Implement Site settings for Third Party Auth providers

parent cdf413bd
......@@ -42,7 +42,7 @@ class ConfigurationModelManager(models.Manager):
assert self.model.KEY_FIELDS != (), "Just use model.current() if there are no KEY_FIELDS"
return self.get_queryset().extra( # pylint: disable=no-member
where=["{table_name}.id IN ({subquery})".format(
table_name=self.model._meta.db_table, # pylint: disable=protected-access
table_name=self.model._meta.db_table, # pylint: disable=protected-access, no-member
subquery=self._current_ids_subquery(),
)],
select={'is_active': 1}, # This annotation is used by the admin changelist. sqlite requires '1', not 'True'
......@@ -57,15 +57,15 @@ class ConfigurationModelManager(models.Manager):
subquery = self._current_ids_subquery()
return self.get_queryset().extra( # pylint: disable=no-member
select={'is_active': "{table_name}.id IN ({subquery})".format(
table_name=self.model._meta.db_table, # pylint: disable=protected-access
table_name=self.model._meta.db_table, # pylint: disable=protected-access, no-member
subquery=subquery,
)}
)
else:
return self.get_queryset().extra( # pylint: disable=no-member
select={'is_active': "{table_name}.id = {pk}".format(
table_name=self.model._meta.db_table, # pylint: disable=protected-access
pk=self.model.current().pk,
table_name=self.model._meta.db_table, # pylint: disable=protected-access, no-member
pk=self.model.current().pk, # pylint: disable=no-member
)}
)
......@@ -145,9 +145,15 @@ class ConfigurationModel(models.Model):
return current
@classmethod
def is_enabled(cls):
"""Returns True if this feature is configured as enabled, else False."""
return cls.current().enabled
def is_enabled(cls, *key_fields):
"""
Returns True if this feature is configured as enabled, else False.
Arguments:
key_fields: The positional arguments are the KEY_FIELDS used to identify the
configuration to be checked.
"""
return cls.current(*key_fields).enabled
@classmethod
def key_values_cache_key_name(cls, *key_fields):
......
......@@ -33,7 +33,7 @@ class OAuth2ProviderConfigAdmin(KeyedConfigurationModelAdmin):
def get_list_display(self, request):
""" Don't show every single field in the admin change list """
return (
'name', 'enabled', 'backend_name', 'secondary', 'skip_registration_form',
'name', 'enabled', 'site', 'backend_name', 'secondary', 'skip_registration_form',
'skip_email_verification', 'change_date', 'changed_by', 'edit_link',
)
......@@ -52,7 +52,7 @@ class SAMLProviderConfigAdmin(KeyedConfigurationModelAdmin):
def get_list_display(self, request):
""" Don't show every single field in the admin change list """
return (
'name', 'enabled', 'backend_name', 'entity_id', 'metadata_source',
'name', 'enabled', 'site', 'backend_name', 'entity_id', 'metadata_source',
'has_data', 'mode', 'change_date', 'changed_by', 'edit_link',
)
......@@ -86,13 +86,13 @@ class SAMLProviderConfigAdmin(KeyedConfigurationModelAdmin):
admin.site.register(SAMLProviderConfig, SAMLProviderConfigAdmin)
class SAMLConfigurationAdmin(ConfigurationModelAdmin):
class SAMLConfigurationAdmin(KeyedConfigurationModelAdmin):
""" Django Admin class for SAMLConfiguration """
def get_list_display(self, request):
""" Shorten the public/private keys in the change view """
return (
'change_date', 'changed_by', 'enabled', 'entity_id',
'org_info_str', 'key_summary',
'site', 'change_date', 'changed_by', 'enabled', 'entity_id',
'org_info_str', 'key_summary', 'edit_link',
)
def key_summary(self, inst):
......@@ -136,6 +136,7 @@ class LTIProviderConfigAdmin(KeyedConfigurationModelAdmin):
return (
'name',
'enabled',
'site',
'lti_consumer_key',
'lti_max_timestamp_age',
'change_date',
......
......@@ -198,9 +198,9 @@ class LTIAuthBackend(BaseAuth):
"""
from .models import LTIProviderConfig
provider_config = LTIProviderConfig.current(lti_consumer_key)
if provider_config and provider_config.enabled:
if provider_config and provider_config.enabled_for_current_site:
return (
provider_config.enabled,
provider_config.enabled_for_current_site,
provider_config.get_lti_consumer_secret(),
provider_config.lti_max_timestamp_age,
)
......
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
from django.conf import settings
from django.db import migrations, models
def fill_oauth2_slug(apps, schema_editor):
"""
Fill in the provider_slug to be the same as backend_name for backwards compatability.
"""
OAuth2ProviderConfig = apps.get_model('third_party_auth', 'OAuth2ProviderConfig')
for config in OAuth2ProviderConfig.objects.all():
config.provider_slug = config.backend_name
config.save()
class Migration(migrations.Migration):
dependencies = [
('sites', '0001_initial'),
('third_party_auth', '0004_add_visible_field'),
]
operations = [
migrations.AddField(
model_name='oauth2providerconfig',
name='provider_slug',
field=models.SlugField(
default='temp',
help_text=b'A short string uniquely identifying this provider. Cannot contain spaces and should be a usable as a CSS class. Examples: "ubc", "mit-staging"',
max_length=30
),
preserve_default=False,
),
migrations.RunPython(fill_oauth2_slug, reverse_code=migrations.RunPython.noop),
migrations.AddField(
model_name='ltiproviderconfig',
name='site',
field=models.ForeignKey(
related_name='ltiproviderconfigs',
default=settings.SITE_ID,
to='sites.Site',
help_text='The Site that this provider configuration belongs to.'
),
),
migrations.AddField(
model_name='oauth2providerconfig',
name='site',
field=models.ForeignKey(
related_name='oauth2providerconfigs',
default=settings.SITE_ID,
to='sites.Site',
help_text='The Site that this provider configuration belongs to.'
),
),
migrations.AddField(
model_name='samlproviderconfig',
name='site',
field=models.ForeignKey(
related_name='samlproviderconfigs',
default=settings.SITE_ID,
to='sites.Site',
help_text='The Site that this provider configuration belongs to.'
),
),
migrations.AddField(
model_name='samlconfiguration',
name='site',
field=models.ForeignKey(
related_name='samlconfigurations',
default=settings.SITE_ID,
to='sites.Site',
help_text='The Site that this SAML configuration belongs to.'
),
),
]
......@@ -7,6 +7,7 @@ from __future__ import absolute_import
from config_models.models import ConfigurationModel, cache
from django.conf import settings
from django.contrib.sites.models import Site
from django.core.exceptions import ValidationError
from django.db import models
from django.utils import timezone
......@@ -22,6 +23,7 @@ from .lti import LTIAuthBackend, LTI_PARAMS_KEY
from social.exceptions import SocialAuthBaseException
from social.utils import module_member
from openedx.core.djangoapps.site_configuration import helpers as configuration_helpers
from openedx.core.djangoapps.theming.helpers import get_current_request
log = logging.getLogger(__name__)
......@@ -106,6 +108,14 @@ class ProviderConfig(ConfigurationModel):
'in a separate list of "Institution" login providers.'
),
)
site = models.ForeignKey(
Site,
default=settings.SITE_ID,
related_name='%(class)ss',
help_text=_(
'The Site that this provider configuration belongs to.'
),
)
skip_registration_form = models.BooleanField(
default=False,
help_text=_(
......@@ -226,7 +236,14 @@ class ProviderConfig(ConfigurationModel):
Determines whether the provider ought to be shown as an option with
which to authenticate on the login screen, registration screen, and elsewhere.
"""
return bool(self.enabled and self.accepts_logins and self.visible)
return bool(self.enabled_for_current_site and self.accepts_logins and self.visible)
@property
def enabled_for_current_site(self):
"""
Determines if the provider is able to be used with the current site.
"""
return self.enabled and self.site == Site.objects.get_current(get_current_request())
class OAuth2ProviderConfig(ProviderConfig):
......@@ -235,7 +252,7 @@ class OAuth2ProviderConfig(ProviderConfig):
Also works for OAuth1 providers.
"""
prefix = 'oa2'
KEY_FIELDS = ('backend_name', ) # Backend name is unique
KEY_FIELDS = ('provider_slug', ) # Backend name is unique
backend_name = models.CharField(
max_length=50, blank=False, db_index=True,
help_text=(
......@@ -244,6 +261,12 @@ class OAuth2ProviderConfig(ProviderConfig):
# To be precise, it's set by AUTHENTICATION_BACKENDS - which aws.py sets from THIRD_PARTY_AUTH_BACKENDS
)
)
provider_slug = models.SlugField(
max_length=30, db_index=True,
help_text=(
'A short string uniquely identifying this provider. '
'Cannot contain spaces and should be a usable as a CSS class. Examples: "ubc", "mit-staging"'
))
key = models.TextField(blank=True, verbose_name="Client ID")
secret = models.TextField(
blank=True,
......@@ -406,6 +429,15 @@ class SAMLConfiguration(ConfigurationModel):
Service Provider and allow users to authenticate via third party SAML
Identity Providers (IdPs)
"""
KEY_FIELDS = ('site_id', )
site = models.ForeignKey(
Site,
default=settings.SITE_ID,
related_name='%(class)ss',
help_text=_(
'The Site that this SAML configuration belongs to.'
),
)
private_key = models.TextField(
help_text=(
'To generate a key pair as two files, run '
......
"""
Third-party auth provider configuration API.
"""
from django.contrib.sites.models import Site
from openedx.core.djangoapps.theming.helpers import get_current_request
from .models import (
OAuth2ProviderConfig, SAMLConfiguration, SAMLProviderConfig, LTIProviderConfig,
_PSA_OAUTH2_BACKENDS, _PSA_SAML_BACKENDS, _LTI_BACKENDS,
......@@ -15,20 +18,23 @@ class Registry(object):
"""
@classmethod
def _enabled_providers(cls):
""" Helper method to iterate over all providers """
"""
Helper method that returns a generator used to iterate over all providers
of the current site.
"""
for backend_name in _PSA_OAUTH2_BACKENDS:
provider = OAuth2ProviderConfig.current(backend_name)
if provider.enabled:
if provider.enabled_for_current_site:
yield provider
if SAMLConfiguration.is_enabled():
if SAMLConfiguration.is_enabled(Site.objects.get_current(get_current_request())):
idp_slugs = SAMLProviderConfig.key_values('idp_slug', flat=True)
for idp_slug in idp_slugs:
provider = SAMLProviderConfig.current(idp_slug)
if provider.enabled and provider.backend_name in _PSA_SAML_BACKENDS:
if provider.enabled_for_current_site and provider.backend_name in _PSA_SAML_BACKENDS:
yield provider
for consumer_key in LTIProviderConfig.key_values('lti_consumer_key', flat=True):
provider = LTIProviderConfig.current(consumer_key)
if provider.enabled and provider.backend_name in _LTI_BACKENDS:
if provider.enabled_for_current_site and provider.backend_name in _LTI_BACKENDS:
yield provider
@classmethod
......@@ -69,7 +75,7 @@ class Registry(object):
@classmethod
def get_enabled_by_backend_name(cls, backend_name):
"""Generator returning all enabled providers that use the specified
backend.
backend on the current site.
Example:
>>> list(get_enabled_by_backend_name("tpa-saml"))
......@@ -84,16 +90,17 @@ class Registry(object):
"""
if backend_name in _PSA_OAUTH2_BACKENDS:
provider = OAuth2ProviderConfig.current(backend_name)
if provider.enabled:
if provider.enabled_for_current_site:
yield provider
elif backend_name in _PSA_SAML_BACKENDS and SAMLConfiguration.is_enabled():
elif backend_name in _PSA_SAML_BACKENDS and SAMLConfiguration.is_enabled(
Site.objects.get_current(get_current_request())):
idp_names = SAMLProviderConfig.key_values('idp_slug', flat=True)
for idp_name in idp_names:
provider = SAMLProviderConfig.current(idp_name)
if provider.backend_name == backend_name and provider.enabled:
if provider.backend_name == backend_name and provider.enabled_for_current_site:
yield provider
elif backend_name in _LTI_BACKENDS:
for consumer_key in LTIProviderConfig.key_values('lti_consumer_key', flat=True):
provider = LTIProviderConfig.current(consumer_key)
if provider.backend_name == backend_name and provider.enabled:
if provider.backend_name == backend_name and provider.enabled_for_current_site:
yield provider
......@@ -2,8 +2,10 @@
Slightly customized python-social-auth backend for SAML 2.0 support
"""
import logging
from django.contrib.sites.models import Site
from django.http import Http404
from django.utils.functional import cached_property
from openedx.core.djangoapps.theming.helpers import get_current_request
from social.backends.saml import SAMLAuth, OID_EDU_PERSON_ENTITLEMENT
from social.exceptions import AuthForbidden, AuthMissingParameter
......@@ -41,10 +43,12 @@ class SAMLAuthBackend(SAMLAuth): # pylint: disable=abstract-method
if not self._config.enabled:
log.error('SAML authentication is not enabled')
raise Http404
# TODO: remove this check once the fix is merged upstream:
# https://github.com/omab/python-social-auth/pull/821
if 'idp' not in self.strategy.request_data():
raise AuthMissingParameter(self, 'idp')
return super(SAMLAuthBackend, self).auth_url()
def _check_entitlements(self, idp, attributes):
......@@ -93,4 +97,4 @@ class SAMLAuthBackend(SAMLAuth): # pylint: disable=abstract-method
@cached_property
def _config(self):
from .models import SAMLConfiguration
return SAMLConfiguration.current()
return SAMLConfiguration.current(Site.objects.get_current(get_current_request()))
......@@ -26,7 +26,7 @@ class ConfigurationModelStrategy(DjangoStrategy):
"""
if isinstance(backend, OAuthAuth):
provider_config = OAuth2ProviderConfig.current(backend.name)
if not provider_config.enabled:
if not provider_config.enabled_for_current_site:
raise Exception("Can't fetch setting of a disabled backend/provider.")
try:
return provider_config.get_setting(name)
......
......@@ -36,16 +36,13 @@ def fetch_saml_metadata():
num_failed: Number of providers that could not be updated
num_total: Total number of providers whose metadata was fetched
"""
if not SAMLConfiguration.is_enabled():
return (0, 0, 0) # Nothing to do until SAML is enabled.
num_changed, num_failed = 0, 0
# First make a list of all the metadata XML URLs:
url_map = {}
for idp_slug in SAMLProviderConfig.key_values('idp_slug', flat=True):
config = SAMLProviderConfig.current(idp_slug)
if not config.enabled:
if not config.enabled or not SAMLConfiguration.is_enabled(config.site):
continue
url = config.metadata_source
if url not in url_map:
......
"""Unit tests for provider.py."""
from django.contrib.sites.models import Site
from mock import Mock, patch
from openedx.core.djangoapps.site_configuration.tests.test_util import with_site_configuration
from third_party_auth import provider
from third_party_auth.tests import testutil
import unittest
SITE_DOMAIN_A = 'professionalx.example.com'
SITE_DOMAIN_B = 'somethingelse.example.com'
@unittest.skipUnless(testutil.AUTH_FEATURE_ENABLED, 'third_party_auth not enabled')
class RegistryTest(testutil.TestCase):
......@@ -84,6 +89,22 @@ class RegistryTest(testutil.TestCase):
self.assertNotIn(no_log_in_provider.provider_id, provider_ids)
self.assertIn(normal_provider.provider_id, provider_ids)
def test_provider_enabled_for_current_site(self):
"""
Verify that enabled_for_current_site returns True when the provider matches the current site.
"""
prov = self.configure_google_provider(visible=True, enabled=True, site=Site.objects.get_current())
self.assertEqual(prov.enabled_for_current_site, True)
@with_site_configuration(SITE_DOMAIN_A)
def test_provider_disabled_for_mismatching_site(self):
"""
Verify that enabled_for_current_site returns False when the provider is configured for a different site.
"""
site_b = Site.objects.get_or_create(domain=SITE_DOMAIN_B, name=SITE_DOMAIN_B)[0]
prov = self.configure_google_provider(visible=True, enabled=True, site=site_b)
self.assertEqual(prov.enabled_for_current_site, False)
def test_get_returns_enabled_provider(self):
google_provider = self.configure_google_provider(enabled=True)
self.assertEqual(google_provider.id, provider.Registry.get(google_provider.provider_id).id)
......
......@@ -7,6 +7,7 @@ Used by Django and non-Django tests; must not have Django deps.
from contextlib import contextmanager
from django.conf import settings
from django.contrib.auth.models import User
from django.contrib.sites.models import Site
from provider.oauth2.models import Client as OAuth2Client
from provider import constants
import django.test
......@@ -76,13 +77,17 @@ class ThirdPartyAuthTestMixin(object):
@staticmethod
def configure_oauth_provider(**kwargs):
""" Update the settings for an OAuth2-based third party auth provider """
kwargs.setdefault('provider_slug', kwargs['backend_name'])
obj = OAuth2ProviderConfig(**kwargs)
obj.save()
return obj
def configure_saml_provider(self, **kwargs):
""" Update the settings for a SAML-based third party auth provider """
self.assertTrue(SAMLConfiguration.is_enabled(), "SAML Provider Configuration only works if SAML is enabled.")
self.assertTrue(
SAMLConfiguration.is_enabled(Site.objects.get_current()),
"SAML Provider Configuration only works if SAML is enabled."
)
obj = SAMLProviderConfig(**kwargs)
obj.save()
return obj
......
......@@ -36,7 +36,7 @@ def saml_metadata_view(request):
Get the Service Provider metadata for this edx-platform instance.
You must send this XML to any Shibboleth Identity Provider that you wish to use.
"""
if not SAMLConfiguration.is_enabled():
if not SAMLConfiguration.is_enabled(request.site):
raise Http404
complete_url = reverse('social:complete', args=("tpa-saml", ))
if settings.APPEND_SLASH and not complete_url.endswith('/'):
......
......@@ -10,8 +10,10 @@
"icon_class": "fa-google-plus",
"icon_image": null,
"backend_name": "google-oauth2",
"provider_slug": "google-oauth2",
"key": "test",
"secret": "test",
"site": 2,
"other_settings": "{}",
"visible": true
}
......@@ -27,8 +29,10 @@
"icon_class": "fa-facebook",
"icon_image": null,
"backend_name": "facebook",
"provider_slug": "facebook",
"key": "test",
"secret": "test",
"site": 2,
"other_settings": "{}",
"visible": true
}
......@@ -44,8 +48,10 @@
"icon_class": "",
"icon_image": "test-icon.png",
"backend_name": "dummy",
"provider_slug": "dummy",
"key": "",
"secret": "",
"site": 2,
"other_settings": "{}",
"visible": true
}
......
......@@ -37,7 +37,8 @@ def with_site_configuration(domain="test.localhost", configuration=None):
with patch('openedx.core.djangoapps.site_configuration.helpers.get_current_site_configuration',
return_value=site_configuration):
with patch('openedx.core.djangoapps.theming.helpers.get_current_site', return_value=site):
return func(*args, **kwargs)
with patch('django.contrib.sites.models.SiteManager.get_current', return_value=site):
return func(*args, **kwargs)
return _decorated
return _decorator
......@@ -63,4 +64,5 @@ def with_site_configuration_context(domain="test.localhost", configuration=None)
with patch('openedx.core.djangoapps.site_configuration.helpers.get_current_site_configuration',
return_value=site_configuration):
with patch('openedx.core.djangoapps.theming.helpers.get_current_site', return_value=site):
yield
with patch('django.contrib.sites.models.SiteManager.get_current', return_value=site):
yield
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment