Commit c553a47c by Will Daly

Update third party auth on the logistration form

parent 13986d84
...@@ -4,7 +4,6 @@ import unittest ...@@ -4,7 +4,6 @@ import unittest
from mock import patch from mock import patch
from django.conf import settings from django.conf import settings
from django.core.urlresolvers import reverse from django.core.urlresolvers import reverse
from django.test import TestCase
import ddt import ddt
from django.test.utils import override_settings from django.test.utils import override_settings
from xmodule.modulestore.tests.factories import CourseFactory from xmodule.modulestore.tests.factories import CourseFactory
...@@ -130,7 +129,7 @@ class LoginFormTest(ModuleStoreTestCase): ...@@ -130,7 +129,7 @@ class LoginFormTest(ModuleStoreTestCase):
@ddt.ddt @ddt.ddt
@override_settings(MODULESTORE=MODULESTORE_CONFIG) @override_settings(MODULESTORE=MODULESTORE_CONFIG)
@unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms') @unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms')
class RegisterFormTest(TestCase): class RegisterFormTest(ModuleStoreTestCase):
"""Test rendering of the registration form. """ """Test rendering of the registration form. """
def setUp(self): def setUp(self):
......
"""Helper functions for the student account app. """
# TODO: move this function here instead of importing it from student
from student.helpers import auth_pipeline_urls
...@@ -12,12 +12,21 @@ from django.test import TestCase ...@@ -12,12 +12,21 @@ from django.test import TestCase
from django.conf import settings from django.conf import settings
from django.core.urlresolvers import reverse from django.core.urlresolvers import reverse
from django.core import mail from django.core import mail
from django.test.utils import override_settings
from util.testing import UrlResetMixin from util.testing import UrlResetMixin
from third_party_auth.tests.testutil import simulate_running_pipeline from third_party_auth.tests.testutil import simulate_running_pipeline
from user_api.api import account as account_api from user_api.api import account as account_api
from user_api.api import profile as profile_api from user_api.api import profile as profile_api
from util.bad_request_rate_limiter import BadRequestRateLimiter from util.bad_request_rate_limiter import BadRequestRateLimiter
from xmodule.modulestore.tests.django_utils import (
ModuleStoreTestCase, mixed_store_config
)
from xmodule.modulestore.tests.factories import CourseFactory
from student.tests.factories import CourseModeFactory
MODULESTORE_CONFIG = mixed_store_config(settings.COMMON_TEST_DATA_ROOT, {}, include_xml=False)
@ddt.ddt @ddt.ddt
...@@ -365,7 +374,8 @@ class StudentAccountUpdateTest(UrlResetMixin, TestCase): ...@@ -365,7 +374,8 @@ class StudentAccountUpdateTest(UrlResetMixin, TestCase):
@ddt.ddt @ddt.ddt
class StudentAccountLoginAndRegistrationTest(TestCase): @override_settings(MODULESTORE=MODULESTORE_CONFIG)
class StudentAccountLoginAndRegistrationTest(ModuleStoreTestCase):
""" Tests for the student account views that update the user's account information. """ """ Tests for the student account views that update the user's account information. """
USERNAME = "bob" USERNAME = "bob"
...@@ -398,13 +408,7 @@ class StudentAccountLoginAndRegistrationTest(TestCase): ...@@ -398,13 +408,7 @@ class StudentAccountLoginAndRegistrationTest(TestCase):
@ddt.data("account_login", "account_register") @ddt.data("account_login", "account_register")
def test_third_party_auth_disabled(self, url_name): def test_third_party_auth_disabled(self, url_name):
response = self.client.get(reverse(url_name)) response = self.client.get(reverse(url_name))
expected_data = "data-third-party-auth='{auth_info}'".format( self._assert_third_party_auth_data(response, None, [])
auth_info=json.dumps({
"currentProvider": None,
"providers": []
})
)
self.assertContains(response, expected_data)
@ddt.data( @ddt.data(
("account_login", None, None), ("account_login", None, None),
...@@ -427,23 +431,141 @@ class StudentAccountLoginAndRegistrationTest(TestCase): ...@@ -427,23 +431,141 @@ class StudentAccountLoginAndRegistrationTest(TestCase):
response = self.client.get(reverse(url_name)) response = self.client.get(reverse(url_name))
# This relies on the THIRD_PARTY_AUTH configuration in the test settings # This relies on the THIRD_PARTY_AUTH configuration in the test settings
expected_providers = [
{
"name": "Facebook",
"iconClass": "icon-facebook",
"loginUrl": self._third_party_login_url("facebook", "account_login"),
"registerUrl": self._third_party_login_url("facebook", "account_register")
},
{
"name": "Google",
"iconClass": "icon-google-plus",
"loginUrl": self._third_party_login_url("google-oauth2", "account_login"),
"registerUrl": self._third_party_login_url("google-oauth2", "account_register")
}
]
self._assert_third_party_auth_data(response, current_provider, expected_providers)
@ddt.data([], ["honor"], ["honor", "verified", "audit"], ["professional"])
def test_third_party_auth_course_id_verified(self, modes):
# Create a course with the specified course modes
course = CourseFactory.create()
for slug in modes:
CourseModeFactory.create(
course_id=course.id,
mode_slug=slug,
mode_display_name=slug
)
# Verify that the entry URL for third party auth
# contains the course ID and redirects to the track selection page.
course_modes_choose_url = reverse(
"course_modes_choose",
kwargs={"course_id": unicode(course.id)}
)
expected_providers = [
{
"name": "Facebook",
"iconClass": "icon-facebook",
"loginUrl": self._third_party_login_url(
"facebook", "account_login",
course_id=unicode(course.id),
redirect_url=course_modes_choose_url
),
"registerUrl": self._third_party_login_url(
"facebook", "account_register",
course_id=unicode(course.id),
redirect_url=course_modes_choose_url
)
},
{
"name": "Google",
"iconClass": "icon-google-plus",
"loginUrl": self._third_party_login_url(
"google-oauth2", "account_login",
course_id=unicode(course.id),
redirect_url=course_modes_choose_url
),
"registerUrl": self._third_party_login_url(
"google-oauth2", "account_register",
course_id=unicode(course.id),
redirect_url=course_modes_choose_url
)
}
]
# Verify that the login page contains the correct provider URLs
response = self.client.get(reverse("account_login"), {"course_id": unicode(course.id)})
self._assert_third_party_auth_data(response, None, expected_providers)
def test_third_party_auth_course_id_shopping_cart(self):
# Create a course with a white-label course mode
course = CourseFactory.create()
CourseModeFactory.create(
course_id=course.id,
mode_slug="honor",
mode_display_name="Honor",
min_price=100
)
# Verify that the entry URL for third party auth
# contains the course ID and redirects to the shopping cart
shoppingcart_url = reverse("shoppingcart.views.show_cart")
expected_providers = [
{
"name": "Facebook",
"iconClass": "icon-facebook",
"loginUrl": self._third_party_login_url(
"facebook", "account_login",
course_id=unicode(course.id),
redirect_url=shoppingcart_url
),
"registerUrl": self._third_party_login_url(
"facebook", "account_register",
course_id=unicode(course.id),
redirect_url=shoppingcart_url
)
},
{
"name": "Google",
"iconClass": "icon-google-plus",
"loginUrl": self._third_party_login_url(
"google-oauth2", "account_login",
course_id=unicode(course.id),
redirect_url=shoppingcart_url
),
"registerUrl": self._third_party_login_url(
"google-oauth2", "account_register",
course_id=unicode(course.id),
redirect_url=shoppingcart_url
)
}
]
# Verify that the login page contains the correct provider URLs
response = self.client.get(reverse("account_login"), {"course_id": unicode(course.id)})
self._assert_third_party_auth_data(response, None, expected_providers)
def _assert_third_party_auth_data(self, response, current_provider, providers):
"""Verify that third party auth info is rendered correctly in a DOM data attribute. """
expected_data = u"data-third-party-auth='{auth_info}'".format( expected_data = u"data-third-party-auth='{auth_info}'".format(
auth_info=json.dumps({ auth_info=json.dumps({
"currentProvider": current_provider, "currentProvider": current_provider,
"providers": [ "providers": providers
{
"name": "Facebook",
"iconClass": "icon-facebook",
"loginUrl": "/auth/login/facebook/?auth_entry=account_login",
"registerUrl": "/auth/login/facebook/?auth_entry=account_register",
},
{
"name": "Google",
"iconClass": "icon-google-plus",
"loginUrl": "/auth/login/google-oauth2/?auth_entry=account_login",
"registerUrl": "/auth/login/google-oauth2/?auth_entry=account_register",
}
]
}) })
) )
self.assertContains(response, expected_data) self.assertContains(response, expected_data)
def _third_party_login_url(self, backend_name, auth_entry, course_id=None, redirect_url=None):
"""Construct the login URL to start third party authentication. """
params = [("auth_entry", auth_entry)]
if redirect_url:
params.append(("next", redirect_url))
if course_id:
params.append(("enroll_course_id", course_id))
return u"{url}?{params}".format(
url=reverse("social:begin", kwargs={"backend": backend_name}),
params=urlencode(params)
)
...@@ -20,6 +20,8 @@ from user_api.api import account as account_api ...@@ -20,6 +20,8 @@ from user_api.api import account as account_api
from user_api.api import profile as profile_api from user_api.api import profile as profile_api
from util.bad_request_rate_limiter import BadRequestRateLimiter from util.bad_request_rate_limiter import BadRequestRateLimiter
from student_account.helpers import auth_pipeline_urls
AUDIT_LOG = logging.getLogger("audit") AUDIT_LOG = logging.getLogger("audit")
...@@ -279,17 +281,23 @@ def _third_party_auth_context(request): ...@@ -279,17 +281,23 @@ def _third_party_auth_context(request):
"providers": [] "providers": []
} }
course_id = request.GET.get("course_id")
login_urls = auth_pipeline_urls(
third_party_auth.pipeline.AUTH_ENTRY_LOGIN_2,
course_id=course_id
)
register_urls = auth_pipeline_urls(
third_party_auth.pipeline.AUTH_ENTRY_REGISTER_2,
course_id=course_id
)
if third_party_auth.is_enabled(): if third_party_auth.is_enabled():
context["providers"] = [ context["providers"] = [
{ {
"name": enabled.NAME, "name": enabled.NAME,
"iconClass": enabled.ICON_CLASS, "iconClass": enabled.ICON_CLASS,
"loginUrl": third_party_auth.pipeline.get_login_url( "loginUrl": login_urls[enabled.NAME],
enabled.NAME, third_party_auth.pipeline.AUTH_ENTRY_LOGIN_2 "registerUrl": register_urls[enabled.NAME]
),
"registerUrl": third_party_auth.pipeline.get_login_url(
enabled.NAME, third_party_auth.pipeline.AUTH_ENTRY_REGISTER_2
)
} }
for enabled in third_party_auth.provider.Registry.enabled() for enabled in third_party_auth.provider.Registry.enabled()
] ]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment