Commit ae8a5f2b by Brittney Exline

ENT-447 Add flag to third party auth SAML provider to send to the registration page first

parent f99fbaea
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('third_party_auth', '0011_auto_20170616_0112'),
]
operations = [
migrations.AddField(
model_name='ltiproviderconfig',
name='send_to_registration_first',
field=models.BooleanField(default=False, help_text='If this option is selected, users will be directed to the registration page immediately after authenticating with the third party instead of the login page.'),
),
migrations.AddField(
model_name='oauth2providerconfig',
name='send_to_registration_first',
field=models.BooleanField(default=False, help_text='If this option is selected, users will be directed to the registration page immediately after authenticating with the third party instead of the login page.'),
),
migrations.AddField(
model_name='samlproviderconfig',
name='send_to_registration_first',
field=models.BooleanField(default=False, help_text='If this option is selected, users will be directed to the registration page immediately after authenticating with the third party instead of the login page.'),
),
]
...@@ -176,6 +176,13 @@ class ProviderConfig(ConfigurationModel): ...@@ -176,6 +176,13 @@ class ProviderConfig(ConfigurationModel):
"Django platform session default length will be used." "Django platform session default length will be used."
) )
) )
send_to_registration_first = models.BooleanField(
default=False,
help_text=_(
"If this option is selected, users will be directed to the registration page "
"immediately after authenticating with the third party instead of the login page."
),
)
prefix = None # used for provider_id. Set to a string value in subclass prefix = None # used for provider_id. Set to a string value in subclass
backend_name = None # Set to a field or fixed value in subclass backend_name = None # Set to a field or fixed value in subclass
accepts_logins = True # Whether to display a sign-in button when the provider is enabled accepts_logins = True # Whether to display a sign-in button when the provider is enabled
......
...@@ -558,7 +558,8 @@ def ensure_user_information(strategy, auth_entry, backend=None, user=None, socia ...@@ -558,7 +558,8 @@ def ensure_user_information(strategy, auth_entry, backend=None, user=None, socia
def should_force_account_creation(): def should_force_account_creation():
""" For some third party providers, we auto-create user accounts """ """ For some third party providers, we auto-create user accounts """
current_provider = provider.Registry.get_from_pipeline({'backend': current_partial.backend, 'kwargs': kwargs}) current_provider = provider.Registry.get_from_pipeline({'backend': current_partial.backend, 'kwargs': kwargs})
return current_provider and current_provider.skip_email_verification return (current_provider and
(current_provider.skip_email_verification or current_provider.send_to_registration_first))
if not user: if not user:
if auth_entry in [AUTH_ENTRY_LOGIN_API, AUTH_ENTRY_REGISTER_API]: if auth_entry in [AUTH_ENTRY_LOGIN_API, AUTH_ENTRY_REGISTER_API]:
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import unittest import unittest
import mock import mock
import ddt
from django import test from django import test
from django.conf import settings from django.conf import settings
from django.contrib.auth import models from django.contrib.auth import models
...@@ -296,3 +297,42 @@ class TestPipelineUtilityFunctions(TestCase, test.TestCase): ...@@ -296,3 +297,42 @@ class TestPipelineUtilityFunctions(TestCase, test.TestCase):
) )
pipeline.lift_quarantine(request) pipeline.lift_quarantine(request)
self.assertNotIn('third_party_auth_quarantined_modules', request.session) self.assertNotIn('third_party_auth_quarantined_modules', request.session)
@unittest.skipUnless(testutil.AUTH_FEATURE_ENABLED, testutil.AUTH_FEATURES_KEY + ' not enabled')
@ddt.ddt
class EnsureUserInformationTestCase(testutil.TestCase, test.TestCase):
"""Tests ensuring that we have the necessary user information to proceed with the pipeline."""
def setUp(self):
super(EnsureUserInformationTestCase, self).setUp()
@ddt.data(
(True, '/register'),
(False, '/login')
)
@ddt.unpack
def test_provider_settings_redirect_to_registration(self, send_to_registration_first, expected_redirect_url):
"""
Test if user is not authenticated, that they get redirected to the appropriate page
based on the provider's setting for send_to_registration_first.
"""
provider = mock.MagicMock(
send_to_registration_first=send_to_registration_first,
skip_email_verification=False
)
with mock.patch('third_party_auth.pipeline.provider.Registry.get_from_pipeline') as get_from_pipeline:
get_from_pipeline.return_value = provider
with mock.patch('social_core.pipeline.partial.partial_prepare') as partial_prepare:
partial_prepare.return_value = mock.MagicMock(token='')
strategy = mock.MagicMock()
response = pipeline.ensure_user_information(
strategy=strategy,
backend=None,
auth_entry=pipeline.AUTH_ENTRY_LOGIN,
pipeline_index=0
)
assert response.status_code == 302
assert response.url == expected_redirect_url
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment