Commit 582b9295 by Nate Hardison

Validating OpenID request trust roots to only come from *.cs50.net

parent 635e8546
...@@ -33,7 +33,7 @@ from openid.consumer.consumer import (Consumer, SUCCESS, CANCEL, FAILURE) ...@@ -33,7 +33,7 @@ from openid.consumer.consumer import (Consumer, SUCCESS, CANCEL, FAILURE)
import django_openid_auth.views as openid_views import django_openid_auth.views as openid_views
from openid.server.server import Server, ProtocolError, CheckIDRequest, EncodingError from openid.server.server import Server, ProtocolError, CheckIDRequest, EncodingError
from openid.server.trustroot import verifyReturnTo from openid.server.trustroot import TrustRoot
from openid.store.filestore import FileOpenIDStore from openid.store.filestore import FileOpenIDStore
from openid.yadis.discover import DiscoveryFailure from openid.yadis.discover import DiscoveryFailure
from openid.consumer.discover import OPENID_IDP_2_0_TYPE from openid.consumer.discover import OPENID_IDP_2_0_TYPE
...@@ -255,7 +255,7 @@ def provider_respond(server, request, response, data): ...@@ -255,7 +255,7 @@ def provider_respond(server, request, response, data):
Respond to an OpenID request Respond to an OpenID request
""" """
# get simple registration request # get simple registration request
sreg_data = {} sreg_data = {}
sreg_request = sreg.SRegRequest.fromOpenIDRequest(request) sreg_request = sreg.SRegRequest.fromOpenIDRequest(request)
sreg_fields = sreg_request.allRequestedFields() sreg_fields = sreg_request.allRequestedFields()
...@@ -305,6 +305,37 @@ def provider_respond(server, request, response, data): ...@@ -305,6 +305,37 @@ def provider_respond(server, request, response, data):
return http_response return http_response
def validate_trust_root(openid_request):
"""
Only allow OpenID requests from valid trust roots
"""
# verify the trust root/return to
trust_root = openid_request.trust_root
return_to = openid_request.return_to
# don't allow empty trust roots
if openid_request.trust_root is None:
return false
# ensure trust root parses cleanly (one wildcard, of form *.foo.com, etc.)
trust_root = TrustRoot.parse(openid_request.trust_root)
if trust_root is None:
return false
# don't allow empty return tos
if openid_request.return_to is None:
return false
# ensure return to is within trust root
if not trust_root.validateURL(openid_request.return_to):
return false
# only allow *.cs50.net for now
return trust_root.host.endswith('cs50.net')
@csrf_exempt @csrf_exempt
def provider_login(request): def provider_login(request):
""" """
...@@ -323,6 +354,10 @@ def provider_login(request): ...@@ -323,6 +354,10 @@ def provider_login(request):
# decode request # decode request
openid_request = server.decodeRequest(query) openid_request = server.decodeRequest(query)
# don't allow invalid and non-*.cs50.net trust roots
if not validate_trust_root(openid_request):
return default_render_failure(request, "Invalid OpenID trust root")
# checkid_immediate not supported, require user interaction # checkid_immediate not supported, require user interaction
if openid_request.mode == 'checkid_immediate': if openid_request.mode == 'checkid_immediate':
return provider_respond(server, openid_request, openid_request.answer(false), {}) return provider_respond(server, openid_request, openid_request.answer(false), {})
...@@ -351,6 +386,10 @@ def provider_login(request): ...@@ -351,6 +386,10 @@ def provider_login(request):
openid_request = request.session['openid_request'] openid_request = request.session['openid_request']
del request.session['openid_request'] del request.session['openid_request']
# don't allow invalid and non-*.cs50.net trust roots
if not validate_trust_root(openid_request):
return default_render_failure(request, "Invalid OpenID trust root")
# check if user with given email exists # check if user with given email exists
email = request.POST['email'] email = request.POST['email']
password = request.POST['password'] password = request.POST['password']
...@@ -431,3 +470,4 @@ def provider_xrds(request): ...@@ -431,3 +470,4 @@ def provider_xrds(request):
# custom XRDS header necessary for discovery process # custom XRDS header necessary for discovery process
response['X-XRDS-Location'] = get_xrds_url('xrds', request) response['X-XRDS-Location'] = get_xrds_url('xrds', request)
return response return response
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment