Implement Site settings for Third Party Auth providers

parent cdf413bd
...@@ -42,7 +42,7 @@ class ConfigurationModelManager(models.Manager): ...@@ -42,7 +42,7 @@ class ConfigurationModelManager(models.Manager):
assert self.model.KEY_FIELDS != (), "Just use model.current() if there are no KEY_FIELDS" assert self.model.KEY_FIELDS != (), "Just use model.current() if there are no KEY_FIELDS"
return self.get_queryset().extra( # pylint: disable=no-member return self.get_queryset().extra( # pylint: disable=no-member
where=["{table_name}.id IN ({subquery})".format( where=["{table_name}.id IN ({subquery})".format(
table_name=self.model._meta.db_table, # pylint: disable=protected-access table_name=self.model._meta.db_table, # pylint: disable=protected-access, no-member
subquery=self._current_ids_subquery(), subquery=self._current_ids_subquery(),
)], )],
select={'is_active': 1}, # This annotation is used by the admin changelist. sqlite requires '1', not 'True' select={'is_active': 1}, # This annotation is used by the admin changelist. sqlite requires '1', not 'True'
...@@ -57,15 +57,15 @@ class ConfigurationModelManager(models.Manager): ...@@ -57,15 +57,15 @@ class ConfigurationModelManager(models.Manager):
subquery = self._current_ids_subquery() subquery = self._current_ids_subquery()
return self.get_queryset().extra( # pylint: disable=no-member return self.get_queryset().extra( # pylint: disable=no-member
select={'is_active': "{table_name}.id IN ({subquery})".format( select={'is_active': "{table_name}.id IN ({subquery})".format(
table_name=self.model._meta.db_table, # pylint: disable=protected-access table_name=self.model._meta.db_table, # pylint: disable=protected-access, no-member
subquery=subquery, subquery=subquery,
)} )}
) )
else: else:
return self.get_queryset().extra( # pylint: disable=no-member return self.get_queryset().extra( # pylint: disable=no-member
select={'is_active': "{table_name}.id = {pk}".format( select={'is_active': "{table_name}.id = {pk}".format(
table_name=self.model._meta.db_table, # pylint: disable=protected-access table_name=self.model._meta.db_table, # pylint: disable=protected-access, no-member
pk=self.model.current().pk, pk=self.model.current().pk, # pylint: disable=no-member
)} )}
) )
...@@ -145,9 +145,15 @@ class ConfigurationModel(models.Model): ...@@ -145,9 +145,15 @@ class ConfigurationModel(models.Model):
return current return current
@classmethod @classmethod
def is_enabled(cls): def is_enabled(cls, *key_fields):
"""Returns True if this feature is configured as enabled, else False.""" """
return cls.current().enabled Returns True if this feature is configured as enabled, else False.
Arguments:
key_fields: The positional arguments are the KEY_FIELDS used to identify the
configuration to be checked.
"""
return cls.current(*key_fields).enabled
@classmethod @classmethod
def key_values_cache_key_name(cls, *key_fields): def key_values_cache_key_name(cls, *key_fields):
......
...@@ -33,7 +33,7 @@ class OAuth2ProviderConfigAdmin(KeyedConfigurationModelAdmin): ...@@ -33,7 +33,7 @@ class OAuth2ProviderConfigAdmin(KeyedConfigurationModelAdmin):
def get_list_display(self, request): def get_list_display(self, request):
""" Don't show every single field in the admin change list """ """ Don't show every single field in the admin change list """
return ( return (
'name', 'enabled', 'backend_name', 'secondary', 'skip_registration_form', 'name', 'enabled', 'site', 'backend_name', 'secondary', 'skip_registration_form',
'skip_email_verification', 'change_date', 'changed_by', 'edit_link', 'skip_email_verification', 'change_date', 'changed_by', 'edit_link',
) )
...@@ -52,7 +52,7 @@ class SAMLProviderConfigAdmin(KeyedConfigurationModelAdmin): ...@@ -52,7 +52,7 @@ class SAMLProviderConfigAdmin(KeyedConfigurationModelAdmin):
def get_list_display(self, request): def get_list_display(self, request):
""" Don't show every single field in the admin change list """ """ Don't show every single field in the admin change list """
return ( return (
'name', 'enabled', 'backend_name', 'entity_id', 'metadata_source', 'name', 'enabled', 'site', 'backend_name', 'entity_id', 'metadata_source',
'has_data', 'mode', 'change_date', 'changed_by', 'edit_link', 'has_data', 'mode', 'change_date', 'changed_by', 'edit_link',
) )
...@@ -86,13 +86,13 @@ class SAMLProviderConfigAdmin(KeyedConfigurationModelAdmin): ...@@ -86,13 +86,13 @@ class SAMLProviderConfigAdmin(KeyedConfigurationModelAdmin):
admin.site.register(SAMLProviderConfig, SAMLProviderConfigAdmin) admin.site.register(SAMLProviderConfig, SAMLProviderConfigAdmin)
class SAMLConfigurationAdmin(ConfigurationModelAdmin): class SAMLConfigurationAdmin(KeyedConfigurationModelAdmin):
""" Django Admin class for SAMLConfiguration """ """ Django Admin class for SAMLConfiguration """
def get_list_display(self, request): def get_list_display(self, request):
""" Shorten the public/private keys in the change view """ """ Shorten the public/private keys in the change view """
return ( return (
'change_date', 'changed_by', 'enabled', 'entity_id', 'site', 'change_date', 'changed_by', 'enabled', 'entity_id',
'org_info_str', 'key_summary', 'org_info_str', 'key_summary', 'edit_link',
) )
def key_summary(self, inst): def key_summary(self, inst):
...@@ -136,6 +136,7 @@ class LTIProviderConfigAdmin(KeyedConfigurationModelAdmin): ...@@ -136,6 +136,7 @@ class LTIProviderConfigAdmin(KeyedConfigurationModelAdmin):
return ( return (
'name', 'name',
'enabled', 'enabled',
'site',
'lti_consumer_key', 'lti_consumer_key',
'lti_max_timestamp_age', 'lti_max_timestamp_age',
'change_date', 'change_date',
......
...@@ -198,9 +198,9 @@ class LTIAuthBackend(BaseAuth): ...@@ -198,9 +198,9 @@ class LTIAuthBackend(BaseAuth):
""" """
from .models import LTIProviderConfig from .models import LTIProviderConfig
provider_config = LTIProviderConfig.current(lti_consumer_key) provider_config = LTIProviderConfig.current(lti_consumer_key)
if provider_config and provider_config.enabled: if provider_config and provider_config.enabled_for_current_site:
return ( return (
provider_config.enabled, provider_config.enabled_for_current_site,
provider_config.get_lti_consumer_secret(), provider_config.get_lti_consumer_secret(),
provider_config.lti_max_timestamp_age, provider_config.lti_max_timestamp_age,
) )
......
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
from django.conf import settings
from django.db import migrations, models
def fill_oauth2_slug(apps, schema_editor):
"""
Fill in the provider_slug to be the same as backend_name for backwards compatability.
"""
OAuth2ProviderConfig = apps.get_model('third_party_auth', 'OAuth2ProviderConfig')
for config in OAuth2ProviderConfig.objects.all():
config.provider_slug = config.backend_name
config.save()
class Migration(migrations.Migration):
dependencies = [
('sites', '0001_initial'),
('third_party_auth', '0004_add_visible_field'),
]
operations = [
migrations.AddField(
model_name='oauth2providerconfig',
name='provider_slug',
field=models.SlugField(
default='temp',
help_text=b'A short string uniquely identifying this provider. Cannot contain spaces and should be a usable as a CSS class. Examples: "ubc", "mit-staging"',
max_length=30
),
preserve_default=False,
),
migrations.RunPython(fill_oauth2_slug, reverse_code=migrations.RunPython.noop),
migrations.AddField(
model_name='ltiproviderconfig',
name='site',
field=models.ForeignKey(
related_name='ltiproviderconfigs',
default=settings.SITE_ID,
to='sites.Site',
help_text='The Site that this provider configuration belongs to.'
),
),
migrations.AddField(
model_name='oauth2providerconfig',
name='site',
field=models.ForeignKey(
related_name='oauth2providerconfigs',
default=settings.SITE_ID,
to='sites.Site',
help_text='The Site that this provider configuration belongs to.'
),
),
migrations.AddField(
model_name='samlproviderconfig',
name='site',
field=models.ForeignKey(
related_name='samlproviderconfigs',
default=settings.SITE_ID,
to='sites.Site',
help_text='The Site that this provider configuration belongs to.'
),
),
migrations.AddField(
model_name='samlconfiguration',
name='site',
field=models.ForeignKey(
related_name='samlconfigurations',
default=settings.SITE_ID,
to='sites.Site',
help_text='The Site that this SAML configuration belongs to.'
),
),
]
...@@ -7,6 +7,7 @@ from __future__ import absolute_import ...@@ -7,6 +7,7 @@ from __future__ import absolute_import
from config_models.models import ConfigurationModel, cache from config_models.models import ConfigurationModel, cache
from django.conf import settings from django.conf import settings
from django.contrib.sites.models import Site
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.db import models from django.db import models
from django.utils import timezone from django.utils import timezone
...@@ -22,6 +23,7 @@ from .lti import LTIAuthBackend, LTI_PARAMS_KEY ...@@ -22,6 +23,7 @@ from .lti import LTIAuthBackend, LTI_PARAMS_KEY
from social.exceptions import SocialAuthBaseException from social.exceptions import SocialAuthBaseException
from social.utils import module_member from social.utils import module_member
from openedx.core.djangoapps.site_configuration import helpers as configuration_helpers from openedx.core.djangoapps.site_configuration import helpers as configuration_helpers
from openedx.core.djangoapps.theming.helpers import get_current_request
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
...@@ -106,6 +108,14 @@ class ProviderConfig(ConfigurationModel): ...@@ -106,6 +108,14 @@ class ProviderConfig(ConfigurationModel):
'in a separate list of "Institution" login providers.' 'in a separate list of "Institution" login providers.'
), ),
) )
site = models.ForeignKey(
Site,
default=settings.SITE_ID,
related_name='%(class)ss',
help_text=_(
'The Site that this provider configuration belongs to.'
),
)
skip_registration_form = models.BooleanField( skip_registration_form = models.BooleanField(
default=False, default=False,
help_text=_( help_text=_(
...@@ -226,7 +236,14 @@ class ProviderConfig(ConfigurationModel): ...@@ -226,7 +236,14 @@ class ProviderConfig(ConfigurationModel):
Determines whether the provider ought to be shown as an option with Determines whether the provider ought to be shown as an option with
which to authenticate on the login screen, registration screen, and elsewhere. which to authenticate on the login screen, registration screen, and elsewhere.
""" """
return bool(self.enabled and self.accepts_logins and self.visible) return bool(self.enabled_for_current_site and self.accepts_logins and self.visible)
@property
def enabled_for_current_site(self):
"""
Determines if the provider is able to be used with the current site.
"""
return self.enabled and self.site == Site.objects.get_current(get_current_request())
class OAuth2ProviderConfig(ProviderConfig): class OAuth2ProviderConfig(ProviderConfig):
...@@ -235,7 +252,7 @@ class OAuth2ProviderConfig(ProviderConfig): ...@@ -235,7 +252,7 @@ class OAuth2ProviderConfig(ProviderConfig):
Also works for OAuth1 providers. Also works for OAuth1 providers.
""" """
prefix = 'oa2' prefix = 'oa2'
KEY_FIELDS = ('backend_name', ) # Backend name is unique KEY_FIELDS = ('provider_slug', ) # Backend name is unique
backend_name = models.CharField( backend_name = models.CharField(
max_length=50, blank=False, db_index=True, max_length=50, blank=False, db_index=True,
help_text=( help_text=(
...@@ -244,6 +261,12 @@ class OAuth2ProviderConfig(ProviderConfig): ...@@ -244,6 +261,12 @@ class OAuth2ProviderConfig(ProviderConfig):
# To be precise, it's set by AUTHENTICATION_BACKENDS - which aws.py sets from THIRD_PARTY_AUTH_BACKENDS # To be precise, it's set by AUTHENTICATION_BACKENDS - which aws.py sets from THIRD_PARTY_AUTH_BACKENDS
) )
) )
provider_slug = models.SlugField(
max_length=30, db_index=True,
help_text=(
'A short string uniquely identifying this provider. '
'Cannot contain spaces and should be a usable as a CSS class. Examples: "ubc", "mit-staging"'
))
key = models.TextField(blank=True, verbose_name="Client ID") key = models.TextField(blank=True, verbose_name="Client ID")
secret = models.TextField( secret = models.TextField(
blank=True, blank=True,
...@@ -406,6 +429,15 @@ class SAMLConfiguration(ConfigurationModel): ...@@ -406,6 +429,15 @@ class SAMLConfiguration(ConfigurationModel):
Service Provider and allow users to authenticate via third party SAML Service Provider and allow users to authenticate via third party SAML
Identity Providers (IdPs) Identity Providers (IdPs)
""" """
KEY_FIELDS = ('site_id', )
site = models.ForeignKey(
Site,
default=settings.SITE_ID,
related_name='%(class)ss',
help_text=_(
'The Site that this SAML configuration belongs to.'
),
)
private_key = models.TextField( private_key = models.TextField(
help_text=( help_text=(
'To generate a key pair as two files, run ' 'To generate a key pair as two files, run '
......
""" """
Third-party auth provider configuration API. Third-party auth provider configuration API.
""" """
from django.contrib.sites.models import Site
from openedx.core.djangoapps.theming.helpers import get_current_request
from .models import ( from .models import (
OAuth2ProviderConfig, SAMLConfiguration, SAMLProviderConfig, LTIProviderConfig, OAuth2ProviderConfig, SAMLConfiguration, SAMLProviderConfig, LTIProviderConfig,
_PSA_OAUTH2_BACKENDS, _PSA_SAML_BACKENDS, _LTI_BACKENDS, _PSA_OAUTH2_BACKENDS, _PSA_SAML_BACKENDS, _LTI_BACKENDS,
...@@ -15,20 +18,23 @@ class Registry(object): ...@@ -15,20 +18,23 @@ class Registry(object):
""" """
@classmethod @classmethod
def _enabled_providers(cls): def _enabled_providers(cls):
""" Helper method to iterate over all providers """ """
Helper method that returns a generator used to iterate over all providers
of the current site.
"""
for backend_name in _PSA_OAUTH2_BACKENDS: for backend_name in _PSA_OAUTH2_BACKENDS:
provider = OAuth2ProviderConfig.current(backend_name) provider = OAuth2ProviderConfig.current(backend_name)
if provider.enabled: if provider.enabled_for_current_site:
yield provider yield provider
if SAMLConfiguration.is_enabled(): if SAMLConfiguration.is_enabled(Site.objects.get_current(get_current_request())):
idp_slugs = SAMLProviderConfig.key_values('idp_slug', flat=True) idp_slugs = SAMLProviderConfig.key_values('idp_slug', flat=True)
for idp_slug in idp_slugs: for idp_slug in idp_slugs:
provider = SAMLProviderConfig.current(idp_slug) provider = SAMLProviderConfig.current(idp_slug)
if provider.enabled and provider.backend_name in _PSA_SAML_BACKENDS: if provider.enabled_for_current_site and provider.backend_name in _PSA_SAML_BACKENDS:
yield provider yield provider
for consumer_key in LTIProviderConfig.key_values('lti_consumer_key', flat=True): for consumer_key in LTIProviderConfig.key_values('lti_consumer_key', flat=True):
provider = LTIProviderConfig.current(consumer_key) provider = LTIProviderConfig.current(consumer_key)
if provider.enabled and provider.backend_name in _LTI_BACKENDS: if provider.enabled_for_current_site and provider.backend_name in _LTI_BACKENDS:
yield provider yield provider
@classmethod @classmethod
...@@ -69,7 +75,7 @@ class Registry(object): ...@@ -69,7 +75,7 @@ class Registry(object):
@classmethod @classmethod
def get_enabled_by_backend_name(cls, backend_name): def get_enabled_by_backend_name(cls, backend_name):
"""Generator returning all enabled providers that use the specified """Generator returning all enabled providers that use the specified
backend. backend on the current site.
Example: Example:
>>> list(get_enabled_by_backend_name("tpa-saml")) >>> list(get_enabled_by_backend_name("tpa-saml"))
...@@ -84,16 +90,17 @@ class Registry(object): ...@@ -84,16 +90,17 @@ class Registry(object):
""" """
if backend_name in _PSA_OAUTH2_BACKENDS: if backend_name in _PSA_OAUTH2_BACKENDS:
provider = OAuth2ProviderConfig.current(backend_name) provider = OAuth2ProviderConfig.current(backend_name)
if provider.enabled: if provider.enabled_for_current_site:
yield provider yield provider
elif backend_name in _PSA_SAML_BACKENDS and SAMLConfiguration.is_enabled(): elif backend_name in _PSA_SAML_BACKENDS and SAMLConfiguration.is_enabled(
Site.objects.get_current(get_current_request())):
idp_names = SAMLProviderConfig.key_values('idp_slug', flat=True) idp_names = SAMLProviderConfig.key_values('idp_slug', flat=True)
for idp_name in idp_names: for idp_name in idp_names:
provider = SAMLProviderConfig.current(idp_name) provider = SAMLProviderConfig.current(idp_name)
if provider.backend_name == backend_name and provider.enabled: if provider.backend_name == backend_name and provider.enabled_for_current_site:
yield provider yield provider
elif backend_name in _LTI_BACKENDS: elif backend_name in _LTI_BACKENDS:
for consumer_key in LTIProviderConfig.key_values('lti_consumer_key', flat=True): for consumer_key in LTIProviderConfig.key_values('lti_consumer_key', flat=True):
provider = LTIProviderConfig.current(consumer_key) provider = LTIProviderConfig.current(consumer_key)
if provider.backend_name == backend_name and provider.enabled: if provider.backend_name == backend_name and provider.enabled_for_current_site:
yield provider yield provider
...@@ -2,8 +2,10 @@ ...@@ -2,8 +2,10 @@
Slightly customized python-social-auth backend for SAML 2.0 support Slightly customized python-social-auth backend for SAML 2.0 support
""" """
import logging import logging
from django.contrib.sites.models import Site
from django.http import Http404 from django.http import Http404
from django.utils.functional import cached_property from django.utils.functional import cached_property
from openedx.core.djangoapps.theming.helpers import get_current_request
from social.backends.saml import SAMLAuth, OID_EDU_PERSON_ENTITLEMENT from social.backends.saml import SAMLAuth, OID_EDU_PERSON_ENTITLEMENT
from social.exceptions import AuthForbidden, AuthMissingParameter from social.exceptions import AuthForbidden, AuthMissingParameter
...@@ -41,10 +43,12 @@ class SAMLAuthBackend(SAMLAuth): # pylint: disable=abstract-method ...@@ -41,10 +43,12 @@ class SAMLAuthBackend(SAMLAuth): # pylint: disable=abstract-method
if not self._config.enabled: if not self._config.enabled:
log.error('SAML authentication is not enabled') log.error('SAML authentication is not enabled')
raise Http404 raise Http404
# TODO: remove this check once the fix is merged upstream: # TODO: remove this check once the fix is merged upstream:
# https://github.com/omab/python-social-auth/pull/821 # https://github.com/omab/python-social-auth/pull/821
if 'idp' not in self.strategy.request_data(): if 'idp' not in self.strategy.request_data():
raise AuthMissingParameter(self, 'idp') raise AuthMissingParameter(self, 'idp')
return super(SAMLAuthBackend, self).auth_url() return super(SAMLAuthBackend, self).auth_url()
def _check_entitlements(self, idp, attributes): def _check_entitlements(self, idp, attributes):
...@@ -93,4 +97,4 @@ class SAMLAuthBackend(SAMLAuth): # pylint: disable=abstract-method ...@@ -93,4 +97,4 @@ class SAMLAuthBackend(SAMLAuth): # pylint: disable=abstract-method
@cached_property @cached_property
def _config(self): def _config(self):
from .models import SAMLConfiguration from .models import SAMLConfiguration
return SAMLConfiguration.current() return SAMLConfiguration.current(Site.objects.get_current(get_current_request()))
...@@ -26,7 +26,7 @@ class ConfigurationModelStrategy(DjangoStrategy): ...@@ -26,7 +26,7 @@ class ConfigurationModelStrategy(DjangoStrategy):
""" """
if isinstance(backend, OAuthAuth): if isinstance(backend, OAuthAuth):
provider_config = OAuth2ProviderConfig.current(backend.name) provider_config = OAuth2ProviderConfig.current(backend.name)
if not provider_config.enabled: if not provider_config.enabled_for_current_site:
raise Exception("Can't fetch setting of a disabled backend/provider.") raise Exception("Can't fetch setting of a disabled backend/provider.")
try: try:
return provider_config.get_setting(name) return provider_config.get_setting(name)
......
...@@ -36,16 +36,13 @@ def fetch_saml_metadata(): ...@@ -36,16 +36,13 @@ def fetch_saml_metadata():
num_failed: Number of providers that could not be updated num_failed: Number of providers that could not be updated
num_total: Total number of providers whose metadata was fetched num_total: Total number of providers whose metadata was fetched
""" """
if not SAMLConfiguration.is_enabled():
return (0, 0, 0) # Nothing to do until SAML is enabled.
num_changed, num_failed = 0, 0 num_changed, num_failed = 0, 0
# First make a list of all the metadata XML URLs: # First make a list of all the metadata XML URLs:
url_map = {} url_map = {}
for idp_slug in SAMLProviderConfig.key_values('idp_slug', flat=True): for idp_slug in SAMLProviderConfig.key_values('idp_slug', flat=True):
config = SAMLProviderConfig.current(idp_slug) config = SAMLProviderConfig.current(idp_slug)
if not config.enabled: if not config.enabled or not SAMLConfiguration.is_enabled(config.site):
continue continue
url = config.metadata_source url = config.metadata_source
if url not in url_map: if url not in url_map:
......
"""Unit tests for provider.py.""" """Unit tests for provider.py."""
from django.contrib.sites.models import Site
from mock import Mock, patch from mock import Mock, patch
from openedx.core.djangoapps.site_configuration.tests.test_util import with_site_configuration
from third_party_auth import provider from third_party_auth import provider
from third_party_auth.tests import testutil from third_party_auth.tests import testutil
import unittest import unittest
SITE_DOMAIN_A = 'professionalx.example.com'
SITE_DOMAIN_B = 'somethingelse.example.com'
@unittest.skipUnless(testutil.AUTH_FEATURE_ENABLED, 'third_party_auth not enabled') @unittest.skipUnless(testutil.AUTH_FEATURE_ENABLED, 'third_party_auth not enabled')
class RegistryTest(testutil.TestCase): class RegistryTest(testutil.TestCase):
...@@ -84,6 +89,22 @@ class RegistryTest(testutil.TestCase): ...@@ -84,6 +89,22 @@ class RegistryTest(testutil.TestCase):
self.assertNotIn(no_log_in_provider.provider_id, provider_ids) self.assertNotIn(no_log_in_provider.provider_id, provider_ids)
self.assertIn(normal_provider.provider_id, provider_ids) self.assertIn(normal_provider.provider_id, provider_ids)
def test_provider_enabled_for_current_site(self):
"""
Verify that enabled_for_current_site returns True when the provider matches the current site.
"""
prov = self.configure_google_provider(visible=True, enabled=True, site=Site.objects.get_current())
self.assertEqual(prov.enabled_for_current_site, True)
@with_site_configuration(SITE_DOMAIN_A)
def test_provider_disabled_for_mismatching_site(self):
"""
Verify that enabled_for_current_site returns False when the provider is configured for a different site.
"""
site_b = Site.objects.get_or_create(domain=SITE_DOMAIN_B, name=SITE_DOMAIN_B)[0]
prov = self.configure_google_provider(visible=True, enabled=True, site=site_b)
self.assertEqual(prov.enabled_for_current_site, False)
def test_get_returns_enabled_provider(self): def test_get_returns_enabled_provider(self):
google_provider = self.configure_google_provider(enabled=True) google_provider = self.configure_google_provider(enabled=True)
self.assertEqual(google_provider.id, provider.Registry.get(google_provider.provider_id).id) self.assertEqual(google_provider.id, provider.Registry.get(google_provider.provider_id).id)
......
...@@ -7,6 +7,7 @@ Used by Django and non-Django tests; must not have Django deps. ...@@ -7,6 +7,7 @@ Used by Django and non-Django tests; must not have Django deps.
from contextlib import contextmanager from contextlib import contextmanager
from django.conf import settings from django.conf import settings
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.contrib.sites.models import Site
from provider.oauth2.models import Client as OAuth2Client from provider.oauth2.models import Client as OAuth2Client
from provider import constants from provider import constants
import django.test import django.test
...@@ -76,13 +77,17 @@ class ThirdPartyAuthTestMixin(object): ...@@ -76,13 +77,17 @@ class ThirdPartyAuthTestMixin(object):
@staticmethod @staticmethod
def configure_oauth_provider(**kwargs): def configure_oauth_provider(**kwargs):
""" Update the settings for an OAuth2-based third party auth provider """ """ Update the settings for an OAuth2-based third party auth provider """
kwargs.setdefault('provider_slug', kwargs['backend_name'])
obj = OAuth2ProviderConfig(**kwargs) obj = OAuth2ProviderConfig(**kwargs)
obj.save() obj.save()
return obj return obj
def configure_saml_provider(self, **kwargs): def configure_saml_provider(self, **kwargs):
""" Update the settings for a SAML-based third party auth provider """ """ Update the settings for a SAML-based third party auth provider """
self.assertTrue(SAMLConfiguration.is_enabled(), "SAML Provider Configuration only works if SAML is enabled.") self.assertTrue(
SAMLConfiguration.is_enabled(Site.objects.get_current()),
"SAML Provider Configuration only works if SAML is enabled."
)
obj = SAMLProviderConfig(**kwargs) obj = SAMLProviderConfig(**kwargs)
obj.save() obj.save()
return obj return obj
......
...@@ -36,7 +36,7 @@ def saml_metadata_view(request): ...@@ -36,7 +36,7 @@ def saml_metadata_view(request):
Get the Service Provider metadata for this edx-platform instance. Get the Service Provider metadata for this edx-platform instance.
You must send this XML to any Shibboleth Identity Provider that you wish to use. You must send this XML to any Shibboleth Identity Provider that you wish to use.
""" """
if not SAMLConfiguration.is_enabled(): if not SAMLConfiguration.is_enabled(request.site):
raise Http404 raise Http404
complete_url = reverse('social:complete', args=("tpa-saml", )) complete_url = reverse('social:complete', args=("tpa-saml", ))
if settings.APPEND_SLASH and not complete_url.endswith('/'): if settings.APPEND_SLASH and not complete_url.endswith('/'):
......
...@@ -10,8 +10,10 @@ ...@@ -10,8 +10,10 @@
"icon_class": "fa-google-plus", "icon_class": "fa-google-plus",
"icon_image": null, "icon_image": null,
"backend_name": "google-oauth2", "backend_name": "google-oauth2",
"provider_slug": "google-oauth2",
"key": "test", "key": "test",
"secret": "test", "secret": "test",
"site": 2,
"other_settings": "{}", "other_settings": "{}",
"visible": true "visible": true
} }
...@@ -27,8 +29,10 @@ ...@@ -27,8 +29,10 @@
"icon_class": "fa-facebook", "icon_class": "fa-facebook",
"icon_image": null, "icon_image": null,
"backend_name": "facebook", "backend_name": "facebook",
"provider_slug": "facebook",
"key": "test", "key": "test",
"secret": "test", "secret": "test",
"site": 2,
"other_settings": "{}", "other_settings": "{}",
"visible": true "visible": true
} }
...@@ -44,8 +48,10 @@ ...@@ -44,8 +48,10 @@
"icon_class": "", "icon_class": "",
"icon_image": "test-icon.png", "icon_image": "test-icon.png",
"backend_name": "dummy", "backend_name": "dummy",
"provider_slug": "dummy",
"key": "", "key": "",
"secret": "", "secret": "",
"site": 2,
"other_settings": "{}", "other_settings": "{}",
"visible": true "visible": true
} }
......
...@@ -37,7 +37,8 @@ def with_site_configuration(domain="test.localhost", configuration=None): ...@@ -37,7 +37,8 @@ def with_site_configuration(domain="test.localhost", configuration=None):
with patch('openedx.core.djangoapps.site_configuration.helpers.get_current_site_configuration', with patch('openedx.core.djangoapps.site_configuration.helpers.get_current_site_configuration',
return_value=site_configuration): return_value=site_configuration):
with patch('openedx.core.djangoapps.theming.helpers.get_current_site', return_value=site): with patch('openedx.core.djangoapps.theming.helpers.get_current_site', return_value=site):
return func(*args, **kwargs) with patch('django.contrib.sites.models.SiteManager.get_current', return_value=site):
return func(*args, **kwargs)
return _decorated return _decorated
return _decorator return _decorator
...@@ -63,4 +64,5 @@ def with_site_configuration_context(domain="test.localhost", configuration=None) ...@@ -63,4 +64,5 @@ def with_site_configuration_context(domain="test.localhost", configuration=None)
with patch('openedx.core.djangoapps.site_configuration.helpers.get_current_site_configuration', with patch('openedx.core.djangoapps.site_configuration.helpers.get_current_site_configuration',
return_value=site_configuration): return_value=site_configuration):
with patch('openedx.core.djangoapps.theming.helpers.get_current_site', return_value=site): with patch('openedx.core.djangoapps.theming.helpers.get_current_site', return_value=site):
yield with patch('django.contrib.sites.models.SiteManager.get_current', return_value=site):
yield
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment