Commit caca3e1b by Braden MacDonald

SAML2 third_party_auth provider(s) - PR 8018

parent 2942846a
...@@ -424,7 +424,7 @@ def register_user(request, extra_context=None): ...@@ -424,7 +424,7 @@ def register_user(request, extra_context=None):
# selected provider. # selected provider.
if third_party_auth.is_enabled() and pipeline.running(request): if third_party_auth.is_enabled() and pipeline.running(request):
running_pipeline = pipeline.get(request) running_pipeline = pipeline.get(request)
current_provider = provider.Registry.get_by_backend_name(running_pipeline.get('backend')) current_provider = provider.Registry.get_from_pipeline(running_pipeline)
overrides = current_provider.get_register_form_data(running_pipeline.get('kwargs')) overrides = current_provider.get_register_form_data(running_pipeline.get('kwargs'))
overrides['running_pipeline'] = running_pipeline overrides['running_pipeline'] = running_pipeline
overrides['selected_provider'] = current_provider.NAME overrides['selected_provider'] = current_provider.NAME
...@@ -952,10 +952,11 @@ def login_user(request, error=""): # pylint: disable-msg=too-many-statements,un ...@@ -952,10 +952,11 @@ def login_user(request, error=""): # pylint: disable-msg=too-many-statements,un
running_pipeline = pipeline.get(request) running_pipeline = pipeline.get(request)
username = running_pipeline['kwargs'].get('username') username = running_pipeline['kwargs'].get('username')
backend_name = running_pipeline['backend'] backend_name = running_pipeline['backend']
requested_provider = provider.Registry.get_by_backend_name(backend_name) third_party_uid = running_pipeline['kwargs']['uid']
requested_provider = provider.Registry.get_from_pipeline(running_pipeline)
try: try:
user = pipeline.get_authenticated_user(username, backend_name) user = pipeline.get_authenticated_user(requested_provider, username, third_party_uid)
third_party_auth_successful = True third_party_auth_successful = True
except User.DoesNotExist: except User.DoesNotExist:
AUDIT_LOG.warning( AUDIT_LOG.warning(
...@@ -1509,7 +1510,7 @@ def create_account_with_params(request, params): ...@@ -1509,7 +1510,7 @@ def create_account_with_params(request, params):
provider_name = None provider_name = None
if third_party_auth.is_enabled() and pipeline.running(request): if third_party_auth.is_enabled() and pipeline.running(request):
running_pipeline = pipeline.get(request) running_pipeline = pipeline.get(request)
current_provider = provider.Registry.get_by_backend_name(running_pipeline.get('backend')) current_provider = provider.Registry.get_from_pipeline(running_pipeline)
provider_name = current_provider.NAME provider_name = current_provider.NAME
analytics.track( analytics.track(
......
...@@ -196,9 +196,11 @@ class ProviderUserState(object): ...@@ -196,9 +196,11 @@ class ProviderUserState(object):
lms/templates/dashboard.html. lms/templates/dashboard.html.
""" """
def __init__(self, enabled_provider, user, state): def __init__(self, enabled_provider, user, association_id=None):
# UserSocialAuth row ID
self.association_id = association_id
# Boolean. Whether the user has an account associated with the provider # Boolean. Whether the user has an account associated with the provider
self.has_account = state self.has_account = association_id is not None
# provider.BaseProvider child. Callers must verify that the provider is # provider.BaseProvider child. Callers must verify that the provider is
# enabled. # enabled.
self.provider = enabled_provider self.provider = enabled_provider
...@@ -215,7 +217,7 @@ def get(request): ...@@ -215,7 +217,7 @@ def get(request):
return request.session.get('partial_pipeline') return request.session.get('partial_pipeline')
def get_authenticated_user(username, backend_name): def get_authenticated_user(auth_provider, username, uid):
"""Gets a saved user authenticated by a particular backend. """Gets a saved user authenticated by a particular backend.
Between pipeline steps User objects are not saved. We need to reconstitute Between pipeline steps User objects are not saved. We need to reconstitute
...@@ -224,26 +226,26 @@ def get_authenticated_user(username, backend_name): ...@@ -224,26 +226,26 @@ def get_authenticated_user(username, backend_name):
authenticate(). authenticate().
Args: Args:
auth_provider: the third_party_auth provider in use for the current pipeline.
username: string. Username of user to get. username: string. Username of user to get.
backend_name: string. The name of the third-party auth backend from uid: string. The user ID according to the third party.
the running pipeline.
Returns: Returns:
User if user is found and has a social auth from the passed User if user is found and has a social auth from the passed
backend_name. provider.
Raises: Raises:
User.DoesNotExist: if no user matching user is found, or the matching User.DoesNotExist: if no user matching user is found, or the matching
user has no social auth associated with the given backend. user has no social auth associated with the given backend.
AssertionError: if the user is not authenticated. AssertionError: if the user is not authenticated.
""" """
user = models.DjangoStorage.user.user_model().objects.get(username=username) match = models.DjangoStorage.user.get_social_auth(provider=auth_provider.BACKEND_CLASS.name, uid=uid)
match = models.DjangoStorage.user.get_social_auth_for_user(user, provider=backend_name)
if not match: if not match or match.user.username != username:
raise User.DoesNotExist raise User.DoesNotExist
user.backend = provider.Registry.get_by_backend_name(backend_name).get_authentication_backend() user = match.user
user.backend = auth_provider.get_authentication_backend()
return user return user
...@@ -257,10 +259,12 @@ def _get_enabled_provider_by_name(provider_name): ...@@ -257,10 +259,12 @@ def _get_enabled_provider_by_name(provider_name):
return enabled_provider return enabled_provider
def _get_url(view_name, backend_name, auth_entry=None, redirect_url=None): def _get_url(view_name, backend_name, auth_entry=None, redirect_url=None,
extra_params=None, url_params=None):
"""Creates a URL to hook into social auth endpoints.""" """Creates a URL to hook into social auth endpoints."""
kwargs = {'backend': backend_name} url_params = url_params or {}
url = reverse(view_name, kwargs=kwargs) url_params['backend'] = backend_name
url = reverse(view_name, kwargs=url_params)
query_params = OrderedDict() query_params = OrderedDict()
if auth_entry: if auth_entry:
...@@ -269,6 +273,9 @@ def _get_url(view_name, backend_name, auth_entry=None, redirect_url=None): ...@@ -269,6 +273,9 @@ def _get_url(view_name, backend_name, auth_entry=None, redirect_url=None):
if redirect_url: if redirect_url:
query_params[AUTH_REDIRECT_KEY] = redirect_url query_params[AUTH_REDIRECT_KEY] = redirect_url
if extra_params:
query_params.update(extra_params)
return u"{url}?{params}".format( return u"{url}?{params}".format(
url=url, url=url,
params=urllib.urlencode(query_params) params=urllib.urlencode(query_params)
...@@ -288,29 +295,32 @@ def get_complete_url(backend_name): ...@@ -288,29 +295,32 @@ def get_complete_url(backend_name):
Raises: Raises:
ValueError: if no provider is enabled with the given backend_name. ValueError: if no provider is enabled with the given backend_name.
""" """
enabled_provider = provider.Registry.get_by_backend_name(backend_name) if not any(provider.Registry.get_enabled_by_backend_name(backend_name)):
if not enabled_provider:
raise ValueError('Provider with backend %s not enabled' % backend_name) raise ValueError('Provider with backend %s not enabled' % backend_name)
return _get_url('social:complete', backend_name) return _get_url('social:complete', backend_name)
def get_disconnect_url(provider_name): def get_disconnect_url(provider_name, association_id):
"""Gets URL for the endpoint that starts the disconnect pipeline. """Gets URL for the endpoint that starts the disconnect pipeline.
Args: Args:
provider_name: string. Name of the provider.BaseProvider child you want provider_name: string. Name of the provider.BaseProvider child you want
to disconnect from. to disconnect from.
association_id: int. Optional ID of a specific row in the UserSocialAuth
table to disconnect (useful if multiple providers use a common backend)
Returns: Returns:
String. URL that starts the disconnection pipeline. String. URL that starts the disconnection pipeline.
Raises: Raises:
ValueError: if no provider is enabled with the given backend_name. ValueError: if no provider is enabled with the given name.
""" """
enabled_provider = _get_enabled_provider_by_name(provider_name) backend_name = _get_enabled_provider_by_name(provider_name).BACKEND_CLASS.name
return _get_url('social:disconnect', enabled_provider.BACKEND_CLASS.name) if association_id:
return _get_url('social:disconnect_individual', backend_name, url_params={'association_id': association_id})
else:
return _get_url('social:disconnect', backend_name)
def get_login_url(provider_name, auth_entry, redirect_url=None): def get_login_url(provider_name, auth_entry, redirect_url=None):
...@@ -340,6 +350,7 @@ def get_login_url(provider_name, auth_entry, redirect_url=None): ...@@ -340,6 +350,7 @@ def get_login_url(provider_name, auth_entry, redirect_url=None):
enabled_provider.BACKEND_CLASS.name, enabled_provider.BACKEND_CLASS.name,
auth_entry=auth_entry, auth_entry=auth_entry,
redirect_url=redirect_url, redirect_url=redirect_url,
extra_params=enabled_provider.get_url_params(),
) )
...@@ -355,7 +366,7 @@ def get_duplicate_provider(messages): ...@@ -355,7 +366,7 @@ def get_duplicate_provider(messages):
unfortunately not in a reusable constant. unfortunately not in a reusable constant.
Returns: Returns:
provider.BaseProvider child instance. The provider of the duplicate string name of the python-social-auth backend that has the duplicate
account, or None if there is no duplicate (and hence no error). account, or None if there is no duplicate (and hence no error).
""" """
social_auth_messages = [m for m in messages if m.message.endswith('is already in use.')] social_auth_messages = [m for m in messages if m.message.endswith('is already in use.')]
...@@ -364,7 +375,8 @@ def get_duplicate_provider(messages): ...@@ -364,7 +375,8 @@ def get_duplicate_provider(messages):
return return
assert len(social_auth_messages) == 1 assert len(social_auth_messages) == 1
return provider.Registry.get_by_backend_name(social_auth_messages[0].extra_tags.split()[1]) backend_name = social_auth_messages[0].extra_tags.split()[1]
return backend_name
def get_provider_user_states(user): def get_provider_user_states(user):
...@@ -378,13 +390,16 @@ def get_provider_user_states(user): ...@@ -378,13 +390,16 @@ def get_provider_user_states(user):
each enabled provider. each enabled provider.
""" """
states = [] states = []
found_user_backends = [ found_user_auths = list(models.DjangoStorage.user.get_social_auth_for_user(user))
social_auth.provider for social_auth in models.DjangoStorage.user.get_social_auth_for_user(user)
]
for enabled_provider in provider.Registry.enabled(): for enabled_provider in provider.Registry.enabled():
association_id = None
for auth in found_user_auths:
if enabled_provider.match_social_auth(auth):
association_id = auth.id
break
states.append( states.append(
ProviderUserState(enabled_provider, user, enabled_provider.BACKEND_CLASS.name in found_user_backends) ProviderUserState(enabled_provider, user, association_id)
) )
return states return states
......
...@@ -5,6 +5,8 @@ invoke the Django armature. ...@@ -5,6 +5,8 @@ invoke the Django armature.
""" """
from social.backends import google, linkedin, facebook from social.backends import google, linkedin, facebook
from social.backends.saml import OID_EDU_PERSON_PRINCIPAL_NAME
from .saml import SAMLAuthBackend
_DEFAULT_ICON_CLASS = 'fa-signin' _DEFAULT_ICON_CLASS = 'fa-signin'
...@@ -109,6 +111,21 @@ class BaseProvider(object): ...@@ -109,6 +111,21 @@ class BaseProvider(object):
for key, value in cls.SETTINGS.iteritems(): for key, value in cls.SETTINGS.iteritems():
setattr(settings, key, value) setattr(settings, key, value)
@classmethod
def get_url_params(cls):
""" Get a dict of GET parameters to append to login links for this provider """
return {}
@classmethod
def is_active_for_pipeline(cls, pipeline):
""" Is this provider being used for the specified pipeline? """
return cls.BACKEND_CLASS.name == pipeline['backend']
@classmethod
def match_social_auth(cls, social_auth):
""" Is this provider being used for this UserSocialAuth entry? """
return cls.BACKEND_CLASS.name == social_auth.provider
class GoogleOauth2(BaseProvider): class GoogleOauth2(BaseProvider):
"""Provider for Google's Oauth2 auth system.""" """Provider for Google's Oauth2 auth system."""
...@@ -146,6 +163,78 @@ class FacebookOauth2(BaseProvider): ...@@ -146,6 +163,78 @@ class FacebookOauth2(BaseProvider):
} }
class SAMLProviderMixin(object):
""" Base class for SAML/Shibboleth providers """
BACKEND_CLASS = SAMLAuthBackend
ICON_CLASS = 'fa-university'
@classmethod
def get_url_params(cls):
""" Get a dict of GET parameters to append to login links for this provider """
return {'idp': cls.IDP["id"]}
@classmethod
def is_active_for_pipeline(cls, pipeline):
""" Is this provider being used for the specified pipeline? """
if cls.BACKEND_CLASS.name == pipeline['backend']:
idp_name = pipeline['kwargs']['response']['idp_name']
return cls.IDP["id"] == idp_name
return False
@classmethod
def match_social_auth(cls, social_auth):
""" Is this provider being used for this UserSocialAuth entry? """
prefix = cls.IDP["id"] + ":"
return cls.BACKEND_CLASS.name == social_auth.provider and social_auth.uid.startswith(prefix)
class TestShibAProvider(SAMLProviderMixin, BaseProvider):
""" Provider for testshib.org public Shibboleth test server. """
NAME = 'TestShib A'
IDP = {
"id": "testshiba", # Required slug
"entity_id": "https://idp.testshib.org/idp/shibboleth",
"url": "https://idp.testshib.org/idp/profile/SAML2/Redirect/SSO",
"attr_email": OID_EDU_PERSON_PRINCIPAL_NAME,
"x509cert": """
MIIEDjCCAvagAwIBAgIBADANBgkqhkiG9w0BAQUFADBnMQswCQYDVQQGEwJVUzEV
MBMGA1UECBMMUGVubnN5bHZhbmlhMRMwEQYDVQQHEwpQaXR0c2J1cmdoMREwDwYD
VQQKEwhUZXN0U2hpYjEZMBcGA1UEAxMQaWRwLnRlc3RzaGliLm9yZzAeFw0wNjA4
MzAyMTEyMjVaFw0xNjA4MjcyMTEyMjVaMGcxCzAJBgNVBAYTAlVTMRUwEwYDVQQI
EwxQZW5uc3lsdmFuaWExEzARBgNVBAcTClBpdHRzYnVyZ2gxETAPBgNVBAoTCFRl
c3RTaGliMRkwFwYDVQQDExBpZHAudGVzdHNoaWIub3JnMIIBIjANBgkqhkiG9w0B
AQEFAAOCAQ8AMIIBCgKCAQEArYkCGuTmJp9eAOSGHwRJo1SNatB5ZOKqDM9ysg7C
yVTDClcpu93gSP10nH4gkCZOlnESNgttg0r+MqL8tfJC6ybddEFB3YBo8PZajKSe
3OQ01Ow3yT4I+Wdg1tsTpSge9gEz7SrC07EkYmHuPtd71CHiUaCWDv+xVfUQX0aT
NPFmDixzUjoYzbGDrtAyCqA8f9CN2txIfJnpHE6q6CmKcoLADS4UrNPlhHSzd614
kR/JYiks0K4kbRqCQF0Dv0P5Di+rEfefC6glV8ysC8dB5/9nb0yh/ojRuJGmgMWH
gWk6h0ihjihqiu4jACovUZ7vVOCgSE5Ipn7OIwqd93zp2wIDAQABo4HEMIHBMB0G
A1UdDgQWBBSsBQ869nh83KqZr5jArr4/7b+QazCBkQYDVR0jBIGJMIGGgBSsBQ86
9nh83KqZr5jArr4/7b+Qa6FrpGkwZzELMAkGA1UEBhMCVVMxFTATBgNVBAgTDFBl
bm5zeWx2YW5pYTETMBEGA1UEBxMKUGl0dHNidXJnaDERMA8GA1UEChMIVGVzdFNo
aWIxGTAXBgNVBAMTEGlkcC50ZXN0c2hpYi5vcmeCAQAwDAYDVR0TBAUwAwEB/zAN
BgkqhkiG9w0BAQUFAAOCAQEAjR29PhrCbk8qLN5MFfSVk98t3CT9jHZoYxd8QMRL
I4j7iYQxXiGJTT1FXs1nd4Rha9un+LqTfeMMYqISdDDI6tv8iNpkOAvZZUosVkUo
93pv1T0RPz35hcHHYq2yee59HJOco2bFlcsH8JBXRSRrJ3Q7Eut+z9uo80JdGNJ4
/SJy5UorZ8KazGj16lfJhOBXldgrhppQBb0Nq6HKHguqmwRfJ+WkxemZXzhediAj
Geka8nz8JjwxpUjAiSWYKLtJhGEaTqCYxCCX2Dw+dOTqUzHOZ7WKv4JXPK5G/Uhr
8K/qhmFT2nIQi538n6rVYLeWj8Bbnl+ev0peYzxFyF5sQA==
"""
}
class TestShibBProvider(SAMLProviderMixin, BaseProvider):
""" Provider for testshib.org public Shibboleth test server. """
NAME = 'TestShib B'
IDP = {
"id": "testshibB", # Required slug
"entity_id": "https://idp.testshib.org/idp/shibboleth",
"url": "https://IDP.TESTSHIB.ORG/idp/profile/SAML2/Redirect/SSO",
"attr_email": OID_EDU_PERSON_PRINCIPAL_NAME,
"x509cert": TestShibAProvider.IDP["x509cert"],
}
class Registry(object): class Registry(object):
"""Singleton registry of third-party auth providers. """Singleton registry of third-party auth providers.
...@@ -211,23 +300,49 @@ class Registry(object): ...@@ -211,23 +300,49 @@ class Registry(object):
return cls._ENABLED.get(provider_name) return cls._ENABLED.get(provider_name)
@classmethod @classmethod
def get_by_backend_name(cls, backend_name): def get_from_pipeline(cls, running_pipeline):
"""Gets provider (or None) by backend name. """Gets the provider that is being used for the specified pipeline (or None).
Args: Args:
backend_name: string. The python-social-auth running_pipeline: The python-social-auth pipeline being used to
backends.base.BaseAuth.name (for example, 'google-oauth2') to authenticate a user.
try and get a provider for.
Returns:
A provider class (a subclass of BaseProvider) or None.
Raises: Raises:
RuntimeError: if the registry has not been configured. RuntimeError: if the registry has not been configured.
""" """
cls._check_configured() cls._check_configured()
for enabled in cls._ENABLED.values(): for enabled in cls._ENABLED.values():
if enabled.BACKEND_CLASS.name == backend_name: if enabled.is_active_for_pipeline(running_pipeline):
return enabled return enabled
@classmethod @classmethod
def get_enabled_by_backend_name(cls, backend_name):
"""Generator returning all enabled providers that use the specified
backend.
Example:
>>> list(get_enabled_by_backend_name("tpa-saml"))
[TestShibAProvider, TestShibBProvider]
Args:
backend_name: The name of a python-social-auth backend used by
one or more providers.
Yields:
Provider classes (subclasses of BaseProvider).
Raises:
RuntimeError: if the registry has not been configured.
"""
cls._check_configured()
for enabled in cls._ENABLED.values():
if enabled.BACKEND_CLASS.name == backend_name:
yield enabled
@classmethod
def _reset(cls): def _reset(cls):
"""Returns the registry to an unconfigured state; for tests only.""" """Returns the registry to an unconfigured state; for tests only."""
cls._CONFIGURED = False cls._CONFIGURED = False
......
"""
Slightly customized python-social-auth backend for SAML 2.0 support
"""
from social.backends.saml import SAMLIdentityProvider, SAMLAuth
class SAMLAuthBackend(SAMLAuth): # pylint: disable=abstract-method
"""
Customized version of SAMLAuth that gets the list of IdPs from third_party_auth's list of
enabled providers.
"""
name = "tpa-saml"
def get_idp(self, idp_name):
""" Given the name of an IdP, get a SAMLIdentityProvider instance """
from .provider import Registry # Import here to avoid circular import
for provider in Registry.enabled():
if issubclass(provider.BACKEND_CLASS, SAMLAuth) and provider.IDP["id"] == idp_name:
return SAMLIdentityProvider(idp_name, **provider.IDP)
raise KeyError("SAML IdP {} not found.".format(idp_name))
...@@ -115,12 +115,12 @@ class IntegrationTest(testutil.TestCase, test.TestCase): ...@@ -115,12 +115,12 @@ class IntegrationTest(testutil.TestCase, test.TestCase):
"""Asserts the user's account settings page context is in the expected state. """Asserts the user's account settings page context is in the expected state.
If duplicate is True, we expect context['duplicate_provider'] to contain If duplicate is True, we expect context['duplicate_provider'] to contain
the duplicate provider object. If linked is passed, we conditionally the duplicate provider backend name. If linked is passed, we conditionally
check that the provider is included in context['auth']['providers'] and check that the provider is included in context['auth']['providers'] and
its connected state is correct. its connected state is correct.
""" """
if duplicate: if duplicate:
self.assertEqual(context['duplicate_provider'].NAME, self.PROVIDER_CLASS.NAME) self.assertEqual(context['duplicate_provider'], self.PROVIDER_CLASS.BACKEND_CLASS.name)
else: else:
self.assertIsNone(context['duplicate_provider']) self.assertIsNone(context['duplicate_provider'])
......
...@@ -38,5 +38,5 @@ class ProviderUserStateTestCase(testutil.TestCase): ...@@ -38,5 +38,5 @@ class ProviderUserStateTestCase(testutil.TestCase):
"""Tests ProviderUserState behavior.""" """Tests ProviderUserState behavior."""
def test_get_unlink_form_name(self): def test_get_unlink_form_name(self):
state = pipeline.ProviderUserState(provider.GoogleOauth2, object(), False) state = pipeline.ProviderUserState(provider.GoogleOauth2, object(), 1000)
self.assertEqual(provider.GoogleOauth2.NAME + '_unlink_form', state.get_unlink_form_name()) self.assertEqual(provider.GoogleOauth2.NAME + '_unlink_form', state.get_unlink_form_name())
...@@ -41,16 +41,16 @@ class GetAuthenticatedUserTestCase(TestCase): ...@@ -41,16 +41,16 @@ class GetAuthenticatedUserTestCase(TestCase):
def test_raises_does_not_exist_if_user_missing(self): def test_raises_does_not_exist_if_user_missing(self):
with self.assertRaises(models.User.DoesNotExist): with self.assertRaises(models.User.DoesNotExist):
pipeline.get_authenticated_user('new_' + self.user.username, 'backend') pipeline.get_authenticated_user(self.enabled_provider, 'new_' + self.user.username, 'user@example.com')
def test_raises_does_not_exist_if_user_found_but_no_association(self): def test_raises_does_not_exist_if_user_found_but_no_association(self):
backend_name = 'backend' backend_name = 'backend'
self.assertIsNotNone(self.get_by_username(self.user.username)) self.assertIsNotNone(self.get_by_username(self.user.username))
self.assertIsNone(provider.Registry.get_by_backend_name(backend_name)) self.assertFalse(any(provider.Registry.get_enabled_by_backend_name(backend_name)))
with self.assertRaises(models.User.DoesNotExist): with self.assertRaises(models.User.DoesNotExist):
pipeline.get_authenticated_user(self.user.username, 'backend') pipeline.get_authenticated_user(self.enabled_provider, self.user.username, 'user@example.com')
def test_raises_does_not_exist_if_user_and_association_found_but_no_match(self): def test_raises_does_not_exist_if_user_and_association_found_but_no_match(self):
self.assertIsNotNone(self.get_by_username(self.user.username)) self.assertIsNotNone(self.get_by_username(self.user.username))
...@@ -58,11 +58,11 @@ class GetAuthenticatedUserTestCase(TestCase): ...@@ -58,11 +58,11 @@ class GetAuthenticatedUserTestCase(TestCase):
self.user, 'uid', 'other_' + self.enabled_provider.BACKEND_CLASS.name) self.user, 'uid', 'other_' + self.enabled_provider.BACKEND_CLASS.name)
with self.assertRaises(models.User.DoesNotExist): with self.assertRaises(models.User.DoesNotExist):
pipeline.get_authenticated_user(self.user.username, self.enabled_provider.BACKEND_CLASS.name) pipeline.get_authenticated_user(self.enabled_provider, self.user.username, 'uid')
def test_returns_user_with_is_authenticated_and_backend_set_if_match(self): def test_returns_user_with_is_authenticated_and_backend_set_if_match(self):
social_models.DjangoStorage.user.create_social_auth(self.user, 'uid', self.enabled_provider.BACKEND_CLASS.name) social_models.DjangoStorage.user.create_social_auth(self.user, 'uid', self.enabled_provider.BACKEND_CLASS.name)
user = pipeline.get_authenticated_user(self.user.username, self.enabled_provider.BACKEND_CLASS.name) user = pipeline.get_authenticated_user(self.enabled_provider, self.user.username, 'uid')
self.assertEqual(self.user, user) self.assertEqual(self.user, user)
self.assertEqual(self.enabled_provider.get_authentication_backend(), user.backend) self.assertEqual(self.enabled_provider.get_authentication_backend(), user.backend)
...@@ -93,8 +93,9 @@ class GetProviderUserStatesTestCase(testutil.TestCase, test.TestCase): ...@@ -93,8 +93,9 @@ class GetProviderUserStatesTestCase(testutil.TestCase, test.TestCase):
def test_states_for_enabled_providers_user_has_accounts_associated_with(self): def test_states_for_enabled_providers_user_has_accounts_associated_with(self):
provider.Registry.configure_once([provider.GoogleOauth2.NAME, provider.LinkedInOauth2.NAME]) provider.Registry.configure_once([provider.GoogleOauth2.NAME, provider.LinkedInOauth2.NAME])
social_models.DjangoStorage.user.create_social_auth(self.user, 'uid', provider.GoogleOauth2.BACKEND_CLASS.name) user_social_auth_google = social_models.DjangoStorage.user.create_social_auth(
social_models.DjangoStorage.user.create_social_auth( self.user, 'uid', provider.GoogleOauth2.BACKEND_CLASS.name)
user_social_auth_linkedin = social_models.DjangoStorage.user.create_social_auth(
self.user, 'uid', provider.LinkedInOauth2.BACKEND_CLASS.name) self.user, 'uid', provider.LinkedInOauth2.BACKEND_CLASS.name)
states = pipeline.get_provider_user_states(self.user) states = pipeline.get_provider_user_states(self.user)
...@@ -106,10 +107,12 @@ class GetProviderUserStatesTestCase(testutil.TestCase, test.TestCase): ...@@ -106,10 +107,12 @@ class GetProviderUserStatesTestCase(testutil.TestCase, test.TestCase):
self.assertTrue(google_state.has_account) self.assertTrue(google_state.has_account)
self.assertEqual(provider.GoogleOauth2, google_state.provider) self.assertEqual(provider.GoogleOauth2, google_state.provider)
self.assertEqual(self.user, google_state.user) self.assertEqual(self.user, google_state.user)
self.assertEqual(user_social_auth_google.id, google_state.association_id)
self.assertTrue(linkedin_state.has_account) self.assertTrue(linkedin_state.has_account)
self.assertEqual(provider.LinkedInOauth2, linkedin_state.provider) self.assertEqual(provider.LinkedInOauth2, linkedin_state.provider)
self.assertEqual(self.user, linkedin_state.user) self.assertEqual(self.user, linkedin_state.user)
self.assertEqual(user_social_auth_linkedin.id, linkedin_state.association_id)
def test_states_for_enabled_providers_user_has_no_account_associated_with(self): def test_states_for_enabled_providers_user_has_no_account_associated_with(self):
provider.Registry.configure_once([provider.GoogleOauth2.NAME, provider.LinkedInOauth2.NAME]) provider.Registry.configure_once([provider.GoogleOauth2.NAME, provider.LinkedInOauth2.NAME])
...@@ -155,13 +158,16 @@ class UrlFormationTestCase(TestCase): ...@@ -155,13 +158,16 @@ class UrlFormationTestCase(TestCase):
self.assertIsNone(provider.Registry.get(provider_name)) self.assertIsNone(provider.Registry.get(provider_name))
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
pipeline.get_disconnect_url(provider_name) pipeline.get_disconnect_url(provider_name, 1000)
def test_disconnect_url_returns_expected_format(self): def test_disconnect_url_returns_expected_format(self):
disconnect_url = pipeline.get_disconnect_url(self.enabled_provider.NAME) disconnect_url = pipeline.get_disconnect_url(self.enabled_provider.NAME, 1000)
disconnect_url = disconnect_url.rstrip('?')
self.assertTrue(disconnect_url.startswith('/auth/disconnect')) self.assertEqual(
self.assertIn(self.enabled_provider.BACKEND_CLASS.name, disconnect_url) disconnect_url,
'/auth/disconnect/{backend}/{association_id}/'.format(
backend=self.enabled_provider.BACKEND_CLASS.name, association_id=1000)
)
def test_login_url_raises_value_error_if_provider_not_enabled(self): def test_login_url_raises_value_error_if_provider_not_enabled(self):
provider_name = 'not_enabled' provider_name = 'not_enabled'
......
"""Unit tests for provider.py.""" """Unit tests for provider.py."""
from mock import Mock
from third_party_auth import provider from third_party_auth import provider
from third_party_auth.tests import testutil from third_party_auth.tests import testutil
...@@ -67,16 +68,22 @@ class RegistryTest(testutil.TestCase): ...@@ -67,16 +68,22 @@ class RegistryTest(testutil.TestCase):
provider.Registry.configure_once([]) provider.Registry.configure_once([])
self.assertIsNone(provider.Registry.get(provider.LinkedInOauth2.NAME)) self.assertIsNone(provider.Registry.get(provider.LinkedInOauth2.NAME))
def test_get_by_backend_name_raises_runtime_error_if_not_configured(self): def test_get_from_pipeline_returns_none_if_provider_not_enabled(self):
provider.Registry.configure_once([])
self.assertIsNone(provider.Registry.get_from_pipeline(Mock()))
def test_get_enabled_by_backend_name_raises_runtime_error_if_not_configured(self):
with self.assertRaisesRegexp(RuntimeError, '^.*not configured$'): with self.assertRaisesRegexp(RuntimeError, '^.*not configured$'):
provider.Registry.get_by_backend_name('') provider.Registry.get_enabled_by_backend_name('').next()
def test_get_by_backend_name_returns_enabled_provider(self): def test_get_enabled_by_backend_name_returns_enabled_provider(self):
provider.Registry.configure_once([provider.GoogleOauth2.NAME]) provider.Registry.configure_once([provider.GoogleOauth2.NAME])
self.assertIs( found = list(provider.Registry.get_enabled_by_backend_name(provider.GoogleOauth2.BACKEND_CLASS.name))
provider.GoogleOauth2, self.assertEqual(found, [provider.GoogleOauth2])
provider.Registry.get_by_backend_name(provider.GoogleOauth2.BACKEND_CLASS.name))
def test_get_by_backend_name_returns_none_if_provider_not_enabled(self): def test_get_enabled_by_backend_name_returns_none_if_provider_not_enabled(self):
provider.Registry.configure_once([]) provider.Registry.configure_once([])
self.assertIsNone(provider.Registry.get_by_backend_name(provider.GoogleOauth2.BACKEND_CLASS.name)) self.assertEqual(
[],
list(provider.Registry.get_enabled_by_backend_name(provider.GoogleOauth2.BACKEND_CLASS.name))
)
...@@ -2,10 +2,11 @@ ...@@ -2,10 +2,11 @@
from django.conf.urls import include, patterns, url from django.conf.urls import include, patterns, url
from .views import inactive_user_view from .views import inactive_user_view, saml_metadata_view
urlpatterns = patterns( urlpatterns = patterns(
'', '',
url(r'^auth/inactive', inactive_user_view), url(r'^auth/inactive', inactive_user_view),
url(r'^auth/saml/metadata.xml', saml_metadata_view),
url(r'^auth/', include('social.apps.django_app.urls', namespace='social')), url(r'^auth/', include('social.apps.django_app.urls', namespace='social')),
) )
""" """
Extra views required for SSO Extra views required for SSO
""" """
from django.conf import settings
from django.core.urlresolvers import reverse
from django.http import HttpResponse, HttpResponseServerError
from django.shortcuts import redirect from django.shortcuts import redirect
from social.apps.django_app.utils import load_strategy, load_backend
def inactive_user_view(request): def inactive_user_view(request):
...@@ -13,3 +17,19 @@ def inactive_user_view(request): ...@@ -13,3 +17,19 @@ def inactive_user_view(request):
# in a course. Otherwise, just redirect them to the dashboard, which displays a message # in a course. Otherwise, just redirect them to the dashboard, which displays a message
# about activating their account. # about activating their account.
return redirect(request.GET.get('next', 'dashboard')) return redirect(request.GET.get('next', 'dashboard'))
def saml_metadata_view(request):
"""
Get the Service Provider metadata for this edx-platform instance.
You must send this XML to any Shibboleth Identity Provider that you wish to use.
"""
complete_url = reverse('social:complete', args=("tpa-saml", ))
if settings.APPEND_SLASH and not complete_url.endswith('/'):
complete_url = complete_url + '/' # Required for consistency
saml_backend = load_backend(load_strategy(request), "tpa-saml", redirect_uri=complete_url)
metadata, errors = saml_backend.generate_metadata_xml()
if not errors:
return HttpResponse(content=metadata, content_type='text/xml')
return HttpResponseServerError(content=', '.join(errors))
...@@ -9,7 +9,7 @@ For processing xml always prefer this over using lxml.etree directly. ...@@ -9,7 +9,7 @@ For processing xml always prefer this over using lxml.etree directly.
from lxml.etree import * # pylint: disable=wildcard-import, unused-wildcard-import from lxml.etree import * # pylint: disable=wildcard-import, unused-wildcard-import
from lxml.etree import XMLParser as _XMLParser from lxml.etree import XMLParser as _XMLParser
from lxml.etree import _ElementTree # pylint: disable=unused-import from lxml.etree import _Element, _ElementTree # pylint: disable=unused-import, no-name-in-module
# This should be imported after lxml.etree so that it overrides the following attributes. # This should be imported after lxml.etree so that it overrides the following attributes.
from defusedxml.lxml import parse, fromstring, XML from defusedxml.lxml import parse, fromstring, XML
......
...@@ -1754,6 +1754,7 @@ class CombinedSystem(object): ...@@ -1754,6 +1754,7 @@ class CombinedSystem(object):
integrate it into a larger whole. integrate it into a larger whole.
""" """
context = context or {}
if view_name in PREVIEW_VIEWS: if view_name in PREVIEW_VIEWS:
block = self._get_student_block(block) block = self._get_student_block(block)
......
...@@ -432,7 +432,7 @@ class AccountSettingsViewTest(TestCase): ...@@ -432,7 +432,7 @@ class AccountSettingsViewTest(TestCase):
context['user_preferences_api_url'], reverse('preferences_api', kwargs={'username': self.user.username}) context['user_preferences_api_url'], reverse('preferences_api', kwargs={'username': self.user.username})
) )
self.assertEqual(context['duplicate_provider'].BACKEND_CLASS.name, 'facebook') self.assertEqual(context['duplicate_provider'], 'facebook')
self.assertEqual(context['auth']['providers'][0]['name'], 'Facebook') self.assertEqual(context['auth']['providers'][0]['name'], 'Facebook')
self.assertEqual(context['auth']['providers'][1]['name'], 'Google') self.assertEqual(context['auth']['providers'][1]['name'], 'Google')
......
...@@ -189,9 +189,7 @@ def _third_party_auth_context(request, redirect_to): ...@@ -189,9 +189,7 @@ def _third_party_auth_context(request, redirect_to):
running_pipeline = pipeline.get(request) running_pipeline = pipeline.get(request)
if running_pipeline is not None: if running_pipeline is not None:
current_provider = third_party_auth.provider.Registry.get_by_backend_name( current_provider = third_party_auth.provider.Registry.get_from_pipeline(running_pipeline)
running_pipeline.get('backend')
)
context["currentProvider"] = current_provider.NAME context["currentProvider"] = current_provider.NAME
context["finishAuthUrl"] = pipeline.get_complete_url(current_provider.BACKEND_CLASS.name) context["finishAuthUrl"] = pipeline.get_complete_url(current_provider.BACKEND_CLASS.name)
...@@ -382,7 +380,7 @@ def account_settings_context(request): ...@@ -382,7 +380,7 @@ def account_settings_context(request):
), ),
# If the user is connected, sending a POST request to this url removes the connection # If the user is connected, sending a POST request to this url removes the connection
# information for this provider from their edX account. # information for this provider from their edX account.
'disconnect_url': pipeline.get_disconnect_url(state.provider.NAME), 'disconnect_url': pipeline.get_disconnect_url(state.provider.NAME, state.association_id),
} for state in auth_states] } for state in auth_states]
return context return context
...@@ -541,6 +541,25 @@ THIRD_PARTY_AUTH = AUTH_TOKENS.get('THIRD_PARTY_AUTH', THIRD_PARTY_AUTH) ...@@ -541,6 +541,25 @@ THIRD_PARTY_AUTH = AUTH_TOKENS.get('THIRD_PARTY_AUTH', THIRD_PARTY_AUTH)
# The reduced session expiry time during the third party login pipeline. (Value in seconds) # The reduced session expiry time during the third party login pipeline. (Value in seconds)
SOCIAL_AUTH_PIPELINE_TIMEOUT = ENV_TOKENS.get('SOCIAL_AUTH_PIPELINE_TIMEOUT', 600) SOCIAL_AUTH_PIPELINE_TIMEOUT = ENV_TOKENS.get('SOCIAL_AUTH_PIPELINE_TIMEOUT', 600)
##### SAML configuration for third_party_auth #####
if 'SOCIAL_AUTH_TPA_SAML_SP_ENTITY_ID' in ENV_TOKENS:
SOCIAL_AUTH_TPA_SAML_SP_ENTITY_ID = ENV_TOKENS.get('SOCIAL_AUTH_TPA_SAML_SP_ENTITY_ID')
SOCIAL_AUTH_TPA_SAML_SP_NAMEID_FORMAT = ENV_TOKENS.get('SOCIAL_AUTH_TPA_SAML_SP_NAMEID_FORMAT', 'unspecified')
SOCIAL_AUTH_TPA_SAML_SP_EXTRA = ENV_TOKENS.get('SOCIAL_AUTH_TPA_SAML_SP_EXTRA', {})
SOCIAL_AUTH_TPA_SAML_ORG_INFO = ENV_TOKENS.get('SOCIAL_AUTH_TPA_SAML_ORG_INFO')
SOCIAL_AUTH_TPA_SAML_TECHNICAL_CONTACT = ENV_TOKENS.get(
'SOCIAL_AUTH_TPA_SAML_TECHNICAL_CONTACT',
{"givenName": "Technical Support", "emailAddress": TECH_SUPPORT_EMAIL}
)
SOCIAL_AUTH_TPA_SAML_SUPPORT_CONTACT = ENV_TOKENS.get(
'SOCIAL_AUTH_TPA_SAML_SUPPORT_CONTACT',
{"givenName": "Support", "emailAddress": TECH_SUPPORT_EMAIL}
)
SOCIAL_AUTH_TPA_SAML_SECURITY_CONFIG = ENV_TOKENS.get('SOCIAL_AUTH_TPA_SAML_SECURITY_CONFIG', {})
SOCIAL_AUTH_TPA_SAML_SP_PUBLIC_CERT = AUTH_TOKENS.get('SOCIAL_AUTH_TPA_SAML_SP_PUBLIC_CERT')
SOCIAL_AUTH_TPA_SAML_SP_PRIVATE_KEY = AUTH_TOKENS.get('SOCIAL_AUTH_TPA_SAML_SP_PRIVATE_KEY')
##### OAUTH2 Provider ############## ##### OAUTH2 Provider ##############
if FEATURES.get('ENABLE_OAUTH2_PROVIDER'): if FEATURES.get('ENABLE_OAUTH2_PROVIDER'):
OAUTH_OIDC_ISSUER = ENV_TOKENS['OAUTH_OIDC_ISSUER'] OAUTH_OIDC_ISSUER = ENV_TOKENS['OAUTH_OIDC_ISSUER']
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
<h2 class="sr">${_("Could Not Link Accounts")}</h2> <h2 class="sr">${_("Could Not Link Accounts")}</h2>
<div class="copy"> <div class="copy">
## Translators: this message is displayed when a user tries to link their account with a third-party authentication provider (for example, Google or LinkedIn) with a given edX account, but their third-party account is already associated with another edX account. provider_name is the name of the third-party authentication provider, and platform_name is the name of the edX deployment. ## Translators: this message is displayed when a user tries to link their account with a third-party authentication provider (for example, Google or LinkedIn) with a given edX account, but their third-party account is already associated with another edX account. provider_name is the name of the third-party authentication provider, and platform_name is the name of the edX deployment.
<p>${_("The {provider_name} account you selected is already linked to another {platform_name} account.").format(provider_name='<strong>{duplicate_provider}</strong>'.format(duplicate_provider=duplicate_provider.NAME), platform_name=platform_name)}</p> <p>${_("The {provider_name} account you selected is already linked to another {platform_name} account.").format(provider_name=duplicate_provider, platform_name=platform_name)}</p>
</div> </div>
</div> </div>
</div> </div>
......
...@@ -22,7 +22,7 @@ from third_party_auth import pipeline ...@@ -22,7 +22,7 @@ from third_party_auth import pipeline
<span class="provider">${state.provider.NAME}</span> <span class="provider">${state.provider.NAME}</span>
<span class="control"> <span class="control">
<form <form
action="${pipeline.get_disconnect_url(state.provider.NAME)}" action="${pipeline.get_disconnect_url(state.provider.NAME, state.association_id)}"
method="post" method="post"
name="${state.get_unlink_form_name()}"> name="${state.get_unlink_form_name()}">
% if state.has_account: % if state.has_account:
......
...@@ -720,7 +720,7 @@ class RegistrationView(APIView): ...@@ -720,7 +720,7 @@ class RegistrationView(APIView):
if third_party_auth.is_enabled(): if third_party_auth.is_enabled():
running_pipeline = third_party_auth.pipeline.get(request) running_pipeline = third_party_auth.pipeline.get(request)
if running_pipeline: if running_pipeline:
current_provider = third_party_auth.provider.Registry.get_by_backend_name(running_pipeline.get('backend')) current_provider = third_party_auth.provider.Registry.get_from_pipeline(running_pipeline)
# Override username / email / full name # Override username / email / full name
field_overrides = current_provider.get_register_form_data( field_overrides = current_provider.get_register_form_data(
......
...@@ -69,7 +69,7 @@ pyparsing==2.0.1 ...@@ -69,7 +69,7 @@ pyparsing==2.0.1
python-memcached==1.48 python-memcached==1.48
python-openid==2.2.5 python-openid==2.2.5
python-dateutil==2.1 python-dateutil==2.1
python-social-auth==0.2.7 # python-social-auth==0.2.7 was here but is temporarily moved to github.txt
pytz==2015.2 pytz==2015.2
pysrt==0.4.7 pysrt==0.4.7
PyYAML==3.10 PyYAML==3.10
......
...@@ -30,6 +30,9 @@ git+https://github.com/pmitros/pyfs.git@96e1922348bfe6d99201b9512a9ed946c87b7e0b ...@@ -30,6 +30,9 @@ git+https://github.com/pmitros/pyfs.git@96e1922348bfe6d99201b9512a9ed946c87b7e0b
git+https://github.com/hmarr/django-debug-toolbar-mongo.git@b0686a76f1ce3532088c4aee6e76b9abe61cc808 git+https://github.com/hmarr/django-debug-toolbar-mongo.git@b0686a76f1ce3532088c4aee6e76b9abe61cc808
# custom opaque-key implementations for ccx # custom opaque-key implementations for ccx
-e git+https://github.com/jazkarta/ccx-keys.git@e6b03704b1bb97c1d2f31301ecb4e3a687c536ea#egg=ccx-keys -e git+https://github.com/jazkarta/ccx-keys.git@e6b03704b1bb97c1d2f31301ecb4e3a687c536ea#egg=ccx-keys
# For SAML Support (To be moved to PyPi installation in base.txt once our changes are merged):
-e git+https://github.com/open-craft/python-saml.git@9602b8133056d8c3caa7c3038761147df3d4b257#egg=python-saml
-e git+https://github.com/open-craft/python-social-auth.git@17def186d4bb7165f9c37037936997ef39ae2f29#egg=python-social-auth
# Our libraries: # Our libraries:
-e git+https://github.com/edx/XBlock.git@74fdc5a361f48e5596acf3846ca3790a33a05253#egg=XBlock -e git+https://github.com/edx/XBlock.git@74fdc5a361f48e5596acf3846ca3790a33a05253#egg=XBlock
......
...@@ -36,3 +36,5 @@ mysql-client ...@@ -36,3 +36,5 @@ mysql-client
virtualenvwrapper virtualenvwrapper
libgeos-ruby1.8 libgeos-ruby1.8
lynx-cur lynx-cur
libxmlsec1-dev
swig
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment