Commit 289469dd by Nimisha Asthagiri

Merge pull request #7353 from edx/mobile/third-party-oauth-reg

Mobile registration with Google/FB
parents 8a1877f0 dfcef9dd
...@@ -57,7 +57,7 @@ class AccessTokenExchangeForm(ScopeMixin, OAuthForm): ...@@ -57,7 +57,7 @@ class AccessTokenExchangeForm(ScopeMixin, OAuthForm):
} }
) )
self.request.session[pipeline.AUTH_ENTRY_KEY] = pipeline.AUTH_ENTRY_API self.request.session[pipeline.AUTH_ENTRY_KEY] = pipeline.AUTH_ENTRY_LOGIN_API
client_id = self.cleaned_data["client_id"] client_id = self.cleaned_data["client_id"]
try: try:
......
...@@ -12,11 +12,8 @@ from provider import scope ...@@ -12,11 +12,8 @@ from provider import scope
import social.apps.django_app.utils as social_utils import social.apps.django_app.utils as social_utils
from oauth_exchange.forms import AccessTokenExchangeForm from oauth_exchange.forms import AccessTokenExchangeForm
from oauth_exchange.tests.utils import ( from oauth_exchange.tests.utils import AccessTokenExchangeTestMixin
AccessTokenExchangeTestMixin, from third_party_auth.tests.utils import ThirdPartyOAuthTestMixinFacebook, ThirdPartyOAuthTestMixinGoogle
AccessTokenExchangeMixinFacebook,
AccessTokenExchangeMixinGoogle
)
class AccessTokenExchangeFormTest(AccessTokenExchangeTestMixin): class AccessTokenExchangeFormTest(AccessTokenExchangeTestMixin):
...@@ -50,7 +47,7 @@ class AccessTokenExchangeFormTest(AccessTokenExchangeTestMixin): ...@@ -50,7 +47,7 @@ class AccessTokenExchangeFormTest(AccessTokenExchangeTestMixin):
@httpretty.activate @httpretty.activate
class AccessTokenExchangeFormTestFacebook( class AccessTokenExchangeFormTestFacebook(
AccessTokenExchangeFormTest, AccessTokenExchangeFormTest,
AccessTokenExchangeMixinFacebook, ThirdPartyOAuthTestMixinFacebook,
TestCase TestCase
): ):
""" """
...@@ -64,7 +61,7 @@ class AccessTokenExchangeFormTestFacebook( ...@@ -64,7 +61,7 @@ class AccessTokenExchangeFormTestFacebook(
@httpretty.activate @httpretty.activate
class AccessTokenExchangeFormTestGoogle( class AccessTokenExchangeFormTestGoogle(
AccessTokenExchangeFormTest, AccessTokenExchangeFormTest,
AccessTokenExchangeMixinGoogle, ThirdPartyOAuthTestMixinGoogle,
TestCase TestCase
): ):
""" """
......
...@@ -14,11 +14,8 @@ import provider.constants ...@@ -14,11 +14,8 @@ import provider.constants
from provider import scope from provider import scope
from provider.oauth2.models import AccessToken from provider.oauth2.models import AccessToken
from oauth_exchange.tests.utils import ( from oauth_exchange.tests.utils import AccessTokenExchangeTestMixin
AccessTokenExchangeTestMixin, from third_party_auth.tests.utils import ThirdPartyOAuthTestMixinFacebook, ThirdPartyOAuthTestMixinGoogle
AccessTokenExchangeMixinFacebook,
AccessTokenExchangeMixinGoogle
)
class AccessTokenExchangeViewTest(AccessTokenExchangeTestMixin): class AccessTokenExchangeViewTest(AccessTokenExchangeTestMixin):
...@@ -95,7 +92,7 @@ class AccessTokenExchangeViewTest(AccessTokenExchangeTestMixin): ...@@ -95,7 +92,7 @@ class AccessTokenExchangeViewTest(AccessTokenExchangeTestMixin):
@httpretty.activate @httpretty.activate
class AccessTokenExchangeViewTestFacebook( class AccessTokenExchangeViewTestFacebook(
AccessTokenExchangeViewTest, AccessTokenExchangeViewTest,
AccessTokenExchangeMixinFacebook, ThirdPartyOAuthTestMixinFacebook,
TestCase TestCase
): ):
""" """
...@@ -109,7 +106,7 @@ class AccessTokenExchangeViewTestFacebook( ...@@ -109,7 +106,7 @@ class AccessTokenExchangeViewTestFacebook(
@httpretty.activate @httpretty.activate
class AccessTokenExchangeViewTestGoogle( class AccessTokenExchangeViewTestGoogle(
AccessTokenExchangeViewTest, AccessTokenExchangeViewTest,
AccessTokenExchangeMixinGoogle, ThirdPartyOAuthTestMixinGoogle,
TestCase TestCase
): ):
""" """
......
""" """
Test utilities for OAuth access token exchange Test utilities for OAuth access token exchange
""" """
import json
import httpretty
import provider.constants import provider.constants
from provider.oauth2.models import Client
from social.apps.django_app.default.models import UserSocialAuth from social.apps.django_app.default.models import UserSocialAuth
from student.tests.factories import UserFactory from third_party_auth.tests.utils import ThirdPartyOAuthTestMixin
class AccessTokenExchangeTestMixin(object): class AccessTokenExchangeTestMixin(ThirdPartyOAuthTestMixin):
""" """
A mixin to define test cases for access token exchange. The following A mixin to define test cases for access token exchange. The following
methods must be implemented by subclasses: methods must be implemented by subclasses:
...@@ -21,40 +17,12 @@ class AccessTokenExchangeTestMixin(object): ...@@ -21,40 +17,12 @@ class AccessTokenExchangeTestMixin(object):
def setUp(self): def setUp(self):
super(AccessTokenExchangeTestMixin, self).setUp() super(AccessTokenExchangeTestMixin, self).setUp()
self.client_id = "test_client_id"
self.oauth_client = Client.objects.create(
client_id=self.client_id,
client_type=provider.constants.PUBLIC
)
self.social_uid = "test_social_uid"
self.user = UserFactory()
UserSocialAuth.objects.create(user=self.user, provider=self.BACKEND, uid=self.social_uid)
self.access_token = "test_access_token"
# Initialize to minimal data # Initialize to minimal data
self.data = { self.data = {
"access_token": self.access_token, "access_token": self.access_token,
"client_id": self.client_id, "client_id": self.client_id,
} }
def _setup_provider_response(self, success):
"""
Register a mock response for the third party user information endpoint;
success indicates whether the response status code should be 200 or 400
"""
if success:
status = 200
body = json.dumps({self.UID_FIELD: self.social_uid})
else:
status = 400
body = json.dumps({})
httpretty.register_uri(
httpretty.GET,
self.USER_URL,
body=body,
status=status,
content_type="application/json"
)
def _assert_error(self, _data, _expected_error, _expected_error_description): def _assert_error(self, _data, _expected_error, _expected_error_description):
""" """
Given request data, execute a test and check that the expected error Given request data, execute a test and check that the expected error
...@@ -101,6 +69,12 @@ class AccessTokenExchangeTestMixin(object): ...@@ -101,6 +69,12 @@ class AccessTokenExchangeTestMixin(object):
"test_client_id is not a public client" "test_client_id is not a public client"
) )
def test_inactive_user(self):
self.user.is_active = False
self.user.save() # pylint: disable=no-member
self._setup_provider_response(success=True)
self._assert_success(self.data, expected_scopes=[])
def test_invalid_acess_token(self): def test_invalid_acess_token(self):
self._setup_provider_response(success=False) self._setup_provider_response(success=False)
self._assert_error(self.data, "invalid_grant", "access_token is not valid") self._assert_error(self.data, "invalid_grant", "access_token is not valid")
...@@ -110,18 +84,14 @@ class AccessTokenExchangeTestMixin(object): ...@@ -110,18 +84,14 @@ class AccessTokenExchangeTestMixin(object):
self._setup_provider_response(success=True) self._setup_provider_response(success=True)
self._assert_error(self.data, "invalid_grant", "access_token is not valid") self._assert_error(self.data, "invalid_grant", "access_token is not valid")
def test_user_automatically_linked_by_email(self):
UserSocialAuth.objects.all().delete()
self._setup_provider_response(success=True, email=self.user.email)
self._assert_success(self.data, expected_scopes=[])
class AccessTokenExchangeMixinFacebook(object): def test_inactive_user_not_automatically_linked(self):
"""Tests access token exchange with the Facebook backend""" UserSocialAuth.objects.all().delete()
BACKEND = "facebook" self._setup_provider_response(success=True, email=self.user.email)
USER_URL = "https://graph.facebook.com/me" self.user.is_active = False
# In facebook responses, the "id" field is used as the user's identifier self.user.save() # pylint: disable=no-member
UID_FIELD = "id" self._assert_error(self.data, "invalid_grant", "access_token is not valid")
class AccessTokenExchangeMixinGoogle(object):
"""Tests access token exchange with the Google backend"""
BACKEND = "google-oauth2"
USER_URL = "https://www.googleapis.com/oauth2/v1/userinfo"
# In google-oauth2 responses, the "email" field is used as the user's identifier
UID_FIELD = "email"
...@@ -10,14 +10,18 @@ from django.conf import settings ...@@ -10,14 +10,18 @@ from django.conf import settings
from django.core.cache import cache from django.core.cache import cache
from django.core.urlresolvers import reverse, NoReverseMatch from django.core.urlresolvers import reverse, NoReverseMatch
from django.http import HttpResponseBadRequest, HttpResponse from django.http import HttpResponseBadRequest, HttpResponse
from external_auth.models import ExternalAuthMap
import httpretty import httpretty
from mock import patch from mock import patch
from social.apps.django_app.default.models import UserSocialAuth from social.apps.django_app.default.models import UserSocialAuth
from external_auth.models import ExternalAuthMap
from student.tests.factories import UserFactory, RegistrationFactory, UserProfileFactory from student.tests.factories import UserFactory, RegistrationFactory, UserProfileFactory
from student.views import login_oauth_token from student.views import login_oauth_token
from third_party_auth.tests.utils import (
ThirdPartyOAuthTestMixin,
ThirdPartyOAuthTestMixinFacebook,
ThirdPartyOAuthTestMixinGoogle
)
from xmodule.modulestore.tests.factories import CourseFactory from xmodule.modulestore.tests.factories import CourseFactory
from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase
...@@ -437,7 +441,7 @@ class ExternalAuthShibTest(ModuleStoreTestCase): ...@@ -437,7 +441,7 @@ class ExternalAuthShibTest(ModuleStoreTestCase):
@httpretty.activate @httpretty.activate
class LoginOAuthTokenMixin(object): class LoginOAuthTokenMixin(ThirdPartyOAuthTestMixin):
""" """
Mixin with tests for the login_oauth_token view. A TestCase that includes Mixin with tests for the login_oauth_token view. A TestCase that includes
this must define the following: this must define the following:
...@@ -448,30 +452,8 @@ class LoginOAuthTokenMixin(object): ...@@ -448,30 +452,8 @@ class LoginOAuthTokenMixin(object):
""" """
def setUp(self): def setUp(self):
self.client = Client() super(LoginOAuthTokenMixin, self).setUp()
self.url = reverse(login_oauth_token, kwargs={"backend": self.BACKEND}) self.url = reverse(login_oauth_token, kwargs={"backend": self.BACKEND})
self.social_uid = "social_uid"
self.user = UserFactory()
UserSocialAuth.objects.create(user=self.user, provider=self.BACKEND, uid=self.social_uid)
def _setup_user_response(self, success):
"""
Register a mock response for the third party user information endpoint;
success indicates whether the response status code should be 200 or 400
"""
if success:
status = 200
body = json.dumps({self.UID_FIELD: self.social_uid})
else:
status = 400
body = json.dumps({})
httpretty.register_uri(
httpretty.GET,
self.USER_URL,
body=body,
status=status,
content_type="application/json"
)
def _assert_error(self, response, status_code, error): def _assert_error(self, response, status_code, error):
"""Assert that the given response was a 400 with the given error code""" """Assert that the given response was a 400 with the given error code"""
...@@ -480,13 +462,13 @@ class LoginOAuthTokenMixin(object): ...@@ -480,13 +462,13 @@ class LoginOAuthTokenMixin(object):
self.assertNotIn("partial_pipeline", self.client.session) self.assertNotIn("partial_pipeline", self.client.session)
def test_success(self): def test_success(self):
self._setup_user_response(success=True) self._setup_provider_response(success=True)
response = self.client.post(self.url, {"access_token": "dummy"}) response = self.client.post(self.url, {"access_token": "dummy"})
self.assertEqual(response.status_code, 204) self.assertEqual(response.status_code, 204)
self.assertEqual(self.client.session['_auth_user_id'], self.user.id) # pylint: disable=no-member self.assertEqual(self.client.session['_auth_user_id'], self.user.id) # pylint: disable=no-member
def test_invalid_token(self): def test_invalid_token(self):
self._setup_user_response(success=False) self._setup_provider_response(success=False)
response = self.client.post(self.url, {"access_token": "dummy"}) response = self.client.post(self.url, {"access_token": "dummy"})
self._assert_error(response, 401, "invalid_token") self._assert_error(response, 401, "invalid_token")
...@@ -496,7 +478,7 @@ class LoginOAuthTokenMixin(object): ...@@ -496,7 +478,7 @@ class LoginOAuthTokenMixin(object):
def test_unlinked_user(self): def test_unlinked_user(self):
UserSocialAuth.objects.all().delete() UserSocialAuth.objects.all().delete()
self._setup_user_response(success=True) self._setup_provider_response(success=True)
response = self.client.post(self.url, {"access_token": "dummy"}) response = self.client.post(self.url, {"access_token": "dummy"})
self._assert_error(response, 401, "invalid_token") self._assert_error(response, 401, "invalid_token")
...@@ -507,17 +489,13 @@ class LoginOAuthTokenMixin(object): ...@@ -507,17 +489,13 @@ class LoginOAuthTokenMixin(object):
# This is necessary because cms does not implement third party auth # This is necessary because cms does not implement third party auth
@unittest.skipUnless(settings.FEATURES.get("ENABLE_THIRD_PARTY_AUTH"), "third party auth not enabled") @unittest.skipUnless(settings.FEATURES.get("ENABLE_THIRD_PARTY_AUTH"), "third party auth not enabled")
class LoginOAuthTokenTestFacebook(LoginOAuthTokenMixin, TestCase): class LoginOAuthTokenTestFacebook(LoginOAuthTokenMixin, ThirdPartyOAuthTestMixinFacebook, TestCase):
"""Tests login_oauth_token with the Facebook backend""" """Tests login_oauth_token with the Facebook backend"""
BACKEND = "facebook" pass
USER_URL = "https://graph.facebook.com/me"
UID_FIELD = "id"
# This is necessary because cms does not implement third party auth # This is necessary because cms does not implement third party auth
@unittest.skipUnless(settings.FEATURES.get("ENABLE_THIRD_PARTY_AUTH"), "third party auth not enabled") @unittest.skipUnless(settings.FEATURES.get("ENABLE_THIRD_PARTY_AUTH"), "third party auth not enabled")
class LoginOAuthTokenTestGoogle(LoginOAuthTokenMixin, TestCase): class LoginOAuthTokenTestGoogle(LoginOAuthTokenMixin, ThirdPartyOAuthTestMixinGoogle, TestCase):
"""Tests login_oauth_token with the Google backend""" """Tests login_oauth_token with the Google backend"""
BACKEND = "google-oauth2" pass
USER_URL = "https://www.googleapis.com/oauth2/v1/userinfo"
UID_FIELD = "email"
...@@ -6,6 +6,7 @@ import logging ...@@ -6,6 +6,7 @@ import logging
import uuid import uuid
import time import time
import json import json
import warnings
from collections import defaultdict from collections import defaultdict
from pytz import UTC from pytz import UTC
from ipware.ip import get_ip from ipware.ip import get_ip
...@@ -43,6 +44,7 @@ from requests import HTTPError ...@@ -43,6 +44,7 @@ from requests import HTTPError
from social.apps.django_app import utils as social_utils from social.apps.django_app import utils as social_utils
from social.backends import oauth as social_oauth from social.backends import oauth as social_oauth
from social.exceptions import AuthException, AuthAlreadyAssociated
from edxmako.shortcuts import render_to_response, render_to_string from edxmako.shortcuts import render_to_response, render_to_string
...@@ -1168,11 +1170,13 @@ def login_oauth_token(request, backend): ...@@ -1168,11 +1170,13 @@ def login_oauth_token(request, backend):
retrieve information from a third party and matching that information to an retrieve information from a third party and matching that information to an
existing user. existing user.
""" """
warnings.warn("Please use AccessTokenExchangeView instead.", DeprecationWarning)
backend = request.social_strategy.backend backend = request.social_strategy.backend
if isinstance(backend, social_oauth.BaseOAuth1) or isinstance(backend, social_oauth.BaseOAuth2): if isinstance(backend, social_oauth.BaseOAuth1) or isinstance(backend, social_oauth.BaseOAuth2):
if "access_token" in request.POST: if "access_token" in request.POST:
# Tell third party auth pipeline that this is an API call # Tell third party auth pipeline that this is an API call
request.session[pipeline.AUTH_ENTRY_KEY] = pipeline.AUTH_ENTRY_API request.session[pipeline.AUTH_ENTRY_KEY] = pipeline.AUTH_ENTRY_LOGIN_API
user = None user = None
try: try:
user = backend.do_auth(request.POST["access_token"]) user = backend.do_auth(request.POST["access_token"])
...@@ -1417,7 +1421,14 @@ def create_account_with_params(request, params): ...@@ -1417,7 +1421,14 @@ def create_account_with_params(request, params):
getattr(settings, 'REGISTRATION_EXTRA_FIELDS', {}) getattr(settings, 'REGISTRATION_EXTRA_FIELDS', {})
) )
if third_party_auth.is_enabled() and pipeline.running(request): # Boolean of whether a 3rd party auth provider and credentials were provided in
# the API so the newly created account can link with the 3rd party account.
#
# Note: this is orthogonal to the 3rd party authentication pipeline that occurs
# when the account is created via the browser and redirect URLs.
should_link_with_social_auth = third_party_auth.is_enabled() and 'provider' in params
if should_link_with_social_auth or (third_party_auth.is_enabled() and pipeline.running(request)):
params["password"] = pipeline.make_random_password() params["password"] = pipeline.make_random_password()
# if doing signup for an external authorization, then get email, password, name from the eamap # if doing signup for an external authorization, then get email, password, name from the eamap
...@@ -1458,13 +1469,38 @@ def create_account_with_params(request, params): ...@@ -1458,13 +1469,38 @@ def create_account_with_params(request, params):
extended_profile_fields=extended_profile_fields, extended_profile_fields=extended_profile_fields,
enforce_username_neq_password=True, enforce_username_neq_password=True,
enforce_password_policy=enforce_password_policy, enforce_password_policy=enforce_password_policy,
tos_required=tos_required tos_required=tos_required,
) )
with transaction.commit_on_success(): with transaction.commit_on_success():
ret = _do_create_account(form) # first, create the account
(user, profile, registration) = _do_create_account(form)
(user, profile, registration) = ret
# next, link the account with social auth, if provided
if should_link_with_social_auth:
request.social_strategy = social_utils.load_strategy(backend=params['provider'], request=request)
social_access_token = params.get('access_token')
if not social_access_token:
raise ValidationError({
'access_token': [
_("An access_token is required when passing value ({}) for provider.").format(
params['provider']
)
]
})
request.session[pipeline.AUTH_ENTRY_KEY] = pipeline.AUTH_ENTRY_REGISTER_API
pipeline_user = None
error_message = ""
try:
pipeline_user = request.social_strategy.backend.do_auth(social_access_token, user=user)
except AuthAlreadyAssociated:
error_message = _("The provided access_token is already associated with another user.")
except (HTTPError, AuthException):
error_message = _("The provided access_token is not valid.")
if not pipeline_user or not isinstance(pipeline_user, User):
# Ensure user does not re-enter the pipeline
request.social_strategy.clean_partial_pipeline()
raise ValidationError({'access_token': [error_message]})
if settings.FEATURES.get('ENABLE_DISCUSSION_EMAIL_DIGEST'): if settings.FEATURES.get('ENABLE_DISCUSSION_EMAIL_DIGEST'):
try: try:
...@@ -1598,6 +1634,8 @@ def create_account(request, post_override=None): ...@@ -1598,6 +1634,8 @@ def create_account(request, post_override=None):
JSON call to create new edX account. JSON call to create new edX account.
Used by form in signup_modal.html, which is included into navigation.html Used by form in signup_modal.html, which is included into navigation.html
""" """
warnings.warn("Please use RegistrationView instead.", DeprecationWarning)
try: try:
create_account_with_params(request, post_override or request.POST) create_account_with_params(request, post_override or request.POST)
except AccountValidationError as exc: except AccountValidationError as exc:
......
...@@ -72,6 +72,7 @@ from django.shortcuts import redirect ...@@ -72,6 +72,7 @@ from django.shortcuts import redirect
from social.apps.django_app.default import models from social.apps.django_app.default import models
from social.exceptions import AuthException from social.exceptions import AuthException
from social.pipeline import partial from social.pipeline import partial
from social.pipeline.social_auth import associate_by_email
import student import student
from embargo import api as embargo_api from embargo import api as embargo_api
...@@ -111,6 +112,8 @@ AUTH_REDIRECT_KEY = 'next' ...@@ -111,6 +112,8 @@ AUTH_REDIRECT_KEY = 'next'
AUTH_ENROLL_COURSE_ID_KEY = 'enroll_course_id' AUTH_ENROLL_COURSE_ID_KEY = 'enroll_course_id'
AUTH_EMAIL_OPT_IN_KEY = 'email_opt_in' AUTH_EMAIL_OPT_IN_KEY = 'email_opt_in'
# The following are various possible values for the AUTH_ENTRY_KEY.
AUTH_ENTRY_DASHBOARD = 'dashboard' AUTH_ENTRY_DASHBOARD = 'dashboard'
AUTH_ENTRY_LOGIN = 'login' AUTH_ENTRY_LOGIN = 'login'
AUTH_ENTRY_REGISTER = 'register' AUTH_ENTRY_REGISTER = 'register'
...@@ -122,7 +125,14 @@ AUTH_ENTRY_REGISTER = 'register' ...@@ -122,7 +125,14 @@ AUTH_ENTRY_REGISTER = 'register'
AUTH_ENTRY_LOGIN_2 = 'account_login' AUTH_ENTRY_LOGIN_2 = 'account_login'
AUTH_ENTRY_REGISTER_2 = 'account_register' AUTH_ENTRY_REGISTER_2 = 'account_register'
AUTH_ENTRY_API = 'api' # Entry modes into the authentication process by a remote API call (as opposed to a browser session).
AUTH_ENTRY_LOGIN_API = 'login_api'
AUTH_ENTRY_REGISTER_API = 'register_api'
def is_api(auth_entry):
"""Returns whether the auth entry point is via an API call."""
return (auth_entry == AUTH_ENTRY_LOGIN_API) or (auth_entry == AUTH_ENTRY_REGISTER_API)
# URLs associated with auth entry points # URLs associated with auth entry points
# These are used to request additional user information # These are used to request additional user information
...@@ -157,7 +167,8 @@ _AUTH_ENTRY_CHOICES = frozenset([ ...@@ -157,7 +167,8 @@ _AUTH_ENTRY_CHOICES = frozenset([
AUTH_ENTRY_LOGIN_2, AUTH_ENTRY_LOGIN_2,
AUTH_ENTRY_REGISTER_2, AUTH_ENTRY_REGISTER_2,
AUTH_ENTRY_API, AUTH_ENTRY_LOGIN_API,
AUTH_ENTRY_REGISTER_API,
]) ])
_DEFAULT_RANDOM_PASSWORD_LENGTH = 12 _DEFAULT_RANDOM_PASSWORD_LENGTH = 12
...@@ -436,39 +447,11 @@ def parse_query_params(strategy, response, *args, **kwargs): ...@@ -436,39 +447,11 @@ def parse_query_params(strategy, response, *args, **kwargs):
if not (auth_entry and auth_entry in _AUTH_ENTRY_CHOICES): if not (auth_entry and auth_entry in _AUTH_ENTRY_CHOICES):
raise AuthEntryError(strategy.backend, 'auth_entry missing or invalid') raise AuthEntryError(strategy.backend, 'auth_entry missing or invalid')
# Note: We expect only one member of this dictionary to be `True` at any return {'auth_entry': auth_entry}
# given time. If something changes this convention in the future, please look
# at the `login_analytics` function in this file as well to ensure logging
# is still done properly
return {
# Whether the auth pipeline entered from /dashboard.
'is_dashboard': auth_entry == AUTH_ENTRY_DASHBOARD,
# Whether the auth pipeline entered from /login.
'is_login': auth_entry in [AUTH_ENTRY_LOGIN, AUTH_ENTRY_LOGIN_2],
# Whether the auth pipeline entered from /register.
'is_register': auth_entry in [AUTH_ENTRY_REGISTER, AUTH_ENTRY_REGISTER_2],
# Whether the auth pipeline entered from an API
'is_api': auth_entry == AUTH_ENTRY_API,
}
@partial.partial @partial.partial
def ensure_user_information( def ensure_user_information(strategy, auth_entry, user=None, *args, **kwargs):
strategy,
details,
response,
uid,
is_dashboard=None,
is_login=None,
is_profile=None,
is_register=None,
is_login_2=None,
is_register_2=None,
is_api=None,
user=None,
*args,
**kwargs
):
""" """
Ensure that we have the necessary information about a user (either an Ensure that we have the necessary information about a user (either an
existing account or registration data) to proceed with the pipeline. existing account or registration data) to proceed with the pipeline.
...@@ -485,32 +468,32 @@ def ensure_user_information( ...@@ -485,32 +468,32 @@ def ensure_user_information(
# It is important that we always execute the entire pipeline. Even if # It is important that we always execute the entire pipeline. Even if
# behavior appears correct without executing a step, it means important # behavior appears correct without executing a step, it means important
# invariants have been violated and future misbehavior is likely. # invariants have been violated and future misbehavior is likely.
user_inactive = user and not user.is_active def dispatch_to_login():
user_unset = user is None """Redirects to the login page."""
return redirect(_create_redirect_url(AUTH_DISPATCH_URLS[AUTH_ENTRY_LOGIN], strategy))
dispatch_to_login = ( def dispatch_to_register():
((is_login or is_login_2) and (user_unset or user_inactive)) """Redirects to the registration page."""
or return redirect(_create_redirect_url(AUTH_DISPATCH_URLS[AUTH_ENTRY_REGISTER], strategy))
((is_register or is_register_2) and user_inactive)
)
dispatch_to_register = (is_register or is_register_2) and user_unset
reject_api_request = is_api and (user_unset or user_inactive)
if reject_api_request: user_inactive = user and not user.is_active
# Content doesn't matter; we just want to exit the pipeline
return HttpResponseBadRequest()
if is_dashboard or is_profile: if auth_entry in [AUTH_ENTRY_LOGIN_API, AUTH_ENTRY_REGISTER_API]:
return if not user:
return HttpResponseBadRequest()
# If the user has a linked account, but has not yet activated elif auth_entry in [AUTH_ENTRY_LOGIN, AUTH_ENTRY_LOGIN_2]:
# we should send them to the login page. The login page if not user or user_inactive:
# will tell them that they need to activate their account. return dispatch_to_login()
if dispatch_to_login:
return redirect(_create_redirect_url(AUTH_DISPATCH_URLS[AUTH_ENTRY_LOGIN], strategy))
if dispatch_to_register: elif auth_entry in [AUTH_ENTRY_REGISTER, AUTH_ENTRY_REGISTER_2]:
return redirect(_create_redirect_url(AUTH_DISPATCH_URLS[AUTH_ENTRY_REGISTER], strategy)) if not user:
return dispatch_to_register()
elif user_inactive:
# If the user has a linked account, but has not yet activated
# we should send them to the login page. The login page
# will tell them that they need to activate their account.
return dispatch_to_login()
def _create_redirect_url(url, strategy): def _create_redirect_url(url, strategy):
...@@ -543,7 +526,7 @@ def _create_redirect_url(url, strategy): ...@@ -543,7 +526,7 @@ def _create_redirect_url(url, strategy):
@partial.partial @partial.partial
def set_logged_in_cookie(backend=None, user=None, request=None, is_api=None, *args, **kwargs): def set_logged_in_cookie(backend=None, user=None, request=None, auth_entry=None, *args, **kwargs):
"""This pipeline step sets the "logged in" cookie for authenticated users. """This pipeline step sets the "logged in" cookie for authenticated users.
Some installations have a marketing site front-end separate from Some installations have a marketing site front-end separate from
...@@ -568,7 +551,7 @@ def set_logged_in_cookie(backend=None, user=None, request=None, is_api=None, *ar ...@@ -568,7 +551,7 @@ def set_logged_in_cookie(backend=None, user=None, request=None, is_api=None, *ar
to the next pipeline step. to the next pipeline step.
""" """
if user is not None and user.is_authenticated() and not is_api: if not is_api(auth_entry) and user is not None and user.is_authenticated():
if request is not None: if request is not None:
# Check that the cookie isn't already set. # Check that the cookie isn't already set.
# This ensures that we allow the user to continue to the next # This ensures that we allow the user to continue to the next
...@@ -588,27 +571,14 @@ def set_logged_in_cookie(backend=None, user=None, request=None, is_api=None, *ar ...@@ -588,27 +571,14 @@ def set_logged_in_cookie(backend=None, user=None, request=None, is_api=None, *ar
@partial.partial @partial.partial
def login_analytics(strategy, *args, **kwargs): def login_analytics(strategy, auth_entry, *args, **kwargs):
""" Sends login info to Segment.io """ """ Sends login info to Segment.io """
event_name = None
action_to_event_name = { event_name = None
'is_login': 'edx.bi.user.account.authenticated', if auth_entry in [AUTH_ENTRY_LOGIN, AUTH_ENTRY_LOGIN_2]:
'is_dashboard': 'edx.bi.user.account.linked', event_name = 'edx.bi.user.account.authenticated'
'is_profile': 'edx.bi.user.account.linked', elif auth_entry in [AUTH_ENTRY_DASHBOARD]:
event_name = 'edx.bi.user.account.linked'
# Backwards compatibility: during an A/B test for the combined
# login/registration form, we introduced a new login end-point.
# Since users may continue to have this in their sessions after
# the test concludes, we need to continue accepting this action.
'is_login_2': 'edx.bi.user.account.authenticated',
}
# Note: we assume only one of the `action` kwargs (is_dashboard, is_login) to be
# `True` at any given time
for action in action_to_event_name.keys():
if kwargs.get(action):
event_name = action_to_event_name[action]
if event_name is not None: if event_name is not None:
tracking_context = tracker.get_tracker().resolve_context() tracking_context = tracker.get_tracker().resolve_context()
...@@ -629,7 +599,7 @@ def login_analytics(strategy, *args, **kwargs): ...@@ -629,7 +599,7 @@ def login_analytics(strategy, *args, **kwargs):
@partial.partial @partial.partial
def change_enrollment(strategy, user=None, is_dashboard=False, *args, **kwargs): def change_enrollment(strategy, auth_entry=None, user=None, *args, **kwargs):
"""Enroll a user in a course. """Enroll a user in a course.
If a user entered the authentication flow when trying to enroll If a user entered the authentication flow when trying to enroll
...@@ -649,10 +619,8 @@ def change_enrollment(strategy, user=None, is_dashboard=False, *args, **kwargs): ...@@ -649,10 +619,8 @@ def change_enrollment(strategy, user=None, is_dashboard=False, *args, **kwargs):
(configured using the ?next parameter to the third party auth login url). (configured using the ?next parameter to the third party auth login url).
Keyword Arguments: Keyword Arguments:
auth_entry: The entry mode into the pipeline.
user (User): The user being authenticated. user (User): The user being authenticated.
is_dashboard (boolean): Whether the user entered the authentication
pipeline from the "link account" button on the student dashboard.
""" """
# We skip enrollment if the user entered the flow from the "link account" # We skip enrollment if the user entered the flow from the "link account"
# button on the student dashboard. At this point, either: # button on the student dashboard. At this point, either:
...@@ -665,7 +633,7 @@ def change_enrollment(strategy, user=None, is_dashboard=False, *args, **kwargs): ...@@ -665,7 +633,7 @@ def change_enrollment(strategy, user=None, is_dashboard=False, *args, **kwargs):
# args when sending users to this page, successfully authenticating through this page # args when sending users to this page, successfully authenticating through this page
# would also enroll the student in the course. # would also enroll the student in the course.
enroll_course_id = strategy.session_get('enroll_course_id') enroll_course_id = strategy.session_get('enroll_course_id')
if enroll_course_id and not is_dashboard: if enroll_course_id and auth_entry != AUTH_ENTRY_DASHBOARD:
course_id = CourseKey.from_string(enroll_course_id) course_id = CourseKey.from_string(enroll_course_id)
modes = CourseMode.modes_for_course_dict(course_id) modes = CourseMode.modes_for_course_dict(course_id)
...@@ -713,11 +681,34 @@ def change_enrollment(strategy, user=None, is_dashboard=False, *args, **kwargs): ...@@ -713,11 +681,34 @@ def change_enrollment(strategy, user=None, is_dashboard=False, *args, **kwargs):
except ( except (
CourseDoesNotExistException, CourseDoesNotExistException,
ItemAlreadyInCartException, ItemAlreadyInCartException,
AlreadyEnrolledInCourseException AlreadyEnrolledInCourseException,
): ):
pass pass
# It's more important to complete login than to # It's more important to complete login than to
# ensure that the course was added to the shopping cart. # ensure that the course was added to the shopping cart.
# Log errors, but don't stop the authentication pipeline. # Log errors, but don't stop the authentication pipeline.
except Exception as ex: except Exception as ex: # pylint: disable=broad-except
logger.exception(ex) logger.exception(ex)
@partial.partial
def associate_by_email_if_login_api(auth_entry, strategy, details, user, *args, **kwargs):
"""
This pipeline step associates the current social auth with the user with the
same email address in the database. It defers to the social library's associate_by_email
implementation, which verifies that only a single database user is associated with the email.
This association is done ONLY if the user entered the pipeline through a LOGIN API.
"""
if auth_entry == AUTH_ENTRY_LOGIN_API:
association_response = associate_by_email(strategy, details, user, *args, **kwargs)
if (
association_response and
association_response.get('user') and
association_response['user'].is_active
):
# Only return the user matched by email if their email has been activated.
# Otherwise, an illegitimate user can create an account with another user's
# email address and the legitimate user would now login to the illegitimate
# account.
return association_response
...@@ -103,6 +103,7 @@ def _set_global_settings(django_settings): ...@@ -103,6 +103,7 @@ def _set_global_settings(django_settings):
'social.pipeline.social_auth.social_uid', 'social.pipeline.social_auth.social_uid',
'social.pipeline.social_auth.auth_allowed', 'social.pipeline.social_auth.auth_allowed',
'social.pipeline.social_auth.social_user', 'social.pipeline.social_auth.social_user',
'third_party_auth.pipeline.associate_by_email_if_login_api',
'social.pipeline.user.get_username', 'social.pipeline.user.get_username',
'third_party_auth.pipeline.ensure_user_information', 'third_party_auth.pipeline.ensure_user_information',
'social.pipeline.user.create_user', 'social.pipeline.user.create_user',
......
...@@ -147,7 +147,7 @@ class PipelineEnrollmentTest(UrlResetMixin, ModuleStoreTestCase): ...@@ -147,7 +147,7 @@ class PipelineEnrollmentTest(UrlResetMixin, ModuleStoreTestCase):
# Simulate completing the pipeline from the student dashboard's # Simulate completing the pipeline from the student dashboard's
# "link account" button. # "link account" button.
result = pipeline.change_enrollment(strategy, 1, user=self.user, is_dashboard=True) # pylint: disable=assignment-from-no-return,redundant-keyword-arg result = pipeline.change_enrollment(strategy, 1, user=self.user, auth_entry=pipeline.AUTH_ENTRY_DASHBOARD) # pylint: disable=assignment-from-no-return,redundant-keyword-arg
# Verify that we were NOT enrolled # Verify that we were NOT enrolled
self.assertEqual(result, {}) self.assertEqual(result, {})
...@@ -165,7 +165,7 @@ class PipelineEnrollmentTest(UrlResetMixin, ModuleStoreTestCase): ...@@ -165,7 +165,7 @@ class PipelineEnrollmentTest(UrlResetMixin, ModuleStoreTestCase):
details=None, details=None,
response=None, response=None,
uid=None, uid=None,
is_register=True, auth_entry=pipeline.AUTH_ENTRY_REGISTER,
backend=backend backend=backend
) )
self.assertIsNotNone(response) self.assertIsNotNone(response)
......
"""Common utility for testing third party oauth2 features."""
import json
import httpretty
from provider.constants import PUBLIC
from provider.oauth2.models import Client
from social.apps.django_app.default.models import UserSocialAuth
from student.tests.factories import UserFactory
@httpretty.activate
class ThirdPartyOAuthTestMixin(object):
"""
Mixin with tests for third party oauth views. A TestCase that includes
this must define the following:
BACKEND: The name of the backend from python-social-auth
USER_URL: The URL of the endpoint that the backend retrieves user data from
UID_FIELD: The field in the user data that the backend uses as the user id
"""
def setUp(self, create_user=True):
super(ThirdPartyOAuthTestMixin, self).setUp()
self.social_uid = "test_social_uid"
self.access_token = "test_access_token"
self.client_id = "test_client_id"
self.oauth_client = Client.objects.create(
client_id=self.client_id,
client_type=PUBLIC
)
if create_user:
self.user = UserFactory()
UserSocialAuth.objects.create(user=self.user, provider=self.BACKEND, uid=self.social_uid)
def _setup_provider_response(self, success=False, email=''):
"""
Register a mock response for the third party user information endpoint;
success indicates whether the response status code should be 200 or 400
"""
if success:
status = 200
response = {self.UID_FIELD: self.social_uid}
if email:
response.update({'email': email})
body = json.dumps(response)
else:
status = 400
body = json.dumps({})
self._setup_provider_response_with_body(status, body)
def _setup_provider_response_with_body(self, status, body):
"""
Register a mock response for the third party user information endpoint with given status and body.
"""
httpretty.register_uri(
httpretty.GET,
self.USER_URL,
body=body,
status=status,
content_type="application/json"
)
class ThirdPartyOAuthTestMixinFacebook(object):
"""Tests oauth with the Facebook backend"""
BACKEND = "facebook"
USER_URL = "https://graph.facebook.com/me"
# In facebook responses, the "id" field is used as the user's identifier
UID_FIELD = "id"
class ThirdPartyOAuthTestMixinGoogle(object):
"""Tests oauth with the Google backend"""
BACKEND = "google-oauth2"
USER_URL = "https://www.googleapis.com/oauth2/v1/userinfo"
# In google-oauth2 responses, the "email" field is used as the user's identifier
UID_FIELD = "email"
...@@ -596,6 +596,8 @@ if settings.FEATURES.get('ENABLE_THIRD_PARTY_AUTH'): ...@@ -596,6 +596,8 @@ if settings.FEATURES.get('ENABLE_THIRD_PARTY_AUTH'):
oauth_exchange.views.AccessTokenExchangeView.as_view(), oauth_exchange.views.AccessTokenExchangeView.as_view(),
name="exchange_access_token" name="exchange_access_token"
), ),
# NOTE: The following login_oauth_token endpoint is DEPRECATED.
# Please use the exchange_access_token endpoint instead.
url(r'^login_oauth_token/(?P<backend>[^/]+)/$', 'student.views.login_oauth_token'), url(r'^login_oauth_token/(?P<backend>[^/]+)/$', 'student.views.login_oauth_token'),
) )
......
...@@ -4,26 +4,33 @@ import datetime ...@@ -4,26 +4,33 @@ import datetime
import base64 import base64
import json import json
import re import re
from unittest import skipUnless, SkipTest
import ddt
import httpretty
from pytz import UTC
import mock
from django.conf import settings from django.conf import settings
from django.core.urlresolvers import reverse from django.core.urlresolvers import reverse
from django.core import mail from django.core import mail
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.test import TestCase from django.test import TestCase
from django.test.testcases import TransactionTestCase
from django.test.utils import override_settings from django.test.utils import override_settings
from unittest import skipUnless
import ddt
from pytz import UTC
import mock
from xmodule.modulestore.tests.factories import CourseFactory
from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase
from student.tests.factories import UserFactory from social.apps.django_app.default.models import UserSocialAuth
from unittest import SkipTest
from django_comment_common import models
from opaque_keys.edx.locations import SlashSeparatedCourseKey from opaque_keys.edx.locations import SlashSeparatedCourseKey
from third_party_auth.tests.testutil import simulate_running_pipeline
from django_comment_common import models
from student.tests.factories import UserFactory
from third_party_auth.tests.testutil import simulate_running_pipeline
from third_party_auth.tests.utils import (
ThirdPartyOAuthTestMixin, ThirdPartyOAuthTestMixinFacebook, ThirdPartyOAuthTestMixinGoogle
)
from xmodule.modulestore.tests.factories import CourseFactory
from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase
from ..accounts.api import get_account_settings from ..accounts.api import get_account_settings
from ..accounts import ( from ..accounts import (
NAME_MAX_LENGTH, EMAIL_MIN_LENGTH, EMAIL_MAX_LENGTH, PASSWORD_MIN_LENGTH, PASSWORD_MAX_LENGTH, NAME_MAX_LENGTH, EMAIL_MIN_LENGTH, EMAIL_MAX_LENGTH, PASSWORD_MIN_LENGTH, PASSWORD_MAX_LENGTH,
...@@ -1546,6 +1553,148 @@ class RegistrationViewTest(ApiTestCase): ...@@ -1546,6 +1553,148 @@ class RegistrationViewTest(ApiTestCase):
) )
@httpretty.activate
@ddt.ddt
class ThirdPartyRegistrationTestMixin(ThirdPartyOAuthTestMixin):
"""
Tests for the User API registration endpoint with 3rd party authentication.
"""
def setUp(self):
super(ThirdPartyRegistrationTestMixin, self).setUp(create_user=False)
self.url = reverse('user_api_registration')
def data(self, user=None):
"""Returns the request data for the endpoint."""
return {
"provider": self.BACKEND,
"access_token": self.access_token,
"client_id": self.client_id,
"honor_code": "true",
"country": "US",
"username": user.username if user else "test_username",
"name": user.first_name if user else "test name",
"email": user.email if user else "test@test.com",
}
def _assert_existing_user_error(self, response):
"""Assert that the given response was an error with the given status_code and error code."""
self.assertEqual(response.status_code, 409)
errors = json.loads(response.content)
for conflict_attribute in ["username", "email"]:
self.assertIn(conflict_attribute, errors)
self.assertIn("belongs to an existing account", errors[conflict_attribute][0]["user_message"])
self.assertNotIn("partial_pipeline", self.client.session)
def _assert_access_token_error(self, response, expected_error_message):
"""Assert that the given response was an error for the access_token field with the given error message."""
self.assertEqual(response.status_code, 400)
response_json = json.loads(response.content)
self.assertEqual(
response_json,
{"access_token": [{"user_message": expected_error_message}]}
)
self.assertNotIn("partial_pipeline", self.client.session)
def _verify_user_existence(self, user_exists, social_link_exists, user_is_active=None, username=None):
"""Verifies whether the user object exists."""
users = User.objects.filter(username=(username if username else "test_username"))
self.assertEquals(users.exists(), user_exists)
if user_exists:
self.assertEquals(users[0].is_active, user_is_active)
self.assertEqual(
UserSocialAuth.objects.filter(user=users[0], provider=self.BACKEND).exists(),
social_link_exists
)
else:
self.assertEquals(UserSocialAuth.objects.count(), 0)
def test_success(self):
self._verify_user_existence(user_exists=False, social_link_exists=False)
self._setup_provider_response(success=True)
response = self.client.post(self.url, self.data())
self.assertEqual(response.status_code, 200)
self._verify_user_existence(user_exists=True, social_link_exists=True, user_is_active=False)
def test_unlinked_active_user(self):
user = UserFactory()
response = self.client.post(self.url, self.data(user))
self._assert_existing_user_error(response)
self._verify_user_existence(
user_exists=True, social_link_exists=False, user_is_active=True, username=user.username
)
def test_unlinked_inactive_user(self):
user = UserFactory(is_active=False)
response = self.client.post(self.url, self.data(user))
self._assert_existing_user_error(response)
self._verify_user_existence(
user_exists=True, social_link_exists=False, user_is_active=False, username=user.username
)
def test_user_already_registered(self):
self._setup_provider_response(success=True)
user = UserFactory()
UserSocialAuth.objects.create(user=user, provider=self.BACKEND, uid=self.social_uid)
response = self.client.post(self.url, self.data(user))
self._assert_existing_user_error(response)
self._verify_user_existence(
user_exists=True, social_link_exists=True, user_is_active=True, username=user.username
)
def test_social_user_conflict(self):
self._setup_provider_response(success=True)
user = UserFactory()
UserSocialAuth.objects.create(user=user, provider=self.BACKEND, uid=self.social_uid)
response = self.client.post(self.url, self.data())
self._assert_access_token_error(response, "The provided access_token is already associated with another user.")
self._verify_user_existence(
user_exists=True, social_link_exists=True, user_is_active=True, username=user.username
)
def test_invalid_token(self):
self._setup_provider_response(success=False)
response = self.client.post(self.url, self.data())
self._assert_access_token_error(response, "The provided access_token is not valid.")
self._verify_user_existence(user_exists=False, social_link_exists=False)
def test_missing_token(self):
data = self.data()
data.pop("access_token")
response = self.client.post(self.url, data)
self._assert_access_token_error(
response,
"An access_token is required when passing value ({}) for provider.".format(self.BACKEND)
)
self._verify_user_existence(user_exists=False, social_link_exists=False)
@skipUnless(settings.FEATURES.get("ENABLE_THIRD_PARTY_AUTH"), "third party auth not enabled")
class TestFacebookRegistrationView(
ThirdPartyRegistrationTestMixin, ThirdPartyOAuthTestMixinFacebook, TransactionTestCase
):
"""Tests the User API registration endpoint with Facebook authentication."""
def test_social_auth_exception(self):
"""
According to the do_auth method in social.backends.facebook.py,
the Facebook API sometimes responds back a JSON with just False as value.
"""
self._setup_provider_response_with_body(200, json.dumps("false"))
response = self.client.post(self.url, self.data())
self._assert_access_token_error(response, "The provided access_token is not valid.")
self._verify_user_existence(user_exists=False, social_link_exists=False)
@skipUnless(settings.FEATURES.get("ENABLE_THIRD_PARTY_AUTH"), "third party auth not enabled")
class TestGoogleRegistrationView(
ThirdPartyRegistrationTestMixin, ThirdPartyOAuthTestMixinGoogle, TransactionTestCase
):
"""Tests the User API registration endpoint with Google authentication."""
pass
@ddt.ddt @ddt.ddt
class UpdateEmailOptInTestCase(ApiTestCase, ModuleStoreTestCase): class UpdateEmailOptInTestCase(ApiTestCase, ModuleStoreTestCase):
"""Tests the UpdateEmailOptInPreference view. """ """Tests the UpdateEmailOptInPreference view. """
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment