"""
Third-party auth provider configuration API.
"""
from django.contrib.sites.models import Site
from openedx.core.djangoapps.theming.helpers import get_current_request

from .models import (
    OAuth2ProviderConfig, SAMLConfiguration, SAMLProviderConfig, LTIProviderConfig,
    _PSA_OAUTH2_BACKENDS, _PSA_SAML_BACKENDS, _LTI_BACKENDS,
)


class Registry(object):
    """
    API for querying third-party auth ProviderConfig objects.

    Providers must subclass ProviderConfig in order to be usable in the registry.
    """
    @classmethod
    def _saml_enabled(cls):
        """
        Helper method that returns whether SAMLConfiguration reports itself
        as enabled for the current site (resolved from the current request).
        """
        return SAMLConfiguration.is_enabled(Site.objects.get_current(get_current_request()))

    @classmethod
    def _current_providers(cls, config_class, key_field, backend_names):
        """
        Helper generator shared by the key-based (SAML, LTI) provider lookups.

        Args:
            config_class: The provider config model class to query; it must
                offer `key_values()` and `current()`.
            key_field: Name of the field whose values identify each provider
                (e.g. 'idp_slug' or 'lti_consumer_key').
            backend_names: Collection of acceptable backend names.

        Yields:
            The current config for each key whose backend_name is in
            backend_names and which is enabled for the current site.
        """
        for key in config_class.key_values(key_field, flat=True):
            provider = config_class.current(key)
            # Compare the backend name first so the site check is only
            # evaluated for providers that could match.
            if provider.backend_name in backend_names and provider.enabled_for_current_site:
                yield provider

    @classmethod
    def _enabled_providers(cls):
        """
        Helper method that returns a generator used to iterate over all providers
        of the current site.

        OAuth2 providers are yielded first, then SAML providers (only when
        SAML is enabled for the current site), then LTI providers.
        """
        for backend_name in _PSA_OAUTH2_BACKENDS:
            provider = OAuth2ProviderConfig.current(backend_name)
            if provider.enabled_for_current_site:
                yield provider
        if cls._saml_enabled():
            for provider in cls._current_providers(SAMLProviderConfig, 'idp_slug', _PSA_SAML_BACKENDS):
                yield provider
        for provider in cls._current_providers(LTIProviderConfig, 'lti_consumer_key', _LTI_BACKENDS):
            yield provider

    @classmethod
    def enabled(cls):
        """Returns list of enabled providers, sorted by provider name."""
        return sorted(cls._enabled_providers(), key=lambda provider: provider.name)

    @classmethod
    def displayed_for_login(cls):
        """Returns list of providers that can be used to initiate logins in the UI"""
        return [provider for provider in cls.enabled() if provider.display_for_login]

    @classmethod
    def get(cls, provider_id):
        """Gets provider by provider_id string if enabled, else None.

        Args:
            provider_id: A provider id string such as "oa2-google", or None.

        Returns:
            The first enabled provider whose provider_id matches, or None if
            provider_id is None or no enabled provider matches.

        Raises:
            ValueError: if provider_id is a string that does not contain '-'.
        """
        if provider_id is None:
            # No provider requested; without this guard the format check
            # below would fail with an unrelated TypeError.
            return None
        if '-' not in provider_id:  # Check format - see models.py:ProviderConfig
            raise ValueError("Invalid provider_id. Expect something like oa2-google")
        try:
            return next(provider for provider in cls._enabled_providers() if provider.provider_id == provider_id)
        except StopIteration:
            return None

    @classmethod
    def get_from_pipeline(cls, running_pipeline):
        """Gets the provider that is being used for the specified pipeline (or None).

        Args:
            running_pipeline: The python-social-auth pipeline being used to
                authenticate a user.

        Returns:
            An instance of ProviderConfig or None.
        """
        for enabled in cls._enabled_providers():
            if enabled.is_active_for_pipeline(running_pipeline):
                return enabled
        return None

    @classmethod
    def get_enabled_by_backend_name(cls, backend_name):
        """Generator returning all enabled providers that use the specified
        backend on the current site.

        Example:
            >>> list(Registry.get_enabled_by_backend_name("tpa-saml"))
                [<SAMLProviderConfig>, <SAMLProviderConfig>]

        Args:
            backend_name: The name of a python-social-auth backend used by
                one or more providers.

        Yields:
            Instances of ProviderConfig.
        """
        if backend_name in _PSA_OAUTH2_BACKENDS:
            provider = OAuth2ProviderConfig.current(backend_name)
            if provider.enabled_for_current_site:
                yield provider
        elif backend_name in _PSA_SAML_BACKENDS and cls._saml_enabled():
            # The SAML check is only evaluated for SAML backend names; when it
            # fails, control still falls through to the LTI branch below.
            for provider in cls._current_providers(SAMLProviderConfig, 'idp_slug', (backend_name,)):
                yield provider
        elif backend_name in _LTI_BACKENDS:
            for provider in cls._current_providers(LTIProviderConfig, 'lti_consumer_key', (backend_name,)):
                yield provider