provider.py 5.19 KB
Newer Older
1
"""
Third-party auth provider configuration API.
"""
4
from django.contrib.sites.models import Site
5

6 7
from openedx.core.djangoapps.theming.helpers import get_current_request

8
from .models import (
9 10 11 12 13 14 15
    _LTI_BACKENDS,
    _PSA_OAUTH2_BACKENDS,
    _PSA_SAML_BACKENDS,
    LTIProviderConfig,
    OAuth2ProviderConfig,
    SAMLConfiguration,
    SAMLProviderConfig
16
)
17 18


19 20
class Registry(object):
    """
    API for querying third-party auth ProviderConfig objects.

    Providers must subclass ProviderConfig in order to be usable in the registry.
    """

    @staticmethod
    def _current_configs(config_class, key_field):
        """
        Yield the current configuration for every key of ``config_class``.

        Args:
            config_class: A ProviderConfig subclass exposing ``key_values`` and
                ``current`` (e.g. OAuth2ProviderConfig).
            key_field (string): Name of the field whose distinct values are
                passed to ``config_class.current`` (e.g. 'provider_slug').

        Yields:
            Whatever ``config_class.current(key)`` returns, once per key, in the
            order returned by ``config_class.key_values(key_field, flat=True)``.
        """
        for key in config_class.key_values(key_field, flat=True):
            yield config_class.current(key)

    @staticmethod
    def _saml_enabled_for_current_site():
        """
        Return the result of ``SAMLConfiguration.is_enabled`` for the Site that
        ``Site.objects.get_current`` resolves from the current request.
        """
        return SAMLConfiguration.is_enabled(Site.objects.get_current(get_current_request()))

    @classmethod
    def _enabled_providers(cls):
        """
        Helper method that returns a generator used to iterate over all providers
        of the current site.

        Providers are yielded grouped by type: OAuth2 first, then SAML, then LTI.
        A provider is yielded only if its ``enabled_for_current_site`` is truthy
        and its ``backend_name`` is one of the known backends for its type.
        SAML providers are looked up only when SAML is enabled for the site of
        the current request.
        """
        for provider in cls._current_configs(OAuth2ProviderConfig, 'provider_slug'):
            if provider.enabled_for_current_site and provider.backend_name in _PSA_OAUTH2_BACKENDS:
                yield provider

        if cls._saml_enabled_for_current_site():
            for provider in cls._current_configs(SAMLProviderConfig, 'idp_slug'):
                if provider.enabled_for_current_site and provider.backend_name in _PSA_SAML_BACKENDS:
                    yield provider

        for provider in cls._current_configs(LTIProviderConfig, 'lti_consumer_key'):
            if provider.enabled_for_current_site and provider.backend_name in _LTI_BACKENDS:
                yield provider

    @classmethod
    def enabled(cls):
        """Returns list of enabled providers, sorted by their ``name``."""
        return sorted(cls._enabled_providers(), key=lambda provider: provider.name)

    @classmethod
    def displayed_for_login(cls, tpa_hint=None):
        """
        Return the enabled providers that should be offered on the login page.

        Args:
            tpa_hint (string): An override used in certain third-party authentication
                scenarios. It causes the provider whose ``provider_id`` equals this
                value to be included in the set, along with any providers whose
                ``display_for_login`` flag is set. With the default of None, only
                ``display_for_login`` providers are returned, assuming no enabled
                provider has a ``provider_id`` of None.

        Returns:
            List of ProviderConfig entities, in the same (name-sorted) order as
            ``enabled()``.
        """
        return [
            provider
            for provider in cls.enabled()
            if provider.display_for_login or provider.provider_id == tpa_hint
        ]

    @classmethod
    def get(cls, provider_id):
        """
        Gets provider by provider_id string if enabled, else None.

        Args:
            provider_id (string): Identifier such as 'oa2-google'.

        Returns:
            The first enabled provider whose ``provider_id`` matches, or None.

        Raises:
            ValueError: If ``provider_id`` contains no '-'.
        """
        if '-' not in provider_id:  # Check format - see models.py:ProviderConfig
            raise ValueError("Invalid provider_id. Expect something like oa2-google")
        return next(
            (provider for provider in cls._enabled_providers() if provider.provider_id == provider_id),
            None,
        )

    @classmethod
    def get_from_pipeline(cls, running_pipeline):
        """Gets the provider that is being used for the specified pipeline (or None).

        Args:
            running_pipeline: The python-social-auth pipeline being used to
                authenticate a user.

        Returns:
            The first enabled ProviderConfig whose ``is_active_for_pipeline``
            returns truthy for ``running_pipeline``, or None if there is none.
        """
        return next(
            (provider for provider in cls._enabled_providers()
             if provider.is_active_for_pipeline(running_pipeline)),
            None,
        )

    @classmethod
    def get_enabled_by_backend_name(cls, backend_name):
        """Generator returning all enabled providers that use the specified
        backend on the current site.

        Example:
            >>> list(Registry.get_enabled_by_backend_name("tpa-saml"))
                [<SAMLProviderConfig>, <SAMLProviderConfig>]

        Args:
            backend_name: The name of a python-social-auth backend used by
                one or more providers.

        Yields:
            Instances of ProviderConfig whose ``backend_name`` equals
            ``backend_name`` and whose ``enabled_for_current_site`` is truthy.
            Nothing is yielded for an unknown backend name, or for a SAML
            backend when SAML is not enabled for the current site.
        """
        # The provider type is chosen by the first backend set that matches. The
        # SAML site check is only evaluated for SAML backends.
        if backend_name in _PSA_OAUTH2_BACKENDS:
            config_class, key_field = OAuth2ProviderConfig, 'provider_slug'
        elif backend_name in _PSA_SAML_BACKENDS and cls._saml_enabled_for_current_site():
            config_class, key_field = SAMLProviderConfig, 'idp_slug'
        elif backend_name in _LTI_BACKENDS:
            config_class, key_field = LTIProviderConfig, 'lti_consumer_key'
        else:
            return

        for provider in cls._current_configs(config_class, key_field):
            if provider.backend_name == backend_name and provider.enabled_for_current_site:
                yield provider