Commit fe78850c by Anthony Lenton

Merged in lp:~michael.nelson/django-openid-auth/810978-required-sreg

parents 1bbcbd7b 4c297d2f
...@@ -45,6 +45,9 @@ class IdentityAlreadyClaimed(Exception): ...@@ -45,6 +45,9 @@ class IdentityAlreadyClaimed(Exception):
class StrictUsernameViolation(Exception): class StrictUsernameViolation(Exception):
pass pass
class RequiredAttributeNotReturned(Exception):
pass
class OpenIDBackend: class OpenIDBackend:
"""A django.contrib.auth backend that authenticates the user based on """A django.contrib.auth backend that authenticates the user based on
an OpenID response.""" an OpenID response."""
...@@ -74,10 +77,7 @@ class OpenIDBackend: ...@@ -74,10 +77,7 @@ class OpenIDBackend:
claimed_id__exact=openid_response.identity_url) claimed_id__exact=openid_response.identity_url)
except UserOpenID.DoesNotExist: except UserOpenID.DoesNotExist:
if getattr(settings, 'OPENID_CREATE_USERS', False): if getattr(settings, 'OPENID_CREATE_USERS', False):
try:
user = self.create_user_from_openid(openid_response) user = self.create_user_from_openid(openid_response)
except StrictUsernameViolation:
return None
else: else:
user = user_openid.user user = user_openid.user
...@@ -144,12 +144,6 @@ class OpenIDBackend: ...@@ -144,12 +144,6 @@ class OpenIDBackend:
first_name=first_name, last_name=last_name) first_name=first_name, last_name=last_name)
def _get_available_username(self, nickname, identity_url): def _get_available_username(self, nickname, identity_url):
# If we're being strict about usernames, throw an error if we didn't
# get one back from the provider
if getattr(settings, 'OPENID_STRICT_USERNAMES', False):
if nickname is None or nickname == '':
raise StrictUsernameViolation("No username")
# If we don't have a nickname, and we're not being strict, use a default # If we don't have a nickname, and we're not being strict, use a default
nickname = nickname or 'openiduser' nickname = nickname or 'openiduser'
...@@ -184,7 +178,9 @@ class OpenIDBackend: ...@@ -184,7 +178,9 @@ class OpenIDBackend:
if getattr(settings, 'OPENID_STRICT_USERNAMES', False): if getattr(settings, 'OPENID_STRICT_USERNAMES', False):
if User.objects.filter(username__exact=nickname).count() > 0: if User.objects.filter(username__exact=nickname).count() > 0:
raise StrictUsernameViolation("Duplicate username: %s" % nickname) raise StrictUsernameViolation(
"The username (%s) with which you tried to log in is "
"already in use for a different account." % nickname)
# Pick a username for the user based on their nickname, # Pick a username for the user based on their nickname,
# checking for conflicts. # checking for conflicts.
...@@ -202,6 +198,16 @@ class OpenIDBackend: ...@@ -202,6 +198,16 @@ class OpenIDBackend:
def create_user_from_openid(self, openid_response): def create_user_from_openid(self, openid_response):
details = self._extract_user_details(openid_response) details = self._extract_user_details(openid_response)
required_attrs = getattr(settings, 'OPENID_SREG_REQUIRED_FIELDS', [])
if getattr(settings, 'OPENID_STRICT_USERNAMES', False):
required_attrs.append('nickname')
for required_attr in required_attrs:
if required_attr not in details or not details[required_attr]:
raise RequiredAttributeNotReturned(
"An attribute required for logging in was not "
"returned ({0}).".format(required_attr))
nickname = details['nickname'] or 'openiduser' nickname = details['nickname'] or 'openiduser'
email = details['email'] or '' email = details['email'] or ''
......
...@@ -137,6 +137,8 @@ class RelyingPartyTests(TestCase): ...@@ -137,6 +137,8 @@ class RelyingPartyTests(TestCase):
self.old_teams_map = getattr(settings, 'OPENID_LAUNCHPAD_TEAMS_MAPPING', {}) self.old_teams_map = getattr(settings, 'OPENID_LAUNCHPAD_TEAMS_MAPPING', {})
self.old_use_as_admin_login = getattr(settings, 'OPENID_USE_AS_ADMIN_LOGIN', False) self.old_use_as_admin_login = getattr(settings, 'OPENID_USE_AS_ADMIN_LOGIN', False)
self.old_follow_renames = getattr(settings, 'OPENID_FOLLOW_RENAMES', False) self.old_follow_renames = getattr(settings, 'OPENID_FOLLOW_RENAMES', False)
self.old_required_fields = getattr(
settings, 'OPENID_SREG_REQUIRED_FIELDS', [])
settings.OPENID_CREATE_USERS = False settings.OPENID_CREATE_USERS = False
settings.OPENID_STRICT_USERNAMES = False settings.OPENID_STRICT_USERNAMES = False
...@@ -145,6 +147,7 @@ class RelyingPartyTests(TestCase): ...@@ -145,6 +147,7 @@ class RelyingPartyTests(TestCase):
settings.OPENID_LAUNCHPAD_TEAMS_MAPPING = {} settings.OPENID_LAUNCHPAD_TEAMS_MAPPING = {}
settings.OPENID_USE_AS_ADMIN_LOGIN = False settings.OPENID_USE_AS_ADMIN_LOGIN = False
settings.OPENID_FOLLOW_RENAMES = False settings.OPENID_FOLLOW_RENAMES = False
settings.OPENID_SREG_REQUIRED_FIELDS = []
def tearDown(self): def tearDown(self):
settings.LOGIN_REDIRECT_URL = self.old_login_redirect_url settings.LOGIN_REDIRECT_URL = self.old_login_redirect_url
...@@ -155,6 +158,7 @@ class RelyingPartyTests(TestCase): ...@@ -155,6 +158,7 @@ class RelyingPartyTests(TestCase):
settings.OPENID_LAUNCHPAD_TEAMS_MAPPING = self.old_teams_map settings.OPENID_LAUNCHPAD_TEAMS_MAPPING = self.old_teams_map
settings.OPENID_USE_AS_ADMIN_LOGIN = self.old_use_as_admin_login settings.OPENID_USE_AS_ADMIN_LOGIN = self.old_use_as_admin_login
settings.OPENID_FOLLOW_RENAMES = self.old_follow_renames settings.OPENID_FOLLOW_RENAMES = self.old_follow_renames
settings.OPENID_SREG_REQUIRED_FIELDS = self.old_required_fields
setDefaultFetcher(None) setDefaultFetcher(None)
super(RelyingPartyTests, self).tearDown() super(RelyingPartyTests, self).tearDown()
...@@ -553,7 +557,39 @@ class RelyingPartyTests(TestCase): ...@@ -553,7 +557,39 @@ class RelyingPartyTests(TestCase):
response = self.complete(openid_response) response = self.complete(openid_response)
# Status code should be 403: Forbidden # Status code should be 403: Forbidden
self.assertEquals(403, response.status_code) self.assertContains(response,
"The username (someuser) with which you tried to log in is "
"already in use for a different account.",
status_code=403)
def test_login_requires_sreg_required_fields(self):
# If any required attributes are not included in the response,
# we fail with a forbidden.
settings.OPENID_CREATE_USERS = True
settings.OPENID_SREG_REQUIRED_FIELDS = ('email', 'language')
# Posting in an identity URL begins the authentication request:
response = self.client.post('/openid/login/',
{'openid_identifier': 'http://example.com/identity',
'next': '/getuser/'})
self.assertContains(response, 'OpenID transaction in progress')
# Complete the request, passing back some simple registration
# data. The user is redirected to the next URL.
openid_request = self.provider.parseFormPost(response.content)
sreg_request = sreg.SRegRequest.fromOpenIDRequest(openid_request)
openid_response = openid_request.answer(True)
sreg_response = sreg.SRegResponse.extractResponse(
sreg_request, {'nickname': 'foo',
'fullname': 'Some User',
'email': 'foo@example.com'})
openid_response.addExtension(sreg_response)
response = self.complete(openid_response)
# Status code should be 403: Forbidden as we didn't include
# a required field - language.
self.assertContains(response,
"An attribute required for logging in was not returned "
"(language)", status_code=403)
def test_login_update_details(self): def test_login_update_details(self):
settings.OPENID_UPDATE_DETAILS_FROM_SREG = True settings.OPENID_UPDATE_DETAILS_FROM_SREG = True
...@@ -599,6 +635,27 @@ class RelyingPartyTests(TestCase): ...@@ -599,6 +635,27 @@ class RelyingPartyTests(TestCase):
for field in ('email', 'fullname', 'nickname', 'language'): for field in ('email', 'fullname', 'nickname', 'language'):
self.assertTrue(field in sreg_request) self.assertTrue(field in sreg_request)
def test_login_uses_sreg_required_fields(self):
# The configurable sreg attributes are used in the request.
settings.OPENID_SREG_REQUIRED_FIELDS = ('email', 'language')
user = User.objects.create_user('testuser', 'someone@example.com')
useropenid = UserOpenID(
user=user,
claimed_id='http://example.com/identity',
display_id='http://example.com/identity')
useropenid.save()
# Posting in an identity URL begins the authentication request:
response = self.client.post('/openid/login/',
{'openid_identifier': 'http://example.com/identity',
'next': '/getuser/'})
openid_request = self.provider.parseFormPost(response.content)
sreg_request = sreg.SRegRequest.fromOpenIDRequest(openid_request)
self.assertEqual(['email', 'language'], sreg_request.required)
self.assertEqual(['fullname', 'nickname'], sreg_request.optional)
def test_login_attribute_exchange(self): def test_login_attribute_exchange(self):
settings.OPENID_UPDATE_DETAILS_FROM_SREG = True settings.OPENID_UPDATE_DETAILS_FROM_SREG = True
user = User.objects.create_user('testuser', 'someone@example.com') user = User.objects.create_user('testuser', 'someone@example.com')
......
...@@ -51,6 +51,10 @@ from openid.consumer.discover import DiscoveryFailure ...@@ -51,6 +51,10 @@ from openid.consumer.discover import DiscoveryFailure
from openid.extensions import sreg, ax from openid.extensions import sreg, ax
from django_openid_auth import teams from django_openid_auth import teams
from django_openid_auth.auth import (
RequiredAttributeNotReturned,
StrictUsernameViolation,
)
from django_openid_auth.forms import OpenIDLoginForm from django_openid_auth.forms import OpenIDLoginForm
from django_openid_auth.models import UserOpenID from django_openid_auth.models import UserOpenID
from django_openid_auth.signals import openid_login_complete from django_openid_auth.signals import openid_login_complete
...@@ -196,11 +200,18 @@ def login_begin(request, template_name='openid/login.html', ...@@ -196,11 +200,18 @@ def login_begin(request, template_name='openid/login.html',
fetch_request.add(ax.AttrInfo(attr, alias=alias, required=True)) fetch_request.add(ax.AttrInfo(attr, alias=alias, required=True))
openid_request.addExtension(fetch_request) openid_request.addExtension(fetch_request)
else: else:
sreg_required_fields = []
sreg_required_fields.extend(
getattr(settings, 'OPENID_SREG_REQUIRED_FIELDS', []))
sreg_optional_fields = ['email', 'fullname', 'nickname'] sreg_optional_fields = ['email', 'fullname', 'nickname']
extra_fields = getattr(settings, 'OPENID_SREG_EXTRA_FIELDS', []) sreg_optional_fields.extend(
sreg_optional_fields.extend(extra_fields) getattr(settings, 'OPENID_SREG_EXTRA_FIELDS', []))
sreg_optional_fields = [
field for field in sreg_optional_fields if (
not field in sreg_required_fields)]
openid_request.addExtension( openid_request.addExtension(
sreg.SRegRequest(optional=sreg_optional_fields)) sreg.SRegRequest(optional=sreg_optional_fields,
required=sreg_required_fields))
# Request team info # Request team info
teams_mapping_auto = getattr(settings, 'OPENID_LAUNCHPAD_TEAMS_MAPPING_AUTO', False) teams_mapping_auto = getattr(settings, 'OPENID_LAUNCHPAD_TEAMS_MAPPING_AUTO', False)
...@@ -240,7 +251,11 @@ def login_complete(request, redirect_field_name=REDIRECT_FIELD_NAME, ...@@ -240,7 +251,11 @@ def login_complete(request, redirect_field_name=REDIRECT_FIELD_NAME,
request, 'This is an OpenID relying party endpoint.') request, 'This is an OpenID relying party endpoint.')
if openid_response.status == SUCCESS: if openid_response.status == SUCCESS:
try:
user = authenticate(openid_response=openid_response) user = authenticate(openid_response=openid_response)
except (StrictUsernameViolation, RequiredAttributeNotReturned), e:
return render_failure(request, e)
if user is not None: if user is not None:
if user.is_active: if user.is_active:
auth_login(request, user) auth_login(request, user)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment