Commit b4904adc by Braden MacDonald

Use ConfigurationModels for third_party_auth, new metadata fetching - PR 8155

parent caca3e1b
......@@ -196,8 +196,9 @@ def auth_pipeline_urls(auth_entry, redirect_url=None):
return {}
return {
provider.NAME: third_party_auth.pipeline.get_login_url(provider.NAME, auth_entry, redirect_url=redirect_url)
for provider in third_party_auth.provider.Registry.enabled()
provider.provider_id: third_party_auth.pipeline.get_login_url(
provider.provider_id, auth_entry, redirect_url=redirect_url
) for provider in third_party_auth.provider.Registry.enabled()
}
......
......@@ -11,12 +11,12 @@ from django.core.urlresolvers import reverse
from util.testing import UrlResetMixin
from xmodule.modulestore.tests.factories import CourseFactory
from student.tests.factories import CourseModeFactory
from third_party_auth.tests.testutil import ThirdPartyAuthTestMixin
from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase
# This relies on third party auth being enabled and configured
# in the test settings. See the setting `THIRD_PARTY_AUTH`
# and the feature flag `ENABLE_THIRD_PARTY_AUTH`
# This relies on third party auth being enabled in the test
# settings with the feature flag `ENABLE_THIRD_PARTY_AUTH`
THIRD_PARTY_AUTH_BACKENDS = ["google-oauth2", "facebook"]
THIRD_PARTY_AUTH_PROVIDERS = ["Google", "Facebook"]
......@@ -40,7 +40,7 @@ def _finish_auth_url(params):
@ddt.ddt
@unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms')
class LoginFormTest(UrlResetMixin, ModuleStoreTestCase):
class LoginFormTest(ThirdPartyAuthTestMixin, UrlResetMixin, ModuleStoreTestCase):
"""Test rendering of the login form. """
@patch.dict(settings.FEATURES, {"ENABLE_COMBINED_LOGIN_REGISTRATION": False})
def setUp(self):
......@@ -50,6 +50,8 @@ class LoginFormTest(UrlResetMixin, ModuleStoreTestCase):
self.course = CourseFactory.create()
self.course_id = unicode(self.course.id)
self.courseware_url = reverse("courseware", args=[self.course_id])
self.configure_google_provider(enabled=True)
self.configure_facebook_provider(enabled=True)
@patch.dict(settings.FEATURES, {"ENABLE_THIRD_PARTY_AUTH": False})
@ddt.data(THIRD_PARTY_AUTH_PROVIDERS)
......@@ -148,7 +150,7 @@ class LoginFormTest(UrlResetMixin, ModuleStoreTestCase):
@ddt.ddt
@unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms')
class RegisterFormTest(UrlResetMixin, ModuleStoreTestCase):
class RegisterFormTest(ThirdPartyAuthTestMixin, UrlResetMixin, ModuleStoreTestCase):
"""Test rendering of the registration form. """
@patch.dict(settings.FEATURES, {"ENABLE_COMBINED_LOGIN_REGISTRATION": False})
def setUp(self):
......@@ -157,6 +159,8 @@ class RegisterFormTest(UrlResetMixin, ModuleStoreTestCase):
self.url = reverse("register_user")
self.course = CourseFactory.create()
self.course_id = unicode(self.course.id)
self.configure_google_provider(enabled=True)
self.configure_facebook_provider(enabled=True)
@patch.dict(settings.FEATURES, {"ENABLE_THIRD_PARTY_AUTH": False})
@ddt.data(*THIRD_PARTY_AUTH_PROVIDERS)
......
......@@ -427,7 +427,7 @@ def register_user(request, extra_context=None):
current_provider = provider.Registry.get_from_pipeline(running_pipeline)
overrides = current_provider.get_register_form_data(running_pipeline.get('kwargs'))
overrides['running_pipeline'] = running_pipeline
overrides['selected_provider'] = current_provider.NAME
overrides['selected_provider'] = current_provider.name
context.update(overrides)
return render_to_response('register.html', context)
......@@ -964,12 +964,12 @@ def login_user(request, error=""): # pylint: disable-msg=too-many-statements,un
username=username, backend_name=backend_name))
return HttpResponse(
_("You've successfully logged into your {provider_name} account, but this account isn't linked with an {platform_name} account yet.").format(
platform_name=settings.PLATFORM_NAME, provider_name=requested_provider.NAME
platform_name=settings.PLATFORM_NAME, provider_name=requested_provider.name
)
+ "<br/><br/>" +
_("Use your {platform_name} username and password to log into {platform_name} below, "
"and then link your {platform_name} account with {provider_name} from your dashboard.").format(
platform_name=settings.PLATFORM_NAME, provider_name=requested_provider.NAME
platform_name=settings.PLATFORM_NAME, provider_name=requested_provider.name
)
+ "<br/><br/>" +
_("If you don't have an {platform_name} account yet, click <strong>Register Now</strong> at the top of the page.").format(
......@@ -1511,7 +1511,7 @@ def create_account_with_params(request, params):
if third_party_auth.is_enabled() and pipeline.running(request):
running_pipeline = pipeline.get(request)
current_provider = provider.Registry.get_from_pipeline(running_pipeline)
provider_name = current_provider.NAME
provider_name = current_provider.name
analytics.track(
user.id,
......
# -*- coding: utf-8 -*-
"""
Admin site configuration for third party authentication
"""
from django.contrib import admin
from config_models.admin import ConfigurationModelAdmin, KeyedConfigurationModelAdmin
from .models import OAuth2ProviderConfig, SAMLProviderConfig, SAMLConfiguration, SAMLProviderData
admin.site.register(OAuth2ProviderConfig, KeyedConfigurationModelAdmin)
class SAMLProviderConfigAdmin(KeyedConfigurationModelAdmin):
""" Django Admin class for SAMLProviderConfig """
def get_list_display(self, request):
""" Don't show every single field in the admin change list """
return (
'name', 'enabled', 'backend_name', 'entity_id', 'metadata_source',
'has_data', 'icon_class', 'change_date', 'changed_by', 'edit_link'
)
def has_data(self, inst):
""" Do we have cached metadata for this SAML provider? """
if not inst.is_active:
return None # N/A
data = SAMLProviderData.current(inst.entity_id)
return bool(data and data.is_valid())
has_data.short_description = u'Metadata Ready'
has_data.boolean = True
admin.site.register(SAMLProviderConfig, SAMLProviderConfigAdmin)
class SAMLConfigurationAdmin(ConfigurationModelAdmin):
""" Django Admin class for SAMLConfiguration """
def get_list_display(self, request):
""" Shorten the public/private keys in the change view """
return (
'change_date', 'changed_by', 'enabled', 'entity_id',
'org_info_str', 'key_summary',
)
def key_summary(self, inst):
""" Short summary of the key pairs configured """
if not inst.public_key or not inst.private_key:
return u'<em>Key pair incomplete/missing</em>'
pub1, pub2 = inst.public_key[0:10], inst.public_key[-10:]
priv1, priv2 = inst.private_key[0:10], inst.private_key[-10:]
return u'Public: {}…{}<br>Private: {}…{}'.format(pub1, pub2, priv1, priv2)
key_summary.allow_tags = True
admin.site.register(SAMLConfiguration, SAMLConfigurationAdmin)
class SAMLProviderDataAdmin(admin.ModelAdmin):
""" Django Admin class for SAMLProviderData """
list_display = ('entity_id', 'is_valid', 'fetched_at', 'expires_at', 'sso_url')
readonly_fields = ('is_valid', )
def get_readonly_fields(self, request, obj=None):
if obj: # editing an existing object
return self.model._meta.get_all_field_names() # pylint: disable=protected-access
return self.readonly_fields
admin.site.register(SAMLProviderData, SAMLProviderDataAdmin)
"""
DummyProvider: A fake Third Party Auth provider for testing & development purposes.
DummyBackend: A fake Third Party Auth provider for testing & development purposes.
"""
from social.backends.base import BaseAuth
from social.backends.oauth import BaseOAuth2
from social.exceptions import AuthFailed
from .provider import BaseProvider
class DummyBackend(BaseAuth): # pylint: disable=abstract-method
class DummyBackend(BaseOAuth2): # pylint: disable=abstract-method
"""
python-social-auth backend that doesn't actually go to any third party site
"""
......@@ -47,12 +45,3 @@ class DummyBackend(BaseAuth): # pylint: disable=abstract-method
kwargs.update({'response': response, 'backend': self})
return self.strategy.authenticate(*args, **kwargs)
class DummyProvider(BaseProvider):
""" Dummy Provider for testing and development """
BACKEND_CLASS = DummyBackend
ICON_CLASS = 'fa-cube'
NAME = 'Dummy'
SETTINGS = {}
# -*- coding: utf-8 -*-
"""
Management commands for third_party_auth
"""
import datetime
import dateutil.parser
from django.core.management.base import BaseCommand, CommandError
from lxml import etree
import requests
from onelogin.saml2.utils import OneLogin_Saml2_Utils
from third_party_auth.models import SAMLConfiguration, SAMLProviderConfig, SAMLProviderData
#pylint: disable=superfluous-parens,no-member
class MetadataParseError(Exception):
""" An error occurred while parsing the SAML metadata from an IdP """
pass
class Command(BaseCommand):
""" manage.py commands to manage SAML/Shibboleth SSO """
help = '''Configure/maintain/update SAML-based SSO'''
def handle(self, *args, **options):
if len(args) != 1:
raise CommandError("saml requires one argument: pull")
if not SAMLConfiguration.is_enabled():
self.stdout.write("Warning: SAML support is disabled via SAMLConfiguration.\n")
subcommand = args[0]
if subcommand == "pull":
self.cmd_pull()
else:
raise CommandError("Unknown argment: {}".format(subcommand))
@staticmethod
def tag_name(tag_name):
""" Get the namespaced-qualified name for an XML tag """
return '{urn:oasis:names:tc:SAML:2.0:metadata}' + tag_name
def cmd_pull(self):
""" Fetch the metadata for each provider and update the DB """
# First make a list of all the metadata XML URLs:
url_map = {}
for idp_slug in SAMLProviderConfig.key_values('idp_slug', flat=True):
config = SAMLProviderConfig.current(idp_slug)
if not config.enabled:
continue
url = config.metadata_source
if url not in url_map:
url_map[url] = []
if config.entity_id not in url_map[url]:
url_map[url].append(config.entity_id)
# Now fetch the metadata:
for url, entity_ids in url_map.items():
try:
self.stdout.write("\n→ Fetching {}\n".format(url))
if not url.lower().startswith('https'):
self.stdout.write("→ WARNING: This URL is not secure! It should use HTTPS.\n")
response = requests.get(url, verify=True) # May raise HTTPError or SSLError
response.raise_for_status() # May raise an HTTPError
try:
parser = etree.XMLParser(remove_comments=True)
xml = etree.fromstring(response.text, parser)
except etree.XMLSyntaxError:
raise
# TODO: Can use OneLogin_Saml2_Utils to validate signed XML if anyone is using that
for entity_id in entity_ids:
self.stdout.write("→ Processing IdP with entityID {}\n".format(entity_id))
public_key, sso_url, expires_at = self._parse_metadata_xml(xml, entity_id)
self._update_data(entity_id, public_key, sso_url, expires_at)
except Exception as err: # pylint: disable=broad-except
self.stderr.write("→ ERROR: {}\n\n".format(err.message))
@classmethod
def _parse_metadata_xml(cls, xml, entity_id):
"""
Given an XML document containing SAML 2.0 metadata, parse it and return a tuple of
(public_key, sso_url, expires_at) for the specified entityID.
Raises MetadataParseError if anything is wrong.
"""
if xml.tag == cls.tag_name('EntityDescriptor'):
entity_desc = xml
else:
if xml.tag != cls.tag_name('EntitiesDescriptor'):
raise MetadataParseError("Expected root element to be <EntitiesDescriptor>, not {}".format(xml.tag))
entity_desc = xml.find(".//{}[@entityID='{}']".format(cls.tag_name('EntityDescriptor'), entity_id))
if not entity_desc:
raise MetadataParseError("Can't find EntityDescriptor for entityID {}".format(entity_id))
expires_at = None
if "validUntil" in xml.attrib:
expires_at = dateutil.parser.parse(xml.attrib["validUntil"])
if "cacheDuration" in xml.attrib:
cache_expires = OneLogin_Saml2_Utils.parse_duration(xml.attrib["cacheDuration"])
if expires_at is None or cache_expires < expires_at:
expires_at = cache_expires
sso_desc = entity_desc.find(cls.tag_name("IDPSSODescriptor"))
if not sso_desc:
raise MetadataParseError("IDPSSODescriptor missing")
if 'urn:oasis:names:tc:SAML:2.0:protocol' not in sso_desc.get("protocolSupportEnumeration"):
raise MetadataParseError("This IdP does not support SAML 2.0")
# Now we just need to get the public_key and sso_url
public_key = sso_desc.findtext("./{}//{}".format(
cls.tag_name("KeyDescriptor"), "{http://www.w3.org/2000/09/xmldsig#}X509Certificate"
))
if not public_key:
raise MetadataParseError("Public Key missing. Expected an <X509Certificate>")
public_key = public_key.replace(" ", "")
binding_elements = sso_desc.iterfind("./{}".format(cls.tag_name("SingleSignOnService")))
sso_bindings = {element.get('Binding'): element.get('Location') for element in binding_elements}
try:
# The only binding supported by python-saml and python-social-auth is HTTP-Redirect:
sso_url = sso_bindings['urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect']
except KeyError:
raise MetadataParseError("Unable to find SSO URL with HTTP-Redirect binding.")
return public_key, sso_url, expires_at
def _update_data(self, entity_id, public_key, sso_url, expires_at):
"""
Update/Create the SAMLProviderData for the given entity ID.
"""
data_obj = SAMLProviderData.current(entity_id)
fetched_at = datetime.datetime.now()
if data_obj and (data_obj.public_key == public_key and data_obj.sso_url == sso_url):
data_obj.expires_at = expires_at
data_obj.fetched_at = fetched_at
data_obj.save()
self.stdout.write("→ Updated existing SAMLProviderData. Nothing has changed.\n")
else:
SAMLProviderData.objects.create(
entity_id=entity_id,
fetched_at=fetched_at,
expires_at=expires_at,
sso_url=sso_url,
public_key=public_key,
)
self.stdout.write("→ Created new record for SAMLProviderData\n")
# -*- coding: utf-8 -*-
from south.utils import datetime_utils as datetime
from south.db import db
from south.v2 import SchemaMigration
from django.db import models
class Migration(SchemaMigration):
def forwards(self, orm):
# Adding model 'OAuth2ProviderConfig'
db.create_table('third_party_auth_oauth2providerconfig', (
('id', self.gf('django.db.models.fields.AutoField')(primary_key=True)),
('change_date', self.gf('django.db.models.fields.DateTimeField')(auto_now_add=True, blank=True)),
('changed_by', self.gf('django.db.models.fields.related.ForeignKey')(to=orm['auth.User'], null=True, on_delete=models.PROTECT)),
('enabled', self.gf('django.db.models.fields.BooleanField')(default=False)),
('icon_class', self.gf('django.db.models.fields.CharField')(default='fa-sign-in', max_length=50)),
('name', self.gf('django.db.models.fields.CharField')(max_length=50)),
('backend_name', self.gf('django.db.models.fields.CharField')(max_length=50, db_index=True)),
('key', self.gf('django.db.models.fields.TextField')(blank=True)),
('secret', self.gf('django.db.models.fields.TextField')(blank=True)),
('other_settings', self.gf('django.db.models.fields.TextField')(blank=True)),
))
db.send_create_signal('third_party_auth', ['OAuth2ProviderConfig'])
# Adding model 'SAMLProviderConfig'
db.create_table('third_party_auth_samlproviderconfig', (
('id', self.gf('django.db.models.fields.AutoField')(primary_key=True)),
('change_date', self.gf('django.db.models.fields.DateTimeField')(auto_now_add=True, blank=True)),
('changed_by', self.gf('django.db.models.fields.related.ForeignKey')(to=orm['auth.User'], null=True, on_delete=models.PROTECT)),
('enabled', self.gf('django.db.models.fields.BooleanField')(default=False)),
('icon_class', self.gf('django.db.models.fields.CharField')(default='fa-sign-in', max_length=50)),
('name', self.gf('django.db.models.fields.CharField')(max_length=50)),
('backend_name', self.gf('django.db.models.fields.CharField')(default='tpa-saml', max_length=50)),
('idp_slug', self.gf('django.db.models.fields.SlugField')(max_length=30)),
('entity_id', self.gf('django.db.models.fields.CharField')(max_length=255)),
('metadata_source', self.gf('django.db.models.fields.CharField')(max_length=255)),
('attr_user_permanent_id', self.gf('django.db.models.fields.CharField')(max_length=128, blank=True)),
('attr_full_name', self.gf('django.db.models.fields.CharField')(max_length=128, blank=True)),
('attr_first_name', self.gf('django.db.models.fields.CharField')(max_length=128, blank=True)),
('attr_last_name', self.gf('django.db.models.fields.CharField')(max_length=128, blank=True)),
('attr_username', self.gf('django.db.models.fields.CharField')(max_length=128, blank=True)),
('attr_email', self.gf('django.db.models.fields.CharField')(max_length=128, blank=True)),
('other_settings', self.gf('django.db.models.fields.TextField')(blank=True)),
))
db.send_create_signal('third_party_auth', ['SAMLProviderConfig'])
# Adding model 'SAMLConfiguration'
db.create_table('third_party_auth_samlconfiguration', (
('id', self.gf('django.db.models.fields.AutoField')(primary_key=True)),
('change_date', self.gf('django.db.models.fields.DateTimeField')(auto_now_add=True, blank=True)),
('changed_by', self.gf('django.db.models.fields.related.ForeignKey')(to=orm['auth.User'], null=True, on_delete=models.PROTECT)),
('enabled', self.gf('django.db.models.fields.BooleanField')(default=False)),
('private_key', self.gf('django.db.models.fields.TextField')()),
('public_key', self.gf('django.db.models.fields.TextField')()),
('entity_id', self.gf('django.db.models.fields.CharField')(default='http://saml.example.com', max_length=255)),
('org_info_str', self.gf('django.db.models.fields.TextField')(default='{"en-US": {"url": "http://www.example.com", "displayname": "Example Inc.", "name": "example"}}')),
('other_config_str', self.gf('django.db.models.fields.TextField')(default='{\n"SECURITY_CONFIG": {"metadataCacheDuration": 604800, "signMetadata": false}\n}')),
))
db.send_create_signal('third_party_auth', ['SAMLConfiguration'])
# Adding model 'SAMLProviderData'
db.create_table('third_party_auth_samlproviderdata', (
('id', self.gf('django.db.models.fields.AutoField')(primary_key=True)),
('fetched_at', self.gf('django.db.models.fields.DateTimeField')(db_index=True)),
('expires_at', self.gf('django.db.models.fields.DateTimeField')(null=True, db_index=True)),
('entity_id', self.gf('django.db.models.fields.CharField')(max_length=255, db_index=True)),
('sso_url', self.gf('django.db.models.fields.URLField')(max_length=200)),
('public_key', self.gf('django.db.models.fields.TextField')()),
))
db.send_create_signal('third_party_auth', ['SAMLProviderData'])
def backwards(self, orm):
# Deleting model 'OAuth2ProviderConfig'
db.delete_table('third_party_auth_oauth2providerconfig')
# Deleting model 'SAMLProviderConfig'
db.delete_table('third_party_auth_samlproviderconfig')
# Deleting model 'SAMLConfiguration'
db.delete_table('third_party_auth_samlconfiguration')
# Deleting model 'SAMLProviderData'
db.delete_table('third_party_auth_samlproviderdata')
models = {
'auth.group': {
'Meta': {'object_name': 'Group'},
'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}),
'name': ('django.db.models.fields.CharField', [], {'unique': 'True', 'max_length': '80'}),
'permissions': ('django.db.models.fields.related.ManyToManyField', [], {'to': "orm['auth.Permission']", 'symmetrical': 'False', 'blank': 'True'})
},
'auth.permission': {
'Meta': {'ordering': "('content_type__app_label', 'content_type__model', 'codename')", 'unique_together': "(('content_type', 'codename'),)", 'object_name': 'Permission'},
'codename': ('django.db.models.fields.CharField', [], {'max_length': '100'}),
'content_type': ('django.db.models.fields.related.ForeignKey', [], {'to': "orm['contenttypes.ContentType']"}),
'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}),
'name': ('django.db.models.fields.CharField', [], {'max_length': '50'})
},
'auth.user': {
'Meta': {'object_name': 'User'},
'date_joined': ('django.db.models.fields.DateTimeField', [], {'default': 'datetime.datetime.now'}),
'email': ('django.db.models.fields.EmailField', [], {'max_length': '75', 'blank': 'True'}),
'first_name': ('django.db.models.fields.CharField', [], {'max_length': '30', 'blank': 'True'}),
'groups': ('django.db.models.fields.related.ManyToManyField', [], {'to': "orm['auth.Group']", 'symmetrical': 'False', 'blank': 'True'}),
'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}),
'is_active': ('django.db.models.fields.BooleanField', [], {'default': 'True'}),
'is_staff': ('django.db.models.fields.BooleanField', [], {'default': 'False'}),
'is_superuser': ('django.db.models.fields.BooleanField', [], {'default': 'False'}),
'last_login': ('django.db.models.fields.DateTimeField', [], {'default': 'datetime.datetime.now'}),
'last_name': ('django.db.models.fields.CharField', [], {'max_length': '30', 'blank': 'True'}),
'password': ('django.db.models.fields.CharField', [], {'max_length': '128'}),
'user_permissions': ('django.db.models.fields.related.ManyToManyField', [], {'to': "orm['auth.Permission']", 'symmetrical': 'False', 'blank': 'True'}),
'username': ('django.db.models.fields.CharField', [], {'unique': 'True', 'max_length': '30'})
},
'contenttypes.contenttype': {
'Meta': {'ordering': "('name',)", 'unique_together': "(('app_label', 'model'),)", 'object_name': 'ContentType', 'db_table': "'django_content_type'"},
'app_label': ('django.db.models.fields.CharField', [], {'max_length': '100'}),
'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}),
'model': ('django.db.models.fields.CharField', [], {'max_length': '100'}),
'name': ('django.db.models.fields.CharField', [], {'max_length': '100'})
},
'third_party_auth.oauth2providerconfig': {
'Meta': {'object_name': 'OAuth2ProviderConfig'},
'backend_name': ('django.db.models.fields.CharField', [], {'max_length': '50', 'db_index': 'True'}),
'change_date': ('django.db.models.fields.DateTimeField', [], {'auto_now_add': 'True', 'blank': 'True'}),
'changed_by': ('django.db.models.fields.related.ForeignKey', [], {'to': "orm['auth.User']", 'null': 'True', 'on_delete': 'models.PROTECT'}),
'enabled': ('django.db.models.fields.BooleanField', [], {'default': 'False'}),
'icon_class': ('django.db.models.fields.CharField', [], {'default': "'fa-sign-in'", 'max_length': '50'}),
'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}),
'key': ('django.db.models.fields.TextField', [], {'blank': 'True'}),
'name': ('django.db.models.fields.CharField', [], {'max_length': '50'}),
'other_settings': ('django.db.models.fields.TextField', [], {'blank': 'True'}),
'secret': ('django.db.models.fields.TextField', [], {'blank': 'True'})
},
'third_party_auth.samlconfiguration': {
'Meta': {'object_name': 'SAMLConfiguration'},
'change_date': ('django.db.models.fields.DateTimeField', [], {'auto_now_add': 'True', 'blank': 'True'}),
'changed_by': ('django.db.models.fields.related.ForeignKey', [], {'to': "orm['auth.User']", 'null': 'True', 'on_delete': 'models.PROTECT'}),
'enabled': ('django.db.models.fields.BooleanField', [], {'default': 'False'}),
'entity_id': ('django.db.models.fields.CharField', [], {'default': "'http://saml.example.com'", 'max_length': '255'}),
'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}),
'org_info_str': ('django.db.models.fields.TextField', [], {'default': '\'{"en-US": {"url": "http://www.example.com", "displayname": "Example Inc.", "name": "example"}}\''}),
'other_config_str': ('django.db.models.fields.TextField', [], {'default': '\'{\\n"SECURITY_CONFIG": {"metadataCacheDuration": 604800, "signMetadata": false}\\n}\''}),
'private_key': ('django.db.models.fields.TextField', [], {}),
'public_key': ('django.db.models.fields.TextField', [], {})
},
'third_party_auth.samlproviderconfig': {
'Meta': {'object_name': 'SAMLProviderConfig'},
'attr_email': ('django.db.models.fields.CharField', [], {'max_length': '128', 'blank': 'True'}),
'attr_first_name': ('django.db.models.fields.CharField', [], {'max_length': '128', 'blank': 'True'}),
'attr_full_name': ('django.db.models.fields.CharField', [], {'max_length': '128', 'blank': 'True'}),
'attr_last_name': ('django.db.models.fields.CharField', [], {'max_length': '128', 'blank': 'True'}),
'attr_user_permanent_id': ('django.db.models.fields.CharField', [], {'max_length': '128', 'blank': 'True'}),
'attr_username': ('django.db.models.fields.CharField', [], {'max_length': '128', 'blank': 'True'}),
'backend_name': ('django.db.models.fields.CharField', [], {'default': "'tpa-saml'", 'max_length': '50'}),
'change_date': ('django.db.models.fields.DateTimeField', [], {'auto_now_add': 'True', 'blank': 'True'}),
'changed_by': ('django.db.models.fields.related.ForeignKey', [], {'to': "orm['auth.User']", 'null': 'True', 'on_delete': 'models.PROTECT'}),
'enabled': ('django.db.models.fields.BooleanField', [], {'default': 'False'}),
'entity_id': ('django.db.models.fields.CharField', [], {'max_length': '255'}),
'icon_class': ('django.db.models.fields.CharField', [], {'default': "'fa-sign-in'", 'max_length': '50'}),
'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}),
'idp_slug': ('django.db.models.fields.SlugField', [], {'max_length': '30'}),
'metadata_source': ('django.db.models.fields.CharField', [], {'max_length': '255'}),
'name': ('django.db.models.fields.CharField', [], {'max_length': '50'}),
'other_settings': ('django.db.models.fields.TextField', [], {'blank': 'True'})
},
'third_party_auth.samlproviderdata': {
'Meta': {'ordering': "('-fetched_at',)", 'object_name': 'SAMLProviderData'},
'entity_id': ('django.db.models.fields.CharField', [], {'max_length': '255', 'db_index': 'True'}),
'expires_at': ('django.db.models.fields.DateTimeField', [], {'null': 'True', 'db_index': 'True'}),
'fetched_at': ('django.db.models.fields.DateTimeField', [], {'db_index': 'True'}),
'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}),
'public_key': ('django.db.models.fields.TextField', [], {}),
'sso_url': ('django.db.models.fields.URLField', [], {'max_length': '200'})
}
}
complete_apps = ['third_party_auth']
\ No newline at end of file
# -*- coding: utf-8 -*-
from django.conf import settings
import json
from south.v2 import DataMigration
class Migration(DataMigration):
def forwards(self, orm):
""" Convert from the THIRD_PARTY_AUTH setting to OAuth2ProviderConfig """
tpa = getattr(settings, 'THIRD_PARTY_AUTH_OLD_CONFIG', {})
if tpa and not any(orm.OAuth2ProviderConfig.objects.all()):
print("Migrating third party auth config to OAuth2ProviderConfig")
providers = (
# Name, backend, icon, prefix
('Google', 'google-oauth2', 'fa-google-plus', 'SOCIAL_AUTH_GOOGLE_OAUTH2_'),
('LinkedIn', 'linkedin-oauth2', 'fa-linkedin', 'SOCIAL_AUTH_LINKEDIN_OAUTH2_'),
('Facebook', 'facebook', 'fa-facebook', 'SOCIAL_AUTH_FACEBOOK_'),
)
for name, backend, icon, prefix in providers:
if name in tpa:
conf = tpa[name]
conf = {key.replace(prefix, ''): val for key, val in conf.items()}
key = conf.pop('KEY', '')
secret = conf.pop('SECRET', '')
orm.OAuth2ProviderConfig.objects.create(
enabled=True,
name=name,
backend_name=backend,
icon_class=icon,
key=key,
secret=secret,
other_settings=json.dumps(conf),
changed_by=None,
)
print(
"Done. Make changes via /admin/third_party_auth/oauth2providerconfig/ "
"from now on. You can remove THIRD_PARTY_AUTH from ~/lms.auth.json"
)
else:
print("Not migrating third party auth config to OAuth2ProviderConfig.")
def backwards(self, orm):
""" No backwards migration necessary """
pass
models = {
'auth.group': {
'Meta': {'object_name': 'Group'},
'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}),
'name': ('django.db.models.fields.CharField', [], {'unique': 'True', 'max_length': '80'}),
'permissions': ('django.db.models.fields.related.ManyToManyField', [], {'to': "orm['auth.Permission']", 'symmetrical': 'False', 'blank': 'True'})
},
'auth.permission': {
'Meta': {'ordering': "('content_type__app_label', 'content_type__model', 'codename')", 'unique_together': "(('content_type', 'codename'),)", 'object_name': 'Permission'},
'codename': ('django.db.models.fields.CharField', [], {'max_length': '100'}),
'content_type': ('django.db.models.fields.related.ForeignKey', [], {'to': "orm['contenttypes.ContentType']"}),
'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}),
'name': ('django.db.models.fields.CharField', [], {'max_length': '50'})
},
'auth.user': {
'Meta': {'object_name': 'User'},
'date_joined': ('django.db.models.fields.DateTimeField', [], {'default': 'datetime.datetime.now'}),
'email': ('django.db.models.fields.EmailField', [], {'max_length': '75', 'blank': 'True'}),
'first_name': ('django.db.models.fields.CharField', [], {'max_length': '30', 'blank': 'True'}),
'groups': ('django.db.models.fields.related.ManyToManyField', [], {'to': "orm['auth.Group']", 'symmetrical': 'False', 'blank': 'True'}),
'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}),
'is_active': ('django.db.models.fields.BooleanField', [], {'default': 'True'}),
'is_staff': ('django.db.models.fields.BooleanField', [], {'default': 'False'}),
'is_superuser': ('django.db.models.fields.BooleanField', [], {'default': 'False'}),
'last_login': ('django.db.models.fields.DateTimeField', [], {'default': 'datetime.datetime.now'}),
'last_name': ('django.db.models.fields.CharField', [], {'max_length': '30', 'blank': 'True'}),
'password': ('django.db.models.fields.CharField', [], {'max_length': '128'}),
'user_permissions': ('django.db.models.fields.related.ManyToManyField', [], {'to': "orm['auth.Permission']", 'symmetrical': 'False', 'blank': 'True'}),
'username': ('django.db.models.fields.CharField', [], {'unique': 'True', 'max_length': '30'})
},
'contenttypes.contenttype': {
'Meta': {'ordering': "('name',)", 'unique_together': "(('app_label', 'model'),)", 'object_name': 'ContentType', 'db_table': "'django_content_type'"},
'app_label': ('django.db.models.fields.CharField', [], {'max_length': '100'}),
'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}),
'model': ('django.db.models.fields.CharField', [], {'max_length': '100'}),
'name': ('django.db.models.fields.CharField', [], {'max_length': '100'})
},
'third_party_auth.oauth2providerconfig': {
'Meta': {'object_name': 'OAuth2ProviderConfig'},
'backend_name': ('django.db.models.fields.CharField', [], {'max_length': '50', 'db_index': 'True'}),
'change_date': ('django.db.models.fields.DateTimeField', [], {'auto_now_add': 'True', 'blank': 'True'}),
'changed_by': ('django.db.models.fields.related.ForeignKey', [], {'to': "orm['auth.User']", 'null': 'True', 'on_delete': 'models.PROTECT'}),
'enabled': ('django.db.models.fields.BooleanField', [], {'default': 'False'}),
'icon_class': ('django.db.models.fields.CharField', [], {'default': "'fa-sign-in'", 'max_length': '50'}),
'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}),
'key': ('django.db.models.fields.TextField', [], {'blank': 'True'}),
'name': ('django.db.models.fields.CharField', [], {'max_length': '50'}),
'other_settings': ('django.db.models.fields.TextField', [], {'blank': 'True'}),
'secret': ('django.db.models.fields.TextField', [], {'blank': 'True'})
},
'third_party_auth.samlconfiguration': {
'Meta': {'object_name': 'SAMLConfiguration'},
'change_date': ('django.db.models.fields.DateTimeField', [], {'auto_now_add': 'True', 'blank': 'True'}),
'changed_by': ('django.db.models.fields.related.ForeignKey', [], {'to': "orm['auth.User']", 'null': 'True', 'on_delete': 'models.PROTECT'}),
'enabled': ('django.db.models.fields.BooleanField', [], {'default': 'False'}),
'entity_id': ('django.db.models.fields.CharField', [], {'default': "'http://saml.example.com'", 'max_length': '255'}),
'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}),
'org_info_str': ('django.db.models.fields.TextField', [], {'default': '\'{"en-US": {"url": "http://www.example.com", "displayname": "Example Inc.", "name": "example"}}\''}),
'other_config_str': ('django.db.models.fields.TextField', [], {'default': '\'{\\n"SECURITY_CONFIG": {"metadataCacheDuration": 604800, "signMetadata": false}\\n}\''}),
'private_key': ('django.db.models.fields.TextField', [], {}),
'public_key': ('django.db.models.fields.TextField', [], {})
},
'third_party_auth.samlproviderconfig': {
'Meta': {'object_name': 'SAMLProviderConfig'},
'attr_email': ('django.db.models.fields.CharField', [], {'max_length': '128', 'blank': 'True'}),
'attr_first_name': ('django.db.models.fields.CharField', [], {'max_length': '128', 'blank': 'True'}),
'attr_full_name': ('django.db.models.fields.CharField', [], {'max_length': '128', 'blank': 'True'}),
'attr_last_name': ('django.db.models.fields.CharField', [], {'max_length': '128', 'blank': 'True'}),
'attr_user_permanent_id': ('django.db.models.fields.CharField', [], {'max_length': '128', 'blank': 'True'}),
'attr_username': ('django.db.models.fields.CharField', [], {'max_length': '128', 'blank': 'True'}),
'backend_name': ('django.db.models.fields.CharField', [], {'default': "'tpa-saml'", 'max_length': '50'}),
'change_date': ('django.db.models.fields.DateTimeField', [], {'auto_now_add': 'True', 'blank': 'True'}),
'changed_by': ('django.db.models.fields.related.ForeignKey', [], {'to': "orm['auth.User']", 'null': 'True', 'on_delete': 'models.PROTECT'}),
'enabled': ('django.db.models.fields.BooleanField', [], {'default': 'False'}),
'entity_id': ('django.db.models.fields.CharField', [], {'max_length': '255'}),
'icon_class': ('django.db.models.fields.CharField', [], {'default': "'fa-sign-in'", 'max_length': '50'}),
'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}),
'idp_slug': ('django.db.models.fields.SlugField', [], {'max_length': '30'}),
'metadata_source': ('django.db.models.fields.CharField', [], {'max_length': '255'}),
'name': ('django.db.models.fields.CharField', [], {'max_length': '50'}),
'other_settings': ('django.db.models.fields.TextField', [], {'blank': 'True'})
},
'third_party_auth.samlproviderdata': {
'Meta': {'ordering': "('-fetched_at',)", 'object_name': 'SAMLProviderData'},
'entity_id': ('django.db.models.fields.CharField', [], {'max_length': '255', 'db_index': 'True'}),
'expires_at': ('django.db.models.fields.DateTimeField', [], {'null': 'True', 'db_index': 'True'}),
'fetched_at': ('django.db.models.fields.DateTimeField', [], {'db_index': 'True'}),
'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}),
'public_key': ('django.db.models.fields.TextField', [], {}),
'sso_url': ('django.db.models.fields.URLField', [], {'max_length': '200'})
}
}
complete_apps = ['third_party_auth']
symmetrical = True
# -*- coding: utf-8 -*-
"""
Models used to implement SAML SSO support in third_party_auth
(inlcuding Shibboleth support)
"""
from config_models.models import ConfigurationModel, cache
from django.conf import settings
from django.core.exceptions import ValidationError
from django.db import models
from django.utils import timezone
import json
import logging
from social.backends.base import BaseAuth
from social.backends.oauth import BaseOAuth2
from social.backends.saml import SAMLAuth, SAMLIdentityProvider
from social.exceptions import SocialAuthBaseException
from social.utils import module_member
log = logging.getLogger(__name__)
# A dictionary of {name: class} entries for each python-social-auth backend available.
# Because this setting can specify arbitrary code to load and execute, it is set via
# normal Django settings only and cannot be changed at runtime:
def _load_backend_classes(base_class=BaseAuth):
""" Load the list of python-social-auth backend classes from Django settings """
for class_path in settings.AUTHENTICATION_BACKENDS:
auth_class = module_member(class_path)
if issubclass(auth_class, base_class):
yield auth_class
_PSA_BACKENDS = {backend_class.name: backend_class for backend_class in _load_backend_classes()}
_PSA_OAUTH2_BACKENDS = [backend_class.name for backend_class in _load_backend_classes(BaseOAuth2)]
_PSA_SAML_BACKENDS = [backend_class.name for backend_class in _load_backend_classes(SAMLAuth)]
def clean_json(value, of_type):
""" Simple helper method to parse and clean JSON """
if not value.strip():
return json.dumps(of_type())
try:
value_python = json.loads(value)
except ValueError as err:
raise ValidationError("Invalid JSON: {}".format(err.message))
if not isinstance(value_python, of_type):
raise ValidationError("Expected a JSON {}".format(of_type))
return json.dumps(value_python, indent=4)
class AuthNotConfigured(SocialAuthBaseException):
""" Exception when SAMLProviderData or other required info is missing """
def __init__(self, provider_name):
super(AuthNotConfigured, self).__init__()
self.provider_name = provider_name
def __str__(self):
return 'Authentication with {} is currently unavailable.'.format(
self.provider_name
)
class ProviderConfig(ConfigurationModel):
"""
Abstract Base Class for configuring a third_party_auth provider
"""
icon_class = models.CharField(
max_length=50, default='fa-sign-in',
help_text=(
'The Font Awesome (or custom) icon class to use on the login button for this provider. '
'Examples: fa-google-plus, fa-facebook, fa-linkedin, fa-sign-in, fa-university'
))
name = models.CharField(max_length=50, blank=False, help_text="Name of this provider (shown to users)")
prefix = None # used for provider_id. Set to a string value in subclass
backend_name = None # Set to a field or fixed value in subclass
# "enabled" field is inherited from ConfigurationModel
class Meta(object): # pylint: disable=missing-docstring
abstract = True
@property
def provider_id(self):
""" Unique string key identifying this provider. Must be URL and css class friendly. """
assert self.prefix is not None
return "-".join((self.prefix, ) + tuple(getattr(self, field) for field in self.KEY_FIELDS))
@property
def backend_class(self):
""" Get the python-social-auth backend class used for this provider """
return _PSA_BACKENDS[self.backend_name]
def get_url_params(self):
""" Get a dict of GET parameters to append to login links for this provider """
return {}
def is_active_for_pipeline(self, pipeline):
""" Is this provider being used for the specified pipeline? """
return self.backend_name == pipeline['backend']
def match_social_auth(self, social_auth):
""" Is this provider being used for this UserSocialAuth entry? """
return self.backend_name == social_auth.provider
@classmethod
def get_register_form_data(cls, pipeline_kwargs):
"""Gets dict of data to display on the register form.
common.djangoapps.student.views.register_user uses this to populate the
new account creation form with values supplied by the user's chosen
provider, preventing duplicate data entry.
Args:
pipeline_kwargs: dict of string -> object. Keyword arguments
accumulated by the pipeline thus far.
Returns:
Dict of string -> string. Keys are names of form fields; values are
values for that field. Where there is no value, the empty string
must be used.
"""
# Details about the user sent back from the provider.
details = pipeline_kwargs.get('details')
# Get the username separately to take advantage of the de-duping logic
# built into the pipeline. The provider cannot de-dupe because it can't
# check the state of taken usernames in our system. Note that there is
# technically a data race between the creation of this value and the
# creation of the user object, so it is still possible for users to get
# an error on submit.
suggested_username = pipeline_kwargs.get('username')
return {
'email': details.get('email', ''),
'name': details.get('fullname', ''),
'username': suggested_username,
}
def get_authentication_backend(self):
"""Gets associated Django settings.AUTHENTICATION_BACKEND string."""
return '{}.{}'.format(self.backend_class.__module__, self.backend_class.__name__)
class OAuth2ProviderConfig(ProviderConfig):
"""
Configuration Entry for an OAuth2 based provider.
"""
prefix = 'oa2'
KEY_FIELDS = ('backend_name', ) # Backend name is unique
backend_name = models.CharField(
max_length=50, choices=[(name, name) for name in _PSA_OAUTH2_BACKENDS], blank=False, db_index=True,
help_text=(
"Which python-social-auth OAuth2 provider backend to use. "
"The list of backend choices is determined by the THIRD_PARTY_AUTH_BACKENDS setting."
# To be precise, it's set by AUTHENTICATION_BACKENDS - which aws.py sets from THIRD_PARTY_AUTH_BACKENDS
)
)
key = models.TextField(blank=True, verbose_name="Client ID")
secret = models.TextField(blank=True, verbose_name="Client Secret")
other_settings = models.TextField(blank=True, help_text="Optional JSON object with advanced settings, if any.")
class Meta(object): # pylint: disable=missing-docstring
verbose_name = "Provider Configuration (OAuth2)"
verbose_name_plural = verbose_name
def clean(self):
""" Standardize and validate fields """
super(OAuth2ProviderConfig, self).clean()
self.other_settings = clean_json(self.other_settings, dict)
def get_setting(self, name):
""" Get the value of a setting, or raise KeyError """
if name in ("KEY", "SECRET"):
return getattr(self, name.lower())
if self.other_settings:
other_settings = json.loads(self.other_settings)
assert isinstance(other_settings, dict), "other_settings should be a JSON object (dictionary)"
return other_settings[name]
raise KeyError
class SAMLProviderConfig(ProviderConfig):
"""
Configuration Entry for a SAML/Shibboleth provider.
"""
prefix = 'saml'
KEY_FIELDS = ('idp_slug', )
backend_name = models.CharField(
max_length=50, default='tpa-saml', choices=[(name, name) for name in _PSA_SAML_BACKENDS], blank=False,
help_text="Which python-social-auth provider backend to use. 'tpa-saml' is the standard edX SAML backend.")
idp_slug = models.SlugField(
max_length=30, db_index=True,
help_text=(
'A short string uniquely identifying this provider. '
'Cannot contain spaces and should be a usable as a CSS class. Examples: "ubc", "mit-staging"'
))
entity_id = models.CharField(
max_length=255, verbose_name="Entity ID", help_text="Example: https://idp.testshib.org/idp/shibboleth")
metadata_source = models.CharField(
max_length=255,
help_text=(
"URL to this provider's XML metadata. Should be an HTTPS URL. "
"Example: https://www.testshib.org/metadata/testshib-providers.xml"
))
attr_user_permanent_id = models.CharField(
max_length=128, blank=True, verbose_name="User ID Attribute",
help_text="URN of the SAML attribute that we can use as a unique, persistent user ID. Leave blank for default.")
attr_full_name = models.CharField(
max_length=128, blank=True, verbose_name="Full Name Attribute",
help_text="URN of SAML attribute containing the user's full name. Leave blank for default.")
attr_first_name = models.CharField(
max_length=128, blank=True, verbose_name="First Name Attribute",
help_text="URN of SAML attribute containing the user's first name. Leave blank for default.")
attr_last_name = models.CharField(
max_length=128, blank=True, verbose_name="Last Name Attribute",
help_text="URN of SAML attribute containing the user's last name. Leave blank for default.")
attr_username = models.CharField(
max_length=128, blank=True, verbose_name="Username Hint Attribute",
help_text="URN of SAML attribute to use as a suggested username for this user. Leave blank for default.")
attr_email = models.CharField(
max_length=128, blank=True, verbose_name="Email Attribute",
help_text="URN of SAML attribute containing the user's email address[es]. Leave blank for default.")
other_settings = models.TextField(
verbose_name="Advanced settings", blank=True,
help_text=(
'For advanced use cases, enter a JSON object with addtional configuration. '
'The tpa-saml backend supports only {"requiredEntitlements": ["urn:..."]} '
'which can be used to require the presence of a specific eduPersonEntitlement.'
))
def clean(self):
""" Standardize and validate fields """
super(SAMLProviderConfig, self).clean()
self.other_settings = clean_json(self.other_settings, dict)
class Meta(object): # pylint: disable=missing-docstring
verbose_name = "Provider Configuration (SAML IdP)"
verbose_name_plural = "Provider Configuration (SAML IdPs)"
def get_url_params(self):
""" Get a dict of GET parameters to append to login links for this provider """
return {'idp': self.idp_slug}
def is_active_for_pipeline(self, pipeline):
""" Is this provider being used for the specified pipeline? """
return self.backend_name == pipeline['backend'] and self.idp_slug == pipeline['kwargs']['response']['idp_name']
def match_social_auth(self, social_auth):
""" Is this provider being used for this UserSocialAuth entry? """
prefix = self.idp_slug + ":"
return self.backend_name == social_auth.provider and social_auth.uid.startswith(prefix)
def get_config(self):
"""
Return a SAMLIdentityProvider instance for use by SAMLAuthBackend.
Essentially this just returns the values of this object and its
associated 'SAMLProviderData' entry.
"""
if self.other_settings:
conf = json.loads(self.other_settings)
else:
conf = {}
attrs = (
'attr_user_permanent_id', 'attr_full_name', 'attr_first_name',
'attr_last_name', 'attr_username', 'attr_email', 'entity_id')
for field in attrs:
val = getattr(self, field)
if val:
conf[field] = val
# Now get the data fetched automatically from the metadata.xml:
data = SAMLProviderData.current(self.entity_id)
if not data or not data.is_valid():
log.error("No SAMLProviderData found for %s. Run 'manage.py saml pull' to fix or debug.", self.entity_id)
raise AuthNotConfigured(provider_name=self.name)
conf['x509cert'] = data.public_key
conf['url'] = data.sso_url
return SAMLIdentityProvider(self.idp_slug, **conf)
class SAMLConfiguration(ConfigurationModel):
"""
General configuration required for this edX instance to act as a SAML
Service Provider and allow users to authenticate via third party SAML
Identity Providers (IdPs)
"""
private_key = models.TextField(
help_text=(
'To generate a key pair as two files, run '
'"openssl req -new -x509 -days 3652 -nodes -out saml.crt -keyout saml.key". '
'Paste the contents of saml.key here.'
)
)
public_key = models.TextField(help_text="Public key certificate.")
entity_id = models.CharField(max_length=255, default="http://saml.example.com", verbose_name="Entity ID")
org_info_str = models.TextField(
verbose_name="Organization Info",
default='{"en-US": {"url": "http://www.example.com", "displayname": "Example Inc.", "name": "example"}}',
help_text="JSON dictionary of 'url', 'displayname', and 'name' for each language",
)
other_config_str = models.TextField(
default='{\n"SECURITY_CONFIG": {"metadataCacheDuration": 604800, "signMetadata": false}\n}',
help_text=(
"JSON object defining advanced settings that are passed on to python-saml. "
"Valid keys that can be set here include: SECURITY_CONFIG, SP_NAMEID_FORMATS, SP_EXTRA"
),
)
class Meta(object): # pylint: disable=missing-docstring
verbose_name = "SAML Configuration"
verbose_name_plural = verbose_name
def clean(self):
""" Standardize and validate fields """
super(SAMLConfiguration, self).clean()
self.org_info_str = clean_json(self.org_info_str, dict)
self.other_config_str = clean_json(self.other_config_str, dict)
self.private_key = self.private_key.replace("-----BEGIN PRIVATE KEY-----", "").strip()
self.private_key = self.private_key.replace("-----END PRIVATE KEY-----", "").strip()
self.public_key = self.public_key.replace("-----BEGIN CERTIFICATE-----", "").strip()
self.public_key = self.public_key.replace("-----END CERTIFICATE-----", "").strip()
def get_setting(self, name):
""" Get the value of a setting, or raise KeyError """
if name == "ORG_INFO":
return json.loads(self.org_info_str)
if name == "SP_ENTITY_ID":
return self.entity_id
if name == "SP_PUBLIC_CERT":
return self.public_key
if name == "SP_PRIVATE_KEY":
return self.private_key
if name == "TECHNICAL_CONTACT":
return {"givenName": "Technical Support", "emailAddress": settings.TECH_SUPPORT_EMAIL}
if name == "SUPPORT_CONTACT":
return {"givenName": "SAML Support", "emailAddress": settings.TECH_SUPPORT_EMAIL}
other_config = json.loads(self.other_config_str)
return other_config[name] # SECURITY_CONFIG, SP_NAMEID_FORMATS, SP_EXTRA
class SAMLProviderData(models.Model):
"""
Data about a SAML IdP that is fetched automatically by 'manage.py saml pull'
This data is only required during the actual authentication process.
"""
cache_timeout = 600
fetched_at = models.DateTimeField(db_index=True, null=False)
expires_at = models.DateTimeField(db_index=True, null=True)
entity_id = models.CharField(max_length=255, db_index=True) # This is the key for lookups in this table
sso_url = models.URLField(verbose_name="SSO URL")
public_key = models.TextField()
class Meta(object): # pylint: disable=missing-docstring
verbose_name = "SAML Provider Data"
verbose_name_plural = verbose_name
ordering = ('-fetched_at', )
def is_valid(self):
""" Is this data valid? """
if self.expires_at and timezone.now() > self.expires_at:
return False
return bool(self.entity_id and self.sso_url and self.public_key)
is_valid.boolean = True
@classmethod
def cache_key_name(cls, entity_id):
""" Return the name of the key to use to cache the current data """
return 'configuration/{}/current/{}'.format(cls.__name__, entity_id)
@classmethod
def current(cls, entity_id):
"""
Return the active data entry, if any, otherwise None
"""
cached = cache.get(cls.cache_key_name(entity_id))
if cached is not None:
return cached
try:
current = cls.objects.filter(entity_id=entity_id).order_by('-fetched_at')[0]
except IndexError:
current = None
cache.set(cls.cache_key_name(entity_id), current, cls.cache_timeout)
return current
......@@ -209,7 +209,7 @@ class ProviderUserState(object):
def get_unlink_form_name(self):
"""Gets the name used in HTML forms that unlink a provider account."""
return self.provider.NAME + '_unlink_form'
return self.provider.provider_id + '_unlink_form'
def get(request):
......@@ -239,7 +239,7 @@ def get_authenticated_user(auth_provider, username, uid):
user has no social auth associated with the given backend.
AssertionError: if the user is not authenticated.
"""
match = models.DjangoStorage.user.get_social_auth(provider=auth_provider.BACKEND_CLASS.name, uid=uid)
match = models.DjangoStorage.user.get_social_auth(provider=auth_provider.backend_name, uid=uid)
if not match or match.user.username != username:
raise User.DoesNotExist
......@@ -249,12 +249,12 @@ def get_authenticated_user(auth_provider, username, uid):
return user
def _get_enabled_provider_by_name(provider_name):
"""Gets an enabled provider by its NAME member or throws."""
enabled_provider = provider.Registry.get(provider_name)
def _get_enabled_provider(provider_id):
"""Gets an enabled provider by its provider_id member or throws."""
enabled_provider = provider.Registry.get(provider_id)
if not enabled_provider:
raise ValueError('Provider %s not enabled' % provider_name)
raise ValueError('Provider %s not enabled' % provider_id)
return enabled_provider
......@@ -301,11 +301,11 @@ def get_complete_url(backend_name):
return _get_url('social:complete', backend_name)
def get_disconnect_url(provider_name, association_id):
def get_disconnect_url(provider_id, association_id):
"""Gets URL for the endpoint that starts the disconnect pipeline.
Args:
provider_name: string. Name of the provider.BaseProvider child you want
provider_id: string identifier of the models.ProviderConfig child you want
to disconnect from.
association_id: int. Optional ID of a specific row in the UserSocialAuth
table to disconnect (useful if multiple providers use a common backend)
......@@ -314,21 +314,21 @@ def get_disconnect_url(provider_name, association_id):
String. URL that starts the disconnection pipeline.
Raises:
ValueError: if no provider is enabled with the given name.
ValueError: if no provider is enabled with the given ID.
"""
backend_name = _get_enabled_provider_by_name(provider_name).BACKEND_CLASS.name
backend_name = _get_enabled_provider(provider_id).backend_name
if association_id:
return _get_url('social:disconnect_individual', backend_name, url_params={'association_id': association_id})
else:
return _get_url('social:disconnect', backend_name)
def get_login_url(provider_name, auth_entry, redirect_url=None):
def get_login_url(provider_id, auth_entry, redirect_url=None):
"""Gets the login URL for the endpoint that kicks off auth with a provider.
Args:
provider_name: string. The name of the provider.Provider that has been
enabled.
provider_id: string identifier of the models.ProviderConfig child you want
to disconnect from.
auth_entry: string. Query argument specifying the desired entry point
for the auth pipeline. Used by the pipeline for later branching.
Must be one of _AUTH_ENTRY_CHOICES.
......@@ -341,13 +341,13 @@ def get_login_url(provider_name, auth_entry, redirect_url=None):
String. URL that starts the auth pipeline for a provider.
Raises:
ValueError: if no provider is enabled with the given provider_name.
ValueError: if no provider is enabled with the given provider_id.
"""
assert auth_entry in _AUTH_ENTRY_CHOICES
enabled_provider = _get_enabled_provider_by_name(provider_name)
enabled_provider = _get_enabled_provider(provider_id)
return _get_url(
'social:begin',
enabled_provider.BACKEND_CLASS.name,
enabled_provider.backend_name,
auth_entry=auth_entry,
redirect_url=redirect_url,
extra_params=enabled_provider.get_url_params(),
......
"""Third-party auth provider definitions.
Loaded by Django's settings mechanism. Consequently, this module must not
invoke the Django armature.
"""
from social.backends import google, linkedin, facebook
from social.backends.saml import OID_EDU_PERSON_PRINCIPAL_NAME
from .saml import SAMLAuthBackend
_DEFAULT_ICON_CLASS = 'fa-signin'
class BaseProvider(object):
"""Abstract base class for third-party auth providers.
All providers must subclass BaseProvider -- otherwise, they cannot be put
in the provider Registry.
"""
# Class. The provider's backing social.backends.base.BaseAuth child.
BACKEND_CLASS = None
# String. Name of the FontAwesome glyph to use for sign in buttons (or the
# name of a user-supplied custom glyph that is present at runtime).
ICON_CLASS = _DEFAULT_ICON_CLASS
# String. User-facing name of the provider. Must be unique across all
# enabled providers. Will be presented in the UI.
NAME = None
# Dict of string -> object. Settings that will be merged into Django's
# settings instance. In most cases the value will be None, since real
# values are merged from .json files (foo.auth.json; foo.env.json) onto the
# settings instance during application initialization.
SETTINGS = {}
@classmethod
def get_authentication_backend(cls):
"""Gets associated Django settings.AUTHENTICATION_BACKEND string."""
return '%s.%s' % (cls.BACKEND_CLASS.__module__, cls.BACKEND_CLASS.__name__)
@classmethod
def get_email(cls, provider_details):
"""Gets user's email address.
Provider responses can contain arbitrary data. This method can be
overridden to extract an email address from the provider details
extracted by the social_details pipeline step.
Args:
provider_details: dict of string -> string. Data about the
user passed back by the provider.
Returns:
String or None. The user's email address, if any.
"""
return provider_details.get('email')
@classmethod
def get_name(cls, provider_details):
"""Gets user's name.
Provider responses can contain arbitrary data. This method can be
overridden to extract a full name for a user from the provider details
extracted by the social_details pipeline step.
Args:
provider_details: dict of string -> string. Data about the
user passed back by the provider.
Returns:
String or None. The user's full name, if any.
"""
return provider_details.get('fullname')
@classmethod
def get_register_form_data(cls, pipeline_kwargs):
"""Gets dict of data to display on the register form.
common.djangoapps.student.views.register_user uses this to populate the
new account creation form with values supplied by the user's chosen
provider, preventing duplicate data entry.
Args:
pipeline_kwargs: dict of string -> object. Keyword arguments
accumulated by the pipeline thus far.
Returns:
Dict of string -> string. Keys are names of form fields; values are
values for that field. Where there is no value, the empty string
must be used.
"""
# Details about the user sent back from the provider.
details = pipeline_kwargs.get('details')
# Get the username separately to take advantage of the de-duping logic
# built into the pipeline. The provider cannot de-dupe because it can't
# check the state of taken usernames in our system. Note that there is
# technically a data race between the creation of this value and the
# creation of the user object, so it is still possible for users to get
# an error on submit.
suggested_username = pipeline_kwargs.get('username')
return {
'email': cls.get_email(details) or '',
'name': cls.get_name(details) or '',
'username': suggested_username,
}
@classmethod
def merge_onto(cls, settings):
"""Merge class-level settings onto a django settings module."""
for key, value in cls.SETTINGS.iteritems():
setattr(settings, key, value)
@classmethod
def get_url_params(cls):
""" Get a dict of GET parameters to append to login links for this provider """
return {}
@classmethod
def is_active_for_pipeline(cls, pipeline):
""" Is this provider being used for the specified pipeline? """
return cls.BACKEND_CLASS.name == pipeline['backend']
@classmethod
def match_social_auth(cls, social_auth):
""" Is this provider being used for this UserSocialAuth entry? """
return cls.BACKEND_CLASS.name == social_auth.provider
class GoogleOauth2(BaseProvider):
"""Provider for Google's Oauth2 auth system."""
BACKEND_CLASS = google.GoogleOAuth2
ICON_CLASS = 'fa-google-plus'
NAME = 'Google'
SETTINGS = {
'SOCIAL_AUTH_GOOGLE_OAUTH2_KEY': None,
'SOCIAL_AUTH_GOOGLE_OAUTH2_SECRET': None,
}
class LinkedInOauth2(BaseProvider):
"""Provider for LinkedIn's Oauth2 auth system."""
BACKEND_CLASS = linkedin.LinkedinOAuth2
ICON_CLASS = 'fa-linkedin'
NAME = 'LinkedIn'
SETTINGS = {
'SOCIAL_AUTH_LINKEDIN_OAUTH2_KEY': None,
'SOCIAL_AUTH_LINKEDIN_OAUTH2_SECRET': None,
}
class FacebookOauth2(BaseProvider):
"""Provider for LinkedIn's Oauth2 auth system."""
BACKEND_CLASS = facebook.FacebookOAuth2
ICON_CLASS = 'fa-facebook'
NAME = 'Facebook'
SETTINGS = {
'SOCIAL_AUTH_FACEBOOK_KEY': None,
'SOCIAL_AUTH_FACEBOOK_SECRET': None,
}
class SAMLProviderMixin(object):
""" Base class for SAML/Shibboleth providers """
BACKEND_CLASS = SAMLAuthBackend
ICON_CLASS = 'fa-university'
@classmethod
def get_url_params(cls):
""" Get a dict of GET parameters to append to login links for this provider """
return {'idp': cls.IDP["id"]}
@classmethod
def is_active_for_pipeline(cls, pipeline):
""" Is this provider being used for the specified pipeline? """
if cls.BACKEND_CLASS.name == pipeline['backend']:
idp_name = pipeline['kwargs']['response']['idp_name']
return cls.IDP["id"] == idp_name
return False
@classmethod
def match_social_auth(cls, social_auth):
""" Is this provider being used for this UserSocialAuth entry? """
prefix = cls.IDP["id"] + ":"
return cls.BACKEND_CLASS.name == social_auth.provider and social_auth.uid.startswith(prefix)
class TestShibAProvider(SAMLProviderMixin, BaseProvider):
""" Provider for testshib.org public Shibboleth test server. """
NAME = 'TestShib A'
IDP = {
"id": "testshiba", # Required slug
"entity_id": "https://idp.testshib.org/idp/shibboleth",
"url": "https://idp.testshib.org/idp/profile/SAML2/Redirect/SSO",
"attr_email": OID_EDU_PERSON_PRINCIPAL_NAME,
"x509cert": """
MIIEDjCCAvagAwIBAgIBADANBgkqhkiG9w0BAQUFADBnMQswCQYDVQQGEwJVUzEV
MBMGA1UECBMMUGVubnN5bHZhbmlhMRMwEQYDVQQHEwpQaXR0c2J1cmdoMREwDwYD
VQQKEwhUZXN0U2hpYjEZMBcGA1UEAxMQaWRwLnRlc3RzaGliLm9yZzAeFw0wNjA4
MzAyMTEyMjVaFw0xNjA4MjcyMTEyMjVaMGcxCzAJBgNVBAYTAlVTMRUwEwYDVQQI
EwxQZW5uc3lsdmFuaWExEzARBgNVBAcTClBpdHRzYnVyZ2gxETAPBgNVBAoTCFRl
c3RTaGliMRkwFwYDVQQDExBpZHAudGVzdHNoaWIub3JnMIIBIjANBgkqhkiG9w0B
AQEFAAOCAQ8AMIIBCgKCAQEArYkCGuTmJp9eAOSGHwRJo1SNatB5ZOKqDM9ysg7C
yVTDClcpu93gSP10nH4gkCZOlnESNgttg0r+MqL8tfJC6ybddEFB3YBo8PZajKSe
3OQ01Ow3yT4I+Wdg1tsTpSge9gEz7SrC07EkYmHuPtd71CHiUaCWDv+xVfUQX0aT
NPFmDixzUjoYzbGDrtAyCqA8f9CN2txIfJnpHE6q6CmKcoLADS4UrNPlhHSzd614
kR/JYiks0K4kbRqCQF0Dv0P5Di+rEfefC6glV8ysC8dB5/9nb0yh/ojRuJGmgMWH
gWk6h0ihjihqiu4jACovUZ7vVOCgSE5Ipn7OIwqd93zp2wIDAQABo4HEMIHBMB0G
A1UdDgQWBBSsBQ869nh83KqZr5jArr4/7b+QazCBkQYDVR0jBIGJMIGGgBSsBQ86
9nh83KqZr5jArr4/7b+Qa6FrpGkwZzELMAkGA1UEBhMCVVMxFTATBgNVBAgTDFBl
bm5zeWx2YW5pYTETMBEGA1UEBxMKUGl0dHNidXJnaDERMA8GA1UEChMIVGVzdFNo
aWIxGTAXBgNVBAMTEGlkcC50ZXN0c2hpYi5vcmeCAQAwDAYDVR0TBAUwAwEB/zAN
BgkqhkiG9w0BAQUFAAOCAQEAjR29PhrCbk8qLN5MFfSVk98t3CT9jHZoYxd8QMRL
I4j7iYQxXiGJTT1FXs1nd4Rha9un+LqTfeMMYqISdDDI6tv8iNpkOAvZZUosVkUo
93pv1T0RPz35hcHHYq2yee59HJOco2bFlcsH8JBXRSRrJ3Q7Eut+z9uo80JdGNJ4
/SJy5UorZ8KazGj16lfJhOBXldgrhppQBb0Nq6HKHguqmwRfJ+WkxemZXzhediAj
Geka8nz8JjwxpUjAiSWYKLtJhGEaTqCYxCCX2Dw+dOTqUzHOZ7WKv4JXPK5G/Uhr
8K/qhmFT2nIQi538n6rVYLeWj8Bbnl+ev0peYzxFyF5sQA==
"""
}
class TestShibBProvider(SAMLProviderMixin, BaseProvider):
""" Provider for testshib.org public Shibboleth test server. """
NAME = 'TestShib B'
IDP = {
"id": "testshibB", # Required slug
"entity_id": "https://idp.testshib.org/idp/shibboleth",
"url": "https://IDP.TESTSHIB.ORG/idp/profile/SAML2/Redirect/SSO",
"attr_email": OID_EDU_PERSON_PRINCIPAL_NAME,
"x509cert": TestShibAProvider.IDP["x509cert"],
}
Third-party auth provider configuration API.
"""
from .models import (
OAuth2ProviderConfig, SAMLConfiguration, SAMLProviderConfig,
_PSA_OAUTH2_BACKENDS, _PSA_SAML_BACKENDS
)
class Registry(object):
"""Singleton registry of third-party auth providers.
Providers must subclass BaseProvider in order to be usable in the registry.
"""
API for querying third-party auth ProviderConfig objects.
_CONFIGURED = False
_ENABLED = {}
@classmethod
def _check_configured(cls):
"""Ensures registry is configured."""
if not cls._CONFIGURED:
raise RuntimeError('Registry not configured')
@classmethod
def _get_all(cls):
"""Gets all provider implementations loaded into the Python runtime."""
# BaseProvider does so have __subclassess__. pylint: disable-msg=no-member
return {klass.NAME: klass for klass in BaseProvider.__subclasses__()}
@classmethod
def _enable(cls, provider):
"""Enables a single provider."""
if provider.NAME in cls._ENABLED:
raise ValueError('Provider %s already enabled' % provider.NAME)
cls._ENABLED[provider.NAME] = provider
@classmethod
def configure_once(cls, provider_names):
"""Configures providers.
Args:
provider_names: list of string. The providers to configure.
Raises:
ValueError: if the registry has already been configured, or if any
of the passed provider_names does not have a corresponding
BaseProvider child implementation.
Providers must subclass ProviderConfig in order to be usable in the registry.
"""
if cls._CONFIGURED:
raise ValueError('Provider registry already configured')
# Flip the bit eagerly -- configure() should not be re-callable if one
# _enable call fails.
cls._CONFIGURED = True
for name in provider_names:
all_providers = cls._get_all()
if name not in all_providers:
raise ValueError('No implementation found for provider ' + name)
cls._enable(all_providers.get(name))
@classmethod
def _enabled_providers(cls):
""" Helper method to iterate over all providers """
for backend_name in _PSA_OAUTH2_BACKENDS:
provider = OAuth2ProviderConfig.current(backend_name)
if provider.enabled:
yield provider
if SAMLConfiguration.is_enabled():
idp_slugs = SAMLProviderConfig.key_values('idp_slug', flat=True)
for idp_slug in idp_slugs:
provider = SAMLProviderConfig.current(idp_slug)
if provider.enabled and provider.backend_name in _PSA_SAML_BACKENDS:
yield provider
@classmethod
def enabled(cls):
"""Returns list of enabled providers."""
cls._check_configured()
return sorted(cls._ENABLED.values(), key=lambda provider: provider.NAME)
return sorted(cls._enabled_providers(), key=lambda provider: provider.name)
@classmethod
def get(cls, provider_name):
"""Gets provider named provider_name string if enabled, else None."""
cls._check_configured()
return cls._ENABLED.get(provider_name)
def get(cls, provider_id):
"""Gets provider by provider_id string if enabled, else None."""
if '-' not in provider_id: # Check format - see models.py:ProviderConfig
raise ValueError("Invalid provider_id. Expect something like oa2-google")
try:
return next(provider for provider in cls._enabled_providers() if provider.provider_id == provider_id)
except StopIteration:
return None
@classmethod
def get_from_pipeline(cls, running_pipeline):
......@@ -308,13 +51,9 @@ class Registry(object):
authenticate a user.
Returns:
A provider class (a subclass of BaseProvider) or None.
Raises:
RuntimeError: if the registry has not been configured.
An instance of ProviderConfig or None.
"""
cls._check_configured()
for enabled in cls._ENABLED.values():
for enabled in cls._enabled_providers():
if enabled.is_active_for_pipeline(running_pipeline):
return enabled
......@@ -325,25 +64,22 @@ class Registry(object):
Example:
>>> list(get_enabled_by_backend_name("tpa-saml"))
[TestShibAProvider, TestShibBProvider]
[<SAMLProviderConfig>, <SAMLProviderConfig>]
Args:
backend_name: The name of a python-social-auth backend used by
one or more providers.
Yields:
Provider classes (subclasses of BaseProvider).
Raises:
RuntimeError: if the registry has not been configured.
Instances of ProviderConfig.
"""
cls._check_configured()
for enabled in cls._ENABLED.values():
if enabled.BACKEND_CLASS.name == backend_name:
yield enabled
@classmethod
def _reset(cls):
"""Returns the registry to an unconfigured state; for tests only."""
cls._CONFIGURED = False
cls._ENABLED = {}
if backend_name in _PSA_OAUTH2_BACKENDS:
provider = OAuth2ProviderConfig.current(backend_name)
if provider.enabled:
yield provider
elif backend_name in _PSA_SAML_BACKENDS and SAMLConfiguration.is_enabled():
idp_names = SAMLProviderConfig.key_values('idp_slug', flat=True)
for idp_name in idp_names:
provider = SAMLProviderConfig.current(idp_name)
if provider.backend_name == backend_name and provider.enabled:
yield provider
"""
Slightly customized python-social-auth backend for SAML 2.0 support
"""
import logging
from social.backends.saml import SAMLAuth, OID_EDU_PERSON_ENTITLEMENT
from social.exceptions import AuthForbidden
from social.backends.saml import SAMLIdentityProvider, SAMLAuth
log = logging.getLogger(__name__)
class SAMLAuthBackend(SAMLAuth): # pylint: disable=abstract-method
......@@ -14,8 +17,33 @@ class SAMLAuthBackend(SAMLAuth): # pylint: disable=abstract-method
def get_idp(self, idp_name):
""" Given the name of an IdP, get a SAMLIdentityProvider instance """
from .provider import Registry # Import here to avoid circular import
for provider in Registry.enabled():
if issubclass(provider.BACKEND_CLASS, SAMLAuth) and provider.IDP["id"] == idp_name:
return SAMLIdentityProvider(idp_name, **provider.IDP)
raise KeyError("SAML IdP {} not found.".format(idp_name))
from .models import SAMLProviderConfig
return SAMLProviderConfig.current(idp_name).get_config()
def setting(self, name, default=None):
""" Get a setting, from SAMLConfiguration """
if not hasattr(self, '_config'):
from .models import SAMLConfiguration
self._config = SAMLConfiguration.current() # pylint: disable=attribute-defined-outside-init
if not self._config.enabled:
from django.core.exceptions import ImproperlyConfigured
raise ImproperlyConfigured("SAML Authentication is not enabled.")
try:
return self._config.get_setting(name)
except KeyError:
return self.strategy.setting(name, default)
def _check_entitlements(self, idp, attributes):
"""
Check if we require the presence of any specific eduPersonEntitlement.
raise AuthForbidden if the user should not be authenticated, or do nothing
to allow the login pipeline to continue.
"""
if "requiredEntitlements" in idp.conf:
entitlements = attributes.get(OID_EDU_PERSON_ENTITLEMENT, [])
for expected in idp.conf['requiredEntitlements']:
if expected not in entitlements:
log.warning(
"SAML user from IdP %s rejected due to missing eduPersonEntitlement %s", idp.name, expected)
raise AuthForbidden(self)
"""Settings for the third-party auth module.
Defers configuration of settings so we can inspect the provider registry and
create settings placeholders for only those values actually needed by a given
deployment. Required by Django; consequently, this file must not invoke the
Django armature.
The flow for settings registration is:
The base settings file contains a boolean, ENABLE_THIRD_PARTY_AUTH, indicating
whether this module is enabled. Ancillary settings files (aws.py, dev.py) put
options in THIRD_PARTY_SETTINGS. startup.py probes the ENABLE_THIRD_PARTY_AUTH.
whether this module is enabled. startup.py probes the ENABLE_THIRD_PARTY_AUTH.
If true, it:
a) loads this module.
b) calls apply_settings(), passing in settings.THIRD_PARTY_AUTH.
THIRD_PARTY AUTH is a dict of the form
'THIRD_PARTY_AUTH': {
'<PROVIDER_NAME>': {
'<PROVIDER_SETTING_NAME>': '<PROVIDER_SETTING_VALUE>',
[...]
},
[...]
}
If you are using a dev settings file, your settings dict starts at the
level of <PROVIDER_NAME> and is a map of provider name string to
settings dict. If you are using an auth.json file, it should contain a
THIRD_PARTY_AUTH entry as above.
c) apply_settings() builds a list of <PROVIDER_NAMES>. These are the
enabled third party auth providers for the deployment. These are enabled
in provider.Registry, the canonical list of enabled providers.
d) then, it sets global, provider-independent settings.
e) then, it sets provider-specific settings. For each enabled provider, we
read its SETTINGS member. These are merged onto the Django settings
object. In most cases these are stubs and the real values are set from
THIRD_PARTY_AUTH. All values that are set from this dict must first be
initialized from SETTINGS. This allows us to validate the dict and
ensure that the values match expected configuration options on the
provider.
f) finally, the (key, value) pairs from the dict file are merged onto the
django settings object.
b) calls apply_settings(), passing in the Django settings
"""
from . import provider
_FIELDS_STORED_IN_SESSION = ['auth_entry', 'next']
_MIDDLEWARE_CLASSES = (
'third_party_auth.middleware.ExceptionMiddleware',
......@@ -53,25 +17,7 @@ _MIDDLEWARE_CLASSES = (
_SOCIAL_AUTH_LOGIN_REDIRECT_URL = '/dashboard'
def _merge_auth_info(django_settings, auth_info):
"""Merge auth_info dict onto django_settings module."""
enabled_provider_names = []
to_merge = []
for provider_name, provider_dict in auth_info.items():
enabled_provider_names.append(provider_name)
# Merge iff all settings have been intialized.
for key in provider_dict:
if key not in dir(django_settings):
raise ValueError('Auth setting %s not initialized' % key)
to_merge.append(provider_dict)
for passed_validation in to_merge:
for key, value in passed_validation.iteritems():
setattr(django_settings, key, value)
def _set_global_settings(django_settings):
def apply_settings(django_settings):
"""Set provider-independent settings."""
# Whitelisted URL query parameters retrained in the pipeline session.
......@@ -115,6 +61,9 @@ def _set_global_settings(django_settings):
'third_party_auth.pipeline.login_analytics',
)
# Required so that we can use unmodified PSA OAuth2 backends:
django_settings.SOCIAL_AUTH_STRATEGY = 'third_party_auth.strategy.ConfigurationModelStrategy'
# We let the user specify their email address during signup.
django_settings.SOCIAL_AUTH_PROTECTED_USER_FIELDS = ['email']
......@@ -136,30 +85,3 @@ def _set_global_settings(django_settings):
'social.apps.django_app.context_processors.backends',
'social.apps.django_app.context_processors.login_redirect',
)
def _set_provider_settings(django_settings, enabled_providers, auth_info):
"""Sets provider-specific settings."""
# Must prepend here so we get called first.
django_settings.AUTHENTICATION_BACKENDS = (
tuple(enabled_provider.get_authentication_backend() for enabled_provider in enabled_providers) +
django_settings.AUTHENTICATION_BACKENDS)
# Merge settings from provider classes, and configure all placeholders.
for enabled_provider in enabled_providers:
enabled_provider.merge_onto(django_settings)
# Merge settings from <deployment>.auth.json, overwriting placeholders.
_merge_auth_info(django_settings, auth_info)
def apply_settings(auth_info, django_settings):
"""Applies settings from auth_info dict to django_settings module."""
if django_settings.FEATURES.get('ENABLE_DUMMY_THIRD_PARTY_AUTH_PROVIDER'):
# The Dummy provider is handy for testing and development.
from .dummy import DummyProvider # pylint: disable=unused-variable
provider_names = auth_info.keys()
provider.Registry.configure_once(provider_names)
enabled_providers = provider.Registry.enabled()
_set_global_settings(django_settings)
_set_provider_settings(django_settings, enabled_providers, auth_info)
"""
A custom Strategy for python-social-auth that allows us to fetch configuration from
ConfigurationModels rather than django.settings
"""
from .models import OAuth2ProviderConfig
from social.backends.oauth import BaseOAuth2
from social.strategies.django_strategy import DjangoStrategy
class ConfigurationModelStrategy(DjangoStrategy):
"""
A DjangoStrategy customized to load settings from ConfigurationModels
for upstream python-social-auth backends that we cannot otherwise modify.
"""
def setting(self, name, default=None, backend=None):
"""
Load the setting from a ConfigurationModel if possible, or fall back to the normal
Django settings lookup.
BaseOAuth2 subclasses will call this method for every setting they want to look up.
SAMLAuthBackend subclasses will call this method only after first checking if the
setting 'name' is configured via SAMLProviderConfig.
"""
if isinstance(backend, BaseOAuth2):
provider_config = OAuth2ProviderConfig.current(backend.name)
if not provider_config.enabled:
raise Exception("Can't fetch setting of a disabled backend/provider.")
try:
return provider_config.get_setting(name)
except KeyError:
pass
# At this point, we know 'name' is not set in a [OAuth2|SAML]ProviderConfig row.
# It's probably a global Django setting like 'FIELDS_STORED_IN_SESSION':
return super(ConfigurationModelStrategy, self).setting(name, default, backend)
......@@ -32,15 +32,8 @@ from third_party_auth.tests import testutil
class IntegrationTest(testutil.TestCase, test.TestCase):
"""Abstract base class for provider integration tests."""
# Configuration. You will need to override these values in your test cases.
# Class. The third_party_auth.provider.BaseProvider child we are testing.
PROVIDER_CLASS = None
# Dict of string -> object. Settings that will be merged onto Django's
# settings object before test execution. In most cases, this is
# PROVIDER_CLASS.SETTINGS with test values.
PROVIDER_SETTINGS = {}
# Override setUp and set this:
provider = None
# Methods you must override in your children.
......@@ -94,10 +87,10 @@ class IntegrationTest(testutil.TestCase, test.TestCase):
"""
self.assertEqual(200, response.status_code)
# Check that the correct provider was selected.
self.assertIn('successfully signed in with <strong>%s</strong>' % self.PROVIDER_CLASS.NAME, response.content)
self.assertIn('successfully signed in with <strong>%s</strong>' % self.provider.name, response.content)
# Expect that each truthy value we've prepopulated the register form
# with is actually present.
for prepopulated_form_value in self.PROVIDER_CLASS.get_register_form_data(pipeline_kwargs).values():
for prepopulated_form_value in self.provider.get_register_form_data(pipeline_kwargs).values():
if prepopulated_form_value:
self.assertIn(prepopulated_form_value, response.content)
......@@ -106,12 +99,15 @@ class IntegrationTest(testutil.TestCase, test.TestCase):
def setUp(self):
super(IntegrationTest, self).setUp()
self.configure_runtime()
self.backend_name = self.PROVIDER_CLASS.BACKEND_CLASS.name
self.client = test.Client()
self.request_factory = test.RequestFactory()
def assert_account_settings_context_looks_correct(self, context, user, duplicate=False, linked=None):
@property
def backend_name(self):
""" Shortcut for the backend name """
return self.provider.backend_name
# pylint: disable=invalid-name
def assert_account_settings_context_looks_correct(self, context, _user, duplicate=False, linked=None):
"""Asserts the user's account settings page context is in the expected state.
If duplicate is True, we expect context['duplicate_provider'] to contain
......@@ -120,13 +116,13 @@ class IntegrationTest(testutil.TestCase, test.TestCase):
its connected state is correct.
"""
if duplicate:
self.assertEqual(context['duplicate_provider'], self.PROVIDER_CLASS.BACKEND_CLASS.name)
self.assertEqual(context['duplicate_provider'], self.provider.backend_name)
else:
self.assertIsNone(context['duplicate_provider'])
if linked is not None:
expected_provider = [
provider for provider in context['auth']['providers'] if provider['name'] == self.PROVIDER_CLASS.NAME
provider for provider in context['auth']['providers'] if provider['name'] == self.provider.name
][0]
self.assertIsNotNone(expected_provider)
self.assertEqual(expected_provider['connected'], linked)
......@@ -197,7 +193,10 @@ class IntegrationTest(testutil.TestCase, test.TestCase):
def assert_json_failure_response_is_missing_social_auth(self, response):
"""Asserts failure on /login for missing social auth looks right."""
self.assertEqual(403, response.status_code)
self.assertIn("successfully logged into your %s account, but this account isn't linked" % self.PROVIDER_CLASS.NAME, response.content)
self.assertIn(
"successfully logged into your %s account, but this account isn't linked" % self.provider.name,
response.content
)
def assert_json_failure_response_is_username_collision(self, response):
"""Asserts the json response indicates a username collision."""
......@@ -211,7 +210,7 @@ class IntegrationTest(testutil.TestCase, test.TestCase):
self.assertEqual(200, response.status_code)
payload = json.loads(response.content)
self.assertTrue(payload.get('success'))
self.assertEqual(pipeline.get_complete_url(self.PROVIDER_CLASS.BACKEND_CLASS.name), payload.get('redirect_url'))
self.assertEqual(pipeline.get_complete_url(self.provider.backend_name), payload.get('redirect_url'))
def assert_login_response_before_pipeline_looks_correct(self, response):
"""Asserts a GET of /login not in the pipeline looks correct."""
......@@ -219,7 +218,7 @@ class IntegrationTest(testutil.TestCase, test.TestCase):
# The combined login/registration page dynamically generates the login button,
# but we can still check that the provider name is passed in the data attribute
# for the container element.
self.assertIn(self.PROVIDER_CLASS.NAME, response.content)
self.assertIn(self.provider.name, response.content)
def assert_login_response_in_pipeline_looks_correct(self, response):
"""Asserts a GET of /login in the pipeline looks correct."""
......@@ -258,28 +257,21 @@ class IntegrationTest(testutil.TestCase, test.TestCase):
# The combined login/registration page dynamically generates the register button,
# but we can still check that the provider name is passed in the data attribute
# for the container element.
self.assertIn(self.PROVIDER_CLASS.NAME, response.content)
self.assertIn(self.provider.name, response.content)
def assert_social_auth_does_not_exist_for_user(self, user, strategy):
"""Asserts a user does not have an auth with the expected provider."""
social_auths = strategy.storage.user.get_social_auth_for_user(
user, provider=self.PROVIDER_CLASS.BACKEND_CLASS.name)
user, provider=self.provider.backend_name)
self.assertEqual(0, len(social_auths))
def assert_social_auth_exists_for_user(self, user, strategy):
"""Asserts a user has a social auth with the expected provider."""
social_auths = strategy.storage.user.get_social_auth_for_user(
user, provider=self.PROVIDER_CLASS.BACKEND_CLASS.name)
user, provider=self.provider.backend_name)
self.assertEqual(1, len(social_auths))
self.assertEqual(self.backend_name, social_auths[0].provider)
def configure_runtime(self):
"""Configures settings details."""
auth_settings.apply_settings({self.PROVIDER_CLASS.NAME: self.PROVIDER_SETTINGS}, django_settings)
# Force settings to propagate into cached members on
# social.apps.django_app.utils.
reload(social_utils)
def create_user_models_for_existing_account(self, strategy, email, password, username, skip_social_auth=False):
"""Creates user, profile, registration, and (usually) social auth.
......@@ -296,7 +288,7 @@ class IntegrationTest(testutil.TestCase, test.TestCase):
registration.save()
if not skip_social_auth:
social_utils.Storage.user.create_social_auth(user, uid, self.PROVIDER_CLASS.BACKEND_CLASS.name)
social_utils.Storage.user.create_social_auth(user, uid, self.provider.backend_name)
return user
......@@ -370,7 +362,7 @@ class IntegrationTest(testutil.TestCase, test.TestCase):
self.assertEqual(response.status_code, 302)
self.assertEqual(
response["Location"],
pipeline.get_complete_url(self.PROVIDER_CLASS.BACKEND_CLASS.name)
pipeline.get_complete_url(self.provider.backend_name)
)
self.assertEqual(response.cookies[django_settings.EDXMKTG_LOGGED_IN_COOKIE_NAME].value, 'true')
self.assertIn(django_settings.EDXMKTG_USER_INFO_COOKIE_NAME, response.cookies)
......@@ -417,7 +409,7 @@ class IntegrationTest(testutil.TestCase, test.TestCase):
# Instrument the pipeline to get to the dashboard with the full
# expected state.
self.client.get(
pipeline.get_login_url(self.PROVIDER_CLASS.NAME, pipeline.AUTH_ENTRY_LOGIN))
pipeline.get_login_url(self.provider.provider_id, pipeline.AUTH_ENTRY_LOGIN))
actions.do_complete(request.backend, social_views._do_login) # pylint: disable=protected-access
mako_middleware_process_request(strategy.request)
......@@ -465,7 +457,7 @@ class IntegrationTest(testutil.TestCase, test.TestCase):
# Instrument the pipeline to get to the dashboard with the full
# expected state.
self.client.get(
pipeline.get_login_url(self.PROVIDER_CLASS.NAME, pipeline.AUTH_ENTRY_LOGIN))
pipeline.get_login_url(self.provider.provider_id, pipeline.AUTH_ENTRY_LOGIN))
actions.do_complete(request.backend, social_views._do_login) # pylint: disable=protected-access
mako_middleware_process_request(strategy.request)
......@@ -524,7 +516,7 @@ class IntegrationTest(testutil.TestCase, test.TestCase):
self.assert_social_auth_exists_for_user(user, strategy)
self.client.get('/login')
self.client.get(pipeline.get_login_url(self.PROVIDER_CLASS.NAME, pipeline.AUTH_ENTRY_LOGIN))
self.client.get(pipeline.get_login_url(self.provider.provider_id, pipeline.AUTH_ENTRY_LOGIN))
actions.do_complete(request.backend, social_views._do_login) # pylint: disable=protected-access
mako_middleware_process_request(strategy.request)
......@@ -536,7 +528,7 @@ class IntegrationTest(testutil.TestCase, test.TestCase):
request._messages = fallback.FallbackStorage(request)
middleware.ExceptionMiddleware().process_exception(
request,
exceptions.AuthAlreadyAssociated(self.PROVIDER_CLASS.BACKEND_CLASS.name, 'account is already in use.'))
exceptions.AuthAlreadyAssociated(self.provider.backend_name, 'account is already in use.'))
self.assert_account_settings_context_looks_correct(
account_settings_context(request), user, duplicate=True, linked=True)
......@@ -561,7 +553,7 @@ class IntegrationTest(testutil.TestCase, test.TestCase):
# Synthesize that request and check that it redirects to the correct
# provider page.
self.assert_redirect_to_provider_looks_correct(self.client.get(
pipeline.get_login_url(self.PROVIDER_CLASS.NAME, pipeline.AUTH_ENTRY_LOGIN)))
pipeline.get_login_url(self.provider.provider_id, pipeline.AUTH_ENTRY_LOGIN)))
# Next, the provider makes a request against /auth/complete/<provider>
# to resume the pipeline.
......@@ -641,7 +633,7 @@ class IntegrationTest(testutil.TestCase, test.TestCase):
# Synthesize that request and check that it redirects to the correct
# provider page.
self.assert_redirect_to_provider_looks_correct(self.client.get(
pipeline.get_login_url(self.PROVIDER_CLASS.NAME, pipeline.AUTH_ENTRY_LOGIN)))
pipeline.get_login_url(self.provider.provider_id, pipeline.AUTH_ENTRY_LOGIN)))
# Next, the provider makes a request against /auth/complete/<provider>.
# pylint: disable=protected-access
......
......@@ -7,11 +7,14 @@ from third_party_auth.tests.specs import base
class GoogleOauth2IntegrationTest(base.Oauth2IntegrationTest):
"""Integration tests for provider.GoogleOauth2."""
PROVIDER_CLASS = provider.GoogleOauth2
PROVIDER_SETTINGS = {
'SOCIAL_AUTH_GOOGLE_OAUTH2_KEY': 'google_oauth2_key',
'SOCIAL_AUTH_GOOGLE_OAUTH2_SECRET': 'google_oauth2_secret',
}
def setUp(self):
super(GoogleOauth2IntegrationTest, self).setUp()
self.provider = self.configure_google_provider(
enabled=True,
key='google_oauth2_key',
secret='google_oauth2_secret',
)
TOKEN_RESPONSE_DATA = {
'access_token': 'access_token_value',
'expires_in': 'expires_in_value',
......
......@@ -7,11 +7,14 @@ from third_party_auth.tests.specs import base
class LinkedInOauth2IntegrationTest(base.Oauth2IntegrationTest):
"""Integration tests for provider.LinkedInOauth2."""
PROVIDER_CLASS = provider.LinkedInOauth2
PROVIDER_SETTINGS = {
'SOCIAL_AUTH_LINKEDIN_OAUTH2_KEY': 'linkedin_oauth2_key',
'SOCIAL_AUTH_LINKEDIN_OAUTH2_SECRET': 'linkedin_oauth2_secret',
}
def setUp(self):
super(LinkedInOauth2IntegrationTest, self).setUp()
self.provider = self.configure_linkedin_provider(
enabled=True,
key='linkedin_oauth2_key',
secret='linkedin_oauth2_secret',
)
TOKEN_RESPONSE_DATA = {
'access_token': 'access_token_value',
'expires_in': 'expires_in_value',
......
......@@ -4,6 +4,7 @@ import random
from third_party_auth import pipeline, provider
from third_party_auth.tests import testutil
import unittest
# Allow tests access to protected methods (or module-protected methods) under
......@@ -34,9 +35,11 @@ class MakeRandomPasswordTest(testutil.TestCase):
self.assertEqual(expected, pipeline.make_random_password(choice_fn=random_instance.choice))
@unittest.skipUnless(testutil.AUTH_FEATURE_ENABLED, 'third_party_auth not enabled')
class ProviderUserStateTestCase(testutil.TestCase):
"""Tests ProviderUserState behavior."""
def test_get_unlink_form_name(self):
state = pipeline.ProviderUserState(provider.GoogleOauth2, object(), 1000)
self.assertEqual(provider.GoogleOauth2.NAME + '_unlink_form', state.get_unlink_form_name())
google_provider = self.configure_google_provider(enabled=True)
state = pipeline.ProviderUserState(google_provider, object(), 1000)
self.assertEqual(google_provider.provider_id + '_unlink_form', state.get_unlink_form_name())
......@@ -21,9 +21,7 @@ class TestCase(testutil.TestCase, test.TestCase):
def setUp(self):
super(TestCase, self).setUp()
self.enabled_provider_name = provider.GoogleOauth2.NAME
provider.Registry.configure_once([self.enabled_provider_name])
self.enabled_provider = provider.Registry.get(self.enabled_provider_name)
self.enabled_provider = self.configure_google_provider(enabled=True)
@unittest.skipUnless(
......@@ -55,13 +53,13 @@ class GetAuthenticatedUserTestCase(TestCase):
def test_raises_does_not_exist_if_user_and_association_found_but_no_match(self):
self.assertIsNotNone(self.get_by_username(self.user.username))
social_models.DjangoStorage.user.create_social_auth(
self.user, 'uid', 'other_' + self.enabled_provider.BACKEND_CLASS.name)
self.user, 'uid', 'other_' + self.enabled_provider.backend_name)
with self.assertRaises(models.User.DoesNotExist):
pipeline.get_authenticated_user(self.enabled_provider, self.user.username, 'uid')
def test_returns_user_with_is_authenticated_and_backend_set_if_match(self):
social_models.DjangoStorage.user.create_social_auth(self.user, 'uid', self.enabled_provider.BACKEND_CLASS.name)
social_models.DjangoStorage.user.create_social_auth(self.user, 'uid', self.enabled_provider.backend_name)
user = pipeline.get_authenticated_user(self.enabled_provider, self.user.username, 'uid')
self.assertEqual(self.user, user)
......@@ -78,58 +76,70 @@ class GetProviderUserStatesTestCase(testutil.TestCase, test.TestCase):
self.user = social_models.DjangoStorage.user.create_user(username='username', password='password')
def test_returns_empty_list_if_no_enabled_providers(self):
provider.Registry.configure_once([])
self.assertFalse(provider.Registry.enabled())
self.assertEquals([], pipeline.get_provider_user_states(self.user))
def test_state_not_returned_for_disabled_provider(self):
disabled_provider = provider.GoogleOauth2
enabled_provider = provider.LinkedInOauth2
provider.Registry.configure_once([enabled_provider.NAME])
social_models.DjangoStorage.user.create_social_auth(self.user, 'uid', disabled_provider.BACKEND_CLASS.name)
disabled_provider = self.configure_google_provider(enabled=False)
enabled_provider = self.configure_facebook_provider(enabled=True)
social_models.DjangoStorage.user.create_social_auth(self.user, 'uid', disabled_provider.backend_name)
states = pipeline.get_provider_user_states(self.user)
self.assertEqual(1, len(states))
self.assertNotIn(disabled_provider, (state.provider for state in states))
self.assertNotIn(disabled_provider.provider_id, (state.provider.provider_id for state in states))
self.assertIn(enabled_provider.provider_id, (state.provider.provider_id for state in states))
def test_states_for_enabled_providers_user_has_accounts_associated_with(self):
provider.Registry.configure_once([provider.GoogleOauth2.NAME, provider.LinkedInOauth2.NAME])
# Enable two providers - Google and LinkedIn:
google_provider = self.configure_google_provider(enabled=True)
linkedin_provider = self.configure_linkedin_provider(enabled=True)
user_social_auth_google = social_models.DjangoStorage.user.create_social_auth(
self.user, 'uid', provider.GoogleOauth2.BACKEND_CLASS.name)
self.user, 'uid', google_provider.backend_name)
user_social_auth_linkedin = social_models.DjangoStorage.user.create_social_auth(
self.user, 'uid', provider.LinkedInOauth2.BACKEND_CLASS.name)
self.user, 'uid', linkedin_provider.backend_name)
states = pipeline.get_provider_user_states(self.user)
self.assertEqual(2, len(states))
google_state = [state for state in states if state.provider == provider.GoogleOauth2][0]
linkedin_state = [state for state in states if state.provider == provider.LinkedInOauth2][0]
google_state = [state for state in states if state.provider.provider_id == google_provider.provider_id][0]
linkedin_state = [state for state in states if state.provider.provider_id == linkedin_provider.provider_id][0]
self.assertTrue(google_state.has_account)
self.assertEqual(provider.GoogleOauth2, google_state.provider)
self.assertEqual(google_provider.provider_id, google_state.provider.provider_id)
# Also check the row ID. Note this 'id' changes whenever the configuration does:
self.assertEqual(google_provider.id, google_state.provider.id) # pylint: disable=no-member
self.assertEqual(self.user, google_state.user)
self.assertEqual(user_social_auth_google.id, google_state.association_id)
self.assertTrue(linkedin_state.has_account)
self.assertEqual(provider.LinkedInOauth2, linkedin_state.provider)
self.assertEqual(linkedin_provider.provider_id, linkedin_state.provider.provider_id)
self.assertEqual(linkedin_provider.id, linkedin_state.provider.id) # pylint: disable=no-member
self.assertEqual(self.user, linkedin_state.user)
self.assertEqual(user_social_auth_linkedin.id, linkedin_state.association_id)
def test_states_for_enabled_providers_user_has_no_account_associated_with(self):
provider.Registry.configure_once([provider.GoogleOauth2.NAME, provider.LinkedInOauth2.NAME])
# Enable two providers - Google and LinkedIn:
google_provider = self.configure_google_provider(enabled=True)
linkedin_provider = self.configure_linkedin_provider(enabled=True)
self.assertEqual(len(provider.Registry.enabled()), 2)
states = pipeline.get_provider_user_states(self.user)
self.assertEqual([], [x for x in social_models.DjangoStorage.user.objects.all()])
self.assertEqual(2, len(states))
google_state = [state for state in states if state.provider == provider.GoogleOauth2][0]
linkedin_state = [state for state in states if state.provider == provider.LinkedInOauth2][0]
google_state = [state for state in states if state.provider.provider_id == google_provider.provider_id][0]
linkedin_state = [state for state in states if state.provider.provider_id == linkedin_provider.provider_id][0]
self.assertFalse(google_state.has_account)
self.assertEqual(provider.GoogleOauth2, google_state.provider)
self.assertEqual(google_provider.provider_id, google_state.provider.provider_id)
# Also check the row ID. Note this 'id' changes whenever the configuration does:
self.assertEqual(google_provider.id, google_state.provider.id) # pylint: disable=no-member
self.assertEqual(self.user, google_state.user)
self.assertFalse(linkedin_state.has_account)
self.assertEqual(provider.LinkedInOauth2, linkedin_state.provider)
self.assertEqual(linkedin_provider.provider_id, linkedin_state.provider.provider_id)
self.assertEqual(linkedin_provider.id, linkedin_state.provider.id) # pylint: disable=no-member
self.assertEqual(self.user, linkedin_state.user)
......@@ -139,7 +149,7 @@ class UrlFormationTestCase(TestCase):
"""Tests formation of URLs for pipeline hook points."""
def test_complete_url_raises_value_error_if_provider_not_enabled(self):
provider_name = 'not_enabled'
provider_name = 'oa2-not-enabled'
self.assertIsNone(provider.Registry.get(provider_name))
......@@ -147,13 +157,13 @@ class UrlFormationTestCase(TestCase):
pipeline.get_complete_url(provider_name)
def test_complete_url_returns_expected_format(self):
complete_url = pipeline.get_complete_url(self.enabled_provider.BACKEND_CLASS.name)
complete_url = pipeline.get_complete_url(self.enabled_provider.backend_name)
self.assertTrue(complete_url.startswith('/auth/complete'))
self.assertIn(self.enabled_provider.BACKEND_CLASS.name, complete_url)
self.assertIn(self.enabled_provider.backend_name, complete_url)
def test_disconnect_url_raises_value_error_if_provider_not_enabled(self):
provider_name = 'not_enabled'
provider_name = 'oa2-not-enabled'
self.assertIsNone(provider.Registry.get(provider_name))
......@@ -161,25 +171,40 @@ class UrlFormationTestCase(TestCase):
pipeline.get_disconnect_url(provider_name, 1000)
def test_disconnect_url_returns_expected_format(self):
disconnect_url = pipeline.get_disconnect_url(self.enabled_provider.NAME, 1000)
disconnect_url = pipeline.get_disconnect_url(self.enabled_provider.provider_id, 1000)
disconnect_url = disconnect_url.rstrip('?')
self.assertEqual(
disconnect_url,
'/auth/disconnect/{backend}/{association_id}/'.format(
backend=self.enabled_provider.BACKEND_CLASS.name, association_id=1000)
backend=self.enabled_provider.backend_name, association_id=1000)
)
def test_login_url_raises_value_error_if_provider_not_enabled(self):
provider_name = 'not_enabled'
provider_id = 'oa2-not-enabled'
self.assertIsNone(provider.Registry.get(provider_name))
self.assertIsNone(provider.Registry.get(provider_id))
with self.assertRaises(ValueError):
pipeline.get_login_url(provider_name, pipeline.AUTH_ENTRY_LOGIN)
pipeline.get_login_url(provider_id, pipeline.AUTH_ENTRY_LOGIN)
def test_login_url_returns_expected_format(self):
login_url = pipeline.get_login_url(self.enabled_provider.NAME, pipeline.AUTH_ENTRY_LOGIN)
login_url = pipeline.get_login_url(self.enabled_provider.provider_id, pipeline.AUTH_ENTRY_LOGIN)
self.assertTrue(login_url.startswith('/auth/login'))
self.assertIn(self.enabled_provider.BACKEND_CLASS.name, login_url)
self.assertIn(self.enabled_provider.backend_name, login_url)
self.assertTrue(login_url.endswith(pipeline.AUTH_ENTRY_LOGIN))
def test_for_value_error_if_provider_id_invalid(self):
provider_id = 'invalid' # Format is normally "{prefix}-{identifier}"
with self.assertRaises(ValueError):
provider.Registry.get(provider_id)
with self.assertRaises(ValueError):
pipeline.get_login_url(provider_id, pipeline.AUTH_ENTRY_LOGIN)
with self.assertRaises(ValueError):
pipeline.get_disconnect_url(provider_id, 1000)
with self.assertRaises(ValueError):
pipeline.get_complete_url(provider_id)
"""Unit tests for provider.py."""
from mock import Mock
from mock import Mock, patch
from third_party_auth import provider
from third_party_auth.tests import testutil
import unittest
@unittest.skipUnless(testutil.AUTH_FEATURE_ENABLED, 'third_party_auth not enabled')
class RegistryTest(testutil.TestCase):
"""Tests registry discovery and operation."""
# Allow access to protected methods (or module-protected methods) under
# test. pylint: disable-msg=protected-access
def test_calling_configure_once_twice_raises_value_error(self):
provider.Registry.configure_once([provider.GoogleOauth2.NAME])
with self.assertRaisesRegexp(ValueError, '^.*already configured$'):
provider.Registry.configure_once([provider.GoogleOauth2.NAME])
def test_configure_once_adds_gettable_providers(self):
provider.Registry.configure_once([provider.GoogleOauth2.NAME])
self.assertIs(provider.GoogleOauth2, provider.Registry.get(provider.GoogleOauth2.NAME))
def test_configuring_provider_with_no_implementation_raises_value_error(self):
with self.assertRaisesRegexp(ValueError, '^.*no_implementation$'):
provider.Registry.configure_once(['no_implementation'])
def test_configuring_single_provider_twice_raises_value_error(self):
provider.Registry._enable(provider.GoogleOauth2)
with self.assertRaisesRegexp(ValueError, '^.*already enabled'):
provider.Registry.configure_once([provider.GoogleOauth2.NAME])
def test_custom_provider_can_be_enabled(self):
name = 'CustomProvider'
with self.assertRaisesRegexp(ValueError, '^No implementation.*$'):
provider.Registry.configure_once([name])
class CustomProvider(provider.BaseProvider):
"""Custom class to ensure BaseProvider children outside provider can be enabled."""
NAME = name
provider.Registry._reset()
provider.Registry.configure_once([CustomProvider.NAME])
self.assertEqual([CustomProvider], provider.Registry.enabled())
def test_enabled_raises_runtime_error_if_not_configured(self):
with self.assertRaisesRegexp(RuntimeError, '^.*not configured$'):
provider.Registry.enabled()
facebook_provider = self.configure_facebook_provider(enabled=True)
# pylint: disable=no-member
self.assertEqual(facebook_provider.id, provider.Registry.get(facebook_provider.provider_id).id)
def test_no_providers_by_default(self):
enabled_providers = provider.Registry.enabled()
self.assertEqual(len(enabled_providers), 0, "By default, no providers are enabled.")
def test_runtime_configuration(self):
self.configure_google_provider(enabled=True)
enabled_providers = provider.Registry.enabled()
self.assertEqual(len(enabled_providers), 1)
self.assertEqual(enabled_providers[0].name, "Google")
self.assertEqual(enabled_providers[0].secret, "opensesame")
self.configure_google_provider(enabled=False)
enabled_providers = provider.Registry.enabled()
self.assertEqual(len(enabled_providers), 0)
self.configure_google_provider(enabled=True, secret="alohomora")
enabled_providers = provider.Registry.enabled()
self.assertEqual(len(enabled_providers), 1)
self.assertEqual(enabled_providers[0].secret, "alohomora")
def test_cannot_load_arbitrary_backends(self):
""" Test that only backend_names listed in settings.AUTHENTICATION_BACKENDS can be used """
self.configure_oauth_provider(enabled=True, name="Disallowed", backend_name="disallowed")
self.enable_saml()
self.configure_saml_provider(enabled=True, name="Disallowed", idp_slug="test", backend_name="disallowed")
self.assertEqual(len(provider.Registry.enabled()), 0)
def test_enabled_returns_list_of_enabled_providers_sorted_by_name(self):
all_providers = provider.Registry._get_all()
provider.Registry.configure_once(all_providers.keys())
self.assertEqual(
sorted(all_providers.values(), key=lambda provider: provider.NAME), provider.Registry.enabled())
provider_names = ["Stack Overflow", "Google", "LinkedIn", "GitHub"]
backend_names = []
for name in provider_names:
backend_name = name.lower().replace(' ', '')
backend_names.append(backend_name)
self.configure_oauth_provider(enabled=True, name=name, backend_name=backend_name)
def test_get_raises_runtime_error_if_not_configured(self):
with self.assertRaisesRegexp(RuntimeError, '^.*not configured$'):
provider.Registry.get('anything')
with patch('third_party_auth.provider._PSA_OAUTH2_BACKENDS', backend_names):
self.assertEqual(sorted(provider_names), [prov.name for prov in provider.Registry.enabled()])
def test_get_returns_enabled_provider(self):
provider.Registry.configure_once([provider.GoogleOauth2.NAME])
self.assertIs(provider.GoogleOauth2, provider.Registry.get(provider.GoogleOauth2.NAME))
google_provider = self.configure_google_provider(enabled=True)
# pylint: disable=no-member
self.assertEqual(google_provider.id, provider.Registry.get(google_provider.provider_id).id)
def test_get_returns_none_if_provider_not_enabled(self):
provider.Registry.configure_once([])
self.assertIsNone(provider.Registry.get(provider.LinkedInOauth2.NAME))
linkedin_provider_id = "oa2-linkedin-oauth2"
# At this point there should be no configuration entries at all so no providers should be enabled
self.assertEqual(provider.Registry.enabled(), [])
self.assertIsNone(provider.Registry.get(linkedin_provider_id))
# Now explicitly disabled this provider:
self.configure_linkedin_provider(enabled=False)
self.assertIsNone(provider.Registry.get(linkedin_provider_id))
self.configure_linkedin_provider(enabled=True)
self.assertEqual(provider.Registry.get(linkedin_provider_id).provider_id, linkedin_provider_id)
def test_get_from_pipeline_returns_none_if_provider_not_enabled(self):
provider.Registry.configure_once([])
self.assertEqual(provider.Registry.enabled(), [], "By default, no providers are enabled.")
self.assertIsNone(provider.Registry.get_from_pipeline(Mock()))
def test_get_enabled_by_backend_name_raises_runtime_error_if_not_configured(self):
with self.assertRaisesRegexp(RuntimeError, '^.*not configured$'):
provider.Registry.get_enabled_by_backend_name('').next()
def test_get_enabled_by_backend_name_returns_enabled_provider(self):
provider.Registry.configure_once([provider.GoogleOauth2.NAME])
found = list(provider.Registry.get_enabled_by_backend_name(provider.GoogleOauth2.BACKEND_CLASS.name))
self.assertEqual(found, [provider.GoogleOauth2])
google_provider = self.configure_google_provider(enabled=True)
found = list(provider.Registry.get_enabled_by_backend_name(google_provider.backend_name))
self.assertEqual(found, [google_provider])
def test_get_enabled_by_backend_name_returns_none_if_provider_not_enabled(self):
provider.Registry.configure_once([])
self.assertEqual(
[],
list(provider.Registry.get_enabled_by_backend_name(provider.GoogleOauth2.BACKEND_CLASS.name))
)
google_provider = self.configure_google_provider(enabled=False)
found = list(provider.Registry.get_enabled_by_backend_name(google_provider.backend_name))
self.assertEqual(found, [])
......@@ -2,6 +2,7 @@
from third_party_auth import provider, settings
from third_party_auth.tests import testutil
import unittest
_ORIGINAL_AUTHENTICATION_BACKENDS = ('first_authentication_backend',)
......@@ -30,56 +31,26 @@ class SettingsUnitTest(testutil.TestCase):
self.settings = testutil.FakeDjangoSettings(_SETTINGS_MAP)
def test_apply_settings_adds_exception_middleware(self):
settings.apply_settings({}, self.settings)
settings.apply_settings(self.settings)
for middleware_name in settings._MIDDLEWARE_CLASSES:
self.assertIn(middleware_name, self.settings.MIDDLEWARE_CLASSES)
def test_apply_settings_adds_fields_stored_in_session(self):
settings.apply_settings({}, self.settings)
settings.apply_settings(self.settings)
self.assertEqual(settings._FIELDS_STORED_IN_SESSION, self.settings.FIELDS_STORED_IN_SESSION)
def test_apply_settings_adds_third_party_auth_to_installed_apps(self):
settings.apply_settings({}, self.settings)
settings.apply_settings(self.settings)
self.assertIn('third_party_auth', self.settings.INSTALLED_APPS)
def test_apply_settings_enables_no_providers_and_completes_when_app_info_empty(self):
settings.apply_settings({}, self.settings)
@unittest.skipUnless(testutil.AUTH_FEATURE_ENABLED, 'third_party_auth not enabled')
def test_apply_settings_enables_no_providers_by_default(self):
# Providers are only enabled via ConfigurationModels in the database
settings.apply_settings(self.settings)
self.assertEqual([], provider.Registry.enabled())
def test_apply_settings_initializes_stubs_and_merges_settings_from_auth_info(self):
for key in provider.GoogleOauth2.SETTINGS:
self.assertFalse(hasattr(self.settings, key))
auth_info = {
provider.GoogleOauth2.NAME: {
'SOCIAL_AUTH_GOOGLE_OAUTH2_KEY': 'google_oauth2_key',
},
}
settings.apply_settings(auth_info, self.settings)
self.assertEqual('google_oauth2_key', self.settings.SOCIAL_AUTH_GOOGLE_OAUTH2_KEY)
self.assertIsNone(self.settings.SOCIAL_AUTH_GOOGLE_OAUTH2_SECRET)
def test_apply_settings_prepends_auth_backends(self):
self.assertEqual(_ORIGINAL_AUTHENTICATION_BACKENDS, self.settings.AUTHENTICATION_BACKENDS)
settings.apply_settings({provider.GoogleOauth2.NAME: {}, provider.LinkedInOauth2.NAME: {}}, self.settings)
self.assertEqual((
provider.GoogleOauth2.get_authentication_backend(), provider.LinkedInOauth2.get_authentication_backend()) +
_ORIGINAL_AUTHENTICATION_BACKENDS,
self.settings.AUTHENTICATION_BACKENDS)
def test_apply_settings_raises_value_error_if_provider_contains_uninitialized_setting(self):
bad_setting_name = 'bad_setting'
self.assertNotIn('bad_setting_name', provider.GoogleOauth2.SETTINGS)
auth_info = {
provider.GoogleOauth2.NAME: {
bad_setting_name: None,
},
}
with self.assertRaisesRegexp(ValueError, '^.*not initialized$'):
settings.apply_settings(auth_info, self.settings)
def test_apply_settings_turns_off_raising_social_exceptions(self):
# Guard against submitting a conf change that's convenient in dev but
# bad in prod.
settings.apply_settings({}, self.settings)
settings.apply_settings(self.settings)
self.assertFalse(self.settings.SOCIAL_AUTH_RAISE_EXCEPTIONS)
"""Integration tests for settings.py."""
from django.conf import settings
from third_party_auth import provider
from third_party_auth import settings as auth_settings
from third_party_auth.tests import testutil
class SettingsIntegrationTest(testutil.TestCase):
"""Integration tests of auth settings pipeline.
Note that ENABLE_THIRD_PARTY_AUTH is True in lms/envs/test.py and False in
cms/envs/test.py. This implicitly gives us coverage of the full settings
mechanism with both values, so we do not have explicit test methods as they
are superfluous.
"""
def test_can_enable_google_oauth2(self):
auth_settings.apply_settings({'Google': {'SOCIAL_AUTH_GOOGLE_OAUTH2_KEY': 'google_key'}}, settings)
self.assertEqual([provider.GoogleOauth2], provider.Registry.enabled())
self.assertEqual('google_key', settings.SOCIAL_AUTH_GOOGLE_OAUTH2_KEY)
def test_can_enable_linkedin_oauth2(self):
auth_settings.apply_settings({'LinkedIn': {'SOCIAL_AUTH_LINKEDIN_OAUTH2_KEY': 'linkedin_key'}}, settings)
self.assertEqual([provider.LinkedInOauth2], provider.Registry.enabled())
self.assertEqual('linkedin_key', settings.SOCIAL_AUTH_LINKEDIN_OAUTH2_KEY)
......@@ -5,13 +5,15 @@ Used by Django and non-Django tests; must not have Django deps.
"""
from contextlib import contextmanager
import unittest
from django.conf import settings
import django.test
import mock
from third_party_auth import provider
from third_party_auth.models import OAuth2ProviderConfig, SAMLProviderConfig, SAMLConfiguration, cache as config_cache
AUTH_FEATURES_KEY = 'ENABLE_THIRD_PARTY_AUTH'
AUTH_FEATURE_ENABLED = AUTH_FEATURES_KEY in settings.FEATURES
class FakeDjangoSettings(object):
......@@ -23,22 +25,66 @@ class FakeDjangoSettings(object):
setattr(self, key, value)
class TestCase(unittest.TestCase):
"""Base class for auth test cases."""
# Allow access to protected methods (or module-protected methods) under
# test.
# pylint: disable-msg=protected-access
def setUp(self):
super(TestCase, self).setUp()
self._original_providers = provider.Registry._get_all()
provider.Registry._reset()
class ThirdPartyAuthTestMixin(object):
""" Helper methods useful for testing third party auth functionality """
def tearDown(self):
provider.Registry._reset()
provider.Registry.configure_once(self._original_providers)
super(TestCase, self).tearDown()
config_cache.clear()
super(ThirdPartyAuthTestMixin, self).tearDown()
def enable_saml(self, **kwargs):
""" Enable SAML support (via SAMLConfiguration, not for any particular provider) """
kwargs.setdefault('enabled', True)
SAMLConfiguration(**kwargs).save()
@staticmethod
def configure_oauth_provider(**kwargs):
""" Update the settings for an OAuth2-based third party auth provider """
obj = OAuth2ProviderConfig(**kwargs)
obj.save()
return obj
def configure_saml_provider(self, **kwargs):
""" Update the settings for a SAML-based third party auth provider """
self.assertTrue(SAMLConfiguration.is_enabled(), "SAML Provider Configuration only works if SAML is enabled.")
obj = SAMLProviderConfig(**kwargs)
obj.save()
return obj
@classmethod
def configure_google_provider(cls, **kwargs):
""" Update the settings for the Google third party auth provider/backend """
kwargs.setdefault("name", "Google")
kwargs.setdefault("backend_name", "google-oauth2")
kwargs.setdefault("icon_class", "fa-google-plus")
kwargs.setdefault("key", "test-fake-key.apps.googleusercontent.com")
kwargs.setdefault("secret", "opensesame")
return cls.configure_oauth_provider(**kwargs)
@classmethod
def configure_facebook_provider(cls, **kwargs):
""" Update the settings for the Facebook third party auth provider/backend """
kwargs.setdefault("name", "Facebook")
kwargs.setdefault("backend_name", "facebook")
kwargs.setdefault("icon_class", "fa-facebook")
kwargs.setdefault("key", "FB_TEST_APP")
kwargs.setdefault("secret", "opensesame")
return cls.configure_oauth_provider(**kwargs)
@classmethod
def configure_linkedin_provider(cls, **kwargs):
""" Update the settings for the LinkedIn third party auth provider/backend """
kwargs.setdefault("name", "LinkedIn")
kwargs.setdefault("backend_name", "linkedin-oauth2")
kwargs.setdefault("icon_class", "fa-linkedin")
kwargs.setdefault("key", "test")
kwargs.setdefault("secret", "test")
return cls.configure_oauth_provider(**kwargs)
class TestCase(ThirdPartyAuthTestMixin, django.test.TestCase):
"""Base class for auth test cases."""
pass
@contextmanager
......
......@@ -9,9 +9,11 @@ from social.apps.django_app.default.models import UserSocialAuth
from student.tests.factories import UserFactory
from .testutil import ThirdPartyAuthTestMixin
@httpretty.activate
class ThirdPartyOAuthTestMixin(object):
class ThirdPartyOAuthTestMixin(ThirdPartyAuthTestMixin):
"""
Mixin with tests for third party oauth views. A TestCase that includes
this must define the following:
......@@ -32,6 +34,10 @@ class ThirdPartyOAuthTestMixin(object):
if create_user:
self.user = UserFactory()
UserSocialAuth.objects.create(user=self.user, provider=self.BACKEND, uid=self.social_uid)
if self.BACKEND == 'google-oauth2':
self.configure_google_provider(enabled=True)
elif self.BACKEND == 'facebook':
self.configure_facebook_provider(enabled=True)
def _setup_provider_response(self, success=False, email=''):
"""
......
......@@ -3,9 +3,10 @@ Extra views required for SSO
"""
from django.conf import settings
from django.core.urlresolvers import reverse
from django.http import HttpResponse, HttpResponseServerError
from django.http import HttpResponse, HttpResponseServerError, Http404
from django.shortcuts import redirect
from social.apps.django_app.utils import load_strategy, load_backend
from .models import SAMLConfiguration
def inactive_user_view(request):
......@@ -24,6 +25,8 @@ def saml_metadata_view(request):
Get the Service Provider metadata for this edx-platform instance.
You must send this XML to any Shibboleth Identity Provider that you wish to use.
"""
if not SAMLConfiguration.is_enabled():
raise Http404
complete_url = reverse('social:complete', args=("tpa-saml", ))
if settings.APPEND_SLASH and not complete_url.endswith('/'):
complete_url = complete_url + '/' # Required for consistency
......
......@@ -232,7 +232,7 @@ class CombinedLoginAndRegisterPage(PageObject):
Only the "Dummy" provider is used for bok choy because it is the only
one that doesn't send traffic to external servers.
"""
self.q(css="button.{}-Dummy".format(self.current_form)).click()
self.q(css="button.{}-oa2-dummy".format(self.current_form)).click()
def password_reset(self, email):
"""Navigates to, fills in, and submits the password reset form.
......
......@@ -437,9 +437,10 @@ class AccountSettingsPageTest(AccountSettingsTestMixin, WebAppTest):
Currently there is no way to test the whole authentication process
because that would require accounts with the providers.
"""
for field_id, title, link_title in [
['auth-facebook', 'Facebook', 'Link'],
['auth-google', 'Google', 'Link'],
]:
providers = (
['auth-oa2-facebook', 'Facebook', 'Link'],
['auth-oa2-google-oauth2', 'Google', 'Link'],
)
for field_id, title, link_title in providers:
self.assertEqual(self.account_settings_page.title_for_field(field_id), title)
self.assertEqual(self.account_settings_page.link_title_for_link_field(field_id), link_title)
......@@ -166,7 +166,7 @@ class LoginFromCombinedPageTest(UniqueCourseTest):
# Now unlink the account (To test the account settings view and also to prevent cross-test side effects)
account_settings = AccountSettingsPage(self.browser).visit()
field_id = "auth-dummy"
field_id = "auth-oa2-dummy"
account_settings.wait_for_field(field_id)
self.assertEqual("Unlink", account_settings.link_title_for_link_field(field_id))
account_settings.click_on_link_in_link_field(field_id)
......@@ -305,7 +305,7 @@ class RegisterFromCombinedPageTest(UniqueCourseTest):
# Now unlink the account (To test the account settings view and also to prevent cross-test side effects)
account_settings = AccountSettingsPage(self.browser).visit()
field_id = "auth-dummy"
field_id = "auth-oa2-dummy"
account_settings.wait_for_field(field_id)
self.assertEqual("Unlink", account_settings.link_title_for_link_field(field_id))
account_settings.click_on_link_in_link_field(field_id)
......
[
{
"pk": 1,
"model": "third_party_auth.oauth2providerconfig",
"fields": {
"enabled": true,
"change_date": "2001-02-03T04:05:06Z",
"changed_by": null,
"name": "Google",
"icon_class": "fa-google-plus",
"backend_name": "google-oauth2",
"key": "test",
"secret": "test",
"other_settings": "{}"
}
},
{
"pk": 2,
"model": "third_party_auth.oauth2providerconfig",
"fields": {
"enabled": true,
"change_date": "2001-02-03T04:05:06Z",
"changed_by": null,
"name": "Facebook",
"icon_class": "fa-facebook",
"backend_name": "facebook",
"key": "test",
"secret": "test",
"other_settings": "{}"
}
},
{
"pk": 3,
"model": "third_party_auth.oauth2providerconfig",
"fields": {
"enabled": true,
"change_date": "2001-02-03T04:05:06Z",
"changed_by": null,
"name": "Dummy",
"icon_class": "fa-sign-in",
"backend_name": "dummy",
"key": "",
"secret": "",
"other_settings": "{}"
}
}
]
......@@ -23,7 +23,7 @@ from openedx.core.djangoapps.user_api.accounts.api import activate_account, crea
from openedx.core.djangoapps.user_api.accounts import EMAIL_MAX_LENGTH
from student.tests.factories import CourseModeFactory, UserFactory
from student_account.views import account_settings_context
from third_party_auth.tests.testutil import simulate_running_pipeline
from third_party_auth.tests.testutil import simulate_running_pipeline, ThirdPartyAuthTestMixin
from util.testing import UrlResetMixin
from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase
from xmodule.modulestore.tests.factories import CourseFactory
......@@ -204,7 +204,7 @@ class StudentAccountUpdateTest(UrlResetMixin, TestCase):
@ddt.ddt
class StudentAccountLoginAndRegistrationTest(UrlResetMixin, ModuleStoreTestCase):
class StudentAccountLoginAndRegistrationTest(ThirdPartyAuthTestMixin, UrlResetMixin, ModuleStoreTestCase):
""" Tests for the student account views that update the user's account information. """
USERNAME = "bob"
......@@ -214,6 +214,9 @@ class StudentAccountLoginAndRegistrationTest(UrlResetMixin, ModuleStoreTestCase)
@mock.patch.dict(settings.FEATURES, {'EMBARGO': True})
def setUp(self):
super(StudentAccountLoginAndRegistrationTest, self).setUp('embargo')
# For these tests, two third party auth providers are enabled by default:
self.configure_google_provider(enabled=True)
self.configure_facebook_provider(enabled=True)
@ddt.data(
("account_login", "login"),
......@@ -290,7 +293,7 @@ class StudentAccountLoginAndRegistrationTest(UrlResetMixin, ModuleStoreTestCase)
@ddt.unpack
def test_third_party_auth(self, url_name, current_backend, current_provider):
params = [
('course_id', 'edX/DemoX/Demo_Course'),
('course_id', 'course-v1:Org+Course+Run'),
('enrollment_action', 'enroll'),
('course_mode', 'honor'),
('email_opt_in', 'true'),
......@@ -310,12 +313,14 @@ class StudentAccountLoginAndRegistrationTest(UrlResetMixin, ModuleStoreTestCase)
# This relies on the THIRD_PARTY_AUTH configuration in the test settings
expected_providers = [
{
"id": "oa2-facebook",
"name": "Facebook",
"iconClass": "fa-facebook",
"loginUrl": self._third_party_login_url("facebook", "login", params),
"registerUrl": self._third_party_login_url("facebook", "register", params)
},
{
"id": "oa2-google-oauth2",
"name": "Google",
"iconClass": "fa-google-plus",
"loginUrl": self._third_party_login_url("google-oauth2", "login", params),
......@@ -347,11 +352,14 @@ class StudentAccountLoginAndRegistrationTest(UrlResetMixin, ModuleStoreTestCase)
def _assert_third_party_auth_data(self, response, current_backend, current_provider, providers):
"""Verify that third party auth info is rendered correctly in a DOM data attribute. """
finish_auth_url = None
if current_backend:
finish_auth_url = reverse("social:complete", kwargs={"backend": current_backend}) + "?"
auth_info = markupsafe.escape(
json.dumps({
"currentProvider": current_provider,
"providers": providers,
"finishAuthUrl": "/auth/complete/{}?".format(current_backend) if current_backend else None,
"finishAuthUrl": finish_auth_url,
"errorMessage": None,
})
)
......@@ -382,7 +390,7 @@ class StudentAccountLoginAndRegistrationTest(UrlResetMixin, ModuleStoreTestCase)
})
class AccountSettingsViewTest(TestCase):
class AccountSettingsViewTest(ThirdPartyAuthTestMixin, TestCase):
""" Tests for the account settings view. """
USERNAME = 'student'
......@@ -406,6 +414,10 @@ class AccountSettingsViewTest(TestCase):
self.request = RequestFactory()
self.request.user = self.user
# For these tests, two third party auth providers are enabled by default:
self.configure_google_provider(enabled=True)
self.configure_facebook_provider(enabled=True)
# Python-social saves auth failure notifcations in Django messages.
# See pipeline.get_duplicate_provider() for details.
self.request.COOKIES = {}
......
......@@ -171,15 +171,16 @@ def _third_party_auth_context(request, redirect_to):
if third_party_auth.is_enabled():
context["providers"] = [
{
"name": enabled.NAME,
"iconClass": enabled.ICON_CLASS,
"id": enabled.provider_id,
"name": enabled.name,
"iconClass": enabled.icon_class,
"loginUrl": pipeline.get_login_url(
enabled.NAME,
enabled.provider_id,
pipeline.AUTH_ENTRY_LOGIN,
redirect_url=redirect_to,
),
"registerUrl": pipeline.get_login_url(
enabled.NAME,
enabled.provider_id,
pipeline.AUTH_ENTRY_REGISTER,
redirect_url=redirect_to,
),
......@@ -190,13 +191,14 @@ def _third_party_auth_context(request, redirect_to):
running_pipeline = pipeline.get(request)
if running_pipeline is not None:
current_provider = third_party_auth.provider.Registry.get_from_pipeline(running_pipeline)
context["currentProvider"] = current_provider.NAME
context["finishAuthUrl"] = pipeline.get_complete_url(current_provider.BACKEND_CLASS.name)
context["currentProvider"] = current_provider.name
context["finishAuthUrl"] = pipeline.get_complete_url(current_provider.backend_name)
# Check for any error messages we may want to display:
for msg in messages.get_messages(request):
if msg.extra_tags.split()[0] == "social-auth":
context['errorMessage'] = unicode(msg)
# msg may or may not be translated. Try translating [again] in case we are able to:
context['errorMessage'] = _(msg) # pylint: disable=translation-of-non-string
break
return context
......@@ -368,19 +370,20 @@ def account_settings_context(request):
auth_states = pipeline.get_provider_user_states(user)
context['auth']['providers'] = [{
'name': state.provider.NAME, # The name of the provider e.g. Facebook
'id': state.provider.provider_id,
'name': state.provider.name, # The name of the provider e.g. Facebook
'connected': state.has_account, # Whether the user's edX account is connected with the provider.
# If the user is not connected, they should be directed to this page to authenticate
# with the particular provider.
'connect_url': pipeline.get_login_url(
state.provider.NAME,
state.provider.provider_id,
pipeline.AUTH_ENTRY_ACCOUNT_SETTINGS,
# The url the user should be directed to after the auth process has completed.
redirect_url=reverse('account_settings'),
),
# If the user is connected, sending a POST request to this url removes the connection
# information for this provider from their edX account.
'disconnect_url': pipeline.get_disconnect_url(state.provider.NAME, state.association_id),
'disconnect_url': pipeline.get_disconnect_url(state.provider.provider_id, state.association_id),
} for state in auth_states]
return context
......@@ -536,29 +536,21 @@ TIME_ZONE_DISPLAYED_FOR_DEADLINES = ENV_TOKENS.get("TIME_ZONE_DISPLAYED_FOR_DEAD
X_FRAME_OPTIONS = ENV_TOKENS.get('X_FRAME_OPTIONS', X_FRAME_OPTIONS)
##### Third-party auth options ################################################
THIRD_PARTY_AUTH = AUTH_TOKENS.get('THIRD_PARTY_AUTH', THIRD_PARTY_AUTH)
# The reduced session expiry time during the third party login pipeline. (Value in seconds)
SOCIAL_AUTH_PIPELINE_TIMEOUT = ENV_TOKENS.get('SOCIAL_AUTH_PIPELINE_TIMEOUT', 600)
if FEATURES.get('ENABLE_THIRD_PARTY_AUTH'):
AUTHENTICATION_BACKENDS = (
ENV_TOKENS.get('THIRD_PARTY_AUTH_BACKENDS', [
'social.backends.google.GoogleOAuth2',
'social.backends.linkedin.LinkedinOAuth2',
'social.backends.facebook.FacebookOAuth2',
'third_party_auth.saml.SAMLAuthBackend',
]) + list(AUTHENTICATION_BACKENDS)
)
##### SAML configuration for third_party_auth #####
# The reduced session expiry time during the third party login pipeline. (Value in seconds)
SOCIAL_AUTH_PIPELINE_TIMEOUT = ENV_TOKENS.get('SOCIAL_AUTH_PIPELINE_TIMEOUT', 600)
if 'SOCIAL_AUTH_TPA_SAML_SP_ENTITY_ID' in ENV_TOKENS:
SOCIAL_AUTH_TPA_SAML_SP_ENTITY_ID = ENV_TOKENS.get('SOCIAL_AUTH_TPA_SAML_SP_ENTITY_ID')
SOCIAL_AUTH_TPA_SAML_SP_NAMEID_FORMAT = ENV_TOKENS.get('SOCIAL_AUTH_TPA_SAML_SP_NAMEID_FORMAT', 'unspecified')
SOCIAL_AUTH_TPA_SAML_SP_EXTRA = ENV_TOKENS.get('SOCIAL_AUTH_TPA_SAML_SP_EXTRA', {})
SOCIAL_AUTH_TPA_SAML_ORG_INFO = ENV_TOKENS.get('SOCIAL_AUTH_TPA_SAML_ORG_INFO')
SOCIAL_AUTH_TPA_SAML_TECHNICAL_CONTACT = ENV_TOKENS.get(
'SOCIAL_AUTH_TPA_SAML_TECHNICAL_CONTACT',
{"givenName": "Technical Support", "emailAddress": TECH_SUPPORT_EMAIL}
)
SOCIAL_AUTH_TPA_SAML_SUPPORT_CONTACT = ENV_TOKENS.get(
'SOCIAL_AUTH_TPA_SAML_SUPPORT_CONTACT',
{"givenName": "Support", "emailAddress": TECH_SUPPORT_EMAIL}
)
SOCIAL_AUTH_TPA_SAML_SECURITY_CONFIG = ENV_TOKENS.get('SOCIAL_AUTH_TPA_SAML_SECURITY_CONFIG', {})
SOCIAL_AUTH_TPA_SAML_SP_PUBLIC_CERT = AUTH_TOKENS.get('SOCIAL_AUTH_TPA_SAML_SP_PUBLIC_CERT')
SOCIAL_AUTH_TPA_SAML_SP_PRIVATE_KEY = AUTH_TOKENS.get('SOCIAL_AUTH_TPA_SAML_SP_PRIVATE_KEY')
# third_party_auth config moved to ConfigurationModels. This is for data migration only:
THIRD_PARTY_AUTH_OLD_CONFIG = AUTH_TOKENS.get('THIRD_PARTY_AUTH', None)
##### OAUTH2 Provider ##############
if FEATURES.get('ENABLE_OAUTH2_PROVIDER'):
......
......@@ -117,17 +117,6 @@
"username": "lms"
},
"SECRET_KEY": "",
"THIRD_PARTY_AUTH": {
"Dummy": {},
"Google": {
"SOCIAL_AUTH_GOOGLE_OAUTH2_KEY": "test",
"SOCIAL_AUTH_GOOGLE_OAUTH2_SECRET": "test"
},
"Facebook": {
"SOCIAL_AUTH_FACEBOOK_KEY": "test",
"SOCIAL_AUTH_FACEBOOK_SECRET": "test"
}
},
"DJFS": {
"type": "s3fs",
"bucket": "test",
......
......@@ -79,7 +79,6 @@
"ENABLE_INSTRUCTOR_ANALYTICS": true,
"ENABLE_S3_GRADE_DOWNLOADS": true,
"ENABLE_THIRD_PARTY_AUTH": true,
"ENABLE_DUMMY_THIRD_PARTY_AUTH_PROVIDER": true,
"ENABLE_COMBINED_LOGIN_REGISTRATION": true,
"PREVIEW_LMS_BASE": "localhost:8003",
"SUBDOMAIN_BRANDING": false,
......@@ -119,6 +118,13 @@
"SYSLOG_SERVER": "",
"TECH_SUPPORT_EMAIL": "technical@example.com",
"THEME_NAME": "",
"THIRD_PARTY_AUTH_BACKENDS": [
"social.backends.google.GoogleOAuth2",
"social.backends.linkedin.LinkedinOAuth2",
"social.backends.facebook.FacebookOAuth2",
"third_party_auth.dummy.DummyBackend",
"third_party_auth.saml.SAMLAuthBackend"
],
"TIME_ZONE": "America/New_York",
"WIKI_ENABLED": true
}
......@@ -2385,10 +2385,6 @@ for app_name in OPTIONAL_APPS:
continue
INSTALLED_APPS += (app_name,)
# Stub for third_party_auth options.
# See common/djangoapps/third_party_auth/settings.py for configuration details.
THIRD_PARTY_AUTH = {}
### ADVANCED_SECURITY_CONFIG
# Empty by default
ADVANCED_SECURITY_CONFIG = {}
......
......@@ -170,6 +170,10 @@ FEATURES['STORE_BILLING_INFO'] = True
FEATURES['ENABLE_PAID_COURSE_REGISTRATION'] = True
FEATURES['ENABLE_COSMETIC_DISPLAY_PRICE'] = True
########################## Third Party Auth #######################
if FEATURES.get('ENABLE_THIRD_PARTY_AUTH') and 'third_party_auth.dummy.DummyBackend' not in AUTHENTICATION_BACKENDS:
AUTHENTICATION_BACKENDS = ['third_party_auth.dummy.DummyBackend'] + list(AUTHENTICATION_BACKENDS)
#####################################################################
# See if the developer has any local overrides.
......
......@@ -238,18 +238,13 @@ PASSWORD_COMPLEXITY = {}
######### Third-party auth ##########
FEATURES['ENABLE_THIRD_PARTY_AUTH'] = True
THIRD_PARTY_AUTH = {
"Google": {
"SOCIAL_AUTH_GOOGLE_OAUTH2_KEY": "test",
"SOCIAL_AUTH_GOOGLE_OAUTH2_SECRET": "test",
},
"Facebook": {
"SOCIAL_AUTH_FACEBOOK_KEY": "test",
"SOCIAL_AUTH_FACEBOOK_SECRET": "test",
},
}
FEATURES['ENABLE_DUMMY_THIRD_PARTY_AUTH_PROVIDER'] = True
AUTHENTICATION_BACKENDS = (
'social.backends.google.GoogleOAuth2',
'social.backends.linkedin.LinkedinOAuth2',
'social.backends.facebook.FacebookOAuth2',
'third_party_auth.dummy.DummyBackend',
'third_party_auth.saml.SAMLAuthBackend',
) + AUTHENTICATION_BACKENDS
################################## OPENID #####################################
FEATURES['AUTH_USE_OPENID'] = True
......
......@@ -141,4 +141,4 @@ def enable_third_party_auth():
"""
from third_party_auth import settings as auth_settings
auth_settings.apply_settings(settings.THIRD_PARTY_AUTH, settings)
auth_settings.apply_settings(settings)
......@@ -32,12 +32,14 @@ define(['backbone', 'jquery', 'underscore', 'common/js/spec_helpers/ajax_helpers
var AUTH_DATA = {
'providers': [
{
'id': 'oa2-network1',
'name': "Network1",
'connected': true,
'connect_url': 'yetanother1.com/auth/connect',
'disconnect_url': 'yetanother1.com/auth/disconnect'
},
{
'id': 'oa2-network2',
'name': "Network2",
'connected': true,
'connect_url': 'yetanother2.com/auth/connect',
......
......@@ -25,12 +25,14 @@ define([
currentProvider: null,
providers: [
{
id: 'oa2-google-oauth2',
name: 'Google',
iconClass: 'fa-google-plus',
loginUrl: '/auth/login/google-oauth2/?auth_entry=account_login',
registerUrl: '/auth/login/google-oauth2/?auth_entry=account_register'
},
{
id: 'oa2-facebook',
name: 'Facebook',
iconClass: 'fa-facebook',
loginUrl: '/auth/login/facebook/?auth_entry=account_login',
......@@ -195,8 +197,8 @@ define([
createLoginView(this);
// Verify that Google and Facebook registration buttons are displayed
expect($('.button-Google')).toBeVisible();
expect($('.button-Facebook')).toBeVisible();
expect($('.button-oa2-google-oauth2')).toBeVisible();
expect($('.button-oa2-facebook')).toBeVisible();
});
it('displays a link to the password reset form', function() {
......
......@@ -32,12 +32,14 @@ define([
currentProvider: null,
providers: [
{
id: 'oa2-google-oauth2',
name: 'Google',
iconClass: 'fa-google-plus',
loginUrl: '/auth/login/google-oauth2/?auth_entry=account_login',
registerUrl: '/auth/login/google-oauth2/?auth_entry=account_register'
},
{
id: 'oa2-facebook',
name: 'Facebook',
iconClass: 'fa-facebook',
loginUrl: '/auth/login/facebook/?auth_entry=account_login',
......@@ -284,8 +286,8 @@ define([
createRegisterView(this);
// Verify that Google and Facebook registration buttons are displayed
expect($('.button-Google')).toBeVisible();
expect($('.button-Facebook')).toBeVisible();
expect($('.button-oa2-google-oauth2')).toBeVisible();
expect($('.button-oa2-facebook')).toBeVisible();
});
it('validates registration form fields', function() {
......
......@@ -137,7 +137,7 @@
screenReaderTitle: interpolate_text(
gettext("Connect your {accountName} account"), {accountName: provider['name']}
),
valueAttribute: 'auth-' + provider.name.toLowerCase(),
valueAttribute: 'auth-' + provider.id,
helpMessage: '',
connected: provider.connected,
connectUrl: provider.connect_url,
......
......@@ -532,30 +532,30 @@
margin-right: 0;
}
&.button-Google:hover, &.button-Google:focus {
&.button-oa2-google-oauth2:hover, &.button-oa2-google-oauth2:focus {
background-color: #dd4b39;
border: 1px solid #A5382B;
}
&.button-Google:hover {
&.button-oa2-google-oauth2:hover {
box-shadow: 0 2px 1px 0 #8D3024;
}
&.button-Facebook:hover, &.button-Facebook:focus {
&.button-oa2-facebook:hover, &.button-oa2-facebook:focus {
background-color: #3b5998;
border: 1px solid #263A62;
}
&.button-Facebook:hover {
&.button-oa2-facebook:hover {
box-shadow: 0 2px 1px 0 #30487C;
}
&.button-LinkedIn:hover , &.button-LinkedIn:focus {
&.button-oa2-linkedin-oauth2:hover , &.button-oa2-linkedin-oauth2:focus {
background-color: #0077b5;
border: 1px solid #06527D;
}
&.button-LinkedIn:hover {
&.button-oa2-linkedin-oauth2:hover {
box-shadow: 0 2px 1px 0 #005D8E;
}
......
......@@ -388,7 +388,7 @@ $sm-btn-linkedin: #0077b5;
margin-bottom: $baseline;
}
&.button-Google {
&.button-oa2-google-oauth2 {
color: $sm-btn-google;
.icon {
......@@ -407,7 +407,7 @@ $sm-btn-linkedin: #0077b5;
}
}
&.button-Facebook {
&.button-oa2-facebook {
color: $sm-btn-facebook;
.icon {
......@@ -426,7 +426,7 @@ $sm-btn-linkedin: #0077b5;
}
}
&.button-LinkedIn {
&.button-oa2-linkedin-oauth2 {
color: $sm-btn-linkedin;
.icon {
......
......@@ -221,7 +221,7 @@ from microsite_configuration import microsite
% for enabled in provider.Registry.enabled():
## Translators: provider_name is the name of an external, third-party user authentication provider (like Google or LinkedIn).
<button type="submit" class="button button-primary button-${enabled.NAME} login-${enabled.NAME}" onclick="thirdPartySignin(event, '${pipeline_url[enabled.NAME]}');"><span class="icon fa ${enabled.ICON_CLASS}"></span>${_('Sign in with {provider_name}').format(provider_name=enabled.NAME)}</button>
<button type="submit" class="button button-primary button-${enabled.provider_id} login-${enabled.provider_id}" onclick="thirdPartySignin(event, '${pipeline_url[enabled.provider_id]}');"><span class="icon fa ${enabled.icon_class}"></span>${_('Sign in with {provider_name}').format(provider_name=enabled.name)}</button>
% endfor
</div>
......
......@@ -132,7 +132,7 @@ import calendar
% for enabled in provider.Registry.enabled():
## Translators: provider_name is the name of an external, third-party user authentication service (like Google or LinkedIn).
<button type="submit" class="button button-primary button-${enabled.NAME} register-${enabled.NAME}" onclick="thirdPartySignin(event, '${pipeline_urls[enabled.NAME]}');"><span class="icon fa ${enabled.ICON_CLASS}"></span>${_('Sign up with {provider_name}').format(provider_name=enabled.NAME)}</button>
<button type="submit" class="button button-primary button-${enabled.provider_id} register-${enabled.provider_id}" onclick="thirdPartySignin(event, '${pipeline_urls[enabled.provider_id]}');"><span class="icon fa ${enabled.icon_class}"></span>${_('Sign up with {provider_name}').format(provider_name=enabled.name)}</button>
% endfor
</div>
......
......@@ -49,7 +49,7 @@
<% _.each( context.providers, function( provider ) {
if ( provider.loginUrl ) { %>
<button type="button" class="button button-primary button-<%- provider.name %> login-provider login-<%- provider.name %>" data-provider-url="<%- provider.loginUrl %>">
<button type="button" class="button button-primary button-<%- provider.id %> login-provider login-<%- provider.id %>" data-provider-url="<%- provider.loginUrl %>">
<div class="icon fa <%- provider.iconClass %>" aria-hidden="true"></div>
<%- provider.name %>
</button>
......
......@@ -29,7 +29,7 @@
<%
_.each( context.providers, function( provider) {
if ( provider.registerUrl ) { %>
<button type="button" class="button button-primary button-<%- provider.name %> login-provider register-<%- provider.name %>" data-provider-url="<%- provider.registerUrl %>">
<button type="button" class="button button-primary button-<%- provider.id %> login-provider register-<%- provider.id %>" data-provider-url="<%- provider.registerUrl %>">
<span class="icon fa <%- provider.iconClass %>" aria-hidden="true"></span>
<%- provider.name %>
</button>
......
......@@ -19,10 +19,10 @@ from third_party_auth import pipeline
<i class="icon fa fa-unlink"></i><span class="copy">${_('Not Linked')}</span>
% endif
</div>
<span class="provider">${state.provider.NAME}</span>
<span class="provider">${state.provider.name}</span>
<span class="control">
<form
action="${pipeline.get_disconnect_url(state.provider.NAME, state.association_id)}"
action="${pipeline.get_disconnect_url(state.provider.provider_id, state.association_id)}"
method="post"
name="${state.get_unlink_form_name()}">
% if state.has_account:
......@@ -33,7 +33,7 @@ from third_party_auth import pipeline
${_("Unlink")}
</a>
% else:
<a href="${pipeline.get_login_url(state.provider.NAME, pipeline.AUTH_ENTRY_PROFILE)}">
<a href="${pipeline.get_login_url(state.provider.provider_id, pipeline.AUTH_ENTRY_PROFILE)}">
## Translators: clicking on this creates a link between a user's edX account and their account with an external authentication provider (like Google or LinkedIn).
${_("Link")}
</a>
......
......@@ -25,7 +25,7 @@ from opaque_keys.edx.locations import SlashSeparatedCourseKey
from django_comment_common import models
from student.tests.factories import UserFactory
from third_party_auth.tests.testutil import simulate_running_pipeline
from third_party_auth.tests.testutil import simulate_running_pipeline, ThirdPartyAuthTestMixin
from third_party_auth.tests.utils import (
ThirdPartyOAuthTestMixin, ThirdPartyOAuthTestMixinFacebook, ThirdPartyOAuthTestMixinGoogle
)
......@@ -800,7 +800,7 @@ class PasswordResetViewTest(ApiTestCase):
@ddt.ddt
@skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms')
class RegistrationViewTest(ApiTestCase):
class RegistrationViewTest(ThirdPartyAuthTestMixin, ApiTestCase):
"""Tests for the registration end-points of the User API. """
maxDiff = None
......@@ -907,6 +907,7 @@ class RegistrationViewTest(ApiTestCase):
def test_register_form_third_party_auth_running(self):
no_extra_fields_setting = {}
self.configure_google_provider(enabled=True)
with simulate_running_pipeline(
"openedx.core.djangoapps.user_api.views.third_party_auth.pipeline",
"google-oauth2", email="bob@example.com",
......
......@@ -32,7 +32,7 @@ git+https://github.com/hmarr/django-debug-toolbar-mongo.git@b0686a76f1ce3532088c
-e git+https://github.com/jazkarta/ccx-keys.git@e6b03704b1bb97c1d2f31301ecb4e3a687c536ea#egg=ccx-keys
# For SAML Support (To be moved to PyPi installation in base.txt once our changes are merged):
-e git+https://github.com/open-craft/python-saml.git@9602b8133056d8c3caa7c3038761147df3d4b257#egg=python-saml
-e git+https://github.com/open-craft/python-social-auth.git@17def186d4bb7165f9c37037936997ef39ae2f29#egg=python-social-auth
-e git+https://github.com/open-craft/python-social-auth.git@02ab628b8961b969021de87aeb23551da4e751b7#egg=python-social-auth
# Our libraries:
-e git+https://github.com/edx/XBlock.git@74fdc5a361f48e5596acf3846ca3790a33a05253#egg=XBlock
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment