Commit 8b210641 by bmedx

Update repo from upstream revision 128

- Previous pull was revision 115
- Diff of versions: http://bazaar.launchpad.net/~ubuntuone-pqm-team/django-openid-auth/trunk/revision/128?remember=115&compare_revid=115
- Should get us Py3 and Django up to 1.10 support
parent 86b822c2
# Created by .ignore support plugin (hsz.mobi)
### Python template
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*,cover
.hypothesis/
# Django stuff:
*.log
local_settings.py
# dotenv
.env
# virtualenv
.venv
venv/
ENV/
.gitignore
.idea/
django_openid_auth.egg-info/
...@@ -26,3 +26,6 @@ ...@@ -26,3 +26,6 @@
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN # LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE # ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE. # POSSIBILITY OF SUCH DAMAGE.
import sys
PY3 = sys.version_info.major >= 3
...@@ -29,8 +29,11 @@ ...@@ -29,8 +29,11 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from urllib import urlencode try:
from urlparse import parse_qsl, urlparse from urllib.parse import parse_qsl, urlencode, urlparse
except ImportError:
from urllib import urlencode
from urlparse import parse_qsl, urlparse
from django.conf import settings from django.conf import settings
from django.contrib import admin from django.contrib import admin
...@@ -50,6 +53,7 @@ class NonceAdmin(admin.ModelAdmin): ...@@ -50,6 +53,7 @@ class NonceAdmin(admin.ModelAdmin):
self.message_user(request, "%d expired nonces removed" % count) self.message_user(request, "%d expired nonces removed" % count)
cleanup_nonces.short_description = "Clean up expired nonces" cleanup_nonces.short_description = "Clean up expired nonces"
admin.site.register(Nonce, NonceAdmin) admin.site.register(Nonce, NonceAdmin)
...@@ -65,6 +69,7 @@ class AssociationAdmin(admin.ModelAdmin): ...@@ -65,6 +69,7 @@ class AssociationAdmin(admin.ModelAdmin):
self.message_user(request, "%d expired associations removed" % count) self.message_user(request, "%d expired associations removed" % count)
cleanup_associations.short_description = "Clean up expired associations" cleanup_associations.short_description = "Clean up expired associations"
admin.site.register(Association, AssociationAdmin) admin.site.register(Association, AssociationAdmin)
...@@ -73,6 +78,7 @@ class UserOpenIDAdmin(admin.ModelAdmin): ...@@ -73,6 +78,7 @@ class UserOpenIDAdmin(admin.ModelAdmin):
list_display = ('user', 'claimed_id') list_display = ('user', 'claimed_id')
search_fields = ('claimed_id',) search_fields = ('claimed_id',)
admin.site.register(UserOpenID, UserOpenIDAdmin) admin.site.register(UserOpenID, UserOpenIDAdmin)
......
...@@ -30,13 +30,12 @@ ...@@ -30,13 +30,12 @@
from __future__ import unicode_literals from __future__ import unicode_literals
__metaclass__ = type
import re import re
from django.conf import settings from django.conf import settings
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
from django.contrib.auth.models import Group, Permission from django.contrib.auth.models import Group, Permission
from django.core.exceptions import ImproperlyConfigured
from openid.consumer.consumer import SUCCESS from openid.consumer.consumer import SUCCESS
from openid.extensions import ax, sreg, pape from openid.extensions import ax, sreg, pape
...@@ -49,12 +48,51 @@ from django_openid_auth.exceptions import ( ...@@ -49,12 +48,51 @@ from django_openid_auth.exceptions import (
MissingPhysicalMultiFactor, MissingPhysicalMultiFactor,
RequiredAttributeNotReturned, RequiredAttributeNotReturned,
) )
from django_openid_auth.signals import openid_duplicate_username
User = get_user_model() User = get_user_model()
class OpenIDBackend: def get_user_group_model():
"""Returns the model used for mapping users to groups."""
user_group_model_name = getattr(settings, 'AUTH_USER_GROUP_MODEL', None)
if user_group_model_name is None:
return User.groups.through
else:
try:
# django.apps available starting from django 1.7
from django.apps import apps
get_model = apps.get_model
args = (user_group_model_name,)
except ImportError:
# if we can't import, then it must be django 1.6, still using
# the old django.db.models.loading code
from django.db.models.loading import get_model
app_label, model_name = user_group_model_name.split('.', 1)
args = (app_label, model_name)
try:
model = get_model(*args)
if model is None:
# in django 1.6 referring to a non-installed app will
# return None for get_model, but in 1.7 onwards it will
# raise a LookupError exception.
raise LookupError()
return model
except ValueError:
raise ImproperlyConfigured(
"AUTH_USER_GROUP_MODEL must be of the form "
"'app_label.model_name'")
except LookupError:
raise ImproperlyConfigured(
"AUTH_USER_GROUP_MODEL refers to model '%s' that has not been "
"installed" % user_group_model_name)
UserGroup = get_user_group_model()
class OpenIDBackend(object):
"""A django.contrib.auth backend that authenticates the user based on """A django.contrib.auth backend that authenticates the user based on
an OpenID response.""" an OpenID response."""
...@@ -116,8 +154,9 @@ class OpenIDBackend: ...@@ -116,8 +154,9 @@ class OpenIDBackend:
teams_mapping = self.get_teams_mapping() teams_mapping = self.get_teams_mapping()
groups_required = [group for team, group in teams_mapping.items() groups_required = [group for team, group in teams_mapping.items()
if team in teams_required] if team in teams_required]
user_groups = UserGroup.objects.filter(user=user)
matches = set(groups_required).intersection( matches = set(groups_required).intersection(
user.groups.values_list('name', flat=True)) user_groups.values_list('group__name', flat=True))
if not matches: if not matches:
name = 'OPENID_EMAIL_WHITELIST_REGEXP_LIST' name = 'OPENID_EMAIL_WHITELIST_REGEXP_LIST'
whitelist_regexp_list = getattr(settings, name, []) whitelist_regexp_list = getattr(settings, name, [])
...@@ -194,28 +233,19 @@ class OpenIDBackend: ...@@ -194,28 +233,19 @@ class OpenIDBackend:
return suggestion return suggestion
return 'openiduser' return 'openiduser'
def _get_available_username(self, nickname, identity_url): def _get_available_username_for_nickname(self, nickname, identity_url):
# If we're being strict about usernames, throw an error if we didn't
# get one back from the provider
if getattr(settings, 'OPENID_STRICT_USERNAMES', False):
if nickname is None or nickname == '':
raise MissingUsernameViolation()
# If we don't have a nickname, and we're not being strict, use a # If we don't have a nickname, and we're not being strict, use a
# default # default
nickname = nickname or 'openiduser' nickname = nickname or 'openiduser'
# See if we already have this nickname assigned to a username # See if we already have this nickname assigned to a username
try: if not User.objects.filter(username=nickname).exists():
User.objects.get(username__exact=nickname)
except User.DoesNotExist:
# No conflict, we can use this nickname
return nickname return nickname
# Check if we already have nickname+i for this identity_url # Check if we already have nickname+i for this identity_url
try: try:
user_openid = UserOpenID.objects.get( user_openid = UserOpenID.objects.get(
claimed_id__exact=identity_url, claimed_id=identity_url,
user__username__startswith=nickname) user__username__startswith=nickname)
# No exception means we have an existing user for this identity # No exception means we have an existing user for this identity
# that starts with this nickname. # that starts with this nickname.
...@@ -239,28 +269,48 @@ class OpenIDBackend: ...@@ -239,28 +269,48 @@ class OpenIDBackend:
# No user associated with this identity_url # No user associated with this identity_url
pass pass
if getattr(settings, 'OPENID_STRICT_USERNAMES', False):
if User.objects.filter(username__exact=nickname).count() > 0:
raise DuplicateUsernameViolation(
"The username (%s) with which you tried to log in is "
"already in use for a different account." % nickname)
# Pick a username for the user based on their nickname, # Pick a username for the user based on their nickname,
# checking for conflicts. Start with number of existing users who's # checking for conflicts. Start with number of existing users who's
# username starts with this nickname to avoid having to iterate over # username starts with this nickname to avoid having to iterate over
# all of the existing ones. # all of the existing ones.
i = User.objects.filter(username__startswith=nickname).count() + 1 i = User.objects.filter(username__startswith=nickname).count() + 1
while True: username = nickname
username = nickname while User.objects.filter(username=username).exists():
if i > 1: username = nickname + str(i)
username += str(i)
try:
User.objects.get(username__exact=username)
except User.DoesNotExist:
break
i += 1 i += 1
return username return username
def _ensure_available_username(self, nickname, identity_url):
if not nickname:
raise MissingUsernameViolation()
# As long as the `QuerySet` does not get evaluated, no
# caching should be involved in our multiple `exists()`
# calls. See docs for details: http://bit.ly/2aYCmkw
user_with_same_username = User.objects.exclude(
useropenid__claimed_id=identity_url
).filter(username=nickname)
if user_with_same_username.exists():
# Notify any listeners that a duplicated username was
# found and give the opportunity to handle conflict.
openid_duplicate_username.send(sender=User, username=nickname)
# Check for conflicts again as the signal could have handled it.
if user_with_same_username.exists():
raise DuplicateUsernameViolation(
"The username (%s) with which you tried to log in is "
"already in use for a different account." % nickname)
def _get_available_username(self, nickname, identity_url):
if getattr(settings, 'OPENID_STRICT_USERNAMES', False):
self._ensure_available_username(nickname, identity_url)
else:
nickname = self._get_available_username_for_nickname(
nickname, identity_url)
return nickname
def create_user_from_openid(self, openid_response): def create_user_from_openid(self, openid_response):
details = self._extract_user_details(openid_response) details = self._extract_user_details(openid_response)
required_attrs = getattr(settings, 'OPENID_SREG_REQUIRED_FIELDS', []) required_attrs = getattr(settings, 'OPENID_SREG_REQUIRED_FIELDS', [])
...@@ -357,13 +407,17 @@ class OpenIDBackend: ...@@ -357,13 +407,17 @@ class OpenIDBackend:
mapping = [ mapping = [
teams_mapping[lp_team] for lp_team in teams_response.is_member teams_mapping[lp_team] for lp_team in teams_response.is_member
if lp_team in teams_mapping] if lp_team in teams_mapping]
user_groups = UserGroup.objects.filter(user=user)
matching_groups = user_groups.filter(
group__name__in=teams_mapping.values())
current_groups = set( current_groups = set(
user.groups.filter(name__in=teams_mapping.values())) user_group.group for user_group in matching_groups)
desired_groups = set(Group.objects.filter(name__in=mapping)) desired_groups = set(Group.objects.filter(name__in=mapping))
for group in current_groups - desired_groups: groups_to_remove = current_groups - desired_groups
user.groups.remove(group) groups_to_add = desired_groups - current_groups
for group in desired_groups - current_groups: user_groups.filter(group__in=groups_to_remove).delete()
user.groups.add(group) for group in groups_to_add:
UserGroup.objects.create(user=user, group=group)
def update_staff_status_from_teams(self, user, teams_response): def update_staff_status_from_teams(self, user, teams_response):
if not hasattr(settings, 'OPENID_LAUNCHPAD_STAFF_TEAMS'): if not hasattr(settings, 'OPENID_LAUNCHPAD_STAFF_TEAMS'):
......
...@@ -38,6 +38,8 @@ from django.conf import settings ...@@ -38,6 +38,8 @@ from django.conf import settings
from openid.yadis import xri from openid.yadis import xri
from django_openid_auth import PY3
def teams_new_unicode(self): def teams_new_unicode(self):
""" """
...@@ -52,8 +54,13 @@ def teams_new_unicode(self): ...@@ -52,8 +54,13 @@ def teams_new_unicode(self):
else: else:
return name return name
Group.unicode_before_teams = Group.__unicode__
Group.__unicode__ = teams_new_unicode if PY3:
Group.unicode_before_teams = Group.__str__
Group.__str__ = teams_new_unicode
else:
Group.unicode_before_teams = Group.__unicode__
Group.__unicode__ = teams_new_unicode
class UserChangeFormWithTeamRestriction(UserChangeForm): class UserChangeFormWithTeamRestriction(UserChangeForm):
...@@ -72,6 +79,7 @@ class UserChangeFormWithTeamRestriction(UserChangeForm): ...@@ -72,6 +79,7 @@ class UserChangeFormWithTeamRestriction(UserChangeForm):
"You cannot assign it manually." % group.name) "You cannot assign it manually." % group.name)
return data return data
UserAdmin.form = UserChangeFormWithTeamRestriction UserAdmin.form = UserChangeFormWithTeamRestriction
......
...@@ -34,3 +34,4 @@ import django.dispatch ...@@ -34,3 +34,4 @@ import django.dispatch
openid_login_complete = django.dispatch.Signal(providing_args=[ openid_login_complete = django.dispatch.Signal(providing_args=[
'request', 'openid_response']) 'request', 'openid_response'])
openid_duplicate_username = django.dispatch.Signal(providing_args=['username'])
...@@ -36,6 +36,7 @@ from openid.association import Association as OIDAssociation ...@@ -36,6 +36,7 @@ from openid.association import Association as OIDAssociation
from openid.store.interface import OpenIDStore from openid.store.interface import OpenIDStore
from openid.store.nonce import SKEW from openid.store.nonce import SKEW
from django_openid_auth import PY3
from django_openid_auth.models import Association, Nonce from django_openid_auth.models import Association, Nonce
...@@ -75,10 +76,15 @@ class DjangoOpenIDStore(OpenIDStore): ...@@ -75,10 +76,15 @@ class DjangoOpenIDStore(OpenIDStore):
expired = [] expired = []
for assoc in assocs: for assoc in assocs:
association = OIDAssociation( association = OIDAssociation(
assoc.handle, base64.decodestring(assoc.secret), assoc.issued, assoc.handle,
assoc.lifetime, assoc.assoc_type base64.decodestring(assoc.secret.encode('utf-8')),
assoc.issued, assoc.lifetime, assoc.assoc_type
) )
if association.getExpiresIn() == 0: if PY3:
expires_in = association.expiresIn
else:
expires_in = association.getExpiresIn()
if expires_in == 0:
expired.append(assoc) expired.append(assoc)
else: else:
associations.append((association.issued, association)) associations.append((association.issued, association))
......
...@@ -72,6 +72,7 @@ from openid.message import ( ...@@ -72,6 +72,7 @@ from openid.message import (
registerNamespaceAlias, registerNamespaceAlias,
NamespaceAliasRegistrationError, NamespaceAliasRegistrationError,
) )
from six import string_types
__all__ = [ __all__ = [
'TeamsRequest', 'TeamsRequest',
...@@ -84,7 +85,7 @@ ns_uri = 'http://ns.launchpad.net/2007/openid-teams' ...@@ -84,7 +85,7 @@ ns_uri = 'http://ns.launchpad.net/2007/openid-teams'
try: try:
registerNamespaceAlias(ns_uri, 'lp') registerNamespaceAlias(ns_uri, 'lp')
except NamespaceAliasRegistrationError, e: except NamespaceAliasRegistrationError as e:
oidutil.log( oidutil.log(
'registerNamespaceAlias(%r, %r) failed: %s' % (ns_uri, 'lp', str(e))) 'registerNamespaceAlias(%r, %r) failed: %s' % (ns_uri, 'lp', str(e)))
...@@ -139,7 +140,7 @@ def getTeamsNS(message): ...@@ -139,7 +140,7 @@ def getTeamsNS(message):
# There is no alias, so try to add one. (OpenID version 1) # There is no alias, so try to add one. (OpenID version 1)
try: try:
message.namespaces.addAlias(ns_uri, 'lp') message.namespaces.addAlias(ns_uri, 'lp')
except KeyError, why: except KeyError as why:
# An alias for the string 'lp' already exists, but it's # An alias for the string 'lp' already exists, but it's
# defined for something other than Launchpad teams # defined for something other than Launchpad teams
raise TeamsNamespaceError(why[0]) raise TeamsNamespaceError(why[0])
...@@ -287,7 +288,7 @@ class TeamsRequest(Extension): ...@@ -287,7 +288,7 @@ class TeamsRequest(Extension):
@raise ValueError: when a team requested is not a string @raise ValueError: when a team requested is not a string
or strict is set and a team was requested more than once or strict is set and a team was requested more than once
""" """
if isinstance(query_membership, basestring): if isinstance(query_membership, string_types):
raise TypeError('Teams should be passed as a list of ' raise TypeError('Teams should be passed as a list of '
'strings (not %r)' % (type(query_membership),)) 'strings (not %r)' % (type(query_membership),))
......
{% load i18n %} {% load i18n %}
{% load url from future %}
<!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 4.01//EN" "http://www.w3.org/TR/html4/strict.dtd"> <!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 4.01//EN" "http://www.w3.org/TR/html4/strict.dtd">
<html> <html>
<head> <head>
......
from django.conf import settings
from django.contrib.auth.models import Group
from django.db import models
class UserGroup(models.Model):
user = models.ForeignKey(settings.AUTH_USER_MODEL)
group = models.ForeignKey(Group)
...@@ -28,56 +28,69 @@ ...@@ -28,56 +28,69 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import re
try:
from urllib.parse import urljoin
except ImportError:
from urlparse import urljoin
from django.conf import settings
from django.contrib.auth.models import Group, Permission, User from django.contrib.auth.models import Group, Permission, User
from django.core.exceptions import ImproperlyConfigured
from django.test import TestCase from django.test import TestCase
from django.test.utils import override_settings from django.test.utils import override_settings
from openid.consumer.consumer import (
from openid.consumer.consumer import SuccessResponse CancelResponse,
FailureResponse,
SetupNeededResponse,
SuccessResponse,
)
from openid.consumer.discover import OpenIDServiceEndpoint from openid.consumer.discover import OpenIDServiceEndpoint
from openid.extensions import pape
from openid.message import Message, OPENID2_NS from openid.message import Message, OPENID2_NS
from django_openid_auth.auth import OpenIDBackend from django_openid_auth.auth import OpenIDBackend, get_user_group_model
from django_openid_auth.exceptions import (
DuplicateUsernameViolation,
MissingPhysicalMultiFactor,
MissingUsernameViolation,
RequiredAttributeNotReturned,
)
from django_openid_auth.models import UserOpenID from django_openid_auth.models import UserOpenID
from django_openid_auth.signals import openid_duplicate_username
from django_openid_auth.teams import ns_uri as TEAMS_NS from django_openid_auth.teams import ns_uri as TEAMS_NS
from django_openid_auth.tests.helpers import override_session_serializer from django_openid_auth.tests.helpers import override_session_serializer
SREG_NS = "http://openid.net/sreg/1.0" SREG_NS = "http://openid.net/sreg/1.0"
AX_NS = "http://openid.net/srv/ax/1.0" AX_NS = "http://openid.net/srv/ax/1.0"
SERVER_URL = 'http://example.com'
@override_session_serializer def make_claimed_id(id_):
@override_settings( return urljoin(SERVER_URL, id_)
OPENID_USE_EMAIL_FOR_USERNAME=False,
OPENID_LAUNCHPAD_TEAMS_REQUIRED=[],
OPENID_LAUNCHPAD_TEAMS_MAPPING_AUTO=False,
OPENID_EMAIL_WHITELIST_REGEXP_LIST=[])
class OpenIDBackendTests(TestCase):
def setUp(self):
super(OpenIDBackendTests, self).setUp()
self.backend = OpenIDBackend()
def make_openid_response(self, sreg_args=None, teams_args=None): class TestMessage(Message):
endpoint = OpenIDServiceEndpoint() """Convenience class to construct test OpenID messages and responses."""
endpoint.claimed_id = 'some-id'
def __init__(self, openid_namespace=OPENID2_NS):
message = Message(OPENID2_NS) super(TestMessage, self).__init__(openid_namespace=openid_namespace)
if sreg_args is not None:
for key, value in sreg_args.items():
message.setArg(SREG_NS, key, value)
if teams_args is not None:
for key, value in teams_args.items():
message.setArg(TEAMS_NS, key, value)
response = SuccessResponse(
endpoint, message, signed_fields=message.toPostArgs().keys())
return response
def make_response_ax(
self, schema="http://axschema.org/",
fullname="Some User", nickname="someuser", email="foo@example.com",
first=None, last=None, verified=False):
endpoint = OpenIDServiceEndpoint() endpoint = OpenIDServiceEndpoint()
message = Message(OPENID2_NS) endpoint.claimed_id = make_claimed_id('some-id')
endpoint.server_url = SERVER_URL
self.endpoint = endpoint
def set_ax_args(
self,
email="foo@example.com",
first=None,
fullname="Some User",
last=None,
nickname="someuser",
schema="http://axschema.org/",
verified=False):
attributes = [ attributes = [
("nickname", schema + "namePerson/friendly", nickname), ("nickname", schema + "namePerson/friendly", nickname),
("fullname", schema + "namePerson", fullname), ("fullname", schema + "namePerson", fullname),
...@@ -93,53 +106,79 @@ class OpenIDBackendTests(TestCase): ...@@ -93,53 +106,79 @@ class OpenIDBackendTests(TestCase):
attributes.append( attributes.append(
("last", "http://axschema.org/namePerson/last", last)) ("last", "http://axschema.org/namePerson/last", last))
message.setArg(AX_NS, "mode", "fetch_response") self.setArg(AX_NS, "mode", "fetch_response")
for (alias, uri, value) in attributes: for (alias, uri, value) in attributes:
message.setArg(AX_NS, "type.%s" % alias, uri) self.setArg(AX_NS, "type.%s" % alias, uri)
message.setArg(AX_NS, "value.%s" % alias, value) self.setArg(AX_NS, "value.%s" % alias, value)
def set_pape_args(self, *auth_policies):
self.setArg(pape.ns_uri, 'auth_policies', ' '.join(auth_policies))
def _set_args(self, ns, **kwargs):
for key, value in kwargs.items():
if value is not None:
self.setArg(ns, key, value)
elif self.hasKey(ns, key):
self.delArg(ns, key)
def set_sreg_args(self, **kwargs):
self._set_args(SREG_NS, **kwargs)
def set_team_args(self, **kwargs):
self._set_args(TEAMS_NS, **kwargs)
def to_response(self):
return SuccessResponse( return SuccessResponse(
endpoint, message, signed_fields=message.toPostArgs().keys()) self.endpoint, self, signed_fields=self.toPostArgs().keys())
def make_user_openid(self, user=None, @override_session_serializer
claimed_id='http://example.com/existing_identity', @override_settings(
display_id='http://example.com/existing_identity'): OPENID_USE_EMAIL_FOR_USERNAME=False,
OPENID_LAUNCHPAD_TEAMS_REQUIRED=[],
OPENID_LAUNCHPAD_TEAMS_MAPPING_AUTO=False,
OPENID_EMAIL_WHITELIST_REGEXP_LIST=[])
class OpenIDBackendTests(TestCase):
def setUp(self):
super(OpenIDBackendTests, self).setUp()
self.backend = OpenIDBackend()
self.message = TestMessage()
def make_user_openid(
self, user=None, claimed_id=make_claimed_id('existing_identity')):
if user is None: if user is None:
user = User.objects.create_user( user = User.objects.create_user(
username='someuser', email='someuser@example.com', username='someuser', email='someuser@example.com',
password='12345678') password='12345678')
user_openid, created = UserOpenID.objects.get_or_create( user_openid, created = UserOpenID.objects.get_or_create(
user=user, claimed_id=claimed_id, display_id=display_id) user=user, claimed_id=claimed_id, display_id=claimed_id)
return user_openid return user_openid
def assert_account_verified(self, user, initially_verified, verified): def _assert_account_verified(self, user, expected):
# set user's verification status
permission = Permission.objects.get(codename='account_verified') permission = Permission.objects.get(codename='account_verified')
if initially_verified: perm_label = '%s.%s' % (permission.content_type.app_label,
user.user_permissions.add(permission) permission.codename)
else: # Always invalidate the per-request perm cache
user.user_permissions.remove(permission) attrs = list(user.__dict__.keys())
for attr in attrs:
if attr.endswith('_perm_cache'):
delattr(user, attr)
user = User.objects.get(pk=user.pk) self.assertEqual(user.has_perm(perm_label), expected)
has_perm = user.has_perm('django_openid_auth.account_verified')
assert has_perm == initially_verified
if hasattr(user, '_perm_cache'): def assert_account_not_verified(self, user):
del user._perm_cache self._assert_account_verified(user, False)
# get a response including verification status def assert_account_verified(self, user):
response = self.make_response_ax() self._assert_account_verified(user, True)
data = dict(first_name=u"Some56789012345678901234567890123",
last_name=u"User56789012345678901234567890123",
email=u"someotheruser@example.com",
account_verified=verified)
self.backend.update_user_details(user, data, response)
# refresh object from the database def assert_no_users_created(self, expected_count=0):
user = User.objects.get(pk=user.pk) current_count = User.objects.count()
# check the verification status msg = 'New users found (expected: %i, current: %i)' % (
self.assertEqual( expected_count, current_count)
user.has_perm('django_openid_auth.account_verified'), verified) self.assertEqual(current_count, expected_count, msg)
def test_extract_user_details_sreg(self): def test_extract_user_details_sreg(self):
expected = { expected = {
...@@ -155,16 +194,19 @@ class OpenIDBackendTests(TestCase): ...@@ -155,16 +194,19 @@ class OpenIDBackendTests(TestCase):
expected['last_name']), expected['last_name']),
'email': expected['email'], 'email': expected['email'],
} }
response = self.make_openid_response(sreg_args=data) self.message.set_sreg_args(**data)
details = self.backend._extract_user_details(response) details = self.backend._extract_user_details(
self.message.to_response())
self.assertEqual(details, expected) self.assertEqual(details, expected)
def test_extract_user_details_ax(self): def test_extract_user_details_ax(self):
response = self.make_response_ax( self.message.set_ax_args(
fullname="Some User", nickname="someuser", email="foo@example.com") email="foo@example.com",
fullname="Some User",
data = self.backend._extract_user_details(response) nickname="someuser",
)
data = self.backend._extract_user_details(self.message.to_response())
self.assertEqual(data, {"nickname": "someuser", self.assertEqual(data, {"nickname": "someuser",
"first_name": "Some", "first_name": "Some",
...@@ -175,10 +217,9 @@ class OpenIDBackendTests(TestCase): ...@@ -175,10 +217,9 @@ class OpenIDBackendTests(TestCase):
def test_extract_user_details_ax_split_name(self): def test_extract_user_details_ax_split_name(self):
# Include fullname too to show that the split data takes # Include fullname too to show that the split data takes
# precedence. # precedence.
response = self.make_response_ax( self.message.set_ax_args(
fullname="Bad Data", first="Some", last="User") fullname="Bad Data", first="Some", last="User")
data = self.backend._extract_user_details(self.message.to_response())
data = self.backend._extract_user_details(response)
self.assertEqual(data, {"nickname": "someuser", self.assertEqual(data, {"nickname": "someuser",
"first_name": "Some", "first_name": "Some",
...@@ -187,11 +228,10 @@ class OpenIDBackendTests(TestCase): ...@@ -187,11 +228,10 @@ class OpenIDBackendTests(TestCase):
"account_verified": False}) "account_verified": False})
def test_extract_user_details_ax_broken_myopenid(self): def test_extract_user_details_ax_broken_myopenid(self):
response = self.make_response_ax( self.message.set_ax_args(
schema="http://schema.openid.net/", fullname="Some User", schema="http://schema.openid.net/", fullname="Some User",
nickname="someuser", email="foo@example.com") nickname="someuser", email="foo@example.com")
data = self.backend._extract_user_details(self.message.to_response())
data = self.backend._extract_user_details(response)
self.assertEqual(data, {"nickname": "someuser", self.assertEqual(data, {"nickname": "someuser",
"first_name": "Some", "first_name": "Some",
...@@ -200,7 +240,7 @@ class OpenIDBackendTests(TestCase): ...@@ -200,7 +240,7 @@ class OpenIDBackendTests(TestCase):
"account_verified": False}) "account_verified": False})
def test_update_user_details_long_names(self): def test_update_user_details_long_names(self):
response = self.make_response_ax() self.message.set_ax_args()
user = User.objects.create_user( user = User.objects.create_user(
'someuser', 'someuser@example.com', password=None) 'someuser', 'someuser@example.com', password=None)
user_openid, created = UserOpenID.objects.get_or_create( user_openid, created = UserOpenID.objects.get_or_create(
...@@ -212,49 +252,92 @@ class OpenIDBackendTests(TestCase): ...@@ -212,49 +252,92 @@ class OpenIDBackendTests(TestCase):
last_name=u"User56789012345678901234567890123", last_name=u"User56789012345678901234567890123",
email=u"someotheruser@example.com", account_verified=False) email=u"someotheruser@example.com", account_verified=False)
self.backend.update_user_details(user, data, response) self.backend.update_user_details(
user, data, self.message.to_response())
self.assertEqual("Some56789012345678901234567890", user.first_name) self.assertEqual("Some56789012345678901234567890", user.first_name)
self.assertEqual("User56789012345678901234567890", user.last_name) self.assertEqual("User56789012345678901234567890", user.last_name)
def _test_update_user_perms_account_verified(
self, user, initially_verified, verified):
# set user's verification status
permission = Permission.objects.get(codename='account_verified')
if initially_verified:
user.user_permissions.add(permission)
else:
user.user_permissions.remove(permission)
if initially_verified:
self.assert_account_verified(user)
else:
self.assert_account_not_verified(user)
# get a response including verification status
self.message.set_ax_args()
data = dict(first_name=u"Some56789012345678901234567890123",
last_name=u"User56789012345678901234567890123",
email=u"someotheruser@example.com",
account_verified=verified)
self.backend.update_user_details(
user, data, self.message.to_response())
# refresh object from the database
user = User.objects.get(pk=user.pk)
if verified:
self.assert_account_verified(user)
else:
self.assert_account_not_verified(user)
def test_update_user_perms_initially_verified_then_verified(self): def test_update_user_perms_initially_verified_then_verified(self):
self.assert_account_verified( self._test_update_user_perms_account_verified(
self.make_user_openid().user, self.make_user_openid().user,
initially_verified=True, verified=True) initially_verified=True, verified=True)
def test_update_user_perms_initially_verified_then_unverified(self): def test_update_user_perms_initially_verified_then_unverified(self):
self.assert_account_verified( self._test_update_user_perms_account_verified(
self.make_user_openid().user, self.make_user_openid().user,
initially_verified=True, verified=False) initially_verified=True, verified=False)
def test_update_user_perms_initially_not_verified_then_verified(self): def test_update_user_perms_initially_not_verified_then_verified(self):
self.assert_account_verified( self._test_update_user_perms_account_verified(
self.make_user_openid().user, self.make_user_openid().user,
initially_verified=False, verified=True) initially_verified=False, verified=True)
def test_update_user_perms_initially_not_verified_then_unverified(self): def test_update_user_perms_initially_not_verified_then_unverified(self):
self.assert_account_verified( self._test_update_user_perms_account_verified(
self.make_user_openid().user, self.make_user_openid().user,
initially_verified=False, verified=False) initially_verified=False, verified=False)
def test_extract_user_details_name_with_trailing_space(self): def test_extract_user_details_name_with_trailing_space(self):
response = self.make_response_ax(fullname="SomeUser ") self.message.set_ax_args(fullname="SomeUser ")
data = self.backend._extract_user_details(response) data = self.backend._extract_user_details(self.message.to_response())
self.assertEqual("", data['first_name']) self.assertEqual("", data['first_name'])
self.assertEqual("SomeUser", data['last_name']) self.assertEqual("SomeUser", data['last_name'])
def test_extract_user_details_name_with_thin_space(self): def test_extract_user_details_name_with_thin_space(self):
response = self.make_response_ax(fullname=u"Some\u2009User") self.message.set_ax_args(fullname=u"Some\u2009User")
data = self.backend._extract_user_details(response) data = self.backend._extract_user_details(self.message.to_response())
self.assertEqual("Some", data['first_name']) self.assertEqual("Some", data['first_name'])
self.assertEqual("User", data['last_name']) self.assertEqual("User", data['last_name'])
@override_settings(OPENID_CREATE_USERS=True)
def test_auth_username_when_no_nickname(self):
self.message.set_sreg_args(nickname='')
user = self.backend.authenticate(
openid_response=self.message.to_response())
self.assertIsNotNone(user)
self.assertEqual(
user.username, 'openiduser',
"username must default to 'openiduser'")
@override_settings(OPENID_USE_EMAIL_FOR_USERNAME=True) @override_settings(OPENID_USE_EMAIL_FOR_USERNAME=True)
def test_preferred_username_email_munging(self): def test_auth_username_email_munging(self):
for nick, email, expected in [ for nick, email, expected in [
('nickcomesfirst', 'foo@example.com', 'nickcomesfirst'), ('nickcomesfirst', 'foo@example.com', 'nickcomesfirst'),
('', 'foo@example.com', 'fooexamplecom'), ('', 'foo@example.com', 'fooexamplecom'),
...@@ -262,11 +345,17 @@ class OpenIDBackendTests(TestCase): ...@@ -262,11 +345,17 @@ class OpenIDBackendTests(TestCase):
('', '@%.-', 'openiduser'), ('', '@%.-', 'openiduser'),
('', '', 'openiduser'), ('', '', 'openiduser'),
(None, None, 'openiduser')]: (None, None, 'openiduser')]:
self.assertEqual( self.message.set_sreg_args(nickname=nick, email=email)
expected, user = self.backend.authenticate(
self.backend._get_preferred_username(nick, email)) openid_response=self.message.to_response())
# Cleanup user for further tests
user.delete()
def test_preferred_username_no_email_munging(self): self.assertIsNotNone(user)
self.assertEqual(user.username, expected)
@override_settings(OPENID_USE_EMAIL_FOR_USERNAME=False)
def test_auth_username_no_email_munging(self):
for nick, email, expected in [ for nick, email, expected in [
('nickcomesfirst', 'foo@example.com', 'nickcomesfirst'), ('nickcomesfirst', 'foo@example.com', 'nickcomesfirst'),
('', 'foo@example.com', 'openiduser'), ('', 'foo@example.com', 'openiduser'),
...@@ -274,9 +363,89 @@ class OpenIDBackendTests(TestCase): ...@@ -274,9 +363,89 @@ class OpenIDBackendTests(TestCase):
('', '@%.-', 'openiduser'), ('', '@%.-', 'openiduser'),
('', '', 'openiduser'), ('', '', 'openiduser'),
(None, None, 'openiduser')]: (None, None, 'openiduser')]:
self.assertEqual( self.message.set_sreg_args(nickname=nick, email=email)
expected, user = self.backend.authenticate(
self.backend._get_preferred_username(nick, email)) openid_response=self.message.to_response())
# Cleanup user for further tests
user.delete()
self.assertIsNotNone(user)
self.assertEqual(user.username, expected)
@override_settings(
OPENID_CREATE_USERS=True,
OPENID_FOLLOW_RENAMES=False,
OPENID_UPDATE_DETAILS_FROM_SREG=True)
def test_auth_username_duplicate_numbering(self):
# Setup existing user to conflict with
User.objects.create_user('testuser')
self.message.set_sreg_args(nickname='testuser')
user = self.backend.authenticate(
openid_response=self.message.to_response())
self.assertIsNotNone(user)
self.assertEqual(
user.username, 'testuser2',
'Username must contain numeric suffix to avoid collisions.')
def test_auth_username_duplicate_numbering_with_conflicts(self):
# Setup existing users to conflict with
User.objects.create_user('testuser')
User.objects.create_user('testuser3')
self.message.set_sreg_args(nickname='testuser')
user = self.backend.authenticate(
openid_response=self.message.to_response())
# Since this username is already taken by someone else, we go through
# the process of adding +i to it starting with the count of users with
# username starting with 'testuser', of which there are 2. i should
# start at 3, which already exists, so it should skip to 4.
self.assertIsNotNone(user)
self.assertEqual(
user.username, 'testuser4',
'Username must contain numeric suffix to avoid collisions.')
def test_auth_username_duplicate_numbering_with_holes(self):
# Setup existing users to conflict with
User.objects.create_user('testuser')
User.objects.create_user('testuser1')
User.objects.create_user('testuser6')
User.objects.create_user('testuser7')
User.objects.create_user('testuser8')
self.message.set_sreg_args(nickname='testuser')
user = self.backend.authenticate(
openid_response=self.message.to_response())
# Since this username is already taken by someone else, we go through
# the process of adding +i to it starting with the count of users with
# username starting with 'testuser', of which there are 5. i should
# start at 6, and increment until it reaches 9.
self.assertIsNotNone(user)
self.assertEqual(
user.username, 'testuser9',
'Username must contain numeric suffix to avoid collisions.')
def test_auth_username_duplicate_numbering_with_nonsequential_matches(
self):
# Setup existing users to conflict with
User.objects.create_user('testuser')
User.objects.create_user('testuserfoo')
self.message.set_sreg_args(nickname='testuser')
user = self.backend.authenticate(
openid_response=self.message.to_response())
# Since this username is already taken by someone else, we go through
# the process of adding +i to it starting with the count of users with
# username starting with 'testuser', of which there are 2. i should
# start at 3, which will be available.
self.assertIsNotNone(user)
self.assertEqual(
user.username, 'testuser3',
'Username must contain numeric suffix to avoid collisions.')
@override_settings( @override_settings(
OPENID_LAUNCHPAD_TEAMS_MAPPING_AUTO=True, OPENID_LAUNCHPAD_TEAMS_MAPPING_AUTO=True,
...@@ -284,10 +453,10 @@ class OpenIDBackendTests(TestCase): ...@@ -284,10 +453,10 @@ class OpenIDBackendTests(TestCase):
def test_authenticate_when_not_member_of_teams_required(self): def test_authenticate_when_not_member_of_teams_required(self):
Group.objects.create(name='team') Group.objects.create(name='team')
response = self.make_openid_response( self.message.set_sreg_args(nickname='someuser')
sreg_args=dict(nickname='someuser'), self.message.set_team_args(is_member='foo')
teams_args=dict(is_member='foo')) user = self.backend.authenticate(
user = self.backend.authenticate(openid_response=response) openid_response=self.message.to_response())
self.assertIsNone(user) self.assertIsNone(user)
...@@ -297,10 +466,10 @@ class OpenIDBackendTests(TestCase): ...@@ -297,10 +466,10 @@ class OpenIDBackendTests(TestCase):
def test_authenticate_when_no_group_mapping_to_required_team(self): def test_authenticate_when_no_group_mapping_to_required_team(self):
assert Group.objects.filter(name='team').count() == 0 assert Group.objects.filter(name='team').count() == 0
response = self.make_openid_response( self.message.set_sreg_args(nickname='someuser')
sreg_args=dict(nickname='someuser'), self.message.set_team_args(is_member='foo')
teams_args=dict(is_member='foo')) user = self.backend.authenticate(
user = self.backend.authenticate(openid_response=response) openid_response=self.message.to_response())
self.assertIsNone(user) self.assertIsNone(user)
...@@ -310,19 +479,19 @@ class OpenIDBackendTests(TestCase): ...@@ -310,19 +479,19 @@ class OpenIDBackendTests(TestCase):
def test_authenticate_when_member_of_teams_required(self): def test_authenticate_when_member_of_teams_required(self):
Group.objects.create(name='team') Group.objects.create(name='team')
response = self.make_openid_response( self.message.set_sreg_args(nickname='someuser')
sreg_args=dict(nickname='someuser'), self.message.set_team_args(is_member='foo,team')
teams_args=dict(is_member='foo,team')) user = self.backend.authenticate(
user = self.backend.authenticate(openid_response=response) openid_response=self.message.to_response())
self.assertIsNotNone(user) self.assertIsNotNone(user)
@override_settings(OPENID_LAUNCHPAD_TEAMS_REQUIRED=[]) @override_settings(OPENID_LAUNCHPAD_TEAMS_REQUIRED=[])
def test_authenticate_when_no_teams_required(self): def test_authenticate_when_no_teams_required(self):
response = self.make_openid_response( self.message.set_sreg_args(nickname='someuser')
sreg_args=dict(nickname='someuser'), self.message.set_team_args(is_member='team')
teams_args=dict(is_member='team')) user = self.backend.authenticate(
user = self.backend.authenticate(openid_response=response) openid_response=self.message.to_response())
self.assertIsNotNone(user) self.assertIsNotNone(user)
...@@ -332,10 +501,10 @@ class OpenIDBackendTests(TestCase): ...@@ -332,10 +501,10 @@ class OpenIDBackendTests(TestCase):
def test_authenticate_when_member_of_at_least_one_team(self): def test_authenticate_when_member_of_at_least_one_team(self):
Group.objects.create(name='team1') Group.objects.create(name='team1')
response = self.make_openid_response( self.message.set_sreg_args(nickname='someuser')
sreg_args=dict(nickname='someuser'), self.message.set_team_args(is_member='foo,team1')
teams_args=dict(is_member='foo,team1')) user = self.backend.authenticate(
user = self.backend.authenticate(openid_response=response) openid_response=self.message.to_response())
self.assertIsNotNone(user) self.assertIsNotNone(user)
...@@ -347,17 +516,17 @@ class OpenIDBackendTests(TestCase): ...@@ -347,17 +516,17 @@ class OpenIDBackendTests(TestCase):
self): self):
assert Group.objects.filter(name='team').count() == 0 assert Group.objects.filter(name='team').count() == 0
response = self.make_openid_response( self.message.set_sreg_args(
sreg_args=dict(nickname='someuser', email='foo@foo.com'), nickname='someuser', email='foo@foo.com')
teams_args=dict(is_member='foo')) user = self.backend.authenticate(
user = self.backend.authenticate(openid_response=response) openid_response=self.message.to_response())
self.assertIsNotNone(user) self.assertIsNotNone(user)
response = self.make_openid_response( self.message.set_sreg_args(
sreg_args=dict(nickname='someuser', email='foo+bar@foo.com'), nickname='someuser', email='foo+bar@foo.com')
teams_args=dict(is_member='foo')) user = self.backend.authenticate(
user = self.backend.authenticate(openid_response=response) openid_response=self.message.to_response())
self.assertIsNotNone(user) self.assertIsNotNone(user)
...@@ -368,10 +537,9 @@ class OpenIDBackendTests(TestCase): ...@@ -368,10 +537,9 @@ class OpenIDBackendTests(TestCase):
def test_authenticate_whitelisted_email_multiple_patterns(self): def test_authenticate_whitelisted_email_multiple_patterns(self):
assert Group.objects.filter(name='team').count() == 0 assert Group.objects.filter(name='team').count() == 0
response = self.make_openid_response( self.message.set_sreg_args(nickname='someuser', email='bar@foo.com')
sreg_args=dict(nickname='someuser', email='bar@foo.com'), user = self.backend.authenticate(
teams_args=dict(is_member='foo')) openid_response=self.message.to_response())
user = self.backend.authenticate(openid_response=response)
self.assertIsNotNone(user) self.assertIsNotNone(user)
...@@ -382,9 +550,530 @@ class OpenIDBackendTests(TestCase): ...@@ -382,9 +550,530 @@ class OpenIDBackendTests(TestCase):
def test_authenticate_whitelisted_email_not_match(self): def test_authenticate_whitelisted_email_not_match(self):
assert Group.objects.filter(name='team').count() == 0 assert Group.objects.filter(name='team').count() == 0
response = self.make_openid_response( self.message.set_sreg_args(nickname='someuser', email='bar@foo.com')
sreg_args=dict(nickname='someuser', email='bar@foo.com'), self.message.set_team_args(is_member='foo')
teams_args=dict(is_member='foo')) user = self.backend.authenticate(
user = self.backend.authenticate(openid_response=response) openid_response=self.message.to_response())
self.assertIsNone(user)
def test_auth_no_response(self):
self.assertIsNone(self.backend.authenticate())
self.assert_no_users_created()
def test_auth_cancel_response(self):
response = CancelResponse(OpenIDServiceEndpoint())
self.assertIsNone(self.backend.authenticate(openid_response=response))
self.assert_no_users_created()
def test_auth_failure_response(self):
response = FailureResponse(OpenIDServiceEndpoint())
self.assertIsNone(self.backend.authenticate(openid_response=response))
self.assert_no_users_created()
def test_auth_setup_needed_response(self):
response = SetupNeededResponse(OpenIDServiceEndpoint())
self.assertIsNone(self.backend.authenticate(openid_response=response))
self.assert_no_users_created()
@override_settings(OPENID_CREATE_USERS=False)
def test_auth_no_create_users(self):
user = self.backend.authenticate(
openid_response=self.message.to_response())
self.assertIsNone(user) self.assertIsNone(user)
self.assert_no_users_created()
@override_settings(OPENID_CREATE_USERS=False)
def test_auth_no_create_users_existing_user(self):
existing_openid = self.make_user_openid(
claimed_id=self.message.endpoint.claimed_id)
expected_user_count = User.objects.count()
user = self.backend.authenticate(
openid_response=self.message.to_response())
self.assertIsNotNone(user)
self.assertEqual(user, existing_openid.user)
self.assert_no_users_created(expected_count=expected_user_count)
@override_settings(
OPENID_UPDATE_DETAILS_FROM_SREG=True,
OPENID_VALID_VERIFICATION_SCHEMES={
SERVER_URL: {'token_via_email'}})
def test_auth_update_details_from_sreg(self):
first_name = 'a' * 31
last_name = 'b' * 31
email = 'new@email.com'
self.message.set_ax_args(
fullname=first_name + ' ' + last_name,
nickname='newnickname',
email=email,
first=first_name,
last=last_name,
verified=True,
)
existing_openid = self.make_user_openid(
claimed_id=self.message.endpoint.claimed_id)
original_username = existing_openid.user.username
expected_user_count = User.objects.count()
self.assert_account_not_verified(existing_openid.user)
user = self.backend.authenticate(
openid_response=self.message.to_response())
self.assertEqual(user, existing_openid.user)
self.assertEqual(
user.username, original_username,
'Username must not be updated unless OPENID_FOLLOW_RENAMES=True.')
self.assertEqual(user.email, email)
self.assertEqual(user.first_name, first_name[:30])
self.assertEqual(user.last_name, last_name[:30])
self.assert_account_verified(user)
self.assert_no_users_created(expected_count=expected_user_count)
@override_settings(
OPENID_UPDATE_DETAILS_FROM_SREG=True,
OPENID_VALID_VERIFICATION_SCHEMES={
SERVER_URL: {'token_via_email'}})
def test_auth_update_details_from_sreg_unverifies_account(self):
first_name = 'a' * 31
last_name = 'b' * 31
email = 'new@email.com'
kwargs = dict(
fullname=first_name + ' ' + last_name,
nickname='newnickname',
email=email,
first=first_name,
last=last_name,
verified=True,
)
self.message.set_ax_args(**kwargs)
verified_user = self.backend.authenticate(
openid_response=self.message.to_response())
self.assert_account_verified(verified_user)
expected_user_count = User.objects.count()
kwargs['verified'] = False
self.message.set_ax_args(**kwargs)
unverified_user = self.backend.authenticate(
openid_response=self.message.to_response())
self.assertEqual(verified_user, unverified_user)
self.assert_account_not_verified(unverified_user)
self.assert_no_users_created(expected_count=expected_user_count)
@override_settings(OPENID_PHYSICAL_MULTIFACTOR_REQUIRED=True)
def test_physical_multifactor_required_not_given(self):
response = self.message.to_response()
with self.assertRaises(MissingPhysicalMultiFactor):
self.backend.authenticate(openid_response=response)
self.assertTrue(
UserOpenID.objects.filter(
claimed_id=self.message.endpoint.claimed_id).exists(),
'User must be created anyways.')
@override_settings(OPENID_PHYSICAL_MULTIFACTOR_REQUIRED=True)
def test_physical_multifactor_required_invalid_auth_policy(self):
self.message.set_pape_args(
pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT)
with self.assertRaises(MissingPhysicalMultiFactor):
self.backend.authenticate(
openid_response=self.message.to_response())
self.assertTrue(
UserOpenID.objects.filter(
claimed_id=self.message.endpoint.claimed_id).exists(),
'User must be created anyways.')
@override_settings(OPENID_PHYSICAL_MULTIFACTOR_REQUIRED=True)
def test_physical_multifactor_required_valid_auth_policy(self):
self.message.set_pape_args(
pape.AUTH_MULTI_FACTOR, pape.AUTH_MULTI_FACTOR_PHYSICAL,
pape.AUTH_PHISHING_RESISTANT)
user = self.backend.authenticate(
openid_response=self.message.to_response())
self.assertIsNotNone(user)
@override_settings(OPENID_STRICT_USERNAMES=True)
def test_auth_strict_usernames(self):
username = 'nickname'
self.message.set_sreg_args(nickname=username)
user = self.backend.authenticate(
openid_response=self.message.to_response())
self.assertIsNotNone(user, 'User must be created')
self.assertEqual(user.username, username)
@override_settings(OPENID_STRICT_USERNAMES=True)
def test_auth_strict_usernames_no_nickname(self):
self.message.set_sreg_args(nickname='')
msg = re.escape(
"An attribute required for logging in was not returned (nickname)")
with self.assertRaisesRegexp(RequiredAttributeNotReturned, msg):
self.backend.authenticate(
openid_response=self.message.to_response())
self.assert_no_users_created()
@override_settings(
OPENID_STRICT_USERNAMES=True,
OPENID_UPDATE_DETAILS_FROM_SREG=True)
def test_auth_strict_usernames_conflict(self):
existing_openid = self.make_user_openid()
expected_user_count = User.objects.count()
self.message.set_sreg_args(
nickname=existing_openid.user.username)
with self.assertRaises(DuplicateUsernameViolation):
self.backend.authenticate(
openid_response=self.message.to_response())
self.assert_no_users_created(expected_count=expected_user_count)
@override_settings(
OPENID_FOLLOW_RENAMES=True,
OPENID_UPDATE_DETAILS_FROM_SREG=True)
def test_auth_follow_renames(self):
new_username = 'new'
self.message.set_sreg_args(nickname='username')
user = self.backend.authenticate(
openid_response=self.message.to_response())
expected_user_count = User.objects.count()
self.assertIsNotNone(user, 'User must be created')
self.message.set_sreg_args(nickname=new_username)
renamed_user = self.backend.authenticate(
openid_response=self.message.to_response())
self.assertEqual(user.pk, renamed_user.pk)
self.assertEqual(renamed_user.username, new_username)
self.assert_no_users_created(expected_count=expected_user_count)
@override_settings(
OPENID_FOLLOW_RENAMES=True,
OPENID_STRICT_USERNAMES=True,
OPENID_UPDATE_DETAILS_FROM_SREG=True)
def test_auth_follow_renames_strict_usernames_no_nickname(self):
self.message.set_sreg_args(nickname='nickame')
user = self.backend.authenticate(
openid_response=self.message.to_response())
expected_user_count = User.objects.count()
self.assertIsNotNone(user, 'User must be created')
self.message.set_sreg_args(nickname='')
# XXX: Check possibilities to normalize this error into a
# `RequiredAttributeNotReturned`.
with self.assertRaises(MissingUsernameViolation):
self.backend.authenticate(
openid_response=self.message.to_response())
self.assert_no_users_created(expected_count=expected_user_count)
@override_settings(
OPENID_FOLLOW_RENAMES=True,
OPENID_STRICT_USERNAMES=True,
OPENID_UPDATE_DETAILS_FROM_SREG=True)
def test_auth_follow_renames_strict_usernames_rename_conflict(self):
# Setup existing user to conflict with
User.objects.create_user('testuser')
self.message.set_sreg_args(nickname='nickname')
user = self.backend.authenticate(
openid_response=self.message.to_response())
expected_user_count = User.objects.count()
self.assertIsNotNone(user, 'First request should succeed')
self.message.set_sreg_args(nickname='testuser')
with self.assertRaises(DuplicateUsernameViolation):
self.backend.authenticate(
openid_response=self.message.to_response())
db_user = User.objects.get(pk=user.pk)
self.assertEqual(db_user.username, 'nickname')
self.assert_no_users_created(expected_count=expected_user_count)
@override_settings(
OPENID_FOLLOW_RENAMES=True,
OPENID_STRICT_USERNAMES=False,
OPENID_UPDATE_DETAILS_FROM_SREG=True)
def test_auth_follow_renames_to_conflict(self):
# Setup existing user to conflict with
User.objects.create_user('testuser')
# Setup user to rename
user = User.objects.create_user('nickname')
self.make_user_openid(
user=user, claimed_id=self.message.endpoint.claimed_id)
# Trigger rename
self.message.set_sreg_args(nickname='testuser')
renamed_user = self.backend.authenticate(
openid_response=self.message.to_response())
self.assertEqual(renamed_user.pk, user.pk)
self.assertEqual(
renamed_user.username, 'testuser2',
'Username must have a numbered suffix to avoid conflict.')
@override_settings(
OPENID_FOLLOW_RENAMES=True,
OPENID_UPDATE_DETAILS_FROM_SREG=True)
def test_auth_follow_renames_no_change(self):
# Setup user to rename
user = User.objects.create_user('username')
self.make_user_openid(
user=user, claimed_id=self.message.endpoint.claimed_id)
expected_user_count = User.objects.count()
# Trigger rename
self.message.set_sreg_args(nickname=user.username)
renamed_user = self.backend.authenticate(
openid_response=self.message.to_response())
self.assertEqual(renamed_user.pk, user.pk)
self.assertEqual(
renamed_user.username, user.username,
'No numeric suffix should be appended for username owner.')
self.assert_no_users_created(expected_count=expected_user_count)
@override_settings(
OPENID_FOLLOW_RENAMES=True,
OPENID_UPDATE_DETAILS_FROM_SREG=True)
def test_auth_follow_renames_to_numbered_suffix(self):
# Setup user to rename to numbered suffix pattern
user = User.objects.create_user('testuser2000eight')
self.make_user_openid(
user=user, claimed_id=self.message.endpoint.claimed_id)
# Trigger rename
self.message.set_sreg_args(nickname='testuser2')
renamed_user = self.backend.authenticate(
openid_response=self.message.to_response())
self.assertEqual(renamed_user.pk, user.pk)
self.assertEqual(
renamed_user.username, 'testuser2',
'The numbered suffix must be kept.')
@override_settings(
OPENID_FOLLOW_RENAMES=True,
OPENID_UPDATE_DETAILS_FROM_SREG=True)
def test_auth_follow_renames_to_numbered_suffix_with_existing(self):
# Setup existing user to conflict with
User.objects.create_user('testuser')
# Setup user to rename to numbered suffix pattern
user = User.objects.create_user('testuser2000eight')
self.make_user_openid(
user=user, claimed_id=self.message.endpoint.claimed_id)
# Trigger rename
self.message.set_sreg_args(nickname='testuser3')
renamed_user = self.backend.authenticate(
openid_response=self.message.to_response())
self.assertEqual(renamed_user.pk, user.pk)
self.assertEqual(
renamed_user.username, 'testuser3',
'Username must be kept as there are no conflicts.')
@override_settings(
OPENID_FOLLOW_RENAMES=True,
OPENID_UPDATE_DETAILS_FROM_SREG=True)
def test_auth_follow_renames_from_numbered_suffix_to_conflict(self):
# Setup existing user to conflict with
User.objects.create_user('testuser')
# Setup user with numbered suffix pattern
user = User.objects.create_user('testuser2000')
self.make_user_openid(
user=user, claimed_id=self.message.endpoint.claimed_id)
# Trigger rename
self.message.set_sreg_args(nickname='testuser')
renamed_user = self.backend.authenticate(
openid_response=self.message.to_response())
self.assertEqual(renamed_user.pk, user.pk)
self.assertEqual(
user.username, 'testuser2000',
'Since testuser conflicts, username must remain unchanged as it '
'maches the number suffix pattern.')
@override_settings(
OPENID_FOLLOW_RENAMES=True,
OPENID_UPDATE_DETAILS_FROM_SREG=True)
def test_auth_follow_renames_from_numbered_suffix_no_conflict(self):
# Setup user with numbered suffix pattern
user = User.objects.create_user('testuser2')
self.make_user_openid(
user=user, claimed_id=self.message.endpoint.claimed_id)
# Trigger rename
self.message.set_sreg_args(nickname='testuser')
renamed_user = self.backend.authenticate(
openid_response=self.message.to_response())
self.assertEqual(renamed_user.pk, user.pk)
self.assertEqual(
renamed_user.username, 'testuser',
'Username must be updated as there are no conflicts.')
@override_settings(OPENID_STRICT_USERNAMES=True)
def test_auth_duplicate_username_signal_is_sent(self):
existing_openid = self.make_user_openid()
expected_user_count = User.objects.count()
signal_kwargs = {}
def duplicate_username_handler(sender, **kwargs):
signal_kwargs.update(kwargs)
self.addCleanup(
openid_duplicate_username.disconnect,
duplicate_username_handler, sender=User, dispatch_uid='testing')
openid_duplicate_username.connect(
duplicate_username_handler, sender=User, weak=False,
dispatch_uid='testing')
self.message.set_sreg_args(
nickname=existing_openid.user.username)
with self.assertRaises(DuplicateUsernameViolation):
self.backend.authenticate(
openid_response=self.message.to_response())
self.assertIn('username', signal_kwargs)
self.assertEqual(
signal_kwargs['username'], existing_openid.user.username)
self.assert_no_users_created(expected_count=expected_user_count)
@override_settings(OPENID_STRICT_USERNAMES=True)
def test_auth_duplicate_username_signal_can_prevent_duplicate_error(self):
existing_openid = self.make_user_openid()
def duplicate_username_handler(sender, **kwargs):
existing_user = existing_openid.user
existing_user.username += '_other'
existing_user.save()
self.addCleanup(
openid_duplicate_username.disconnect,
duplicate_username_handler, sender=User, dispatch_uid='testing')
openid_duplicate_username.connect(
duplicate_username_handler, sender=User, weak=False,
dispatch_uid='testing')
self.message.set_sreg_args(
nickname=existing_openid.user.username)
user = self.backend.authenticate(
openid_response=self.message.to_response())
self.assertIsNotNone(user)
self.assertNotEqual(user, existing_openid.user)
@override_settings(OPENID_STRICT_USERNAMES=True)
def test_auth_duplicate_username_is_not_called_if_no_conflict(self):
def duplicate_username_handler(sender, **kwargs):
assert False, 'This should never have been called.'
self.addCleanup(
openid_duplicate_username.disconnect,
duplicate_username_handler, sender=User, dispatch_uid='testing')
openid_duplicate_username.connect(
duplicate_username_handler, sender=User, weak=False,
dispatch_uid='testing')
self.message.set_sreg_args(nickname='nickname')
self.backend.authenticate(openid_response=self.message.to_response())
@override_settings(OPENID_STRICT_USERNAMES=False)
def test_auth_duplicate_username_is_not_called_if_not_strict(self):
existing_openid = self.make_user_openid()
def duplicate_username_handler(sender, **kwargs):
assert False, 'This should never have been called.'
self.addCleanup(
openid_duplicate_username.disconnect,
duplicate_username_handler, sender=User, dispatch_uid='testing')
openid_duplicate_username.connect(
duplicate_username_handler, sender=User, weak=False,
dispatch_uid='testing')
self.message.set_sreg_args(nickname=existing_openid.user.username)
self.backend.authenticate(openid_response=self.message.to_response())
@override_settings(OPENID_STRICT_USERNAMES=True)
def test_auth_duplicate_username_handling_bypass_numbered_suffix(self):
nickname = 'nickname87'
existing_openid = self.make_user_openid(
user=User.objects.create_user(nickname))
def duplicate_username_handler(sender, **kwargs):
existing_user = existing_openid.user
existing_user.username += '00'
existing_user.save()
self.addCleanup(
openid_duplicate_username.disconnect,
duplicate_username_handler, sender=User, dispatch_uid='testing')
openid_duplicate_username.connect(
duplicate_username_handler, sender=User, weak=False,
dispatch_uid='testing')
self.message.set_sreg_args(nickname=existing_openid.user.username)
user = self.backend.authenticate(
openid_response=self.message.to_response())
self.assertIsNotNone(user)
self.assertNotEqual(user, existing_openid.user)
self.assertEqual(
user.username, nickname,
'In strict mode, when conflicts are handled, the username must '
'be kept unmodified without numbered suffixes.')
class GetGroupModelTestCase(TestCase):
def setUp(self):
super(GetGroupModelTestCase, self).setUp()
self.inject_test_models()
def inject_test_models(self):
installed_apps = settings.INSTALLED_APPS + (
'django_openid_auth.tests',
)
p = self.settings(INSTALLED_APPS=installed_apps)
p.enable()
self.addCleanup(p.disable)
self.clear_app_cache()
def clear_app_cache(self):
try:
from django.apps import apps
apps.clear_cache()
except ImportError:
from django.db.models.loading import cache
cache.loaded = False
def test_default_group_model(self):
model = get_user_group_model()
self.assertEqual(model, User.groups.through)
@override_settings(AUTH_USER_GROUP_MODEL='tests.UserGroup')
def test_custom_group_model(self):
from django_openid_auth.tests.models import UserGroup
model = get_user_group_model()
self.assertEqual(model, UserGroup)
@override_settings(
AUTH_USER_GROUP_MODEL='django_openid_auth.models.UserGroup')
def test_improperly_configured_invalid_name(self):
self.assertRaises(ImproperlyConfigured, get_user_group_model)
@override_settings(
AUTH_USER_GROUP_MODEL='invalid.UserGroup')
def test_improperly_configured_invalid_app(self):
self.assertRaises(ImproperlyConfigured, get_user_group_model)
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import base64
import time import time
from django.test import TestCase from django.test import TestCase
...@@ -52,7 +53,8 @@ class OpenIDStoreTests(TestCase): ...@@ -52,7 +53,8 @@ class OpenIDStoreTests(TestCase):
server_url='server-url', handle='handle') server_url='server-url', handle='handle')
self.assertEquals(dbassoc.server_url, 'server-url') self.assertEquals(dbassoc.server_url, 'server-url')
self.assertEquals(dbassoc.handle, 'handle') self.assertEquals(dbassoc.handle, 'handle')
self.assertEquals(dbassoc.secret, 'secret'.encode('base-64')) self.assertEquals(
dbassoc.secret, base64.encodestring(b'secret').decode('utf-8'))
self.assertEquals(dbassoc.issued, 42) self.assertEquals(dbassoc.issued, 42)
self.assertEquals(dbassoc.lifetime, 600) self.assertEquals(dbassoc.lifetime, 600)
self.assertEquals(dbassoc.assoc_type, 'HMAC-SHA1') self.assertEquals(dbassoc.assoc_type, 'HMAC-SHA1')
...@@ -66,7 +68,8 @@ class OpenIDStoreTests(TestCase): ...@@ -66,7 +68,8 @@ class OpenIDStoreTests(TestCase):
self.store.storeAssociation('server-url', assoc) self.store.storeAssociation('server-url', assoc)
dbassoc = Association.objects.get( dbassoc = Association.objects.get(
server_url='server-url', handle='handle') server_url='server-url', handle='handle')
self.assertEqual(dbassoc.secret, 'secret2'.encode('base-64')) self.assertEqual(
dbassoc.secret, base64.encodestring(b'secret2').decode('utf-8'))
self.assertEqual(dbassoc.issued, 420) self.assertEqual(dbassoc.issued, 420)
self.assertEqual(dbassoc.lifetime, 900) self.assertEqual(dbassoc.lifetime, 900)
self.assertEqual(dbassoc.assoc_type, 'HMAC-SHA256') self.assertEqual(dbassoc.assoc_type, 'HMAC-SHA256')
...@@ -80,7 +83,7 @@ class OpenIDStoreTests(TestCase): ...@@ -80,7 +83,7 @@ class OpenIDStoreTests(TestCase):
self.assertTrue(isinstance(assoc, OIDAssociation)) self.assertTrue(isinstance(assoc, OIDAssociation))
self.assertEquals(assoc.handle, 'handle') self.assertEquals(assoc.handle, 'handle')
self.assertEquals(assoc.secret, 'secret') self.assertEquals(assoc.secret, b'secret')
self.assertEquals(assoc.issued, timestamp) self.assertEquals(assoc.issued, timestamp)
self.assertEquals(assoc.lifetime, 600) self.assertEquals(assoc.lifetime, 600)
self.assertEquals(assoc.assoc_type, 'HMAC-SHA1') self.assertEquals(assoc.assoc_type, 'HMAC-SHA1')
......
...@@ -31,13 +31,17 @@ from __future__ import unicode_literals ...@@ -31,13 +31,17 @@ from __future__ import unicode_literals
import cgi import cgi
from urlparse import parse_qs try:
from urllib.parse import parse_qs
except ImportError:
from urlparse import parse_qs
from django.conf import settings from django.conf import settings
from django.contrib.auth.models import User, Group, Permission from django.contrib.auth.models import User, Group, Permission
from django.core.urlresolvers import reverse from django.core.urlresolvers import reverse
from django.http import HttpRequest, HttpResponse from django.http import HttpRequest, HttpResponse
from django.test import TestCase from django.test import TestCase
from django.test.client import RequestFactory
from django.test.utils import override_settings from django.test.utils import override_settings
from mock import patch from mock import patch
from openid.consumer.consumer import Consumer, SuccessResponse from openid.consumer.consumer import Consumer, SuccessResponse
...@@ -54,8 +58,9 @@ from django_openid_auth import teams ...@@ -54,8 +58,9 @@ from django_openid_auth import teams
from django_openid_auth.models import UserOpenID from django_openid_auth.models import UserOpenID
from django_openid_auth.tests.helpers import override_session_serializer from django_openid_auth.tests.helpers import override_session_serializer
from django_openid_auth.views import ( from django_openid_auth.views import (
sanitise_redirect_url, get_request_data,
make_consumer, make_consumer,
sanitise_redirect_url,
) )
from django_openid_auth.signals import openid_login_complete from django_openid_auth.signals import openid_login_complete
from django_openid_auth.store import DjangoOpenIDStore from django_openid_auth.store import DjangoOpenIDStore
...@@ -123,8 +128,8 @@ class StubOpenIDProvider(HTTPFetcher): ...@@ -123,8 +128,8 @@ class StubOpenIDProvider(HTTPFetcher):
def parseFormPost(self, content): def parseFormPost(self, content):
"""Parse an HTML form post to create an OpenID request.""" """Parse an HTML form post to create an OpenID request."""
# Hack to make the javascript XML compliant ... # Hack to make the javascript XML compliant ...
content = content.replace('i < elements.length', content = content.replace(
'i &lt; elements.length') 'i < elements.length', 'i &lt; elements.length')
tree = ET.XML(content) tree = ET.XML(content)
form = tree.find('.//form') form = tree.find('.//form')
assert form is not None, 'No form in document' assert form is not None, 'No form in document'
...@@ -135,8 +140,7 @@ class StubOpenIDProvider(HTTPFetcher): ...@@ -135,8 +140,7 @@ class StubOpenIDProvider(HTTPFetcher):
for input in form.findall('input'): for input in form.findall('input'):
if input.get('type') != 'hidden': if input.get('type') != 'hidden':
continue continue
query[input.get('name').encode('UTF-8')] = \ query[input.get('name')] = input.get('value')
input.get('value').encode('UTF-8')
self.last_request = self.server.decodeRequest(query) self.last_request = self.server.decodeRequest(query)
return self.last_request return self.last_request
...@@ -163,13 +167,6 @@ class DummyDjangoRequest(object): ...@@ -163,13 +167,6 @@ class DummyDjangoRequest(object):
def build_absolute_uri(self): def build_absolute_uri(self):
return self.META['SCRIPT_NAME'] + self.request_path return self.META['SCRIPT_NAME'] + self.request_path
def _combined_request(self):
request = {}
request.update(self.POST)
request.update(self.GET)
return request
REQUEST = property(_combined_request)
@override_session_serializer @override_session_serializer
@override_settings( @override_settings(
...@@ -184,10 +181,10 @@ class DummyDjangoRequest(object): ...@@ -184,10 +181,10 @@ class DummyDjangoRequest(object):
OPENID_SREG_REQUIRED_FIELDS=[], OPENID_SREG_REQUIRED_FIELDS=[],
OPENID_USE_EMAIL_FOR_USERNAME=False, OPENID_USE_EMAIL_FOR_USERNAME=False,
OPENID_VALID_VERIFICATION_SCHEMES={}, OPENID_VALID_VERIFICATION_SCHEMES={},
ROOT_URLCONF='django_openid_auth.tests.urls',
) )
class RelyingPartyTests(TestCase): class RelyingPartyTests(TestCase):
urls = 'django_openid_auth.tests.urls'
login_url = reverse('openid-login') login_url = reverse('openid-login')
def setUp(self): def setUp(self):
...@@ -241,7 +238,8 @@ class RelyingPartyTests(TestCase): ...@@ -241,7 +238,8 @@ class RelyingPartyTests(TestCase):
response = self.client.post(self.login_url, self.openid_req) response = self.client.post(self.login_url, self.openid_req)
self.assertContains(response, 'OpenID transaction in progress') self.assertContains(response, 'OpenID transaction in progress')
openid_request = self.provider.parseFormPost(response.content) openid_request = self.provider.parseFormPost(
response.content.decode('utf-8'))
self.assertEqual(openid_request.mode, 'checkid_setup') self.assertEqual(openid_request.mode, 'checkid_setup')
self.assertTrue(openid_request.return_to.startswith( self.assertTrue(openid_request.return_to.startswith(
'http://testserver/openid/complete/')) 'http://testserver/openid/complete/'))
...@@ -253,7 +251,7 @@ class RelyingPartyTests(TestCase): ...@@ -253,7 +251,7 @@ class RelyingPartyTests(TestCase):
# And they are now logged in: # And they are now logged in:
response = self.client.get('/getuser/') response = self.client.get('/getuser/')
self.assertEqual(response.content, 'someuser') self.assertEqual(response.content.decode('utf-8'), 'someuser')
def test_login_with_nonascii_return_to(self): def test_login_with_nonascii_return_to(self):
"""Ensure non-ascii characters can be used for the 'next' arg.""" """Ensure non-ascii characters can be used for the 'next' arg."""
...@@ -275,7 +273,8 @@ class RelyingPartyTests(TestCase): ...@@ -275,7 +273,8 @@ class RelyingPartyTests(TestCase):
self.assertContains(response, 'OpenID transaction in progress') self.assertContains(response, 'OpenID transaction in progress')
openid_request = self.provider.parseFormPost(response.content) openid_request = self.provider.parseFormPost(
response.content.decode('utf-8'))
self.assertEqual(openid_request.mode, 'checkid_setup') self.assertEqual(openid_request.mode, 'checkid_setup')
self.assertTrue(openid_request.return_to.startswith( self.assertTrue(openid_request.return_to.startswith(
'http://testserver/openid/complete/')) 'http://testserver/openid/complete/'))
...@@ -302,7 +301,8 @@ class RelyingPartyTests(TestCase): ...@@ -302,7 +301,8 @@ class RelyingPartyTests(TestCase):
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertContains(response, 'OpenID transaction in progress') self.assertContains(response, 'OpenID transaction in progress')
openid_request = self.provider.parseFormPost(response.content) openid_request = self.provider.parseFormPost(
response.content.decode('utf-8'))
self.assertEqual(openid_request.mode, 'checkid_setup') self.assertEqual(openid_request.mode, 'checkid_setup')
self.assertTrue(openid_request.return_to.startswith( self.assertTrue(openid_request.return_to.startswith(
'http://testserver/openid/complete/')) 'http://testserver/openid/complete/'))
...@@ -314,7 +314,7 @@ class RelyingPartyTests(TestCase): ...@@ -314,7 +314,7 @@ class RelyingPartyTests(TestCase):
# And they are now logged in: # And they are now logged in:
response = self.client.get('/getuser/') response = self.client.get('/getuser/')
self.assertEqual(response.content, 'someuser') self.assertEqual(response.content.decode('utf-8'), 'someuser')
def test_login_create_users(self): def test_login_create_users(self):
# Create a user with the same name as we'll pass back via sreg. # Create a user with the same name as we'll pass back via sreg.
...@@ -326,7 +326,8 @@ class RelyingPartyTests(TestCase): ...@@ -326,7 +326,8 @@ class RelyingPartyTests(TestCase):
# Complete the request, passing back some simple registration # Complete the request, passing back some simple registration
# data. The user is redirected to the next URL. # data. The user is redirected to the next URL.
openid_request = self.provider.parseFormPost(response.content) openid_request = self.provider.parseFormPost(
response.content.decode('utf-8'))
sreg_request = sreg.SRegRequest.fromOpenIDRequest(openid_request) sreg_request = sreg.SRegRequest.fromOpenIDRequest(openid_request)
openid_response = openid_request.answer(True) openid_response = openid_request.answer(True)
sreg_response = sreg.SRegResponse.extractResponse( sreg_response = sreg.SRegResponse.extractResponse(
...@@ -340,7 +341,7 @@ class RelyingPartyTests(TestCase): ...@@ -340,7 +341,7 @@ class RelyingPartyTests(TestCase):
# And they are now logged in as a new user (they haven't taken # And they are now logged in as a new user (they haven't taken
# over the existing "someuser" user). # over the existing "someuser" user).
response = self.client.get('/getuser/') response = self.client.get('/getuser/')
self.assertEqual(response.content, 'someuser2') self.assertEqual(response.content.decode('utf-8'), 'someuser2')
# Check the details of the new user. # Check the details of the new user.
user = User.objects.get(username='someuser2') user = User.objects.get(username='someuser2')
...@@ -364,7 +365,8 @@ class RelyingPartyTests(TestCase): ...@@ -364,7 +365,8 @@ class RelyingPartyTests(TestCase):
# Complete the request, passing back some simple registration # Complete the request, passing back some simple registration
# data. The user is redirected to the next URL. # data. The user is redirected to the next URL.
openid_request = self.provider.parseFormPost(response.content) openid_request = self.provider.parseFormPost(
response.content.decode('utf-8'))
return openid_request return openid_request
def _get_login_response(self, openid_request, resp_data, use_sreg, def _get_login_response(self, openid_request, resp_data, use_sreg,
...@@ -388,7 +390,8 @@ class RelyingPartyTests(TestCase): ...@@ -388,7 +390,8 @@ class RelyingPartyTests(TestCase):
self.provider.type_uris.append(pape.ns_uri) self.provider.type_uris.append(pape.ns_uri)
response = self.client.post(self.login_url, self.openid_req) response = self.client.post(self.login_url, self.openid_req)
openid_request = self.provider.parseFormPost(response.content) openid_request = self.provider.parseFormPost(
response.content.decode('utf-8'))
request_auth = openid_request.message.getArg( request_auth = openid_request.message.getArg(
'http://specs.openid.net/extensions/pape/1.0', 'http://specs.openid.net/extensions/pape/1.0',
...@@ -436,7 +439,7 @@ class RelyingPartyTests(TestCase): ...@@ -436,7 +439,7 @@ class RelyingPartyTests(TestCase):
query['openid.pape.auth_policies'], [preferred_auth]) query['openid.pape.auth_policies'], [preferred_auth])
response = self.client.get('/getuser/') response = self.client.get('/getuser/')
self.assertEqual(response.content, 'testuser') self.assertEqual(response.content.decode('utf-8'), 'testuser')
@override_settings(OPENID_PHYSICAL_MULTIFACTOR_REQUIRED=True) @override_settings(OPENID_PHYSICAL_MULTIFACTOR_REQUIRED=True)
def test_login_physical_multifactor_not_provided(self): def test_login_physical_multifactor_not_provided(self):
...@@ -552,10 +555,11 @@ class RelyingPartyTests(TestCase): ...@@ -552,10 +555,11 @@ class RelyingPartyTests(TestCase):
response = self.client.get('/getuser/') response = self.client.get('/getuser/')
# username defaults to 'openiduser' # username defaults to 'openiduser'
self.assertEqual(response.content, 'openiduser') username = response.content.decode('utf-8')
self.assertEqual(username, 'openiduser')
# The user's full name and email have been updated. # The user's full name and email have been updated.
user = User.objects.get(username=response.content) user = User.objects.get(username=username)
self.assertEqual(user.first_name, 'Openid') self.assertEqual(user.first_name, 'Openid')
self.assertEqual(user.last_name, 'User') self.assertEqual(user.last_name, 'User')
self.assertEqual(user.email, 'foo@example.com') self.assertEqual(user.email, 'foo@example.com')
...@@ -570,7 +574,7 @@ class RelyingPartyTests(TestCase): ...@@ -570,7 +574,7 @@ class RelyingPartyTests(TestCase):
response = self.client.get('/getuser/') response = self.client.get('/getuser/')
# username defaults to a munged version of the email # username defaults to a munged version of the email
self.assertEqual(response.content, 'fooexamplecom') self.assertEqual(response.content.decode('utf-8'), 'fooexamplecom')
def test_login_duplicate_username_numbering(self): def test_login_duplicate_username_numbering(self):
# Setup existing user who's name we're going to conflict with # Setup existing user who's name we're going to conflict with
...@@ -587,7 +591,7 @@ class RelyingPartyTests(TestCase): ...@@ -587,7 +591,7 @@ class RelyingPartyTests(TestCase):
# Since this username is already taken by someone else, we go through # Since this username is already taken by someone else, we go through
# the process of adding +i to it, and get testuser2. # the process of adding +i to it, and get testuser2.
self.assertEqual(response.content, 'testuser2') self.assertEqual(response.content.decode('utf-8'), 'testuser2')
def test_login_duplicate_username_numbering_with_conflicts(self): def test_login_duplicate_username_numbering_with_conflicts(self):
# Setup existing user who's name we're going to conflict with # Setup existing user who's name we're going to conflict with
...@@ -607,7 +611,7 @@ class RelyingPartyTests(TestCase): ...@@ -607,7 +611,7 @@ class RelyingPartyTests(TestCase):
# the process of adding +i to it starting with the count of users with # the process of adding +i to it starting with the count of users with
# username starting with 'testuser', of which there are 2. i should # username starting with 'testuser', of which there are 2. i should
# start at 3, which already exists, so it should skip to 4. # start at 3, which already exists, so it should skip to 4.
self.assertEqual(response.content, 'testuser4') self.assertEqual(response.content.decode('utf-8'), 'testuser4')
def test_login_duplicate_username_numbering_with_holes(self): def test_login_duplicate_username_numbering_with_holes(self):
# Setup existing user who's name we're going to conflict with # Setup existing user who's name we're going to conflict with
...@@ -630,7 +634,7 @@ class RelyingPartyTests(TestCase): ...@@ -630,7 +634,7 @@ class RelyingPartyTests(TestCase):
# the process of adding +i to it starting with the count of users with # the process of adding +i to it starting with the count of users with
# username starting with 'testuser', of which there are 5. i should # username starting with 'testuser', of which there are 5. i should
# start at 6, and increment until it reaches 9. # start at 6, and increment until it reaches 9.
self.assertEqual(response.content, 'testuser9') self.assertEqual(response.content.decode('utf-8'), 'testuser9')
def test_login_duplicate_username_numbering_with_nonsequential_matches( def test_login_duplicate_username_numbering_with_nonsequential_matches(
self): self):
...@@ -651,7 +655,7 @@ class RelyingPartyTests(TestCase): ...@@ -651,7 +655,7 @@ class RelyingPartyTests(TestCase):
# the process of adding +i to it starting with the count of users with # the process of adding +i to it starting with the count of users with
# username starting with 'testuser', of which there are 2. i should # username starting with 'testuser', of which there are 2. i should
# start at 3, which will be available. # start at 3, which will be available.
self.assertEqual(response.content, 'testuser3') self.assertEqual(response.content.decode('utf-8'), 'testuser3')
def test_login_follow_rename(self): def test_login_follow_rename(self):
user = User.objects.create_user('testuser', 'someone@example.com') user = User.objects.create_user('testuser', 'someone@example.com')
...@@ -671,10 +675,11 @@ class RelyingPartyTests(TestCase): ...@@ -671,10 +675,11 @@ class RelyingPartyTests(TestCase):
# If OPENID_FOLLOW_RENAMES, they are logged in as # If OPENID_FOLLOW_RENAMES, they are logged in as
# someuser (the passed in nickname has changed the username) # someuser (the passed in nickname has changed the username)
self.assertEqual(response.content, 'someuser') username = response.content.decode('utf-8')
self.assertEqual(username, 'someuser')
# The user's full name and email have been updated. # The user's full name and email have been updated.
user = User.objects.get(username=response.content) user = User.objects.get(username=username)
self.assertEqual(user.first_name, 'Some') self.assertEqual(user.first_name, 'Some')
self.assertEqual(user.last_name, 'User') self.assertEqual(user.last_name, 'User')
self.assertEqual(user.email, 'foo@example.com') self.assertEqual(user.email, 'foo@example.com')
...@@ -696,10 +701,11 @@ class RelyingPartyTests(TestCase): ...@@ -696,10 +701,11 @@ class RelyingPartyTests(TestCase):
response = self.client.get('/getuser/') response = self.client.get('/getuser/')
# Username should not have changed # Username should not have changed
self.assertEqual(response.content, 'testuser') username = response.content.decode('utf-8')
self.assertEqual(username, 'testuser')
# The user's full name and email have been updated. # The user's full name and email have been updated.
user = User.objects.get(username=response.content) user = User.objects.get(username=username)
self.assertEqual(user.first_name, 'Some') self.assertEqual(user.first_name, 'Some')
self.assertEqual(user.last_name, 'User') self.assertEqual(user.last_name, 'User')
self.assertEqual(user.email, 'foo@example.com') self.assertEqual(user.email, 'foo@example.com')
...@@ -735,10 +741,11 @@ class RelyingPartyTests(TestCase): ...@@ -735,10 +741,11 @@ class RelyingPartyTests(TestCase):
# If OPENID_FOLLOW_RENAMES, attempt to change username to 'testuser' # If OPENID_FOLLOW_RENAMES, attempt to change username to 'testuser'
# but since that username is already taken by someone else, we go # but since that username is already taken by someone else, we go
# through the process of adding +i to it, and get testuser2. # through the process of adding +i to it, and get testuser2.
self.assertEqual(response.content, 'testuser2') username = response.content.decode('utf-8')
self.assertEqual(username, 'testuser2')
# The user's full name and email have been updated. # The user's full name and email have been updated.
user = User.objects.get(username=response.content) user = User.objects.get(username=username)
self.assertEqual(user.first_name, 'Rename') self.assertEqual(user.first_name, 'Rename')
self.assertEqual(user.last_name, 'User') self.assertEqual(user.last_name, 'User')
self.assertEqual(user.email, 'rename@example.com') self.assertEqual(user.email, 'rename@example.com')
...@@ -777,10 +784,11 @@ class RelyingPartyTests(TestCase): ...@@ -777,10 +784,11 @@ class RelyingPartyTests(TestCase):
# the username follows the nickname+i scheme, it has non-numbers in the # the username follows the nickname+i scheme, it has non-numbers in the
# suffix, so it's not an auto-generated one. The regular process of # suffix, so it's not an auto-generated one. The regular process of
# renaming to 'testuser' has a conflict, so we get +2 at the end. # renaming to 'testuser' has a conflict, so we get +2 at the end.
self.assertEqual(response.content, 'testuser2') username = response.content.decode('utf-8')
self.assertEqual(username, 'testuser2')
# The user's full name and email have been updated. # The user's full name and email have been updated.
user = User.objects.get(username=response.content) user = User.objects.get(username=username)
self.assertEqual(user.first_name, 'Rename') self.assertEqual(user.first_name, 'Rename')
self.assertEqual(user.last_name, 'User') self.assertEqual(user.last_name, 'User')
self.assertEqual(user.email, 'rename@example.com') self.assertEqual(user.email, 'rename@example.com')
...@@ -817,10 +825,11 @@ class RelyingPartyTests(TestCase): ...@@ -817,10 +825,11 @@ class RelyingPartyTests(TestCase):
# but since that username is already taken by someone else, we go # but since that username is already taken by someone else, we go
# through the process of adding +i to it. Since the user for this # through the process of adding +i to it. Since the user for this
# identity url already has a name matching that pattern, check if first # identity url already has a name matching that pattern, check if first
self.assertEqual(response.content, 'testuser2000') username = response.content.decode('utf-8')
self.assertEqual(username, 'testuser2000')
# The user's full name and email have been updated. # The user's full name and email have been updated.
user = User.objects.get(username=response.content) user = User.objects.get(username=username)
self.assertEqual(user.first_name, 'Rename') self.assertEqual(user.first_name, 'Rename')
self.assertEqual(user.last_name, 'User') self.assertEqual(user.last_name, 'User')
self.assertEqual(user.email, 'rename@example.com') self.assertEqual(user.email, 'rename@example.com')
...@@ -847,10 +856,11 @@ class RelyingPartyTests(TestCase): ...@@ -847,10 +856,11 @@ class RelyingPartyTests(TestCase):
# If OPENID_FOLLOW_RENAMES, username should be changed to 'testuser' # If OPENID_FOLLOW_RENAMES, username should be changed to 'testuser'
# because it wasn't currently taken # because it wasn't currently taken
self.assertEqual(response.content, 'testuser') username = response.content.decode('utf-8')
self.assertEqual(username, 'testuser')
# The user's full name and email have been updated. # The user's full name and email have been updated.
user = User.objects.get(username=response.content) user = User.objects.get(username=username)
self.assertEqual(user.first_name, 'Same') self.assertEqual(user.first_name, 'Same')
self.assertEqual(user.last_name, 'User') self.assertEqual(user.last_name, 'User')
self.assertEqual(user.email, 'same@example.com') self.assertEqual(user.email, 'same@example.com')
...@@ -865,7 +875,8 @@ class RelyingPartyTests(TestCase): ...@@ -865,7 +875,8 @@ class RelyingPartyTests(TestCase):
# Complete the request, passing back some simple registration # Complete the request, passing back some simple registration
# data. The user is redirected to the next URL. # data. The user is redirected to the next URL.
openid_request = self.provider.parseFormPost(response.content) openid_request = self.provider.parseFormPost(
response.content.decode('utf-8'))
sreg_request = sreg.SRegRequest.fromOpenIDRequest(openid_request) sreg_request = sreg.SRegRequest.fromOpenIDRequest(openid_request)
openid_response = openid_request.answer(True) openid_response = openid_request.answer(True)
sreg_response = sreg.SRegResponse.extractResponse( sreg_response = sreg.SRegResponse.extractResponse(
...@@ -903,7 +914,8 @@ class RelyingPartyTests(TestCase): ...@@ -903,7 +914,8 @@ class RelyingPartyTests(TestCase):
# Complete the request, passing back some simple registration # Complete the request, passing back some simple registration
# data. The user is redirected to the next URL. # data. The user is redirected to the next URL.
openid_request = self.provider.parseFormPost(response.content) openid_request = self.provider.parseFormPost(
response.content.decode('utf-8'))
sreg_request = sreg.SRegRequest.fromOpenIDRequest(openid_request) sreg_request = sreg.SRegRequest.fromOpenIDRequest(openid_request)
openid_response = openid_request.answer(True) openid_response = openid_request.answer(True)
sreg_response = sreg.SRegResponse.extractResponse( sreg_response = sreg.SRegResponse.extractResponse(
...@@ -932,7 +944,8 @@ class RelyingPartyTests(TestCase): ...@@ -932,7 +944,8 @@ class RelyingPartyTests(TestCase):
# Complete the request, passing back some simple registration # Complete the request, passing back some simple registration
# data. The user is redirected to the next URL. # data. The user is redirected to the next URL.
openid_request = self.provider.parseFormPost(response.content) openid_request = self.provider.parseFormPost(
response.content.decode('utf-8'))
sreg_request = sreg.SRegRequest.fromOpenIDRequest(openid_request) sreg_request = sreg.SRegRequest.fromOpenIDRequest(openid_request)
openid_response = openid_request.answer(True) openid_response = openid_request.answer(True)
sreg_response = sreg.SRegResponse.extractResponse( sreg_response = sreg.SRegResponse.extractResponse(
...@@ -974,7 +987,8 @@ class RelyingPartyTests(TestCase): ...@@ -974,7 +987,8 @@ class RelyingPartyTests(TestCase):
# Complete the request, passing back some simple registration # Complete the request, passing back some simple registration
# data. The user is redirected to the next URL. # data. The user is redirected to the next URL.
openid_request = self.provider.parseFormPost(response.content) openid_request = self.provider.parseFormPost(
response.content.decode('utf-8'))
sreg_request = sreg.SRegRequest.fromOpenIDRequest(openid_request) sreg_request = sreg.SRegRequest.fromOpenIDRequest(openid_request)
openid_response = openid_request.answer(True) openid_response = openid_request.answer(True)
sreg_response = sreg.SRegResponse.extractResponse( sreg_response = sreg.SRegResponse.extractResponse(
...@@ -997,7 +1011,8 @@ class RelyingPartyTests(TestCase): ...@@ -997,7 +1011,8 @@ class RelyingPartyTests(TestCase):
# Complete the request, passing back some simple registration # Complete the request, passing back some simple registration
# data. The user is redirected to the next URL. # data. The user is redirected to the next URL.
openid_request = self.provider.parseFormPost(response.content) openid_request = self.provider.parseFormPost(
response.content.decode('utf-8'))
sreg_request = sreg.SRegRequest.fromOpenIDRequest(openid_request) sreg_request = sreg.SRegRequest.fromOpenIDRequest(openid_request)
openid_response = openid_request.answer(True) openid_response = openid_request.answer(True)
sreg_response = sreg.SRegResponse.extractResponse( sreg_response = sreg.SRegResponse.extractResponse(
...@@ -1033,10 +1048,11 @@ class RelyingPartyTests(TestCase): ...@@ -1033,10 +1048,11 @@ class RelyingPartyTests(TestCase):
self._do_user_login(self.openid_req, self.openid_resp) self._do_user_login(self.openid_req, self.openid_resp)
response = self.client.get('/getuser/') response = self.client.get('/getuser/')
self.assertEqual(response.content, 'testuser') username = response.content.decode('utf-8')
self.assertEqual(username, 'testuser')
# The user's full name and email have been updated. # The user's full name and email have been updated.
user = User.objects.get(username=response.content) user = User.objects.get(username=username)
self.assertEqual(user.first_name, 'Some') self.assertEqual(user.first_name, 'Some')
self.assertEqual(user.last_name, 'User') self.assertEqual(user.last_name, 'User')
self.assertEqual(user.email, 'foo@example.com') self.assertEqual(user.email, 'foo@example.com')
...@@ -1052,7 +1068,8 @@ class RelyingPartyTests(TestCase): ...@@ -1052,7 +1068,8 @@ class RelyingPartyTests(TestCase):
with self.settings(OPENID_SREG_EXTRA_FIELDS=('language',)): with self.settings(OPENID_SREG_EXTRA_FIELDS=('language',)):
response = self.client.post(self.login_url, self.openid_req) response = self.client.post(self.login_url, self.openid_req)
openid_request = self.provider.parseFormPost(response.content) openid_request = self.provider.parseFormPost(
response.content.decode('utf-8'))
sreg_request = sreg.SRegRequest.fromOpenIDRequest(openid_request) sreg_request = sreg.SRegRequest.fromOpenIDRequest(openid_request)
for field in ('email', 'fullname', 'nickname', 'language'): for field in ('email', 'fullname', 'nickname', 'language'):
self.assertTrue(field in sreg_request) self.assertTrue(field in sreg_request)
...@@ -1069,7 +1086,8 @@ class RelyingPartyTests(TestCase): ...@@ -1069,7 +1086,8 @@ class RelyingPartyTests(TestCase):
with self.settings(OPENID_SREG_REQUIRED_FIELDS=('email', 'language')): with self.settings(OPENID_SREG_REQUIRED_FIELDS=('email', 'language')):
response = self.client.post(self.login_url, self.openid_req) response = self.client.post(self.login_url, self.openid_req)
openid_request = self.provider.parseFormPost(response.content) openid_request = self.provider.parseFormPost(
response.content.decode('utf-8'))
sreg_request = sreg.SRegRequest.fromOpenIDRequest(openid_request) sreg_request = sreg.SRegRequest.fromOpenIDRequest(openid_request)
self.assertEqual(['email', 'language'], sreg_request.required) self.assertEqual(['email', 'language'], sreg_request.required)
...@@ -1091,7 +1109,8 @@ class RelyingPartyTests(TestCase): ...@@ -1091,7 +1109,8 @@ class RelyingPartyTests(TestCase):
# The resulting OpenID request uses the Attribute Exchange # The resulting OpenID request uses the Attribute Exchange
# extension rather than the Simple Registration extension. # extension rather than the Simple Registration extension.
openid_request = self.provider.parseFormPost(response.content) openid_request = self.provider.parseFormPost(
response.content.decode('utf-8'))
sreg_request = sreg.SRegRequest.fromOpenIDRequest(openid_request) sreg_request = sreg.SRegRequest.fromOpenIDRequest(openid_request)
self.assertEqual(sreg_request.required, []) self.assertEqual(sreg_request.required, [])
self.assertEqual(sreg_request.optional, []) self.assertEqual(sreg_request.optional, [])
...@@ -1137,7 +1156,7 @@ class RelyingPartyTests(TestCase): ...@@ -1137,7 +1156,7 @@ class RelyingPartyTests(TestCase):
assert not settings.OPENID_FOLLOW_RENAMES, ( assert not settings.OPENID_FOLLOW_RENAMES, (
'OPENID_FOLLOW_RENAMES must be False') 'OPENID_FOLLOW_RENAMES must be False')
response = self.client.get('/getuser/') response = self.client.get('/getuser/')
self.assertEqual(response.content, 'testuser') self.assertEqual(response.content.decode('utf-8'), 'testuser')
# The user's full name and email have been updated. # The user's full name and email have been updated.
user = User.objects.get(username='testuser') user = User.objects.get(username='testuser')
...@@ -1221,7 +1240,8 @@ class RelyingPartyTests(TestCase): ...@@ -1221,7 +1240,8 @@ class RelyingPartyTests(TestCase):
self.assertContains(response, 'OpenID transaction in progress') self.assertContains(response, 'OpenID transaction in progress')
# Complete the request # Complete the request
openid_request = self.provider.parseFormPost(response.content) openid_request = self.provider.parseFormPost(
response.content.decode('utf-8'))
openid_response = openid_request.answer(True) openid_response = openid_request.answer(True)
teams_request = teams.TeamsRequest.fromOpenIDRequest(openid_request) teams_request = teams.TeamsRequest.fromOpenIDRequest(openid_request)
teams_response = teams.TeamsResponse.extractResponse( teams_response = teams.TeamsResponse.extractResponse(
...@@ -1236,7 +1256,7 @@ class RelyingPartyTests(TestCase): ...@@ -1236,7 +1256,7 @@ class RelyingPartyTests(TestCase):
# And they are now logged in as testuser # And they are now logged in as testuser
response = self.client.get('/getuser/') response = self.client.get('/getuser/')
self.assertEqual(response.content, 'testuser') self.assertEqual(response.content.decode('utf-8'), 'testuser')
# The user's groups have been updated. # The user's groups have been updated.
User.objects.get(username='testuser') User.objects.get(username='testuser')
...@@ -1268,7 +1288,8 @@ class RelyingPartyTests(TestCase): ...@@ -1268,7 +1288,8 @@ class RelyingPartyTests(TestCase):
OPENID_LAUNCHPAD_TEAMS_MAPPING=mapping, OPENID_LAUNCHPAD_TEAMS_MAPPING=mapping,
OPENID_LAUNCHPAD_TEAMS_MAPPING_AUTO=True, OPENID_LAUNCHPAD_TEAMS_MAPPING_AUTO=True,
OPENID_LAUNCHPAD_TEAMS_MAPPING_AUTO_BLACKLIST=blacklist): OPENID_LAUNCHPAD_TEAMS_MAPPING_AUTO_BLACKLIST=blacklist):
openid_request = self.provider.parseFormPost(response.content) openid_request = self.provider.parseFormPost(
response.content.decode('utf-8'))
openid_request.answer(True) openid_request.answer(True)
teams.TeamsRequest.fromOpenIDRequest(openid_request) teams.TeamsRequest.fromOpenIDRequest(openid_request)
...@@ -1321,7 +1342,8 @@ class RelyingPartyTests(TestCase): ...@@ -1321,7 +1342,8 @@ class RelyingPartyTests(TestCase):
response = self.client.post(self.login_url, self.openid_req_no_next) response = self.client.post(self.login_url, self.openid_req_no_next)
# Complete the request # Complete the request
openid_request = self.provider.parseFormPost(response.content) openid_request = self.provider.parseFormPost(
response.content.decode('utf-8'))
openid_response = openid_request.answer(True) openid_response = openid_request.answer(True)
teams_request = teams.TeamsRequest.fromOpenIDRequest(openid_request) teams_request = teams.TeamsRequest.fromOpenIDRequest(openid_request)
teams_response = teams.TeamsResponse.extractResponse( teams_response = teams.TeamsResponse.extractResponse(
...@@ -1340,7 +1362,8 @@ class RelyingPartyTests(TestCase): ...@@ -1340,7 +1362,8 @@ class RelyingPartyTests(TestCase):
display_id='http://example.com/identity') display_id='http://example.com/identity')
response = self.client.post(self.login_url, self.openid_req_no_next) response = self.client.post(self.login_url, self.openid_req_no_next)
openid_request = self.provider.parseFormPost(response.content) openid_request = self.provider.parseFormPost(
response.content.decode('utf-8'))
openid_response = openid_request.answer(True) openid_response = openid_request.answer(True)
# Use a closure to test whether the signal handler was called. # Use a closure to test whether the signal handler was called.
self.signal_handler_called = False self.signal_handler_called = False
...@@ -1388,3 +1411,24 @@ class HelperFunctionsTest(TestCase): ...@@ -1388,3 +1411,24 @@ class HelperFunctionsTest(TestCase):
self.assertEqual(url, sanitised) self.assertEqual(url, sanitised)
else: else:
self.assertEqual(settings.LOGIN_REDIRECT_URL, sanitised) self.assertEqual(settings.LOGIN_REDIRECT_URL, sanitised)
def test_get_request_data_from_post(self):
request = RequestFactory().post('/', data={'foo': 'bar'})
data = get_request_data(request)
self.assertEqual(dict(data), {'foo': ['bar']})
def test_get_request_data_from_get(self):
request = RequestFactory().get('/', data={'foo': 'bar'})
data = get_request_data(request)
self.assertEqual(dict(data), {'foo': ['bar']})
def test_get_request_data_merged(self):
request = RequestFactory().post('/?baz=42', data={'foo': 'bar'})
data = get_request_data(request)
self.assertEqual(dict(data), {'foo': ['bar'], 'baz': ['42']})
def test_get_request_data_override_order(self):
request = RequestFactory().post('/?foo=42', data={'foo': 'bar'})
data = get_request_data(request)
self.assertEqual(dict(data), {'foo': ['42', 'bar']})
self.assertEqual(data['foo'], 'bar')
...@@ -28,7 +28,7 @@ ...@@ -28,7 +28,7 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from django.conf.urls import patterns, include from django.conf.urls import include, url
from django.http import HttpResponse from django.http import HttpResponse
...@@ -36,8 +36,7 @@ def get_user(request): ...@@ -36,8 +36,7 @@ def get_user(request):
return HttpResponse(request.user.username) return HttpResponse(request.user.username)
urlpatterns = patterns( urlpatterns = [
'', url(r'^getuser/$', get_user),
(r'^getuser/$', get_user), url(r'^openid/', include('django_openid_auth.urls')),
(r'^openid/', include('django_openid_auth.urls')), ]
)
...@@ -29,11 +29,17 @@ ...@@ -29,11 +29,17 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from django.conf.urls import patterns, url from django.conf.urls import url
urlpatterns = patterns( from django_openid_auth.views import (
'django_openid_auth.views', login_begin,
url(r'^login/$', 'login_begin', name='openid-login'), login_complete,
url(r'^complete/$', 'login_complete', name='openid-complete'), logo,
url(r'^logo.gif$', 'logo', name='openid-logo'),
) )
urlpatterns = [
url(r'^login/$', login_begin, name='openid-login'),
url(r'^complete/$', login_complete, name='openid-complete'),
url(r'^logo.gif$', logo, name='openid-logo'),
]
...@@ -30,8 +30,11 @@ ...@@ -30,8 +30,11 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import re import re
import urllib try:
from urlparse import urlsplit from urllib.parse import urlencode, urlsplit
except ImportError:
from urllib import urlencode
from urlparse import urlsplit
from django.conf import settings from django.conf import settings
from django.contrib.auth import ( from django.contrib.auth import (
...@@ -39,6 +42,7 @@ from django.contrib.auth import ( ...@@ -39,6 +42,7 @@ from django.contrib.auth import (
from django.contrib.auth.models import Group from django.contrib.auth.models import Group
from django.core.urlresolvers import reverse from django.core.urlresolvers import reverse
from django.http import HttpResponse, HttpResponseRedirect from django.http import HttpResponse, HttpResponseRedirect
from django.http.request import QueryDict
from django.shortcuts import render_to_response from django.shortcuts import render_to_response
from django.template import RequestContext from django.template import RequestContext
from django.template.loader import render_to_string from django.template.loader import render_to_string
...@@ -129,22 +133,27 @@ def default_render_failure(request, message, status=403, ...@@ -129,22 +133,27 @@ def default_render_failure(request, message, status=403,
template_name='openid/failure.html', template_name='openid/failure.html',
exception=None): exception=None):
"""Render an error page to the user.""" """Render an error page to the user."""
data = render_to_string( context = RequestContext(request)
template_name, dict(message=message, exception=exception), context.update(dict(message=message, exception=exception))
context_instance=RequestContext(request)) data = render_to_string(template_name, context)
return HttpResponse(data, status=status) return HttpResponse(data, status=status)
def parse_openid_response(request): def parse_openid_response(request):
"""Parse an OpenID response from a Django request.""" """Parse an OpenID response from a Django request."""
# Short cut if there is no request parameters.
# if len(request.REQUEST) == 0:
# return None
current_url = request.build_absolute_uri() current_url = request.build_absolute_uri()
consumer = make_consumer(request) consumer = make_consumer(request)
return consumer.complete(dict(request.REQUEST.items()), current_url) data = get_request_data(request)
return consumer.complete(data, current_url)
def get_request_data(request):
# simulate old request.REQUEST for backwards compatibility
data = QueryDict(query_string=None, mutable=True)
data.update(request.GET)
data.update(request.POST)
return data
def login_begin(request, template_name='openid/login.html', def login_begin(request, template_name='openid/login.html',
...@@ -153,7 +162,8 @@ def login_begin(request, template_name='openid/login.html', ...@@ -153,7 +162,8 @@ def login_begin(request, template_name='openid/login.html',
render_failure=default_render_failure, render_failure=default_render_failure,
redirect_field_name=REDIRECT_FIELD_NAME): redirect_field_name=REDIRECT_FIELD_NAME):
"""Begin an OpenID login request, possibly asking for an identity URL.""" """Begin an OpenID login request, possibly asking for an identity URL."""
redirect_to = request.REQUEST.get(redirect_field_name, '') data = get_request_data(request)
redirect_to = data.get(redirect_field_name, '')
# Get the OpenID URL to try. First see if we've been configured # Get the OpenID URL to try. First see if we've been configured
# to use a fixed server URL. # to use a fixed server URL.
...@@ -169,10 +179,12 @@ def login_begin(request, template_name='openid/login.html', ...@@ -169,10 +179,12 @@ def login_begin(request, template_name='openid/login.html',
# Invalid or no form data: # Invalid or no form data:
if openid_url is None: if openid_url is None:
context = {'form': login_form, redirect_field_name: redirect_to} context = RequestContext(request)
return render_to_response( context.update({
template_name, context, 'form': login_form,
context_instance=RequestContext(request)) redirect_field_name: redirect_to,
})
return render_to_response(template_name, context)
consumer = make_consumer(request) consumer = make_consumer(request)
try: try:
...@@ -268,7 +280,7 @@ def login_begin(request, template_name='openid/login.html', ...@@ -268,7 +280,7 @@ def login_begin(request, template_name='openid/login.html',
# Django gives us Unicode, which is great. We must encode URI. # Django gives us Unicode, which is great. We must encode URI.
# urllib enforces str. We can't trust anything about the default # urllib enforces str. We can't trust anything about the default
# encoding inside str(foo) , so we must explicitly make foo a str. # encoding inside str(foo) , so we must explicitly make foo a str.
return_to += urllib.urlencode( return_to += urlencode(
{redirect_field_name: redirect_to.encode("UTF-8")}) {redirect_field_name: redirect_to.encode("UTF-8")})
return render_openid_request(request, openid_request, return_to) return render_openid_request(request, openid_request, return_to)
...@@ -277,7 +289,8 @@ def login_begin(request, template_name='openid/login.html', ...@@ -277,7 +289,8 @@ def login_begin(request, template_name='openid/login.html',
@csrf_exempt @csrf_exempt
def login_complete(request, redirect_field_name=REDIRECT_FIELD_NAME, def login_complete(request, redirect_field_name=REDIRECT_FIELD_NAME,
render_failure=None): render_failure=None):
redirect_to = request.REQUEST.get(redirect_field_name, '') data = get_request_data(request)
redirect_to = data.get(redirect_field_name, '')
render_failure = ( render_failure = (
render_failure or getattr(settings, 'OPENID_RENDER_FAILURE', None) or render_failure or getattr(settings, 'OPENID_RENDER_FAILURE', None) or
default_render_failure) default_render_failure)
...@@ -290,8 +303,9 @@ def login_complete(request, redirect_field_name=REDIRECT_FIELD_NAME, ...@@ -290,8 +303,9 @@ def login_complete(request, redirect_field_name=REDIRECT_FIELD_NAME,
if openid_response.status == SUCCESS: if openid_response.status == SUCCESS:
try: try:
user = authenticate(openid_response=openid_response) user = authenticate(openid_response=openid_response)
except DjangoOpenIDException, e: except DjangoOpenIDException as e:
return render_failure(request, e.message, exception=e) return render_failure(
request, getattr(e, 'message', str(e)), exception=e)
if user is not None: if user is not None:
if user.is_active: if user.is_active:
...@@ -325,6 +339,7 @@ def logo(request): ...@@ -325,6 +339,7 @@ def logo(request):
OPENID_LOGO_BASE_64.decode('base64'), mimetype='image/gif' OPENID_LOGO_BASE_64.decode('base64'), mimetype='image/gif'
) )
# Logo from http://openid.net/login-bg.gif # Logo from http://openid.net/login-bg.gif
# Embedded here for convenience; you should serve this as a static file # Embedded here for convenience; you should serve this as a static file
OPENID_LOGO_BASE_64 = """ OPENID_LOGO_BASE_64 = """
......
...@@ -54,6 +54,24 @@ SECRET_KEY = '34958734985734985734985798437' ...@@ -54,6 +54,24 @@ SECRET_KEY = '34958734985734985734985798437'
DEBUG = True DEBUG = True
TEMPLATE_DEBUG = True TEMPLATE_DEBUG = True
TEMPLATES = [
{
'BACKEND': 'django.template.backends.django.DjangoTemplates',
'DIRS': [],
'APP_DIRS': True,
'OPTIONS': {
'context_processors': [
'django.contrib.auth.context_processors.auth',
'django.template.context_processors.debug',
'django.template.context_processors.i18n',
'django.template.context_processors.media',
'django.template.context_processors.static',
'django.template.context_processors.tz',
'django.contrib.messages.context_processors.messages',
]
}
}
]
ALLOWED_HOSTS = [] ALLOWED_HOSTS = []
......
...@@ -27,20 +27,20 @@ ...@@ -27,20 +27,20 @@
# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE # ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE. # POSSIBILITY OF SUCH DAMAGE.
from django.conf.urls import patterns, include, url from django.conf.urls import include, url
from django.contrib import admin from django.contrib import admin
from django.contrib.auth import views as auth_views
import views from example_consumer import views
admin.autodiscover() admin.autodiscover()
urlpatterns = patterns( urlpatterns = [
'',
url(r'^$', views.index), url(r'^$', views.index),
url(r'^openid/', include('django_openid_auth.urls')), url(r'^openid/', include('django_openid_auth.urls')),
url(r'^logout/$', 'django.contrib.auth.views.logout'), url(r'^logout/$', auth_views.logout),
url(r'^private/$', views.require_authentication), url(r'^private/$', views.require_authentication),
url(r'^admin/', include(admin.site.urls)), url(r'^admin/', include(admin.site.urls)),
) ]
...@@ -39,21 +39,28 @@ library also includes the following features: ...@@ -39,21 +39,28 @@ library also includes the following features:
info. info.
""" """
import sys
from setuptools import find_packages, setup from setuptools import find_packages, setup
PY3 = sys.version_info.major >= 3
description, long_description = __doc__.split('\n\n', 1) description, long_description = __doc__.split('\n\n', 1)
VERSION = '0.8' VERSION = '0.14'
install_requires = ['django>=1.6', 'six']
if PY3:
install_requires.append('python3-openid')
else:
install_requires.append('python-openid>=2.2.0')
setup( setup(
name='django-openid-auth', name='django-openid-auth',
version=VERSION, version=VERSION,
packages=find_packages(), packages=find_packages(),
install_requires=[ install_requires=install_requires,
'django>=1.5',
'python-openid>=2.2.0',
],
package_data={ package_data={
'django_openid_auth': ['templates/openid/*.html'], 'django_openid_auth': ['templates/openid/*.html'],
}, },
......
[tox] [tox]
envlist = envlist =
py2.7-django1.5, py2.7-django1.6, py2.7-django1.7, py2.7-django1.8 py27-django{1.8,1.9,1.10}
# py3-django{1.11}
[testenv] [testenv]
commands = python manage.py test django_openid_auth commands = python manage.py test django_openid_auth
deps= deps =
mock mock
python-openid
[testenv:py2.7-django1.5] [testenv:py27]
basepython = python2.7 basepython = python2.7
deps = deps =
django >= 1.5, < 1.6 python-openid
{[testenv]deps} {[testenv]deps}
south==1.0
[testenv:py2.7-django1.6] [testenv:py3]
basepython = python2.7 basepython = python3
deps = deps =
django >= 1.6, < 1.7 python3-openid
{[testenv]deps} {[testenv]deps}
south==1.0
[testenv:py2.7-django1.7] [testenv:py27-django1.8]
basepython = python2.7
deps = deps =
django >= 1.7, < 1.8 django >= 1.8, < 1.9
{[testenv]deps} {[testenv:py27]deps}
[testenv:py2.7-django1.8] [testenv:py27-django1.9]
basepython = python2.7
deps = deps =
django >= 1.8, < 1.9 django >= 1.9, < 1.10
{[testenv]deps} {[testenv:py27]deps}
[testenv:py27-django1.10]
deps =
django >= 1.10, < 1.11
{[testenv:py27]deps}
[testenv:py27-django1.11]
deps =
django >= 1.11, < 2
{[testenv:py27]deps}
[testenv:py3-django1.11]
deps =
django >= 1.11, < 2
{[testenv:py3]deps}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment