Commit d5c85331 by Uman Shahzad

Automatically populate additional fields for SSO scenarios.

When authenticating using an SAML IdP, gather additional user
data besides what is standard. Requires admin to input JSON
in settings to recognize the additional user data.
parent 9ecc2938
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('third_party_auth', '0010_add_skip_hinted_login_dialog_field'),
]
operations = [
migrations.AlterField(
model_name='samlproviderconfig',
name='other_settings',
field=models.TextField(help_text=b'For advanced use cases, enter a JSON object with addtional configuration. The tpa-saml backend supports {"requiredEntitlements": ["urn:..."]}, which can be used to require the presence of a specific eduPersonEntitlement, and {"extra_field_definitions": [{"name": "...", "urn": "..."},...]}, which can be used to define registration form fields and the URNs that can be used to retrieve the relevant values from the SAML response. Custom provider types, as selected in the "Identity Provider Type" field, may make use of the information stored in this field for additional configuration.', verbose_name=b'Advanced settings', blank=True),
),
]
...@@ -31,6 +31,11 @@ from .saml import STANDARD_SAML_PROVIDER_KEY, get_saml_idp_choices, get_saml_idp ...@@ -31,6 +31,11 @@ from .saml import STANDARD_SAML_PROVIDER_KEY, get_saml_idp_choices, get_saml_idp
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
REGISTRATION_FORM_FIELD_BLACKLIST = [
'name',
'username'
]
# A dictionary of {name: class} entries for each python-social-auth backend available. # A dictionary of {name: class} entries for each python-social-auth backend available.
# Because this setting can specify arbitrary code to load and execute, it is set via # Because this setting can specify arbitrary code to load and execute, it is set via
...@@ -241,8 +246,13 @@ class ProviderConfig(ConfigurationModel): ...@@ -241,8 +246,13 @@ class ProviderConfig(ConfigurationModel):
values for that field. Where there is no value, the empty string values for that field. Where there is no value, the empty string
must be used. must be used.
""" """
registration_form_data = {}
# Details about the user sent back from the provider. # Details about the user sent back from the provider.
details = pipeline_kwargs.get('details') details = pipeline_kwargs.get('details').copy()
# Set the registration form to use the `fullname` detail for the `name` field.
registration_form_data['name'] = details.get('fullname', '')
# Get the username separately to take advantage of the de-duping logic # Get the username separately to take advantage of the de-duping logic
# built into the pipeline. The provider cannot de-dupe because it can't # built into the pipeline. The provider cannot de-dupe because it can't
...@@ -250,13 +260,19 @@ class ProviderConfig(ConfigurationModel): ...@@ -250,13 +260,19 @@ class ProviderConfig(ConfigurationModel):
# technically a data race between the creation of this value and the # technically a data race between the creation of this value and the
# creation of the user object, so it is still possible for users to get # creation of the user object, so it is still possible for users to get
# an error on submit. # an error on submit.
suggested_username = pipeline_kwargs.get('username') registration_form_data['username'] = pipeline_kwargs.get('username')
return { # Any other values that are present in the details dict should be copied
'email': details.get('email', ''), # into the registration form details. This may include details that do
'name': details.get('fullname', ''), # not map to a value that exists in the registration form. However,
'username': suggested_username, # because the fields that are actually rendered are not based on this
} # list, only those values that map to a valid registration form field
# will actually be sent to the form as default values.
for blacklisted_field in REGISTRATION_FORM_FIELD_BLACKLIST:
details.pop(blacklisted_field, None)
registration_form_data.update(details)
return registration_form_data
def get_authentication_backend(self): def get_authentication_backend(self):
"""Gets associated Django settings.AUTHENTICATION_BACKEND string.""" """Gets associated Django settings.AUTHENTICATION_BACKEND string."""
...@@ -401,10 +417,13 @@ class SAMLProviderConfig(ProviderConfig): ...@@ -401,10 +417,13 @@ class SAMLProviderConfig(ProviderConfig):
verbose_name="Advanced settings", blank=True, verbose_name="Advanced settings", blank=True,
help_text=( help_text=(
'For advanced use cases, enter a JSON object with addtional configuration. ' 'For advanced use cases, enter a JSON object with addtional configuration. '
'The tpa-saml backend supports only {"requiredEntitlements": ["urn:..."]} ' 'The tpa-saml backend supports {"requiredEntitlements": ["urn:..."]}, '
'which can be used to require the presence of a specific eduPersonEntitlement. ' 'which can be used to require the presence of a specific eduPersonEntitlement, '
'Custom provider types, as selected in the "Identity Provider Type" field, may make ' 'and {"extra_field_definitions": [{"name": "...", "urn": "..."},...]}, which can be '
'use of the information stored in this field for configuration.' 'used to define registration form fields and the URNs that can be used to retrieve '
'the relevant values from the SAML response. Custom provider types, as selected '
'in the "Identity Provider Type" field, may make use of the information stored '
'in this field for additional configuration.'
)) ))
def clean(self): def clean(self):
......
...@@ -100,9 +100,29 @@ class SAMLAuthBackend(SAMLAuth): # pylint: disable=abstract-method ...@@ -100,9 +100,29 @@ class SAMLAuthBackend(SAMLAuth): # pylint: disable=abstract-method
return SAMLConfiguration.current(Site.objects.get_current(get_current_request())) return SAMLConfiguration.current(Site.objects.get_current(get_current_request()))
class SapSuccessFactorsIdentityProvider(SAMLIdentityProvider): class EdXSAMLIdentityProvider(SAMLIdentityProvider):
""" """
Customized version of SAMLIdentityProvider that knows how to retrieve user details Customized version of SAMLIdentityProvider that can retrieve details beyond the standard
details supported by the canonical upstream version.
"""
def get_user_details(self, attributes):
"""
Overrides `get_user_details` from the base class; retrieves those details,
then updates the dict with values from whatever additional fields are desired.
"""
details = super(EdXSAMLIdentityProvider, self).get_user_details(attributes)
extra_field_definitions = self.conf.get('extra_field_definitions', [])
details.update({
field['name']: attributes[field['urn']][0] if field['urn'] in attributes else None
for field in extra_field_definitions
})
return details
class SapSuccessFactorsIdentityProvider(EdXSAMLIdentityProvider):
"""
Customized version of EdXSAMLIdentityProvider that knows how to retrieve user details
from the SAPSuccessFactors OData API, rather than parse them directly off the from the SAPSuccessFactors OData API, rather than parse them directly off the
SAML assertion that we get in response to a login attempt. SAML assertion that we get in response to a login attempt.
""" """
...@@ -244,12 +264,12 @@ def get_saml_idp_class(idp_identifier_string): ...@@ -244,12 +264,12 @@ def get_saml_idp_class(idp_identifier_string):
the SAMLIdentityProvider subclass able to handle requests for that type of identity provider. the SAMLIdentityProvider subclass able to handle requests for that type of identity provider.
""" """
choices = { choices = {
STANDARD_SAML_PROVIDER_KEY: SAMLIdentityProvider, STANDARD_SAML_PROVIDER_KEY: EdXSAMLIdentityProvider,
SAP_SUCCESSFACTORS_SAML_KEY: SapSuccessFactorsIdentityProvider, SAP_SUCCESSFACTORS_SAML_KEY: SapSuccessFactorsIdentityProvider,
} }
if idp_identifier_string not in choices: if idp_identifier_string not in choices:
log.error( log.error(
'%s is not a valid SAMLIdentityProvider subclass; using SAMLIdentityProvider base class.', '%s is not a valid EdXSAMLIdentityProvider subclass; using EdXSAMLIdentityProvider base class.',
idp_identifier_string idp_identifier_string
) )
return choices.get(idp_identifier_string, SAMLIdentityProvider) return choices.get(idp_identifier_string, EdXSAMLIdentityProvider)
...@@ -251,7 +251,7 @@ class IntegrationTest(testutil.TestCase, test.TestCase): ...@@ -251,7 +251,7 @@ class IntegrationTest(testutil.TestCase, test.TestCase):
self.assertEqual(302, response.status_code) self.assertEqual(302, response.status_code)
self.assertTrue(response.has_header('Location')) self.assertTrue(response.has_header('Location'))
def assert_register_response_in_pipeline_looks_correct(self, response, pipeline_kwargs): def assert_register_response_in_pipeline_looks_correct(self, response, pipeline_kwargs, required_fields):
"""Performs spot checks of the rendered register.html page. """Performs spot checks of the rendered register.html page.
When we display the new account registration form after the user signs When we display the new account registration form after the user signs
...@@ -267,9 +267,10 @@ class IntegrationTest(testutil.TestCase, test.TestCase): ...@@ -267,9 +267,10 @@ class IntegrationTest(testutil.TestCase, test.TestCase):
self.assertIn('successfully signed in with <strong>%s</strong>' % self.provider.name, response.content) self.assertIn('successfully signed in with <strong>%s</strong>' % self.provider.name, response.content)
# Expect that each truthy value we've prepopulated the register form # Expect that each truthy value we've prepopulated the register form
# with is actually present. # with is actually present.
for prepopulated_form_value in self.provider.get_register_form_data(pipeline_kwargs).values(): form_field_data = self.provider.get_register_form_data(pipeline_kwargs)
if prepopulated_form_value: for prepopulated_form_data in form_field_data:
self.assertIn(prepopulated_form_value, response.content) if prepopulated_form_data in required_fields:
self.assertIn(form_field_data[prepopulated_form_data], response.content)
# Implementation details and actual tests past this point -- no more # Implementation details and actual tests past this point -- no more
# configuration needed. # configuration needed.
...@@ -823,7 +824,10 @@ class IntegrationTest(testutil.TestCase, test.TestCase): ...@@ -823,7 +824,10 @@ class IntegrationTest(testutil.TestCase, test.TestCase):
# fire off the view that displays the registration form. # fire off the view that displays the registration form.
with self._patch_edxmako_current_request(request): with self._patch_edxmako_current_request(request):
self.assert_register_response_in_pipeline_looks_correct( self.assert_register_response_in_pipeline_looks_correct(
student_views.register_user(strategy.request), pipeline.get(request)['kwargs']) student_views.register_user(strategy.request),
pipeline.get(request)['kwargs'],
['name', 'username', 'email']
)
# Next, we invoke the view that handles the POST. Not all providers # Next, we invoke the view that handles the POST. Not all providers
# supply email. Manually add it as the user would have to; this # supply email. Manually add it as the user would have to; this
...@@ -892,7 +896,10 @@ class IntegrationTest(testutil.TestCase, test.TestCase): ...@@ -892,7 +896,10 @@ class IntegrationTest(testutil.TestCase, test.TestCase):
with self._patch_edxmako_current_request(request): with self._patch_edxmako_current_request(request):
self.assert_register_response_in_pipeline_looks_correct( self.assert_register_response_in_pipeline_looks_correct(
student_views.register_user(strategy.request), pipeline.get(request)['kwargs']) student_views.register_user(strategy.request),
pipeline.get(request)['kwargs'],
['name', 'username', 'email']
)
with self._patch_edxmako_current_request(strategy.request): with self._patch_edxmako_current_request(strategy.request):
strategy.request.POST = self.get_registration_post_vars() strategy.request.POST = self.get_registration_post_vars()
......
...@@ -25,7 +25,7 @@ from third_party_auth.models import ( ...@@ -25,7 +25,7 @@ from third_party_auth.models import (
SAMLConfiguration, SAMLConfiguration,
SAMLProviderConfig SAMLProviderConfig
) )
from third_party_auth.saml import SAMLIdentityProvider, get_saml_idp_class from third_party_auth.saml import EdXSAMLIdentityProvider, get_saml_idp_class
AUTH_FEATURES_KEY = 'ENABLE_THIRD_PARTY_AUTH' AUTH_FEATURES_KEY = 'ENABLE_THIRD_PARTY_AUTH'
AUTH_FEATURE_ENABLED = AUTH_FEATURES_KEY in settings.FEATURES AUTH_FEATURE_ENABLED = AUTH_FEATURES_KEY in settings.FEATURES
...@@ -219,14 +219,14 @@ class SAMLTestCase(TestCase): ...@@ -219,14 +219,14 @@ class SAMLTestCase(TestCase):
error_mock = log_mock.error error_mock = log_mock.error
idp_class = get_saml_idp_class('fake_idp_class_option') idp_class = get_saml_idp_class('fake_idp_class_option')
error_mock.assert_called_once_with( error_mock.assert_called_once_with(
'%s is not a valid SAMLIdentityProvider subclass; using SAMLIdentityProvider base class.', '%s is not a valid EdXSAMLIdentityProvider subclass; using EdXSAMLIdentityProvider base class.',
'fake_idp_class_option' 'fake_idp_class_option'
) )
self.assertIs(idp_class, SAMLIdentityProvider) self.assertIs(idp_class, EdXSAMLIdentityProvider)
@contextmanager @contextmanager
def simulate_running_pipeline(pipeline_target, backend, email=None, fullname=None, username=None): def simulate_running_pipeline(pipeline_target, backend, email=None, fullname=None, username=None, **kwargs):
"""Simulate that a pipeline is currently running. """Simulate that a pipeline is currently running.
You can use this context manager to test packages that rely on third party auth. You can use this context manager to test packages that rely on third party auth.
...@@ -269,6 +269,9 @@ def simulate_running_pipeline(pipeline_target, backend, email=None, fullname=Non ...@@ -269,6 +269,9 @@ def simulate_running_pipeline(pipeline_target, backend, email=None, fullname=Non
app generates itself and should be available by the time the user app generates itself and should be available by the time the user
is authenticating with a third-party provider. is authenticating with a third-party provider.
kwargs (dict): If provided, simulate that the current provider has
included additional user details (useful for filling in the registration form).
Returns: Returns:
None None
...@@ -276,9 +279,10 @@ def simulate_running_pipeline(pipeline_target, backend, email=None, fullname=Non ...@@ -276,9 +279,10 @@ def simulate_running_pipeline(pipeline_target, backend, email=None, fullname=Non
pipeline_data = { pipeline_data = {
"backend": backend, "backend": backend,
"kwargs": { "kwargs": {
"details": {} "details": kwargs
} }
} }
if email is not None: if email is not None:
pipeline_data["kwargs"]["details"]["email"] = email pipeline_data["kwargs"]["details"]["email"] = email
if fullname is not None: if fullname is not None:
......
...@@ -26,7 +26,7 @@ ...@@ -26,7 +26,7 @@
} %> } %>
<% if ( required ) { %> aria-required="true" required<% } %>> <% if ( required ) { %> aria-required="true" required<% } %>>
<% _.each(options, function(el) { %> <% _.each(options, function(el) { %>
<option value="<%= el.value%>"<% if ( el.default ) { %> data-isdefault="true"<% } %>><%= el.name %></option> <option value="<%= el.value%>"<% if ( el.default ) { %> data-isdefault="true" selected<% } %>><%= el.name %></option>
<% }); %> <% }); %>
</select> </select>
<% if ( instructions ) { %> <span class="tip tip-input" id="<%= form %>-<%= name %>-desc"><%= instructions %></span><% } %> <% if ( instructions ) { %> <span class="tip tip-input" id="<%= form %>-<%= name %>-desc"><%= instructions %></span><% } %>
......
...@@ -234,21 +234,29 @@ class FormDescription(object): ...@@ -234,21 +234,29 @@ class FormDescription(object):
"supplementalText": supplementalText "supplementalText": supplementalText
} }
field_override = self._field_overrides.get(name, {})
if field_type == "select": if field_type == "select":
if options is not None: if options is not None:
field_dict["options"] = [] field_dict["options"] = []
# Include an empty "default" option at the beginning of the list # Get an existing default value from the field override
existing_default_value = field_override.get('defaultValue')
# Include an empty "default" option at the beginning of the list;
# preselect it if there isn't an overriding default.
if include_default_option: if include_default_option:
field_dict["options"].append({ field_dict["options"].append({
"value": "", "value": "",
"name": "--", "name": "--",
"default": True "default": existing_default_value is None
}) })
field_dict["options"].extend([ field_dict["options"].extend([
{"value": option_value, "name": option_name} {
for option_value, option_name in options 'value': option_value,
'name': option_name,
'default': option_value == existing_default_value
} for option_value, option_name in options
]) ])
else: else:
raise InvalidFieldError("You must provide options for a select field.") raise InvalidFieldError("You must provide options for a select field.")
...@@ -270,7 +278,7 @@ class FormDescription(object): ...@@ -270,7 +278,7 @@ class FormDescription(object):
# If there are overrides for this field, apply them now. # If there are overrides for this field, apply them now.
# Any field property can be overwritten (for example, the default value or placeholder) # Any field property can be overwritten (for example, the default value or placeholder)
field_dict.update(self._field_overrides.get(name, {})) field_dict.update(field_override)
self.fields.append(field_dict) self.fields.append(field_dict)
...@@ -291,8 +299,8 @@ class FormDescription(object): ...@@ -291,8 +299,8 @@ class FormDescription(object):
"placeholder": "", "placeholder": "",
"instructions": "", "instructions": "",
"options": [ "options": [
{"value": "cheese", "name": "Cheese"}, {"value": "cheese", "name": "Cheese", "default": False},
{"value": "wine", "name": "Wine"} {"value": "wine", "name": "Wine", "default": False}
] ]
"restrictions": {}, "restrictions": {},
"errorMessages": {}, "errorMessages": {},
......
...@@ -147,6 +147,44 @@ class FormDescriptionTest(TestCase): ...@@ -147,6 +147,44 @@ class FormDescriptionTest(TestCase):
with self.assertRaises(InvalidFieldError): with self.assertRaises(InvalidFieldError):
desc.add_field("name", field_type="text", restrictions={"invalid": 0}) desc.add_field("name", field_type="text", restrictions={"invalid": 0})
def test_option_overrides(self):
desc = FormDescription("post", "/submit")
field = {
"name": "country",
"label": "Country",
"field_type": "select",
"default": "PK",
"required": True,
"error_messages": {
"required": "You must provide a value!"
},
"options": [
("US", "United States of America"),
("PK", "Pakistan")
]
}
desc.override_field_properties(
field["name"],
default="PK"
)
desc.add_field(**field)
self.assertEqual(
desc.fields[0]["options"],
[
{
'default': False,
'name': 'United States of America',
'value': 'US'
},
{
'default': True,
'name': 'Pakistan',
'value': 'PK'
}
]
)
@ddt.ddt @ddt.ddt
class StudentViewShimTest(TestCase): class StudentViewShimTest(TestCase):
......
...@@ -31,6 +31,7 @@ from third_party_auth.tests.utils import ( ...@@ -31,6 +31,7 @@ from third_party_auth.tests.utils import (
from .test_helpers import TestCaseForm from .test_helpers import TestCaseForm
from xmodule.modulestore.tests.django_utils import SharedModuleStoreTestCase from xmodule.modulestore.tests.django_utils import SharedModuleStoreTestCase
from xmodule.modulestore.tests.factories import CourseFactory from xmodule.modulestore.tests.factories import CourseFactory
from ..helpers import FormDescription
from ..accounts import ( from ..accounts import (
NAME_MAX_LENGTH, EMAIL_MIN_LENGTH, EMAIL_MAX_LENGTH, PASSWORD_MIN_LENGTH, PASSWORD_MAX_LENGTH, NAME_MAX_LENGTH, EMAIL_MIN_LENGTH, EMAIL_MAX_LENGTH, PASSWORD_MIN_LENGTH, PASSWORD_MAX_LENGTH,
USERNAME_MIN_LENGTH, USERNAME_MAX_LENGTH USERNAME_MIN_LENGTH, USERNAME_MAX_LENGTH
...@@ -873,8 +874,6 @@ class RegistrationViewTest(ThirdPartyAuthTestMixin, UserAPITestCase): ...@@ -873,8 +874,6 @@ class RegistrationViewTest(ThirdPartyAuthTestMixin, UserAPITestCase):
u"restrictions": { u"restrictions": {
'min_length': PASSWORD_MIN_LENGTH, 'min_length': PASSWORD_MIN_LENGTH,
'max_length': PASSWORD_MAX_LENGTH 'max_length': PASSWORD_MAX_LENGTH
# 'min_length': account_api.PASSWORD_MIN_LENGTH,
# 'max_length': account_api.PASSWORD_MAX_LENGTH
}, },
} }
) )
...@@ -936,35 +935,35 @@ class RegistrationViewTest(ThirdPartyAuthTestMixin, UserAPITestCase): ...@@ -936,35 +935,35 @@ class RegistrationViewTest(ThirdPartyAuthTestMixin, UserAPITestCase):
} }
) )
def test_register_form_third_party_auth_running(self): def test_register_form_third_party_auth_running_google(self):
no_extra_fields_setting = {} no_extra_fields_setting = {}
country_options = (
self.configure_google_provider(enabled=True) [
with simulate_running_pipeline(
"openedx.core.djangoapps.user_api.views.third_party_auth.pipeline",
"google-oauth2", email="bob@example.com",
fullname="Bob", username="Bob123"
):
# Password field should be hidden
self._assert_reg_field(
no_extra_fields_setting,
{ {
"name": "password", "name": "--",
"type": "hidden", "value": "",
"required": False, "default": False
} }
) ] + [
# social_auth_provider should be present
# with value `Google`(we are setting up google provider for this test).
self._assert_reg_field(
no_extra_fields_setting,
{ {
"name": "social_auth_provider", "value": country_code,
"type": "hidden", "name": unicode(country_name),
"required": False, "default": True if country_code == "PK" else False
"defaultValue": "Google"
} }
) for country_code, country_name in SORTED_COUNTRIES
]
)
provider = self.configure_google_provider(enabled=True)
with simulate_running_pipeline(
"openedx.core.djangoapps.user_api.views.third_party_auth.pipeline", "google-oauth2",
email="bob@example.com",
fullname="Bob",
username="Bob123",
country="PK"
):
self._assert_password_field_hidden(no_extra_fields_setting)
self._assert_social_auth_provider_present(no_extra_fields_setting, provider)
# Email should be filled in # Email should be filled in
self._assert_reg_field( self._assert_reg_field(
...@@ -1020,6 +1019,22 @@ class RegistrationViewTest(ThirdPartyAuthTestMixin, UserAPITestCase): ...@@ -1020,6 +1019,22 @@ class RegistrationViewTest(ThirdPartyAuthTestMixin, UserAPITestCase):
} }
) )
# Country should be filled in.
self._assert_reg_field(
{u"country": u"required"},
{
u"label": u"Country",
u"name": u"country",
u"defaultValue": u"PK",
u"type": u"select",
u"required": True,
u"options": country_options,
u"errorMessages": {
u"required": u"Please select your Country."
},
}
)
def test_register_form_level_of_education(self): def test_register_form_level_of_education(self):
self._assert_reg_field( self._assert_reg_field(
{"level_of_education": "optional"}, {"level_of_education": "optional"},
...@@ -1030,15 +1045,15 @@ class RegistrationViewTest(ThirdPartyAuthTestMixin, UserAPITestCase): ...@@ -1030,15 +1045,15 @@ class RegistrationViewTest(ThirdPartyAuthTestMixin, UserAPITestCase):
"label": "Highest level of education completed", "label": "Highest level of education completed",
"options": [ "options": [
{"value": "", "name": "--", "default": True}, {"value": "", "name": "--", "default": True},
{"value": "p", "name": "Doctorate"}, {"value": "p", "name": "Doctorate", "default": False},
{"value": "m", "name": "Master's or professional degree"}, {"value": "m", "name": "Master's or professional degree", "default": False},
{"value": "b", "name": "Bachelor's degree"}, {"value": "b", "name": "Bachelor's degree", "default": False},
{"value": "a", "name": "Associate degree"}, {"value": "a", "name": "Associate degree", "default": False},
{"value": "hs", "name": "Secondary/high school"}, {"value": "hs", "name": "Secondary/high school", "default": False},
{"value": "jhs", "name": "Junior secondary/junior high/middle school"}, {"value": "jhs", "name": "Junior secondary/junior high/middle school", "default": False},
{"value": "el", "name": "Elementary/primary school"}, {"value": "el", "name": "Elementary/primary school", "default": False},
{"value": "none", "name": "No formal education"}, {"value": "none", "name": "No formal education", "default": False},
{"value": "other", "name": "Other education"}, {"value": "other", "name": "Other education", "default": False},
], ],
} }
) )
...@@ -1056,15 +1071,15 @@ class RegistrationViewTest(ThirdPartyAuthTestMixin, UserAPITestCase): ...@@ -1056,15 +1071,15 @@ class RegistrationViewTest(ThirdPartyAuthTestMixin, UserAPITestCase):
"label": "Highest level of education completed TRANSLATED", "label": "Highest level of education completed TRANSLATED",
"options": [ "options": [
{"value": "", "name": "--", "default": True}, {"value": "", "name": "--", "default": True},
{"value": "p", "name": "Doctorate TRANSLATED"}, {"value": "p", "name": "Doctorate TRANSLATED", "default": False},
{"value": "m", "name": "Master's or professional degree TRANSLATED"}, {"value": "m", "name": "Master's or professional degree TRANSLATED", "default": False},
{"value": "b", "name": "Bachelor's degree TRANSLATED"}, {"value": "b", "name": "Bachelor's degree TRANSLATED", "default": False},
{"value": "a", "name": "Associate degree TRANSLATED"}, {"value": "a", "name": "Associate degree TRANSLATED", "default": False},
{"value": "hs", "name": "Secondary/high school TRANSLATED"}, {"value": "hs", "name": "Secondary/high school TRANSLATED", "default": False},
{"value": "jhs", "name": "Junior secondary/junior high/middle school TRANSLATED"}, {"value": "jhs", "name": "Junior secondary/junior high/middle school TRANSLATED", "default": False},
{"value": "el", "name": "Elementary/primary school TRANSLATED"}, {"value": "el", "name": "Elementary/primary school TRANSLATED", "default": False},
{"value": "none", "name": "No formal education TRANSLATED"}, {"value": "none", "name": "No formal education TRANSLATED", "default": False},
{"value": "other", "name": "Other education TRANSLATED"}, {"value": "other", "name": "Other education TRANSLATED", "default": False},
], ],
} }
) )
...@@ -1079,9 +1094,9 @@ class RegistrationViewTest(ThirdPartyAuthTestMixin, UserAPITestCase): ...@@ -1079,9 +1094,9 @@ class RegistrationViewTest(ThirdPartyAuthTestMixin, UserAPITestCase):
"label": "Gender", "label": "Gender",
"options": [ "options": [
{"value": "", "name": "--", "default": True}, {"value": "", "name": "--", "default": True},
{"value": "m", "name": "Male"}, {"value": "m", "name": "Male", "default": False},
{"value": "f", "name": "Female"}, {"value": "f", "name": "Female", "default": False},
{"value": "o", "name": "Other/Prefer Not to Say"}, {"value": "o", "name": "Other/Prefer Not to Say", "default": False},
], ],
} }
) )
...@@ -1099,9 +1114,9 @@ class RegistrationViewTest(ThirdPartyAuthTestMixin, UserAPITestCase): ...@@ -1099,9 +1114,9 @@ class RegistrationViewTest(ThirdPartyAuthTestMixin, UserAPITestCase):
"label": "Gender TRANSLATED", "label": "Gender TRANSLATED",
"options": [ "options": [
{"value": "", "name": "--", "default": True}, {"value": "", "name": "--", "default": True},
{"value": "m", "name": "Male TRANSLATED"}, {"value": "m", "name": "Male TRANSLATED", "default": False},
{"value": "f", "name": "Female TRANSLATED"}, {"value": "f", "name": "Female TRANSLATED", "default": False},
{"value": "o", "name": "Other/Prefer Not to Say TRANSLATED"}, {"value": "o", "name": "Other/Prefer Not to Say TRANSLATED", "default": False},
], ],
} }
) )
...@@ -1109,8 +1124,18 @@ class RegistrationViewTest(ThirdPartyAuthTestMixin, UserAPITestCase): ...@@ -1109,8 +1124,18 @@ class RegistrationViewTest(ThirdPartyAuthTestMixin, UserAPITestCase):
def test_register_form_year_of_birth(self): def test_register_form_year_of_birth(self):
this_year = datetime.datetime.now(UTC).year this_year = datetime.datetime.now(UTC).year
year_options = ( year_options = (
[{"value": "", "name": "--", "default": True}] + [ [
{"value": unicode(year), "name": unicode(year)} {
"value": "",
"name": "--",
"default": True
}
] + [
{
"value": unicode(year),
"name": unicode(year),
"default": False
}
for year in range(this_year, this_year - 120, -1) for year in range(this_year, this_year - 120, -1)
] ]
) )
...@@ -1173,9 +1198,18 @@ class RegistrationViewTest(ThirdPartyAuthTestMixin, UserAPITestCase): ...@@ -1173,9 +1198,18 @@ class RegistrationViewTest(ThirdPartyAuthTestMixin, UserAPITestCase):
def test_registration_form_country(self): def test_registration_form_country(self):
country_options = ( country_options = (
[{"name": "--", "value": "", "default": True}] +
[ [
{"value": country_code, "name": unicode(country_name)} {
"name": "--",
"value": "",
"default": True
}
] + [
{
"value": country_code,
"name": unicode(country_name),
"default": False
}
for country_code, country_name in SORTED_COUNTRIES for country_code, country_name in SORTED_COUNTRIES
] ]
) )
...@@ -1820,6 +1854,63 @@ class RegistrationViewTest(ThirdPartyAuthTestMixin, UserAPITestCase): ...@@ -1820,6 +1854,63 @@ class RegistrationViewTest(ThirdPartyAuthTestMixin, UserAPITestCase):
} }
) )
def test_country_overrides(self):
"""Test that overridden countries are available in country list."""
# Retrieve the registration form description
with override_settings(REGISTRATION_EXTRA_FIELDS={"country": "required"}):
response = self.client.get(self.url)
self.assertHttpOK(response)
self.assertContains(response, 'Kosovo')
def test_create_account_not_allowed(self):
"""
Test case to check user creation is forbidden when ALLOW_PUBLIC_ACCOUNT_CREATION feature flag is turned off
"""
def _side_effect_for_get_value(value, default=None):
"""
returns a side_effect with given return value for a given value
"""
if value == 'ALLOW_PUBLIC_ACCOUNT_CREATION':
return False
else:
return get_value(value, default)
with mock.patch('openedx.core.djangoapps.site_configuration.helpers.get_value') as mock_get_value:
mock_get_value.side_effect = _side_effect_for_get_value
response = self.client.post(self.url, {"email": self.EMAIL, "username": self.USERNAME})
self.assertEqual(response.status_code, 403)
def _assert_fields_match(self, actual_field, expected_field):
self.assertIsNot(
actual_field, None,
msg="Could not find field {name}".format(name=expected_field["name"])
)
for key, value in expected_field.iteritems():
self.assertEqual(
actual_field[key], expected_field[key],
msg=u"Expected {expected} for {key} but got {actual} instead".format(
key=key,
actual=actual_field[key],
expected=expected_field[key]
)
)
def _populate_always_present_fields(self, field):
defaults = [
("label", ""),
("instructions", ""),
("placeholder", ""),
("defaultValue", ""),
("restrictions", {}),
("errorMessages", {}),
]
field.update({
key: value
for key, value in defaults if key not in field
})
def _assert_reg_field(self, extra_fields_setting, expected_field): def _assert_reg_field(self, extra_fields_setting, expected_field):
"""Retrieve the registration form description from the server and """Retrieve the registration form description from the server and
verify that it contains the expected field. verify that it contains the expected field.
...@@ -1835,17 +1926,7 @@ class RegistrationViewTest(ThirdPartyAuthTestMixin, UserAPITestCase): ...@@ -1835,17 +1926,7 @@ class RegistrationViewTest(ThirdPartyAuthTestMixin, UserAPITestCase):
""" """
# Add in fields that are always present # Add in fields that are always present
defaults = [ self._populate_always_present_fields(expected_field)
("label", ""),
("instructions", ""),
("placeholder", ""),
("defaultValue", ""),
("restrictions", {}),
("errorMessages", {}),
]
for key, value in defaults:
if key not in expected_field:
expected_field[key] = value
# Retrieve the registration form description # Retrieve the registration form description
with override_settings(REGISTRATION_EXTRA_FIELDS=extra_fields_setting): with override_settings(REGISTRATION_EXTRA_FIELDS=extra_fields_setting):
...@@ -1855,54 +1936,28 @@ class RegistrationViewTest(ThirdPartyAuthTestMixin, UserAPITestCase): ...@@ -1855,54 +1936,28 @@ class RegistrationViewTest(ThirdPartyAuthTestMixin, UserAPITestCase):
# Verify that the form description matches what we'd expect # Verify that the form description matches what we'd expect
form_desc = json.loads(response.content) form_desc = json.loads(response.content)
# Search the form for this field
actual_field = None actual_field = None
for field in form_desc["fields"]: for field in form_desc["fields"]:
if field["name"] == expected_field["name"]: if field["name"] == expected_field["name"]:
actual_field = field actual_field = field
break break
self.assertIsNot( self._assert_fields_match(actual_field, expected_field)
actual_field, None,
msg="Could not find field {name}".format(name=expected_field["name"])
)
for key, value in expected_field.iteritems():
self.assertEqual(
expected_field[key], actual_field[key],
msg=u"Expected {expected} for {key} but got {actual} instead".format(
key=key,
expected=expected_field[key],
actual=actual_field[key]
)
)
def test_country_overrides(self):
"""Test that overridden countries are available in country list."""
# Retrieve the registration form description
with override_settings(REGISTRATION_EXTRA_FIELDS={"country": "required"}):
response = self.client.get(self.url)
self.assertHttpOK(response)
self.assertContains(response, 'Kosovo')
def test_create_account_not_allowed(self): def _assert_password_field_hidden(self, field_settings):
""" self._assert_reg_field(field_settings, {
Test case to check user creation is forbidden when ALLOW_PUBLIC_ACCOUNT_CREATION feature flag is turned off "name": "password",
""" "type": "hidden",
def _side_effect_for_get_value(value, default=None): "required": False
""" })
returns a side_effect with given return value for a given value
"""
if value == 'ALLOW_PUBLIC_ACCOUNT_CREATION':
return False
else:
return get_value(value, default)
with mock.patch('openedx.core.djangoapps.site_configuration.helpers.get_value') as mock_get_value: def _assert_social_auth_provider_present(self, field_settings, backend):
mock_get_value.side_effect = _side_effect_for_get_value self._assert_reg_field(field_settings, {
response = self.client.post(self.url, {"email": self.EMAIL, "username": self.USERNAME}) "name": "social_auth_provider",
self.assertEqual(response.status_code, 403) "type": "hidden",
"required": False,
"defaultValue": backend.name
})
@httpretty.activate @httpretty.activate
...@@ -1931,8 +1986,7 @@ class ThirdPartyRegistrationTestMixin(ThirdPartyOAuthTestMixin, CacheIsolationTe ...@@ -1931,8 +1986,7 @@ class ThirdPartyRegistrationTestMixin(ThirdPartyOAuthTestMixin, CacheIsolationTe
"country": "US", "country": "US",
"username": user.username if user else "test_username", "username": user.username if user else "test_username",
"name": user.first_name if user else "test name", "name": user.first_name if user else "test name",
"email": user.email if user else "test@test.com", "email": user.email if user else "test@test.com"
} }
def _assert_existing_user_error(self, response): def _assert_existing_user_error(self, response):
......
...@@ -924,7 +924,7 @@ class RegistrationView(APIView): ...@@ -924,7 +924,7 @@ class RegistrationView(APIView):
running_pipeline.get('kwargs') running_pipeline.get('kwargs')
) )
for field_name in self.DEFAULT_FIELDS: for field_name in self.DEFAULT_FIELDS + self.EXTRA_FIELDS:
if field_name in field_overrides: if field_name in field_overrides:
form_desc.override_field_properties( form_desc.override_field_properties(
field_name, default=field_overrides[field_name] field_name, default=field_overrides[field_name]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment