saml.py 13.3 KB
Newer Older
1 2 3
"""
Slightly customized python-social-auth backend for SAML 2.0 support
"""
4
import logging
5
from copy import deepcopy
6 7

import requests
8
from django.contrib.sites.models import Site
9
from django.http import Http404
10
from django.utils.functional import cached_property
11
from django_countries import countries
12 13
from social_core.backends.saml import OID_EDU_PERSON_ENTITLEMENT, SAMLAuth, SAMLIdentityProvider
from social_core.exceptions import AuthForbidden
14

15 16
from openedx.core.djangoapps.theming.helpers import get_current_request

17 18
STANDARD_SAML_PROVIDER_KEY = 'standard_saml_provider'
SAP_SUCCESSFACTORS_SAML_KEY = 'sap_success_factors'
19
log = logging.getLogger(__name__)
20 21 22 23 24 25 26 27 28 29 30


class SAMLAuthBackend(SAMLAuth):  # pylint: disable=abstract-method
    """
    SAML 2.0 authentication backend for third_party_auth.

    Unlike the stock social-core ``SAMLAuth``, identity providers and backend
    settings are resolved through third_party_auth's configuration models
    (``SAMLProviderConfig`` and ``SAMLConfiguration``).
    """
    name = "tpa-saml"

    def get_idp(self, idp_name):
        """
        Return the SAMLIdentityProvider instance for the IdP called ``idp_name``,
        as produced by that provider's current ``SAMLProviderConfig``.
        """
        # Function-level import, presumably to avoid an import cycle with the
        # models module -- TODO confirm.
        from .models import SAMLProviderConfig
        provider_config = SAMLProviderConfig.current(idp_name)
        return provider_config.get_config()

    def setting(self, name, default=None):
        """
        Look up the backend setting ``name``.

        The value held by the current ``SAMLConfiguration`` takes precedence; when it
        has no entry for ``name`` (signalled by KeyError), fall back to the strategy's
        regular setting lookup, which receives ``default``.
        """
        try:
            value = self._config.get_setting(name)
        except KeyError:
            value = self.strategy.setting(name, default, backend=self)
        return value

    def auth_url(self):
        """
        Return the URL to which the user must be redirected in order to authenticate.

        Raises Http404 (after logging an error) when SAML authentication is disabled
        in the current SAMLConfiguration. Nothing here reads an 'idp' request
        parameter; any such requirement is presumably enforced by the parent
        implementation -- TODO confirm.
        """
        if self._config.enabled:
            return super(SAMLAuthBackend, self).auth_url()
        log.error('SAML authentication is not enabled')
        raise Http404

    def _check_entitlements(self, idp, attributes):
        """
        Enforce the IdP's ``requiredEntitlements`` setting, if it defines one.

        Every entitlement listed there must be among the user's eduPersonEntitlement
        attribute values. On success (or when nothing is required) this returns None
        so the login pipeline can continue; otherwise it logs a warning and raises
        AuthForbidden.
        """
        if "requiredEntitlements" not in idp.conf:
            return
        granted = attributes.get(OID_EDU_PERSON_ENTITLEMENT, [])
        for required in idp.conf['requiredEntitlements']:
            if required in granted:
                continue
            log.warning(
                "SAML user from IdP %s rejected due to missing eduPersonEntitlement %s", idp.name, required)
            raise AuthForbidden(self)

    def _create_saml_auth(self, idp):
        """
        Get an instance of OneLogin_Saml2_Auth for ``idp``.

        idp: The Identity Provider - a social_core.backends.saml.SAMLIdentityProvider instance

        This override exists only to add logging when the provider's ``debug_mode``
        is on: the auth object's ``login`` and ``process_response`` methods are then
        wrapped so that the request/response XML is logged after each call.
        """
        # The wrapping is applied to this auth object only, which (per the original
        # author's note) lives just for the current HTTP request.
        auth_inst = super(SAMLAuthBackend, self)._create_saml_auth(idp)
        from .models import SAMLProviderConfig
        if not SAMLProviderConfig.current(idp.name).debug_mode:
            return auth_inst

        def add_debug_logging(method_name, action_description, xml_getter):
            """ Replace ``auth_inst.<method_name>`` with a version that logs the XML afterwards """
            original_method = getattr(auth_inst, method_name)

            def logged_method(*args, **kwargs):
                """ Call the original login/process_response method, then log its XML """
                outcome = original_method(*args, **kwargs)
                log.info("SAML login %s for IdP %s. XML is:\n%s", action_description, idp.name, xml_getter())
                return outcome

            setattr(auth_inst, method_name, logged_method)

        for method_name, action_description, xml_getter in (
                ("login", "request", auth_inst.get_last_request_xml),
                ("process_response", "response", auth_inst.get_last_response_xml),
        ):
            add_debug_logging(method_name, action_description, xml_getter)
        return auth_inst

    @cached_property
    def _config(self):
        """
        The SAMLConfiguration of the site serving the current request; evaluated once
        per backend instance thanks to ``cached_property``.
        """
        from .models import SAMLConfiguration
        current_site = Site.objects.get_current(get_current_request())
        return SAMLConfiguration.current(current_site)
102 103


104
class EdXSAMLIdentityProvider(SAMLIdentityProvider):
    """
    Customized version of SAMLIdentityProvider that can retrieve details beyond the standard
    details supported by the canonical upstream version.
    """

    def get_user_details(self, attributes):
        """
        Overrides `get_user_details` from the base class; retrieves those details,
        then updates the dict with values from whatever additional fields are desired.

        The additional fields come from the ``extra_field_definitions`` configuration
        list. Each entry is read as a dict with a ``urn`` (the SAML attribute to read)
        and a ``name`` (the key to set in the returned details).

        For a list/tuple-valued attribute, the first value is used. An attribute that is
        absent, or present with no values, yields None. Any other value is used as-is.
        """
        details = super(EdXSAMLIdentityProvider, self).get_user_details(attributes)
        extra_field_definitions = self.conf.get('extra_field_definitions', [])
        for field in extra_field_definitions:
            value = attributes.get(field['urn'])
            if isinstance(value, (list, tuple)):
                # Multi-valued attribute: take the first value, guarding against an
                # empty value list, which would otherwise raise IndexError.
                value = value[0] if value else None
            # A non-list value (e.g. a plain string) is kept whole; indexing a string
            # with [0] would keep only its first character.
            details[field['name']] = value
        return details


class SapSuccessFactorsIdentityProvider(EdXSAMLIdentityProvider):
    """
    Customized version of EdXSAMLIdentityProvider that knows how to retrieve user details
    from the SAPSuccessFactors OData API, rather than parse them directly off the
    SAML assertion that we get in response to a login attempt.

    The OData lookup requires every key in ``required_variables`` to be present in this
    provider's configuration. When one is missing, or the API request fails, the details
    parsed from the SAML assertion are returned instead.
    """

    # Configuration keys that must all be present before the OData API can be queried.
    required_variables = (
        'sapsf_oauth_root_url',
        'sapsf_private_key',
        'odata_api_root_url',
        'odata_company_id',
        'odata_client_id',
    )

    # Define the relationships between SAPSF record fields and Open edX logistration fields.
    default_field_mapping = {
        'username': 'username',
        'firstName': 'first_name',
        'lastName': 'last_name',
        'defaultFullName': 'fullname',
        'email': 'email',
        'country': 'country',
        'city': 'city',
    }

    # Define a simple mapping to relate SAPSF values to Open edX-compatible values for
    # any given field. By default, this only contains the Country field, as SAPSF supplies
    # a country name, which has to be translated to a country code.
    default_value_mapping = {
        'country': {name: code for code, name in countries}
    }

    # Unfortunately, not everything has a 1:1 name mapping between Open edX and SAPSF, so
    # we need some overrides. These are country *values*, so they must go inside the
    # 'country' value map; at the top level they would be treated as a field name and
    # never applied. TODO: Fill in necessary mappings
    default_value_mapping['country'].update({
        'United States': 'US',
    })

    def get_registration_fields(self, response):
        """
        Get a dictionary mapping registration field names to default values.

        ``response`` is the decoded JSON body of an OData ``User`` query; the user record
        is read from its ``'d'`` key.

        Each SAPSF field named in ``field_mappings`` is copied under its Open edX name
        (``''`` if the record lacks it). A value that has an entry in ``value_mappings``
        for that Open edX field is then translated.
        """
        field_mapping = self.field_mappings
        record = response['d']
        registration_fields = {
            edx_name: record.get(odata_name, '')
            for odata_name, edx_name in field_mapping.items()
        }
        value_mapping = self.value_mappings
        for field, value in registration_fields.items():
            # Only existing keys are reassigned (the dict never changes size), so
            # updating it while iterating is safe.
            if field in value_mapping and value in value_mapping[field]:
                registration_fields[field] = value_mapping[field][value]
        return registration_fields

    @property
    def field_mappings(self):
        """
        Get a dictionary mapping the field names returned in an SAP SuccessFactors
        user entity to the field names with which those values should be used in
        the Open edX registration form.

        Entries from the ``sapsf_field_mappings`` configuration override the defaults.
        """
        overrides = self.conf.get('sapsf_field_mappings', {})
        base = self.default_field_mapping.copy()
        base.update(overrides)
        return base

    @property
    def value_mappings(self):
        """
        Get a dictionary mapping of field names to override objects which each
        map values received from SAP SuccessFactors to values expected in the
        Open edX platform registration form.

        ``sapsf_value_mappings`` in the configuration has the same shape:
        ``{edx_field_name: {sapsf_value: edx_value}}``. For a field that already has
        defaults, the override entries are merged into them. Otherwise the override
        becomes that field's value map.
        """
        overrides = self.conf.get('sapsf_value_mappings', {})
        # Deep copy so merging overrides never mutates the class-level defaults.
        base = deepcopy(self.default_value_mapping)
        for field, override in overrides.items():
            if field in base:
                base[field].update(override)
            else:
                # The override itself is this field's value map. Copy it so callers
                # cannot mutate the provider configuration through the result.
                base[field] = dict(override)
        return base

    @property
    def timeout(self):
        """
        The number of seconds OData API requests should wait for a response before failing.
        """
        return self.conf.get('odata_api_request_timeout', 10)

    @property
    def sapsf_idp_url(self):
        """
        URL of the SAPSF OAuth ``idp`` endpoint.

        Built by plain concatenation, so the configured root presumably needs a trailing
        slash.
        """
        return self.conf['sapsf_oauth_root_url'] + 'idp'

    @property
    def sapsf_token_url(self):
        """
        URL of the SAPSF OAuth ``token`` endpoint.

        Built by plain concatenation onto ``sapsf_oauth_root_url``.
        """
        return self.conf['sapsf_oauth_root_url'] + 'token'

    @property
    def sapsf_private_key(self):
        """The ``sapsf_private_key`` configuration value."""
        return self.conf['sapsf_private_key']

    @property
    def odata_api_root_url(self):
        """The ``odata_api_root_url`` configuration value (prefix of OData API URLs)."""
        return self.conf['odata_api_root_url']

    @property
    def odata_company_id(self):
        """The ``odata_company_id`` configuration value."""
        return self.conf['odata_company_id']

    @property
    def odata_client_id(self):
        """The ``odata_client_id`` configuration value."""
        return self.conf['odata_client_id']

    def missing_variables(self):
        """
        Check that we have all the details we need to properly retrieve rich data from the
        SAP SuccessFactors OData API. If we don't, then we should log a warning indicating
        the specific variables that are missing.

        Returns the list of missing configuration keys, or None when nothing is missing.
        """
        missing = [var for var in self.required_variables if var not in self.conf]
        if missing:
            log.warning(
                "To retrieve rich user data for an SAP SuccessFactors identity provider, the following keys in "
                "'other_settings' are required, but were missing: %s",
                missing
            )
            return missing
        return None

    def get_odata_api_client(self, user_id):
        """
        Get a Requests session with the headers needed to properly authenticate it with
        the SAP SuccessFactors OData API.

        First, the client id, ``user_id``, token URL and private key are posted to the
        ``idp`` endpoint; the response body is used as a SAML bearer assertion.

        Next, that assertion is exchanged at the token endpoint (grant type
        ``saml2-bearer``) for an access token.

        The returned session sends that token as a Bearer ``Authorization`` header and
        asks for JSON. Raises ``requests.HTTPError`` if either call returns an error
        status.
        """
        session = requests.Session()
        assertion_response = session.post(
            self.sapsf_idp_url,
            data={
                'client_id': self.odata_client_id,
                'user_id': user_id,
                'token_url': self.sapsf_token_url,
                'private_key': self.sapsf_private_key,
            },
            timeout=self.timeout,
        )
        assertion_response.raise_for_status()
        assertion = assertion_response.text
        token_response = session.post(
            self.sapsf_token_url,
            data={
                'client_id': self.odata_client_id,
                'company_id': self.odata_company_id,
                'grant_type': 'urn:ietf:params:oauth:grant-type:saml2-bearer',
                'assertion': assertion,
            },
            timeout=self.timeout,
        )
        token_response.raise_for_status()
        token = token_response.json()['access_token']
        session.headers.update({'Authorization': 'Bearer {}'.format(token), 'Accept': 'application/json'})
        return session

    def get_user_details(self, attributes):
        """
        Attempt to get rich user details from the SAP SuccessFactors OData API. If we're missing any
        of the details we need to do that, fail nicely by returning the details we're able to extract
        from just the SAML response and log a warning.

        On success the result is built solely from the OData record (see
        ``get_registration_fields``). When the request fails with a
        ``requests.RequestException``, a warning is logged and the SAML-derived details
        are returned.
        """
        details = super(SapSuccessFactorsIdentityProvider, self).get_user_details(attributes)
        if self.missing_variables():
            # If there aren't enough details to make the request, log a warning and return the details
            # from the SAML assertion.
            return details
        username = details['username']
        fields = ','.join(self.field_mappings)
        # NOTE(review): the username is interpolated unescaped into the OData key
        # predicate, so a value containing a single quote yields a malformed query.
        odata_api_url = '{root_url}User(userId=\'{user_id}\')?$select={fields}'.format(
            root_url=self.odata_api_root_url,
            user_id=username,
            fields=fields,
        )
        try:
            client = self.get_odata_api_client(user_id=username)
            response = client.get(
                odata_api_url,
                timeout=self.timeout,
            )
            response.raise_for_status()
            response = response.json()
        except requests.RequestException as err:
            # If there was an HTTP level error, log the error and return the details from the SAML assertion.
            # Log the exception object itself: Python 3 exceptions have no ``.message``
            # attribute, so ``err.message`` would raise AttributeError inside this handler.
            log.warning(
                'Unable to retrieve user details with username %s from SAPSuccessFactors for company ID %s '
                'with url "%s" and error message: %s',
                username,
                self.odata_company_id,
                odata_api_url,
                err,
                exc_info=True,
            )
            return details

        return self.get_registration_fields(response)
322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340


def get_saml_idp_choices():
    """
    Return the ``(key, label)`` choices offered by the Django administration form for
    selecting which SAMLIdentityProvider subclass processes a provider's SAML requests.

    The keys are the same constants that ``get_saml_idp_class`` resolves.
    """
    standard_choice = (STANDARD_SAML_PROVIDER_KEY, 'Standard SAML provider')
    sap_choice = (SAP_SUCCESSFACTORS_SAML_KEY, 'SAP SuccessFactors provider')
    return standard_choice, sap_choice


def get_saml_idp_class(idp_identifier_string):
    """
    Resolve a provider-type key to the SAMLIdentityProvider subclass that handles it.

    Known keys are ``STANDARD_SAML_PROVIDER_KEY`` and ``SAP_SUCCESSFACTORS_SAML_KEY``.
    Any other key is logged as an error, and the EdXSAMLIdentityProvider base class is
    returned as the fallback.
    """
    provider_classes = {
        STANDARD_SAML_PROVIDER_KEY: EdXSAMLIdentityProvider,
        SAP_SUCCESSFACTORS_SAML_KEY: SapSuccessFactorsIdentityProvider,
    }
    try:
        return provider_classes[idp_identifier_string]
    except KeyError:
        log.error(
            '%s is not a valid EdXSAMLIdentityProvider subclass; using EdXSAMLIdentityProvider base class.',
            idp_identifier_string
        )
        return EdXSAMLIdentityProvider