Commit 0a81684a by Carlos Andrés Rocha

[34078525] Fix error saving open_id request in session

parent d702082d
......@@ -267,26 +267,21 @@ def ssl_login(request):
# -----------------------------------------------------------------------------
def get_dict_for_openid(data):
"""
Return a dictionary suitable for the OpenID library
"""
return dict((k, v) for k, v in data.iteritems())
def get_xrds_url(resource, request):
"""
Return the XRDS url for a resource
"""
host = request.META['HTTP_HOST']
if not host.endswith('edx.org'):
return None
location = host + '/openid/provider/' + resource + '/'
location = request.META['HTTP_HOST'] + '/openid/provider/' + resource + '/'
if request.is_secure():
url = 'https://' + location
return 'https://' + location
else:
url = 'http://' + location
return url
return 'http://' + location
def add_openid_simple_registration(request, response, data):
......@@ -402,17 +397,24 @@ def provider_login(request):
OpenID login endpoint
"""
# initialize store and server
# make and validate endpoint
endpoint = get_xrds_url('login', request)
if not endpoint:
return default_render_failure(request, "Invalid OpenID request")
# initialize store and server
store = FileOpenIDStore('/tmp/openid_provider')
server = Server(store, endpoint)
# handle OpenID request
query = get_dict_for_openid(request.REQUEST)
querydict = dict(request.REQUEST.items())
error = False
if 'openid.mode' in request.GET or 'openid.mode' in request.POST:
# decode request
openid_request = server.decodeRequest(query)
openid_request = server.decodeRequest(querydict)
if not openid_request:
return default_render_failure(request, "Invalid OpenID request")
# don't allow invalid and non-trusted trust roots
if not validate_trust_root(openid_request):
......@@ -427,7 +429,7 @@ def provider_login(request):
elif openid_request.mode == 'checkid_setup':
if openid_request.idSelect():
# remember request and original path
request.session['openid_request'] = {
request.session['openid_setup'] = {
'request': openid_request,
'url': request.get_full_path()
}
......@@ -443,12 +445,14 @@ def provider_login(request):
server.handleRequest(openid_request), {})
# handle login
if request.method == 'POST' and 'openid_request' in request.session:
if request.method == 'POST' and 'openid_setup' in request.session:
# get OpenID request from session
openid_request = request.session['openid_request']
del request.session['openid_request']
openid_setup = request.session['openid_setup']
openid_request = openid_setup['request']
openid_request_url = openid_setup['url']
del request.session['openid_setup']
# don't allow invalid and non-*.cs50.net trust roots
# don't allow invalid trust roots
if not validate_trust_root(openid_request):
return default_render_failure(request, "Invalid OpenID trust root")
......@@ -460,7 +464,7 @@ def provider_login(request):
request.session['openid_error'] = True
msg = "OpenID login failed - Unknown user email: {0}".format(email)
log.warning(msg)
return HttpResponseRedirect(openid_request['url'])
return HttpResponseRedirect(openid_request_url)
# attempt to authenticate user
username = user.username
......@@ -471,7 +475,7 @@ def provider_login(request):
msg = "OpenID login failed - password for {0} is invalid"
msg = msg.format(email)
log.warning(msg)
return HttpResponseRedirect(openid_request['url'])
return HttpResponseRedirect(openid_request_url)
# authentication succeeded, so log user in
if user is not None and user.is_active:
......@@ -486,10 +490,10 @@ def provider_login(request):
# redirect user to return_to location
url = endpoint + urlquote(user.username)
response = openid_request['request'].answer(True, None, url)
response = openid_request.answer(True, None, url)
return provider_respond(server,
openid_request['request'],
openid_request,
response,
{
'fullname': profile.name,
......@@ -499,7 +503,7 @@ def provider_login(request):
request.session['openid_error'] = True
msg = "Login failed - Account not active for user {0}".format(username)
log.warning(msg)
return HttpResponseRedirect(openid_request['url'])
return HttpResponseRedirect(openid_request_url)
# determine consumer domain if applicable
return_to = ''
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment