Commit 8b65ca17 by Uman Shahzad

Migrate to latest, split python-social-auth.

PSA was monolothic, now split, with new features, like
a DB-backed partial pipeline. FB OAuth2 version also upped.

Partial pipelines don't get cleared except when necessary.
They persist for special cases like change of browser while
still mid-pipeline (i.e. email validation step).

Refactor, cleanup, and update of a lot of small things as well.

PLEASE NOTE the new `social_auth_partial` table.
parent 4dd4f979
...@@ -14,7 +14,7 @@ from django.test import TestCase ...@@ -14,7 +14,7 @@ from django.test import TestCase
from django.test.client import Client from django.test.client import Client
from django.test.utils import override_settings from django.test.utils import override_settings
from mock import patch from mock import patch
from social.apps.django_app.default.models import UserSocialAuth from social_django.models import UserSocialAuth
from openedx.core.djangoapps.external_auth.models import ExternalAuthMap from openedx.core.djangoapps.external_auth.models import ExternalAuthMap
from openedx.core.djangolib.testing.utils import CacheIsolationTestCase from openedx.core.djangolib.testing.utils import CacheIsolationTestCase
...@@ -540,7 +540,6 @@ class LoginOAuthTokenMixin(ThirdPartyOAuthTestMixin): ...@@ -540,7 +540,6 @@ class LoginOAuthTokenMixin(ThirdPartyOAuthTestMixin):
"""Assert that the given response was a 400 with the given error code""" """Assert that the given response was a 400 with the given error code"""
self.assertEqual(response.status_code, status_code) self.assertEqual(response.status_code, status_code)
self.assertEqual(json.loads(response.content), {"error": error}) self.assertEqual(json.loads(response.content), {"error": error})
self.assertNotIn("partial_pipeline", self.client.session)
def test_success(self): def test_success(self):
self._setup_provider_response(success=True) self._setup_provider_response(success=True)
......
...@@ -46,9 +46,9 @@ from provider.oauth2.models import Client ...@@ -46,9 +46,9 @@ from provider.oauth2.models import Client
from pytz import UTC from pytz import UTC
from ratelimitbackend.exceptions import RateLimitException from ratelimitbackend.exceptions import RateLimitException
from requests import HTTPError from requests import HTTPError
from social.apps.django_app import utils as social_utils from social_django import utils as social_utils
from social.backends import oauth as social_oauth from social_core.backends import oauth as social_oauth
from social.exceptions import AuthAlreadyAssociated, AuthException from social_core.exceptions import AuthAlreadyAssociated, AuthException
import dogstats_wrapper as dog_stats_api import dogstats_wrapper as dog_stats_api
import openedx.core.djangoapps.external_auth.views import openedx.core.djangoapps.external_auth.views
...@@ -1519,7 +1519,7 @@ def login_user(request, error=""): # pylint: disable=too-many-statements,unused ...@@ -1519,7 +1519,7 @@ def login_user(request, error=""): # pylint: disable=too-many-statements,unused
@csrf_exempt @csrf_exempt
@require_POST @require_POST
@social_utils.strategy("social:complete") @social_utils.psa("social:complete")
def login_oauth_token(request, backend): def login_oauth_token(request, backend):
""" """
Authenticate the client using an OAuth access token by using the token to Authenticate the client using an OAuth access token by using the token to
...@@ -1534,8 +1534,9 @@ def login_oauth_token(request, backend): ...@@ -1534,8 +1534,9 @@ def login_oauth_token(request, backend):
# Tell third party auth pipeline that this is an API call # Tell third party auth pipeline that this is an API call
request.session[pipeline.AUTH_ENTRY_KEY] = pipeline.AUTH_ENTRY_LOGIN_API request.session[pipeline.AUTH_ENTRY_KEY] = pipeline.AUTH_ENTRY_LOGIN_API
user = None user = None
access_token = request.POST["access_token"]
try: try:
user = backend.do_auth(request.POST["access_token"]) user = backend.do_auth(access_token)
except (HTTPError, AuthException): except (HTTPError, AuthException):
pass pass
# do_auth can return a non-User object if it fails # do_auth can return a non-User object if it fails
...@@ -1544,7 +1545,7 @@ def login_oauth_token(request, backend): ...@@ -1544,7 +1545,7 @@ def login_oauth_token(request, backend):
return JsonResponse(status=204) return JsonResponse(status=204)
else: else:
# Ensure user does not re-enter the pipeline # Ensure user does not re-enter the pipeline
request.social_strategy.clean_partial_pipeline() request.social_strategy.clean_partial_pipeline(access_token)
return JsonResponse({"error": "invalid_token"}, status=401) return JsonResponse({"error": "invalid_token"}, status=401)
else: else:
return JsonResponse({"error": "invalid_request"}, status=400) return JsonResponse({"error": "invalid_request"}, status=400)
...@@ -1889,7 +1890,7 @@ def create_account_with_params(request, params): ...@@ -1889,7 +1890,7 @@ def create_account_with_params(request, params):
error_message = _("The provided access_token is not valid.") error_message = _("The provided access_token is not valid.")
if not pipeline_user or not isinstance(pipeline_user, User): if not pipeline_user or not isinstance(pipeline_user, User):
# Ensure user does not re-enter the pipeline # Ensure user does not re-enter the pipeline
request.social_strategy.clean_partial_pipeline() request.social_strategy.clean_partial_pipeline(social_access_token)
raise ValidationError({'access_token': [error_message]}) raise ValidationError({'access_token': [error_message]})
# Perform operations that are non-critical parts of account creation # Perform operations that are non-critical parts of account creation
......
...@@ -14,7 +14,7 @@ from openedx.core.lib.api.permissions import ApiKeyHeaderPermission ...@@ -14,7 +14,7 @@ from openedx.core.lib.api.permissions import ApiKeyHeaderPermission
from rest_framework.test import APITestCase from rest_framework.test import APITestCase
from django.conf import settings from django.conf import settings
from django.test.utils import override_settings from django.test.utils import override_settings
from social.apps.django_app.default.models import UserSocialAuth from social_django.models import UserSocialAuth
from student.tests.factories import UserFactory from student.tests.factories import UserFactory
from third_party_auth.api.permissions import ThirdPartyAuthProviderApiPermission from third_party_auth.api.permissions import ThirdPartyAuthProviderApiPermission
......
...@@ -9,7 +9,7 @@ from rest_framework.generics import ListAPIView ...@@ -9,7 +9,7 @@ from rest_framework.generics import ListAPIView
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.views import APIView from rest_framework.views import APIView
from rest_framework_oauth.authentication import OAuth2Authentication from rest_framework_oauth.authentication import OAuth2Authentication
from social.apps.django_app.default.models import UserSocialAuth from social_django.models import UserSocialAuth
from openedx.core.lib.api.authentication import ( from openedx.core.lib.api.authentication import (
OAuth2AuthenticationAllowInactiveUser, OAuth2AuthenticationAllowInactiveUser,
......
""" """
DummyBackend: A fake Third Party Auth provider for testing & development purposes. DummyBackend: A fake Third Party Auth provider for testing & development purposes.
""" """
from social.backends.oauth import BaseOAuth2 from social_core.backends.oauth import BaseOAuth2
from social.exceptions import AuthFailed from social_core.exceptions import AuthFailed
class DummyBackend(BaseOAuth2): # pylint: disable=abstract-method class DummyBackend(BaseOAuth2): # pylint: disable=abstract-method
......
...@@ -14,9 +14,9 @@ from oauthlib.oauth1.rfc5849.signature import ( ...@@ -14,9 +14,9 @@ from oauthlib.oauth1.rfc5849.signature import (
normalize_parameters, normalize_parameters,
sign_hmac_sha1 sign_hmac_sha1
) )
from social.backends.base import BaseAuth from social_core.backends.base import BaseAuth
from social.exceptions import AuthFailed from social_core.exceptions import AuthFailed
from social.utils import sanitize_redirect from social_core.utils import sanitize_redirect
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
...@@ -34,16 +34,13 @@ class LTIAuthBackend(BaseAuth): ...@@ -34,16 +34,13 @@ class LTIAuthBackend(BaseAuth):
""" """
Prepare to handle a login request. Prepare to handle a login request.
This method replaces social.actions.do_auth and must be kept in sync This method replaces social_core.actions.do_auth and must be kept in sync
with any upstream changes in that method. In the current version of with any upstream changes in that method. In the current version of
the upstream, this means replacing the logic to populate the session the upstream, this means replacing the logic to populate the session
from request parameters, and not calling backend.start() to avoid from request parameters, and not calling backend.start() to avoid
an unwanted redirect to the non-existent login page. an unwanted redirect to the non-existent login page.
""" """
# Clean any partial pipeline data
self.strategy.clean_partial_pipeline()
# Save validated LTI parameters (or None if invalid or not submitted) # Save validated LTI parameters (or None if invalid or not submitted)
validated_lti_params = self.get_validated_lti_params(self.strategy) validated_lti_params = self.get_validated_lti_params(self.strategy)
......
"""Middleware classes for third_party_auth.""" """Middleware classes for third_party_auth."""
from social.apps.django_app.middleware import SocialAuthExceptionMiddleware from social_django.middleware import SocialAuthExceptionMiddleware
from . import pipeline from . import pipeline
...@@ -45,12 +45,10 @@ class PipelineQuarantineMiddleware(object): ...@@ -45,12 +45,10 @@ class PipelineQuarantineMiddleware(object):
collection of user consent for sharing data with a linked third-party collection of user consent for sharing data with a linked third-party
authentication provider. authentication provider.
""" """
running_pipeline = request.session.get('partial_pipeline') if not pipeline.running(request):
if not running_pipeline:
return return
view_module = view_func.__module__ view_module = view_func.__module__
quarantined_modules = request.session.get('third_party_auth_quarantined_modules', None) quarantined_modules = request.session.get('third_party_auth_quarantined_modules')
if quarantined_modules is not None and not any(view_module.startswith(mod) for mod in quarantined_modules): if quarantined_modules is not None and not any(view_module.startswith(mod) for mod in quarantined_modules):
request.session.flush() request.session.flush()
...@@ -17,11 +17,11 @@ from django.utils import timezone ...@@ -17,11 +17,11 @@ from django.utils import timezone
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from provider.oauth2.models import Client from provider.oauth2.models import Client
from provider.utils import long_token from provider.utils import long_token
from social.backends.base import BaseAuth from social_core.backends.base import BaseAuth
from social.backends.oauth import OAuthAuth from social_core.backends.oauth import OAuthAuth
from social.backends.saml import SAMLAuth, SAMLIdentityProvider from social_core.backends.saml import SAMLAuth, SAMLIdentityProvider
from social.exceptions import SocialAuthBaseException from social_core.exceptions import SocialAuthBaseException
from social.utils import module_member from social_core.utils import module_member
from openedx.core.djangoapps.site_configuration import helpers as configuration_helpers from openedx.core.djangoapps.site_configuration import helpers as configuration_helpers
from openedx.core.djangoapps.theming.helpers import get_current_request from openedx.core.djangoapps.theming.helpers import get_current_request
...@@ -581,7 +581,7 @@ class SAMLConfiguration(ConfigurationModel): ...@@ -581,7 +581,7 @@ class SAMLConfiguration(ConfigurationModel):
return getattr(settings, 'SOCIAL_AUTH_SAML_SP_PRIVATE_KEY', '') return getattr(settings, 'SOCIAL_AUTH_SAML_SP_PRIVATE_KEY', '')
other_config = { other_config = {
# These defaults can be overriden by self.other_config_str # These defaults can be overriden by self.other_config_str
"EXTRA_DATA": ["attributes"], # Save all attribute values the IdP sends into the UserSocialAuth table "GET_ALL_EXTRA_DATA": True, # Save all attribute values the IdP sends into the UserSocialAuth table
"TECHNICAL_CONTACT": DEFAULT_SAML_CONTACT, "TECHNICAL_CONTACT": DEFAULT_SAML_CONTACT,
"SUPPORT_CONTACT": DEFAULT_SAML_CONTACT, "SUPPORT_CONTACT": DEFAULT_SAML_CONTACT,
} }
......
...@@ -54,7 +54,7 @@ This is surprising but important behavior, since it allows a single function in ...@@ -54,7 +54,7 @@ This is surprising but important behavior, since it allows a single function in
the pipeline to consolidate all the operations needed to establish invariants the pipeline to consolidate all the operations needed to establish invariants
rather than spreading them across two functions in the pipeline. rather than spreading them across two functions in the pipeline.
See http://psa.matiasaguirre.net/docs/pipeline.html for more docs. See http://python-social-auth.readthedocs.io/en/latest/pipeline.html for more docs.
""" """
import base64 import base64
...@@ -73,11 +73,10 @@ from django.contrib.auth.models import User ...@@ -73,11 +73,10 @@ from django.contrib.auth.models import User
from django.core.urlresolvers import reverse from django.core.urlresolvers import reverse
from django.http import HttpResponseBadRequest from django.http import HttpResponseBadRequest
from django.shortcuts import redirect from django.shortcuts import redirect
from social.apps.django_app.default import models import social_django
from social.apps.django_app.default.models import UserSocialAuth from social_core.exceptions import AuthException
from social.exceptions import AuthException from social_core.pipeline import partial
from social.pipeline import partial from social_core.pipeline.social_auth import associate_by_email
from social.pipeline.social_auth import associate_by_email
import student import student
from eventtracking import tracker from eventtracking import tracker
...@@ -195,8 +194,14 @@ class ProviderUserState(object): ...@@ -195,8 +194,14 @@ class ProviderUserState(object):
def get(request): def get(request):
"""Gets the running pipeline from the passed request.""" """Gets the running pipeline's data from the passed request."""
return request.session.get('partial_pipeline') strategy = social_django.utils.load_strategy(request)
token = strategy.session_get('partial_pipeline_token')
partial_object = strategy.partial_load(token)
pipeline_data = None
if partial_object:
pipeline_data = {'kwargs': partial_object.kwargs, 'backend': partial_object.backend}
return pipeline_data
def get_real_social_auth_object(request): def get_real_social_auth_object(request):
...@@ -209,7 +214,7 @@ def get_real_social_auth_object(request): ...@@ -209,7 +214,7 @@ def get_real_social_auth_object(request):
if running_pipeline and 'social' in running_pipeline['kwargs']: if running_pipeline and 'social' in running_pipeline['kwargs']:
social = running_pipeline['kwargs']['social'] social = running_pipeline['kwargs']['social']
if isinstance(social, dict): if isinstance(social, dict):
social = UserSocialAuth.objects.get(uid=social.get('uid', '')) social = social_django.models.UserSocialAuth.objects.get(**social)
return social return social
...@@ -253,7 +258,7 @@ def get_authenticated_user(auth_provider, username, uid): ...@@ -253,7 +258,7 @@ def get_authenticated_user(auth_provider, username, uid):
user has no social auth associated with the given backend. user has no social auth associated with the given backend.
AssertionError: if the user is not authenticated. AssertionError: if the user is not authenticated.
""" """
match = models.DjangoStorage.user.get_social_auth(provider=auth_provider.backend_name, uid=uid) match = social_django.models.DjangoStorage.user.get_social_auth(provider=auth_provider.backend_name, uid=uid)
if not match or match.user.username != username: if not match or match.user.username != username:
raise User.DoesNotExist raise User.DoesNotExist
...@@ -319,7 +324,7 @@ def get_disconnect_url(provider_id, association_id): ...@@ -319,7 +324,7 @@ def get_disconnect_url(provider_id, association_id):
"""Gets URL for the endpoint that starts the disconnect pipeline. """Gets URL for the endpoint that starts the disconnect pipeline.
Args: Args:
provider_id: string identifier of the models.ProviderConfig child you want provider_id: string identifier of the social_django.models.ProviderConfig child you want
to disconnect from. to disconnect from.
association_id: int. Optional ID of a specific row in the UserSocialAuth association_id: int. Optional ID of a specific row in the UserSocialAuth
table to disconnect (useful if multiple providers use a common backend) table to disconnect (useful if multiple providers use a common backend)
...@@ -341,7 +346,7 @@ def get_login_url(provider_id, auth_entry, redirect_url=None): ...@@ -341,7 +346,7 @@ def get_login_url(provider_id, auth_entry, redirect_url=None):
"""Gets the login URL for the endpoint that kicks off auth with a provider. """Gets the login URL for the endpoint that kicks off auth with a provider.
Args: Args:
provider_id: string identifier of the models.ProviderConfig child you want provider_id: string identifier of the social_django.models.ProviderConfig child you want
to disconnect from. to disconnect from.
auth_entry: string. Query argument specifying the desired entry point auth_entry: string. Query argument specifying the desired entry point
for the auth pipeline. Used by the pipeline for later branching. for the auth pipeline. Used by the pipeline for later branching.
...@@ -404,7 +409,7 @@ def get_provider_user_states(user): ...@@ -404,7 +409,7 @@ def get_provider_user_states(user):
each enabled provider. each enabled provider.
""" """
states = [] states = []
found_user_auths = list(models.DjangoStorage.user.get_social_auth_for_user(user)) found_user_auths = list(social_django.models.DjangoStorage.user.get_social_auth_for_user(user))
for enabled_provider in provider.Registry.enabled(): for enabled_provider in provider.Registry.enabled():
association = None association = None
...@@ -443,7 +448,7 @@ def make_random_password(length=None, choice_fn=random.SystemRandom().choice): ...@@ -443,7 +448,7 @@ def make_random_password(length=None, choice_fn=random.SystemRandom().choice):
def running(request): def running(request):
"""Returns True iff request is running a third-party auth pipeline.""" """Returns True iff request is running a third-party auth pipeline."""
return request.session.get('partial_pipeline') is not None # Avoid False for {}. return get(request) is not None # Avoid False for {}.
# Pipeline functions. # Pipeline functions.
...@@ -454,7 +459,7 @@ def running(request): ...@@ -454,7 +459,7 @@ def running(request):
def parse_query_params(strategy, response, *args, **kwargs): def parse_query_params(strategy, response, *args, **kwargs):
"""Reads whitelisted query params, transforms them into pipeline args.""" """Reads whitelisted query params, transforms them into pipeline args."""
auth_entry = strategy.session.get(AUTH_ENTRY_KEY) auth_entry = strategy.request.session.get(AUTH_ENTRY_KEY)
if not (auth_entry and auth_entry in _AUTH_ENTRY_CHOICES): if not (auth_entry and auth_entry in _AUTH_ENTRY_CHOICES):
raise AuthEntryError(strategy.request.backend, 'auth_entry missing or invalid') raise AuthEntryError(strategy.request.backend, 'auth_entry missing or invalid')
...@@ -524,7 +529,7 @@ def redirect_to_custom_form(request, auth_entry, kwargs): ...@@ -524,7 +529,7 @@ def redirect_to_custom_form(request, auth_entry, kwargs):
@partial.partial @partial.partial
def ensure_user_information(strategy, auth_entry, backend=None, user=None, social=None, def ensure_user_information(strategy, auth_entry, backend=None, user=None, social=None, current_partial=None,
allow_inactive_user=False, *args, **kwargs): allow_inactive_user=False, *args, **kwargs):
""" """
Ensure that we have the necessary information about a user (either an Ensure that we have the necessary information about a user (either an
...@@ -552,7 +557,7 @@ def ensure_user_information(strategy, auth_entry, backend=None, user=None, socia ...@@ -552,7 +557,7 @@ def ensure_user_information(strategy, auth_entry, backend=None, user=None, socia
def should_force_account_creation(): def should_force_account_creation():
""" For some third party providers, we auto-create user accounts """ """ For some third party providers, we auto-create user accounts """
current_provider = provider.Registry.get_from_pipeline({'backend': backend.name, 'kwargs': kwargs}) current_provider = provider.Registry.get_from_pipeline({'backend': current_partial.backend, 'kwargs': kwargs})
return current_provider and current_provider.skip_email_verification return current_provider and current_provider.skip_email_verification
if not user: if not user:
...@@ -606,7 +611,8 @@ def ensure_user_information(strategy, auth_entry, backend=None, user=None, socia ...@@ -606,7 +611,8 @@ def ensure_user_information(strategy, auth_entry, backend=None, user=None, socia
@partial.partial @partial.partial
def set_logged_in_cookies(backend=None, user=None, strategy=None, auth_entry=None, *args, **kwargs): def set_logged_in_cookies(backend=None, user=None, strategy=None, auth_entry=None, current_partial=None,
*args, **kwargs):
"""This pipeline step sets the "logged in" cookie for authenticated users. """This pipeline step sets the "logged in" cookie for authenticated users.
Some installations have a marketing site front-end separate from Some installations have a marketing site front-end separate from
...@@ -641,7 +647,7 @@ def set_logged_in_cookies(backend=None, user=None, strategy=None, auth_entry=Non ...@@ -641,7 +647,7 @@ def set_logged_in_cookies(backend=None, user=None, strategy=None, auth_entry=Non
has_cookie = student.cookies.is_logged_in_cookie_set(request) has_cookie = student.cookies.is_logged_in_cookie_set(request)
if not has_cookie: if not has_cookie:
try: try:
redirect_url = get_complete_url(backend.name) redirect_url = get_complete_url(current_partial.backend)
except ValueError: except ValueError:
# If for some reason we can't get the URL, just skip this step # If for some reason we can't get the URL, just skip this step
# This may be overly paranoid, but it's far more important that # This may be overly paranoid, but it's far more important that
...@@ -653,7 +659,7 @@ def set_logged_in_cookies(backend=None, user=None, strategy=None, auth_entry=Non ...@@ -653,7 +659,7 @@ def set_logged_in_cookies(backend=None, user=None, strategy=None, auth_entry=Non
@partial.partial @partial.partial
def login_analytics(strategy, auth_entry, *args, **kwargs): def login_analytics(strategy, auth_entry, current_partial=None, *args, **kwargs):
""" Sends login info to Segment """ """ Sends login info to Segment """
event_name = None event_name = None
...@@ -682,7 +688,7 @@ def login_analytics(strategy, auth_entry, *args, **kwargs): ...@@ -682,7 +688,7 @@ def login_analytics(strategy, auth_entry, *args, **kwargs):
@partial.partial @partial.partial
def associate_by_email_if_login_api(auth_entry, backend, details, user, *args, **kwargs): def associate_by_email_if_login_api(auth_entry, backend, details, user, current_partial=None, *args, **kwargs):
""" """
This pipeline step associates the current social auth with the user with the This pipeline step associates the current social auth with the user with the
same email address in the database. It defers to the social library's associate_by_email same email address in the database. It defers to the social library's associate_by_email
......
...@@ -7,8 +7,8 @@ import requests ...@@ -7,8 +7,8 @@ import requests
from django.contrib.sites.models import Site from django.contrib.sites.models import Site
from django.http import Http404 from django.http import Http404
from django.utils.functional import cached_property from django.utils.functional import cached_property
from social.backends.saml import OID_EDU_PERSON_ENTITLEMENT, SAMLAuth, SAMLIdentityProvider from social_core.backends.saml import OID_EDU_PERSON_ENTITLEMENT, SAMLAuth, SAMLIdentityProvider
from social.exceptions import AuthForbidden, AuthMissingParameter from social_core.exceptions import AuthForbidden
from openedx.core.djangoapps.theming.helpers import get_current_request from openedx.core.djangoapps.theming.helpers import get_current_request
...@@ -43,7 +43,6 @@ class SAMLAuthBackend(SAMLAuth): # pylint: disable=abstract-method ...@@ -43,7 +43,6 @@ class SAMLAuthBackend(SAMLAuth): # pylint: disable=abstract-method
authenticate the user. authenticate the user.
raise Http404 if SAML authentication is disabled. raise Http404 if SAML authentication is disabled.
raise AuthMissingParameter if the 'idp' parameter is missing.
""" """
if not self._config.enabled: if not self._config.enabled:
log.error('SAML authentication is not enabled') log.error('SAML authentication is not enabled')
...@@ -70,7 +69,7 @@ class SAMLAuthBackend(SAMLAuth): # pylint: disable=abstract-method ...@@ -70,7 +69,7 @@ class SAMLAuthBackend(SAMLAuth): # pylint: disable=abstract-method
""" """
Get an instance of OneLogin_Saml2_Auth Get an instance of OneLogin_Saml2_Auth
idp: The Identity Provider - a social.backends.saml.SAMLIdentityProvider instance idp: The Identity Provider - a social_core.backends.saml.SAMLIdentityProvider instance
""" """
# We only override this method so that we can add extra debugging when debug_mode is True # We only override this method so that we can add extra debugging when debug_mode is True
# Note that auth_inst is instantiated just for the current HTTP request, then is destroyed # Note that auth_inst is instantiated just for the current HTTP request, then is destroyed
......
...@@ -42,18 +42,18 @@ def apply_settings(django_settings): ...@@ -42,18 +42,18 @@ def apply_settings(django_settings):
# this pipeline. # this pipeline.
django_settings.SOCIAL_AUTH_PIPELINE = [ django_settings.SOCIAL_AUTH_PIPELINE = [
'third_party_auth.pipeline.parse_query_params', 'third_party_auth.pipeline.parse_query_params',
'social.pipeline.social_auth.social_details', 'social_core.pipeline.social_auth.social_details',
'social.pipeline.social_auth.social_uid', 'social_core.pipeline.social_auth.social_uid',
'social.pipeline.social_auth.auth_allowed', 'social_core.pipeline.social_auth.auth_allowed',
'social.pipeline.social_auth.social_user', 'social_core.pipeline.social_auth.social_user',
'third_party_auth.pipeline.associate_by_email_if_login_api', 'third_party_auth.pipeline.associate_by_email_if_login_api',
'social.pipeline.user.get_username', 'social_core.pipeline.user.get_username',
'third_party_auth.pipeline.set_pipeline_timeout', 'third_party_auth.pipeline.set_pipeline_timeout',
'third_party_auth.pipeline.ensure_user_information', 'third_party_auth.pipeline.ensure_user_information',
'social.pipeline.user.create_user', 'social_core.pipeline.user.create_user',
'social.pipeline.social_auth.associate_user', 'social_core.pipeline.social_auth.associate_user',
'social.pipeline.social_auth.load_extra_data', 'social_core.pipeline.social_auth.load_extra_data',
'social.pipeline.user.user_details', 'social_core.pipeline.user.user_details',
'third_party_auth.pipeline.set_logged_in_cookies', 'third_party_auth.pipeline.set_logged_in_cookies',
'third_party_auth.pipeline.login_analytics', 'third_party_auth.pipeline.login_analytics',
] ]
...@@ -87,6 +87,6 @@ def apply_settings(django_settings): ...@@ -87,6 +87,6 @@ def apply_settings(django_settings):
# Context processors required under Django. # Context processors required under Django.
django_settings.SOCIAL_AUTH_UUID_LENGTH = 4 django_settings.SOCIAL_AUTH_UUID_LENGTH = 4
django_settings.DEFAULT_TEMPLATE_ENGINE['OPTIONS']['context_processors'] += ( django_settings.DEFAULT_TEMPLATE_ENGINE['OPTIONS']['context_processors'] += (
'social.apps.django_app.context_processors.backends', 'social_django.context_processors.backends',
'social.apps.django_app.context_processors.login_redirect', 'social_django.context_processors.login_redirect',
) )
...@@ -2,8 +2,9 @@ ...@@ -2,8 +2,9 @@
A custom Strategy for python-social-auth that allows us to fetch configuration from A custom Strategy for python-social-auth that allows us to fetch configuration from
ConfigurationModels rather than django.settings ConfigurationModels rather than django.settings
""" """
from social.backends.oauth import OAuthAuth
from social.strategies.django_strategy import DjangoStrategy from social_core.backends.oauth import OAuthAuth
from social_django.strategy import DjangoStrategy
from .models import OAuth2ProviderConfig from .models import OAuth2ProviderConfig
from .pipeline import get as get_pipeline_from_request from .pipeline import get as get_pipeline_from_request
......
...@@ -14,9 +14,9 @@ from django.contrib.sessions.backends import cache ...@@ -14,9 +14,9 @@ from django.contrib.sessions.backends import cache
from django.core.urlresolvers import reverse from django.core.urlresolvers import reverse
from django.test import utils as django_utils from django.test import utils as django_utils
from django.conf import settings as django_settings from django.conf import settings as django_settings
from social import actions, exceptions from social_core import actions, exceptions
from social.apps.django_app import utils as social_utils from social_django import utils as social_utils
from social.apps.django_app import views as social_views from social_django import views as social_views
from lms.djangoapps.commerce.tests import TEST_API_URL from lms.djangoapps.commerce.tests import TEST_API_URL
from openedx.core.djangoapps.site_configuration.tests.factories import SiteFactory from openedx.core.djangoapps.site_configuration.tests.factories import SiteFactory
...@@ -26,7 +26,6 @@ from student.tests.factories import UserFactory ...@@ -26,7 +26,6 @@ from student.tests.factories import UserFactory
from student_account.views import account_settings_context from student_account.views import account_settings_context
from third_party_auth import middleware, pipeline from third_party_auth import middleware, pipeline
from third_party_auth import settings as auth_settings
from third_party_auth.tests import testutil from third_party_auth.tests import testutil
...@@ -270,7 +269,7 @@ class IntegrationTest(testutil.TestCase, test.TestCase): ...@@ -270,7 +269,7 @@ class IntegrationTest(testutil.TestCase, test.TestCase):
form_field_data = self.provider.get_register_form_data(pipeline_kwargs) form_field_data = self.provider.get_register_form_data(pipeline_kwargs)
for prepopulated_form_data in form_field_data: for prepopulated_form_data in form_field_data:
if prepopulated_form_data in required_fields: if prepopulated_form_data in required_fields:
self.assertIn(form_field_data[prepopulated_form_data], response.content) self.assertIn(form_field_data[prepopulated_form_data], response.content.decode('utf-8'))
# Implementation details and actual tests past this point -- no more # Implementation details and actual tests past this point -- no more
# configuration needed. # configuration needed.
...@@ -285,7 +284,7 @@ class IntegrationTest(testutil.TestCase, test.TestCase): ...@@ -285,7 +284,7 @@ class IntegrationTest(testutil.TestCase, test.TestCase):
return self.provider.backend_name return self.provider.backend_name
# pylint: disable=invalid-name # pylint: disable=invalid-name
def assert_account_settings_context_looks_correct(self, context, _user, duplicate=False, linked=None): def assert_account_settings_context_looks_correct(self, context, duplicate=False, linked=None):
"""Asserts the user's account settings page context is in the expected state. """Asserts the user's account settings page context is in the expected state.
If duplicate is True, we expect context['duplicate_provider'] to contain If duplicate is True, we expect context['duplicate_provider'] to contain
...@@ -473,7 +472,7 @@ class IntegrationTest(testutil.TestCase, test.TestCase): ...@@ -473,7 +472,7 @@ class IntegrationTest(testutil.TestCase, test.TestCase):
return user return user
def fake_auth_complete(self, strategy): def fake_auth_complete(self, strategy):
"""Fake implementation of social.backends.BaseAuth.auth_complete. """Fake implementation of social_core.backends.BaseAuth.auth_complete.
Unlike what the docs say, it does not need to return a user instance. Unlike what the docs say, it does not need to return a user instance.
Sometimes (like when directing users to the /register form) it instead Sometimes (like when directing users to the /register form) it instead
...@@ -515,7 +514,7 @@ class IntegrationTest(testutil.TestCase, test.TestCase): ...@@ -515,7 +514,7 @@ class IntegrationTest(testutil.TestCase, test.TestCase):
These two objects contain circular references, so we create them These two objects contain circular references, so we create them
together. The references themselves are a mixture of normal __init__ together. The references themselves are a mixture of normal __init__
stuff and monkey-patching done by python-social-auth. See, for example, stuff and monkey-patching done by python-social-auth. See, for example,
social.apps.django_apps.utils.strategy(). social_django.utils.strategy().
""" """
request = self.request_factory.get( request = self.request_factory.get(
pipeline.get_complete_url(self.backend_name) + pipeline.get_complete_url(self.backend_name) +
...@@ -600,7 +599,7 @@ class IntegrationTest(testutil.TestCase, test.TestCase): ...@@ -600,7 +599,7 @@ class IntegrationTest(testutil.TestCase, test.TestCase):
# First we expect that we're in the unlinked state, and that there # First we expect that we're in the unlinked state, and that there
# really is no association in the backend. # really is no association in the backend.
self.assert_account_settings_context_looks_correct(account_settings_context(request), request.user, linked=False) self.assert_account_settings_context_looks_correct(account_settings_context(request), linked=False)
self.assert_social_auth_does_not_exist_for_user(request.user, strategy) self.assert_social_auth_does_not_exist_for_user(request.user, strategy)
# We should be redirected back to the complete page, setting # We should be redirected back to the complete page, setting
...@@ -614,13 +613,19 @@ class IntegrationTest(testutil.TestCase, test.TestCase): ...@@ -614,13 +613,19 @@ class IntegrationTest(testutil.TestCase, test.TestCase):
self.set_logged_in_cookies(request) self.set_logged_in_cookies(request)
# Fire off the auth pipeline to link. # Fire off the auth pipeline to link.
self.assert_redirect_to_dashboard_looks_correct(actions.do_complete( self.assert_redirect_to_dashboard_looks_correct( # pylint: disable=protected-access
request.backend, social_views._do_login, request.user, None, # pylint: disable=protected-access actions.do_complete(
redirect_field_name=auth.REDIRECT_FIELD_NAME)) request.backend,
social_views._do_login,
request.user,
None,
redirect_field_name=auth.REDIRECT_FIELD_NAME
)
)
# Now we expect to be in the linked state, with a backend entry. # Now we expect to be in the linked state, with a backend entry.
self.assert_social_auth_exists_for_user(request.user, strategy) self.assert_social_auth_exists_for_user(request.user, strategy)
self.assert_account_settings_context_looks_correct(account_settings_context(request), request.user, linked=True) self.assert_account_settings_context_looks_correct(account_settings_context(request), linked=True)
def test_full_pipeline_succeeds_for_unlinking_account(self): def test_full_pipeline_succeeds_for_unlinking_account(self):
# First, create, the request and strategy that store pipeline state, # First, create, the request and strategy that store pipeline state,
...@@ -647,15 +652,21 @@ class IntegrationTest(testutil.TestCase, test.TestCase): ...@@ -647,15 +652,21 @@ class IntegrationTest(testutil.TestCase, test.TestCase):
actions.do_complete(request.backend, social_views._do_login, user=user) # pylint: disable=protected-access actions.do_complete(request.backend, social_views._do_login, user=user) # pylint: disable=protected-access
# First we expect that we're in the linked state, with a backend entry. # First we expect that we're in the linked state, with a backend entry.
self.assert_account_settings_context_looks_correct(account_settings_context(request), user, linked=True) self.assert_account_settings_context_looks_correct(account_settings_context(request), linked=True)
self.assert_social_auth_exists_for_user(request.user, strategy) self.assert_social_auth_exists_for_user(request.user, strategy)
# Fire off the disconnect pipeline to unlink. # Fire off the disconnect pipeline to unlink.
self.assert_redirect_to_dashboard_looks_correct(actions.do_disconnect( self.assert_redirect_to_dashboard_looks_correct(
request.backend, request.user, None, redirect_field_name=auth.REDIRECT_FIELD_NAME)) actions.do_disconnect(
request.backend,
request.user,
None,
redirect_field_name=auth.REDIRECT_FIELD_NAME
)
)
# Now we expect to be in the unlinked state, with no backend entry. # Now we expect to be in the unlinked state, with no backend entry.
self.assert_account_settings_context_looks_correct(account_settings_context(request), user, linked=False) self.assert_account_settings_context_looks_correct(account_settings_context(request), linked=False)
self.assert_social_auth_does_not_exist_for_user(user, strategy) self.assert_social_auth_does_not_exist_for_user(user, strategy)
def test_linking_already_associated_account_raises_auth_already_associated(self): def test_linking_already_associated_account_raises_auth_already_associated(self):
...@@ -712,7 +723,7 @@ class IntegrationTest(testutil.TestCase, test.TestCase): ...@@ -712,7 +723,7 @@ class IntegrationTest(testutil.TestCase, test.TestCase):
exceptions.AuthAlreadyAssociated(self.provider.backend_name, 'account is already in use.')) exceptions.AuthAlreadyAssociated(self.provider.backend_name, 'account is already in use.'))
self.assert_account_settings_context_looks_correct( self.assert_account_settings_context_looks_correct(
account_settings_context(request), user, duplicate=True, linked=True) account_settings_context(request), duplicate=True, linked=True)
def test_full_pipeline_succeeds_for_signing_in_to_existing_active_account(self): def test_full_pipeline_succeeds_for_signing_in_to_existing_active_account(self):
# First, create, the request and strategy that store pipeline state, # First, create, the request and strategy that store pipeline state,
...@@ -763,7 +774,7 @@ class IntegrationTest(testutil.TestCase, test.TestCase): ...@@ -763,7 +774,7 @@ class IntegrationTest(testutil.TestCase, test.TestCase):
self.assert_redirect_to_dashboard_looks_correct( self.assert_redirect_to_dashboard_looks_correct(
actions.do_complete(request.backend, social_views._do_login, user=user)) actions.do_complete(request.backend, social_views._do_login, user=user))
self.assert_account_settings_context_looks_correct(account_settings_context(request), user) self.assert_account_settings_context_looks_correct(account_settings_context(request))
def test_signin_fails_if_account_not_active(self): def test_signin_fails_if_account_not_active(self):
_, strategy = self.get_request_and_strategy( _, strategy = self.get_request_and_strategy(
...@@ -869,7 +880,7 @@ class IntegrationTest(testutil.TestCase, test.TestCase): ...@@ -869,7 +880,7 @@ class IntegrationTest(testutil.TestCase, test.TestCase):
actions.do_complete(strategy.request.backend, social_views._do_login, user=created_user)) actions.do_complete(strategy.request.backend, social_views._do_login, user=created_user))
# Now the user has been redirected to the dashboard. Their third party account should now be linked. # Now the user has been redirected to the dashboard. Their third party account should now be linked.
self.assert_social_auth_exists_for_user(created_user, strategy) self.assert_social_auth_exists_for_user(created_user, strategy)
self.assert_account_settings_context_looks_correct(account_settings_context(request), created_user, linked=True) self.assert_account_settings_context_looks_correct(account_settings_context(request), linked=True)
def test_new_account_registration_assigns_distinct_username_on_collision(self): def test_new_account_registration_assigns_distinct_username_on_collision(self):
original_username = self.get_username() original_username = self.get_username()
......
...@@ -7,7 +7,7 @@ from third_party_auth.tests import testutil ...@@ -7,7 +7,7 @@ from third_party_auth.tests import testutil
from .base import IntegrationTestMixin from .base import IntegrationTestMixin
@unittest.skipUnless(testutil.AUTH_FEATURE_ENABLED, 'third_party_auth not enabled') @unittest.skipUnless(testutil.AUTH_FEATURE_ENABLED, testutil.AUTH_FEATURES_KEY + ' not enabled')
class GenericIntegrationTest(IntegrationTestMixin, testutil.TestCase): class GenericIntegrationTest(IntegrationTestMixin, testutil.TestCase):
""" """
Basic integration tests of third_party_auth using Dummy provider Basic integration tests of third_party_auth using Dummy provider
......
...@@ -6,7 +6,7 @@ from django.conf import settings ...@@ -6,7 +6,7 @@ from django.conf import settings
from django.core.urlresolvers import reverse from django.core.urlresolvers import reverse
import json import json
from mock import patch from mock import patch
from social.exceptions import AuthException from social_core.exceptions import AuthException
from student.tests.factories import UserFactory from student.tests.factories import UserFactory
from third_party_auth import pipeline from third_party_auth import pipeline
from third_party_auth.tests.specs import base from third_party_auth.tests.specs import base
......
...@@ -6,10 +6,9 @@ import ddt ...@@ -6,10 +6,9 @@ import ddt
import unittest import unittest
import httpretty import httpretty
import json import json
import time
from mock import patch from mock import patch
from freezegun import freeze_time from freezegun import freeze_time
from social.apps.django_app.default.models import UserSocialAuth from social_django.models import UserSocialAuth
from unittest import skip from unittest import skip
from third_party_auth.saml import log as saml_log from third_party_auth.saml import log as saml_log
...@@ -119,7 +118,7 @@ class SamlIntegrationTestUtilities(object): ...@@ -119,7 +118,7 @@ class SamlIntegrationTestUtilities(object):
@ddt.ddt @ddt.ddt
@unittest.skipUnless(testutil.AUTH_FEATURE_ENABLED, 'third_party_auth not enabled') @unittest.skipUnless(testutil.AUTH_FEATURE_ENABLED, testutil.AUTH_FEATURES_KEY + ' not enabled')
class TestShibIntegrationTest(SamlIntegrationTestUtilities, IntegrationTestMixin, testutil.SAMLTestCase): class TestShibIntegrationTest(SamlIntegrationTestUtilities, IntegrationTestMixin, testutil.SAMLTestCase):
""" """
TestShib provider Integration Test, to test SAML functionality TestShib provider Integration Test, to test SAML functionality
...@@ -157,7 +156,7 @@ class TestShibIntegrationTest(SamlIntegrationTestUtilities, IntegrationTestMixin ...@@ -157,7 +156,7 @@ class TestShibIntegrationTest(SamlIntegrationTestUtilities, IntegrationTestMixin
record = UserSocialAuth.objects.get( record = UserSocialAuth.objects.get(
user=self.user, provider=self.PROVIDER_BACKEND, uid__startswith=self.PROVIDER_IDP_SLUG user=self.user, provider=self.PROVIDER_BACKEND, uid__startswith=self.PROVIDER_IDP_SLUG
) )
attributes = record.extra_data["attributes"] attributes = record.extra_data
self.assertEqual( self.assertEqual(
attributes.get("urn:oid:1.3.6.1.4.1.5923.1.1.1.9"), ["Member@testshib.org", "Staff@testshib.org"] attributes.get("urn:oid:1.3.6.1.4.1.5923.1.1.1.9"), ["Member@testshib.org", "Staff@testshib.org"]
) )
...@@ -232,7 +231,7 @@ class TestShibIntegrationTest(SamlIntegrationTestUtilities, IntegrationTestMixin ...@@ -232,7 +231,7 @@ class TestShibIntegrationTest(SamlIntegrationTestUtilities, IntegrationTestMixin
self._test_return_login(previous_session_timed_out=True) self._test_return_login(previous_session_timed_out=True)
@unittest.skipUnless(testutil.AUTH_FEATURE_ENABLED, 'third_party_auth not enabled') @unittest.skipUnless(testutil.AUTH_FEATURE_ENABLED, testutil.AUTH_FEATURES_KEY + ' not enabled')
class SuccessFactorsIntegrationTest(SamlIntegrationTestUtilities, IntegrationTestMixin, testutil.SAMLTestCase): class SuccessFactorsIntegrationTest(SamlIntegrationTestUtilities, IntegrationTestMixin, testutil.SAMLTestCase):
""" """
Test basic SAML capability using the TestShib details, and then check that we're able Test basic SAML capability using the TestShib details, and then check that we're able
......
...@@ -20,7 +20,7 @@ class TwitterIntegrationTest(base.Oauth2IntegrationTest): ...@@ -20,7 +20,7 @@ class TwitterIntegrationTest(base.Oauth2IntegrationTest):
# To test an OAuth1 provider, we need to patch an additional method: # To test an OAuth1 provider, we need to patch an additional method:
patcher = patch( patcher = patch(
'social.backends.twitter.TwitterOAuth.unauthorized_token', 'social_core.backends.twitter.TwitterOAuth.unauthorized_token',
create=True, create=True,
return_value="unauth_token" return_value="unauth_token"
) )
......
...@@ -16,7 +16,7 @@ from third_party_auth.tests import testutil ...@@ -16,7 +16,7 @@ from third_party_auth.tests import testutil
# This is necessary because cms does not implement third party auth # This is necessary because cms does not implement third party auth
@unittest.skipUnless(settings.FEATURES.get('ENABLE_THIRD_PARTY_AUTH'), 'third party auth not enabled') @unittest.skipUnless(testutil.AUTH_FEATURE_ENABLED, testutil.AUTH_FEATURES_KEY + ' not enabled')
class Oauth2ProviderConfigAdminTest(testutil.TestCase): class Oauth2ProviderConfigAdminTest(testutil.TestCase):
""" """
Tests for oauth2 provider config admin Tests for oauth2 provider config admin
......
...@@ -5,6 +5,7 @@ import unittest ...@@ -5,6 +5,7 @@ import unittest
from django.conf import settings from django.conf import settings
from django.test import Client from django.test import Client
from social_django.models import Partial
@unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms') @unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms')
...@@ -13,39 +14,47 @@ class TestSessionFlushMiddleware(unittest.TestCase): ...@@ -13,39 +14,47 @@ class TestSessionFlushMiddleware(unittest.TestCase):
Ensure that if the pipeline is exited when it's been quarantined, Ensure that if the pipeline is exited when it's been quarantined,
the entire session is flushed. the entire session is flushed.
""" """
def setUp(self):
self.client = Client()
self.fancy_variable = 13025
self.token = 'pipeline_running'
self.tpa_quarantined_modules = ('fake_quarantined_module',)
def tearDown(self):
Partial.objects.all().delete()
def test_session_flush(self): def test_session_flush(self):
""" """
Test that a quarantined session is flushed when navigating elsewhere Test that a quarantined session is flushed when navigating elsewhere
""" """
client = Client() session = self.client.session
session = client.session session['fancy_variable'] = self.fancy_variable
session['fancy_variable'] = 13025 session['partial_pipeline_token'] = self.token
session['partial_pipeline'] = 'pipeline_running' session['third_party_auth_quarantined_modules'] = self.tpa_quarantined_modules
session['third_party_auth_quarantined_modules'] = ('fake_quarantined_module',)
session.save() session.save()
client.get('/') Partial.objects.create(token=session.get('partial_pipeline_token'))
self.assertEqual(client.session.get('fancy_variable', None), None) self.client.get('/')
self.assertEqual(self.client.session.get('fancy_variable', None), None)
def test_session_no_running_pipeline(self): def test_session_no_running_pipeline(self):
""" """
Test that a quarantined session without a running pipeline is not flushed Test that a quarantined session without a running pipeline is not flushed
""" """
client = Client() session = self.client.session
session = client.session session['fancy_variable'] = self.fancy_variable
session['fancy_variable'] = 13025 session['third_party_auth_quarantined_modules'] = self.tpa_quarantined_modules
session['third_party_auth_quarantined_modules'] = ('fake_quarantined_module',)
session.save() session.save()
client.get('/') self.client.get('/')
self.assertEqual(client.session.get('fancy_variable', None), 13025) self.assertEqual(self.client.session.get('fancy_variable', None), self.fancy_variable)
def test_session_no_quarantine(self): def test_session_no_quarantine(self):
""" """
Test that a session with a running pipeline but no quarantine is not flushed Test that a session with a running pipeline but no quarantine is not flushed
""" """
client = Client() session = self.client.session
session = client.session session['fancy_variable'] = self.fancy_variable
session['fancy_variable'] = 13025 session['partial_pipeline_token'] = self.token
session['partial_pipeline'] = 'pipeline_running'
session.save() session.save()
client.get('/') Partial.objects.create(token=session.get('partial_pipeline_token'))
self.assertEqual(client.session.get('fancy_variable', None), 13025) self.client.get('/')
self.assertEqual(self.client.session.get('fancy_variable', None), self.fancy_variable)
...@@ -35,7 +35,7 @@ class MakeRandomPasswordTest(testutil.TestCase): ...@@ -35,7 +35,7 @@ class MakeRandomPasswordTest(testutil.TestCase):
self.assertEqual(expected, pipeline.make_random_password(choice_fn=random_instance.choice)) self.assertEqual(expected, pipeline.make_random_password(choice_fn=random_instance.choice))
@unittest.skipUnless(testutil.AUTH_FEATURE_ENABLED, 'third_party_auth not enabled') @unittest.skipUnless(testutil.AUTH_FEATURE_ENABLED, testutil.AUTH_FEATURES_KEY + ' not enabled')
class ProviderUserStateTestCase(testutil.TestCase): class ProviderUserStateTestCase(testutil.TestCase):
"""Tests ProviderUserState behavior.""" """Tests ProviderUserState behavior."""
......
...@@ -6,7 +6,7 @@ import mock ...@@ -6,7 +6,7 @@ import mock
from django import test from django import test
from django.conf import settings from django.conf import settings
from django.contrib.auth import models from django.contrib.auth import models
from social.apps.django_app.default import models as social_models from social_django import models as social_models
from third_party_auth import pipeline, provider from third_party_auth import pipeline, provider
from third_party_auth.tests import testutil from third_party_auth.tests import testutil
...@@ -24,8 +24,7 @@ class TestCase(testutil.TestCase, test.TestCase): ...@@ -24,8 +24,7 @@ class TestCase(testutil.TestCase, test.TestCase):
self.enabled_provider = self.configure_google_provider(enabled=True) self.enabled_provider = self.configure_google_provider(enabled=True)
@unittest.skipUnless( @unittest.skipUnless(testutil.AUTH_FEATURE_ENABLED, testutil.AUTH_FEATURES_KEY + ' not enabled')
testutil.AUTH_FEATURES_KEY in settings.FEATURES, testutil.AUTH_FEATURES_KEY + ' not in settings.FEATURES')
class GetAuthenticatedUserTestCase(TestCase): class GetAuthenticatedUserTestCase(TestCase):
"""Tests for get_authenticated_user.""" """Tests for get_authenticated_user."""
...@@ -66,8 +65,7 @@ class GetAuthenticatedUserTestCase(TestCase): ...@@ -66,8 +65,7 @@ class GetAuthenticatedUserTestCase(TestCase):
self.assertEqual(self.enabled_provider.get_authentication_backend(), user.backend) self.assertEqual(self.enabled_provider.get_authentication_backend(), user.backend)
@unittest.skipUnless( @unittest.skipUnless(testutil.AUTH_FEATURE_ENABLED, testutil.AUTH_FEATURES_KEY + ' not enabled')
testutil.AUTH_FEATURES_KEY in settings.FEATURES, testutil.AUTH_FEATURES_KEY + ' not in settings.FEATURES')
class GetProviderUserStatesTestCase(testutil.TestCase, test.TestCase): class GetProviderUserStatesTestCase(testutil.TestCase, test.TestCase):
"""Tests generation of ProviderUserStates.""" """Tests generation of ProviderUserStates."""
...@@ -143,8 +141,7 @@ class GetProviderUserStatesTestCase(testutil.TestCase, test.TestCase): ...@@ -143,8 +141,7 @@ class GetProviderUserStatesTestCase(testutil.TestCase, test.TestCase):
self.assertEqual(self.user, linkedin_state.user) self.assertEqual(self.user, linkedin_state.user)
@unittest.skipUnless( @unittest.skipUnless(testutil.AUTH_FEATURE_ENABLED, testutil.AUTH_FEATURES_KEY + ' not enabled')
testutil.AUTH_FEATURES_KEY in settings.FEATURES, testutil.AUTH_FEATURES_KEY + ' not in settings.FEATURES')
class UrlFormationTestCase(TestCase): class UrlFormationTestCase(TestCase):
"""Tests formation of URLs for pipeline hook points.""" """Tests formation of URLs for pipeline hook points."""
...@@ -210,8 +207,7 @@ class UrlFormationTestCase(TestCase): ...@@ -210,8 +207,7 @@ class UrlFormationTestCase(TestCase):
pipeline.get_complete_url(provider_id) pipeline.get_complete_url(provider_id)
@unittest.skipUnless( @unittest.skipUnless(testutil.AUTH_FEATURE_ENABLED, testutil.AUTH_FEATURES_KEY + ' not enabled')
testutil.AUTH_FEATURES_KEY in settings.FEATURES, testutil.AUTH_FEATURES_KEY + ' not in settings.FEATURES')
class TestPipelineUtilityFunctions(TestCase, test.TestCase): class TestPipelineUtilityFunctions(TestCase, test.TestCase):
""" """
Test some of the isolated utility functions in the pipeline Test some of the isolated utility functions in the pipeline
...@@ -230,36 +226,36 @@ class TestPipelineUtilityFunctions(TestCase, test.TestCase): ...@@ -230,36 +226,36 @@ class TestPipelineUtilityFunctions(TestCase, test.TestCase):
Test that we can use a dictionary with a UID entry to retrieve a Test that we can use a dictionary with a UID entry to retrieve a
database-backed UserSocialAuth object. database-backed UserSocialAuth object.
""" """
request = mock.MagicMock( request = mock.MagicMock()
session={ pipeline_partial = {
'partial_pipeline': { 'kwargs': {
'kwargs': { 'social': {
'social': { 'uid': 'fake uid'
'uid': 'fake uid'
}
}
} }
} }
) }
real_social = pipeline.get_real_social_auth_object(request)
self.assertEqual(real_social, self.social_auth) with mock.patch('third_party_auth.pipeline.get') as get_pipeline:
get_pipeline.return_value = pipeline_partial
real_social = pipeline.get_real_social_auth_object(request)
self.assertEqual(real_social, self.social_auth)
def test_get_real_social_auth(self): def test_get_real_social_auth(self):
""" """
Test that trying to get a database-backed UserSocialAuth from an existing Test that trying to get a database-backed UserSocialAuth from an existing
instance returns correctly. instance returns correctly.
""" """
request = mock.MagicMock( request = mock.MagicMock()
session={ pipeline_partial = {
'partial_pipeline': { 'kwargs': {
'kwargs': { 'social': self.social_auth
'social': self.social_auth
}
}
} }
) }
real_social = pipeline.get_real_social_auth_object(request)
self.assertEqual(real_social, self.social_auth) with mock.patch('third_party_auth.pipeline.get') as get_pipeline:
get_pipeline.return_value = pipeline_partial
real_social = pipeline.get_real_social_auth_object(request)
self.assertEqual(real_social, self.social_auth)
def test_get_real_social_auth_no_pipeline(self): def test_get_real_social_auth_no_pipeline(self):
""" """
......
...@@ -13,7 +13,7 @@ SITE_DOMAIN_A = 'professionalx.example.com' ...@@ -13,7 +13,7 @@ SITE_DOMAIN_A = 'professionalx.example.com'
SITE_DOMAIN_B = 'somethingelse.example.com' SITE_DOMAIN_B = 'somethingelse.example.com'
@unittest.skipUnless(testutil.AUTH_FEATURE_ENABLED, 'third_party_auth not enabled') @unittest.skipUnless(testutil.AUTH_FEATURE_ENABLED, testutil.AUTH_FEATURES_KEY + ' not enabled')
class RegistryTest(testutil.TestCase): class RegistryTest(testutil.TestCase):
"""Tests registry discovery and operation.""" """Tests registry discovery and operation."""
......
...@@ -45,7 +45,7 @@ class SettingsUnitTest(testutil.TestCase): ...@@ -45,7 +45,7 @@ class SettingsUnitTest(testutil.TestCase):
settings.apply_settings(self.settings) settings.apply_settings(self.settings)
self.assertEqual(settings._FIELDS_STORED_IN_SESSION, self.settings.FIELDS_STORED_IN_SESSION) self.assertEqual(settings._FIELDS_STORED_IN_SESSION, self.settings.FIELDS_STORED_IN_SESSION)
@unittest.skipUnless(testutil.AUTH_FEATURE_ENABLED, 'third_party_auth not enabled') @unittest.skipUnless(testutil.AUTH_FEATURE_ENABLED, testutil.AUTH_FEATURES_KEY + ' not enabled')
def test_apply_settings_enables_no_providers_by_default(self): def test_apply_settings_enables_no_providers_by_default(self):
# Providers are only enabled via ConfigurationModels in the database # Providers are only enabled via ConfigurationModels in the database
settings.apply_settings(self.settings) settings.apply_settings(self.settings)
......
...@@ -12,12 +12,12 @@ from onelogin.saml2.errors import OneLogin_Saml2_Error ...@@ -12,12 +12,12 @@ from onelogin.saml2.errors import OneLogin_Saml2_Error
# Define some XML namespaces: # Define some XML namespaces:
from third_party_auth.tasks import SAML_XML_NS from third_party_auth.tasks import SAML_XML_NS
from .testutil import AUTH_FEATURE_ENABLED, SAMLTestCase from .testutil import AUTH_FEATURE_ENABLED, AUTH_FEATURES_KEY, SAMLTestCase
XMLDSIG_XML_NS = 'http://www.w3.org/2000/09/xmldsig#' XMLDSIG_XML_NS = 'http://www.w3.org/2000/09/xmldsig#'
@unittest.skipUnless(AUTH_FEATURE_ENABLED, 'third_party_auth not enabled') @unittest.skipUnless(AUTH_FEATURE_ENABLED, AUTH_FEATURES_KEY + ' not enabled')
@ddt.ddt @ddt.ddt
class SAMLMetadataTest(SAMLTestCase): class SAMLMetadataTest(SAMLTestCase):
""" """
...@@ -135,7 +135,7 @@ class SAMLMetadataTest(SAMLTestCase): ...@@ -135,7 +135,7 @@ class SAMLMetadataTest(SAMLTestCase):
self.assertEqual(support_email_node.text, support_email) self.assertEqual(support_email_node.text, support_email)
@unittest.skipUnless(AUTH_FEATURE_ENABLED, 'third_party_auth not enabled') @unittest.skipUnless(AUTH_FEATURE_ENABLED, AUTH_FEATURES_KEY + ' not enabled')
class SAMLAuthTest(SAMLTestCase): class SAMLAuthTest(SAMLTestCase):
""" """
Test the SAML auth views Test the SAML auth views
......
...@@ -4,8 +4,8 @@ import json ...@@ -4,8 +4,8 @@ import json
import httpretty import httpretty
from provider.constants import PUBLIC from provider.constants import PUBLIC
from provider.oauth2.models import Client from provider.oauth2.models import Client
from social.apps.django_app.default.models import UserSocialAuth from social_core.backends.facebook import FacebookOAuth2, API_VERSION as FACEBOOK_API_VERSION
from social.backends.facebook import FacebookOAuth2 from social_django.models import UserSocialAuth, Partial
from student.tests.factories import UserFactory from student.tests.factories import UserFactory
...@@ -39,6 +39,10 @@ class ThirdPartyOAuthTestMixin(ThirdPartyAuthTestMixin): ...@@ -39,6 +39,10 @@ class ThirdPartyOAuthTestMixin(ThirdPartyAuthTestMixin):
elif self.BACKEND == 'facebook': elif self.BACKEND == 'facebook':
self.configure_facebook_provider(enabled=True, visible=True) self.configure_facebook_provider(enabled=True, visible=True)
def tearDown(self):
super(ThirdPartyOAuthTestMixin, self).tearDown()
Partial.objects.all().delete()
def _create_client(self): def _create_client(self):
""" """
Create an OAuth2 client application Create an OAuth2 client application
...@@ -81,7 +85,7 @@ class ThirdPartyOAuthTestMixin(ThirdPartyAuthTestMixin): ...@@ -81,7 +85,7 @@ class ThirdPartyOAuthTestMixin(ThirdPartyAuthTestMixin):
class ThirdPartyOAuthTestMixinFacebook(object): class ThirdPartyOAuthTestMixinFacebook(object):
"""Tests oauth with the Facebook backend""" """Tests oauth with the Facebook backend"""
BACKEND = "facebook" BACKEND = "facebook"
USER_URL = FacebookOAuth2.USER_DATA_URL USER_URL = FacebookOAuth2.USER_DATA_URL.format(version=FACEBOOK_API_VERSION)
# In facebook responses, the "id" field is used as the user's identifier # In facebook responses, the "id" field is used as the user's identifier
UID_FIELD = "id" UID_FIELD = "id"
......
...@@ -10,5 +10,5 @@ urlpatterns = patterns( ...@@ -10,5 +10,5 @@ urlpatterns = patterns(
url(r'^auth/custom_auth_entry', post_to_custom_auth_form, name='tpa_post_to_custom_auth_form'), url(r'^auth/custom_auth_entry', post_to_custom_auth_form, name='tpa_post_to_custom_auth_form'),
url(r'^auth/saml/metadata.xml', saml_metadata_view), url(r'^auth/saml/metadata.xml', saml_metadata_view),
url(r'^auth/login/(?P<backend>lti)/$', lti_login_and_complete_view), url(r'^auth/login/(?P<backend>lti)/$', lti_login_and_complete_view),
url(r'^auth/', include('social.apps.django_app.urls', namespace='social')), url(r'^auth/', include('social_django.urls', namespace='social')),
) )
""" """
Extra views required for SSO Extra views required for SSO
""" """
import social
from django.conf import settings from django.conf import settings
from django.core.urlresolvers import reverse from django.core.urlresolvers import reverse
from django.http import Http404, HttpResponse, HttpResponseNotAllowed, HttpResponseServerError from django.http import Http404, HttpResponse, HttpResponseNotAllowed, HttpResponseServerError
from django.shortcuts import redirect, render from django.shortcuts import redirect, render
from django.views.decorators.csrf import csrf_exempt from django.views.decorators.csrf import csrf_exempt
from social.apps.django_app.utils import load_backend, load_strategy from social_django.utils import load_strategy, load_backend, psa
from social.apps.django_app.views import complete from social_django.views import complete
from social.utils import setting_name from social_core.utils import setting_name
from student.models import UserProfile from student.models import UserProfile
from student.views import compose_and_send_activation_email from student.views import compose_and_send_activation_email
...@@ -56,7 +55,7 @@ def saml_metadata_view(request): ...@@ -56,7 +55,7 @@ def saml_metadata_view(request):
@csrf_exempt @csrf_exempt
@social.apps.django_app.utils.psa('{0}:complete'.format(URL_NAMESPACE)) @psa('{0}:complete'.format(URL_NAMESPACE))
def lti_login_and_complete_view(request, backend, *args, **kwargs): def lti_login_and_complete_view(request, backend, *args, **kwargs):
"""This is a combination login/complete due to LTI being a one step login""" """This is a combination login/complete due to LTI being a one step login"""
......
...@@ -385,7 +385,7 @@ class SelfPacedCourseInfoTestCase(LoginEnrollmentTestCase, SharedModuleStoreTest ...@@ -385,7 +385,7 @@ class SelfPacedCourseInfoTestCase(LoginEnrollmentTestCase, SharedModuleStoreTest
self.assertEqual(resp.status_code, 200) self.assertEqual(resp.status_code, 200)
def test_num_queries_instructor_paced(self): def test_num_queries_instructor_paced(self):
self.fetch_course_info_with_queries(self.instructor_paced_course, 23, 4) self.fetch_course_info_with_queries(self.instructor_paced_course, 24, 4)
def test_num_queries_self_paced(self): def test_num_queries_self_paced(self):
self.fetch_course_info_with_queries(self.self_paced_course, 23, 4) self.fetch_course_info_with_queries(self.self_paced_course, 24, 4)
...@@ -209,8 +209,8 @@ class IndexQueryTestCase(ModuleStoreTestCase): ...@@ -209,8 +209,8 @@ class IndexQueryTestCase(ModuleStoreTestCase):
NUM_PROBLEMS = 20 NUM_PROBLEMS = 20
@ddt.data( @ddt.data(
(ModuleStoreEnum.Type.mongo, 10, 149), (ModuleStoreEnum.Type.mongo, 10, 150),
(ModuleStoreEnum.Type.split, 4, 149), (ModuleStoreEnum.Type.split, 4, 150),
) )
@ddt.unpack @ddt.unpack
def test_index_query_counts(self, store_type, expected_mongo_query_count, expected_mysql_query_count): def test_index_query_counts(self, store_type, expected_mongo_query_count, expected_mysql_query_count):
...@@ -1435,12 +1435,12 @@ class ProgressPageTests(ProgressPageBaseTests): ...@@ -1435,12 +1435,12 @@ class ProgressPageTests(ProgressPageBaseTests):
"""Test that query counts remain the same for self-paced and instructor-paced courses.""" """Test that query counts remain the same for self-paced and instructor-paced courses."""
SelfPacedConfiguration(enabled=self_paced_enabled).save() SelfPacedConfiguration(enabled=self_paced_enabled).save()
self.setup_course(self_paced=self_paced) self.setup_course(self_paced=self_paced)
with self.assertNumQueries(43), check_mongo_calls(1): with self.assertNumQueries(44), check_mongo_calls(1):
self._get_progress_page() self._get_progress_page()
@ddt.data( @ddt.data(
(False, 43, 29), (False, 44, 30),
(True, 36, 25) (True, 37, 26)
) )
@ddt.unpack @ddt.unpack
def test_progress_queries(self, enable_waffle, initial, subsequent): def test_progress_queries(self, enable_waffle, initial, subsequent):
......
...@@ -379,8 +379,8 @@ class ViewsQueryCountTestCase( ...@@ -379,8 +379,8 @@ class ViewsQueryCountTestCase(
return inner return inner
@ddt.data( @ddt.data(
(ModuleStoreEnum.Type.mongo, 3, 4, 30), (ModuleStoreEnum.Type.mongo, 3, 4, 31),
(ModuleStoreEnum.Type.split, 3, 13, 30), (ModuleStoreEnum.Type.split, 3, 13, 31),
) )
@ddt.unpack @ddt.unpack
@count_queries @count_queries
...@@ -388,8 +388,8 @@ class ViewsQueryCountTestCase( ...@@ -388,8 +388,8 @@ class ViewsQueryCountTestCase(
self.create_thread_helper(mock_request) self.create_thread_helper(mock_request)
@ddt.data( @ddt.data(
(ModuleStoreEnum.Type.mongo, 3, 3, 26), (ModuleStoreEnum.Type.mongo, 3, 3, 27),
(ModuleStoreEnum.Type.split, 3, 10, 26), (ModuleStoreEnum.Type.split, 3, 10, 27),
) )
@ddt.unpack @ddt.unpack
@count_queries @count_queries
......
...@@ -681,10 +681,10 @@ X_FRAME_OPTIONS = ENV_TOKENS.get('X_FRAME_OPTIONS', X_FRAME_OPTIONS) ...@@ -681,10 +681,10 @@ X_FRAME_OPTIONS = ENV_TOKENS.get('X_FRAME_OPTIONS', X_FRAME_OPTIONS)
if FEATURES.get('ENABLE_THIRD_PARTY_AUTH'): if FEATURES.get('ENABLE_THIRD_PARTY_AUTH'):
AUTHENTICATION_BACKENDS = ( AUTHENTICATION_BACKENDS = (
ENV_TOKENS.get('THIRD_PARTY_AUTH_BACKENDS', [ ENV_TOKENS.get('THIRD_PARTY_AUTH_BACKENDS', [
'social.backends.google.GoogleOAuth2', 'social_core.backends.google.GoogleOAuth2',
'social.backends.linkedin.LinkedinOAuth2', 'social_core.backends.linkedin.LinkedinOAuth2',
'social.backends.facebook.FacebookOAuth2', 'social_core.backends.facebook.FacebookOAuth2',
'social.backends.azuread.AzureADOAuth2', 'social_core.backends.azuread.AzureADOAuth2',
'third_party_auth.saml.SAMLAuthBackend', 'third_party_auth.saml.SAMLAuthBackend',
'third_party_auth.lti.LTIAuthBackend', 'third_party_auth.lti.LTIAuthBackend',
]) + list(AUTHENTICATION_BACKENDS) ]) + list(AUTHENTICATION_BACKENDS)
......
...@@ -142,9 +142,9 @@ ...@@ -142,9 +142,9 @@
"SYSLOG_SERVER": "", "SYSLOG_SERVER": "",
"TECH_SUPPORT_EMAIL": "technical@example.com", "TECH_SUPPORT_EMAIL": "technical@example.com",
"THIRD_PARTY_AUTH_BACKENDS": [ "THIRD_PARTY_AUTH_BACKENDS": [
"social.backends.google.GoogleOAuth2", "social_core.backends.google.GoogleOAuth2",
"social.backends.linkedin.LinkedinOAuth2", "social_core.backends.linkedin.LinkedinOAuth2",
"social.backends.facebook.FacebookOAuth2", "social_core.backends.facebook.FacebookOAuth2",
"third_party_auth.dummy.DummyBackend", "third_party_auth.dummy.DummyBackend",
"third_party_auth.saml.SAMLAuthBackend" "third_party_auth.saml.SAMLAuthBackend"
], ],
......
...@@ -2110,7 +2110,7 @@ INSTALLED_APPS = ( ...@@ -2110,7 +2110,7 @@ INSTALLED_APPS = (
# edX Mobile API # edX Mobile API
'mobile_api', 'mobile_api',
'social.apps.django_app.default', 'social_django',
# Surveys # Surveys
'survey', 'survey',
......
...@@ -260,11 +260,11 @@ PASSWORD_COMPLEXITY = {} ...@@ -260,11 +260,11 @@ PASSWORD_COMPLEXITY = {}
FEATURES['ENABLE_THIRD_PARTY_AUTH'] = True FEATURES['ENABLE_THIRD_PARTY_AUTH'] = True
AUTHENTICATION_BACKENDS = ( AUTHENTICATION_BACKENDS = (
'social.backends.google.GoogleOAuth2', 'social_core.backends.google.GoogleOAuth2',
'social.backends.linkedin.LinkedinOAuth2', 'social_core.backends.linkedin.LinkedinOAuth2',
'social.backends.facebook.FacebookOAuth2', 'social_core.backends.facebook.FacebookOAuth2',
'social.backends.azuread.AzureADOAuth2', 'social_core.backends.azuread.AzureADOAuth2',
'social.backends.twitter.TwitterOAuth', 'social_core.backends.twitter.TwitterOAuth',
'third_party_auth.dummy.DummyBackend', 'third_party_auth.dummy.DummyBackend',
'third_party_auth.saml.SAMLAuthBackend', 'third_party_auth.saml.SAMLAuthBackend',
'third_party_auth.lti.LTIAuthBackend', 'third_party_auth.lti.LTIAuthBackend',
...@@ -554,7 +554,7 @@ SEARCH_ENGINE = "search.tests.mock_search_engine.MockSearchEngine" ...@@ -554,7 +554,7 @@ SEARCH_ENGINE = "search.tests.mock_search_engine.MockSearchEngine"
FACEBOOK_APP_SECRET = "Test" FACEBOOK_APP_SECRET = "Test"
FACEBOOK_APP_ID = "Test" FACEBOOK_APP_ID = "Test"
FACEBOOK_API_VERSION = "v2.2" FACEBOOK_API_VERSION = "v2.8"
######### custom courses ######### ######### custom courses #########
INSTALLED_APPS += ('lms.djangoapps.ccx', 'openedx.core.djangoapps.ccxcon') INSTALLED_APPS += ('lms.djangoapps.ccx', 'openedx.core.djangoapps.ccxcon')
......
...@@ -14,6 +14,7 @@ settings.INSTALLED_APPS # pylint: disable=pointless-statement ...@@ -14,6 +14,7 @@ settings.INSTALLED_APPS # pylint: disable=pointless-statement
from openedx.core.lib.django_startup import autostartup from openedx.core.lib.django_startup import autostartup
from openedx.core.release import doc_version from openedx.core.release import doc_version
import analytics import analytics
from openedx.core.djangoapps.monkey_patch import django_db_models_options from openedx.core.djangoapps.monkey_patch import django_db_models_options
import xmodule.x_module import xmodule.x_module
......
...@@ -10,8 +10,8 @@ from provider.forms import OAuthForm, OAuthValidationError ...@@ -10,8 +10,8 @@ from provider.forms import OAuthForm, OAuthValidationError
from provider.oauth2.forms import ScopeChoiceField, ScopeMixin from provider.oauth2.forms import ScopeChoiceField, ScopeMixin
from provider.oauth2.models import Client from provider.oauth2.models import Client
from requests import HTTPError from requests import HTTPError
from social.backends import oauth as social_oauth from social_core.backends import oauth as social_oauth
from social.exceptions import AuthException from social_core.exceptions import AuthException
from third_party_auth import pipeline from third_party_auth import pipeline
...@@ -90,15 +90,16 @@ class AccessTokenExchangeForm(ScopeMixin, OAuthForm): ...@@ -90,15 +90,16 @@ class AccessTokenExchangeForm(ScopeMixin, OAuthForm):
self.cleaned_data["client"] = client self.cleaned_data["client"] = client
user = None user = None
access_token = self.cleaned_data.get("access_token")
try: try:
user = backend.do_auth(self.cleaned_data.get("access_token"), allow_inactive_user=True) user = backend.do_auth(access_token, allow_inactive_user=True)
except (HTTPError, AuthException): except (HTTPError, AuthException):
pass pass
if user and isinstance(user, User): if user and isinstance(user, User):
self.cleaned_data["user"] = user self.cleaned_data["user"] = user
else: else:
# Ensure user does not re-enter the pipeline # Ensure user does not re-enter the pipeline
self.request.social_strategy.clean_partial_pipeline() self.request.social_strategy.clean_partial_pipeline(access_token)
raise OAuthValidationError( raise OAuthValidationError(
{ {
"error": "invalid_grant", "error": "invalid_grant",
......
...@@ -10,12 +10,13 @@ from django.test import TestCase ...@@ -10,12 +10,13 @@ from django.test import TestCase
from django.test.client import RequestFactory from django.test.client import RequestFactory
import httpretty import httpretty
from provider import scope from provider import scope
import social.apps.django_app.utils as social_utils from social_django.models import Partial
import social_django.utils as social_utils
from third_party_auth.tests.utils import ThirdPartyOAuthTestMixinFacebook, ThirdPartyOAuthTestMixinGoogle from third_party_auth.tests.utils import ThirdPartyOAuthTestMixinFacebook, ThirdPartyOAuthTestMixinGoogle
from ..forms import AccessTokenExchangeForm from ..forms import AccessTokenExchangeForm
from .utils import AccessTokenExchangeTestMixin from .utils import AccessTokenExchangeTestMixin, TPA_FEATURE_ENABLED, TPA_FEATURES_KEY
from .mixins import DOPAdapterMixin, DOTAdapterMixin from .mixins import DOPAdapterMixin, DOTAdapterMixin
...@@ -32,13 +33,16 @@ class AccessTokenExchangeFormTest(AccessTokenExchangeTestMixin): ...@@ -32,13 +33,16 @@ class AccessTokenExchangeFormTest(AccessTokenExchangeTestMixin):
# pylint: disable=no-member # pylint: disable=no-member
self.request.backend = social_utils.load_backend(self.request.social_strategy, self.BACKEND, redirect_uri) self.request.backend = social_utils.load_backend(self.request.social_strategy, self.BACKEND, redirect_uri)
def tearDown(self):
super(AccessTokenExchangeFormTest, self).tearDown()
Partial.objects.all().delete()
def _assert_error(self, data, expected_error, expected_error_description): def _assert_error(self, data, expected_error, expected_error_description):
form = AccessTokenExchangeForm(request=self.request, oauth2_adapter=self.oauth2_adapter, data=data) form = AccessTokenExchangeForm(request=self.request, oauth2_adapter=self.oauth2_adapter, data=data)
self.assertEqual( self.assertEqual(
form.errors, form.errors,
{"error": expected_error, "error_description": expected_error_description} {"error": expected_error, "error_description": expected_error_description}
) )
self.assertNotIn("partial_pipeline", self.request.session)
def _assert_success(self, data, expected_scopes): def _assert_success(self, data, expected_scopes):
form = AccessTokenExchangeForm(request=self.request, oauth2_adapter=self.oauth2_adapter, data=data) form = AccessTokenExchangeForm(request=self.request, oauth2_adapter=self.oauth2_adapter, data=data)
...@@ -49,7 +53,7 @@ class AccessTokenExchangeFormTest(AccessTokenExchangeTestMixin): ...@@ -49,7 +53,7 @@ class AccessTokenExchangeFormTest(AccessTokenExchangeTestMixin):
# This is necessary because cms does not implement third party auth # This is necessary because cms does not implement third party auth
@unittest.skipUnless(settings.FEATURES.get("ENABLE_THIRD_PARTY_AUTH"), "third party auth not enabled") @unittest.skipUnless(TPA_FEATURE_ENABLED, TPA_FEATURES_KEY + " not enabled")
@httpretty.activate @httpretty.activate
class DOPAccessTokenExchangeFormTestFacebook( class DOPAccessTokenExchangeFormTestFacebook(
DOPAdapterMixin, DOPAdapterMixin,
...@@ -65,7 +69,7 @@ class DOPAccessTokenExchangeFormTestFacebook( ...@@ -65,7 +69,7 @@ class DOPAccessTokenExchangeFormTestFacebook(
# This is necessary because cms does not implement third party auth # This is necessary because cms does not implement third party auth
@unittest.skipUnless(settings.FEATURES.get("ENABLE_THIRD_PARTY_AUTH"), "third party auth not enabled") @unittest.skipUnless(TPA_FEATURE_ENABLED, TPA_FEATURES_KEY + " not enabled")
@httpretty.activate @httpretty.activate
class DOTAccessTokenExchangeFormTestFacebook( class DOTAccessTokenExchangeFormTestFacebook(
DOTAdapterMixin, DOTAdapterMixin,
...@@ -81,7 +85,7 @@ class DOTAccessTokenExchangeFormTestFacebook( ...@@ -81,7 +85,7 @@ class DOTAccessTokenExchangeFormTestFacebook(
# This is necessary because cms does not implement third party auth # This is necessary because cms does not implement third party auth
@unittest.skipUnless(settings.FEATURES.get("ENABLE_THIRD_PARTY_AUTH"), "third party auth not enabled") @unittest.skipUnless(TPA_FEATURE_ENABLED, TPA_FEATURES_KEY + " not enabled")
@httpretty.activate @httpretty.activate
class DOPAccessTokenExchangeFormTestGoogle( class DOPAccessTokenExchangeFormTestGoogle(
DOPAdapterMixin, DOPAdapterMixin,
...@@ -97,7 +101,7 @@ class DOPAccessTokenExchangeFormTestGoogle( ...@@ -97,7 +101,7 @@ class DOPAccessTokenExchangeFormTestGoogle(
# This is necessary because cms does not implement third party auth # This is necessary because cms does not implement third party auth
@unittest.skipUnless(settings.FEATURES.get("ENABLE_THIRD_PARTY_AUTH"), "third party auth not enabled") @unittest.skipUnless(TPA_FEATURE_ENABLED, TPA_FEATURES_KEY + " not enabled")
@httpretty.activate @httpretty.activate
class DOTAccessTokenExchangeFormTestGoogle( class DOTAccessTokenExchangeFormTestGoogle(
DOTAdapterMixin, DOTAdapterMixin,
......
...@@ -17,11 +17,12 @@ import httpretty ...@@ -17,11 +17,12 @@ import httpretty
import provider.constants import provider.constants
from provider.oauth2.models import AccessToken, Client from provider.oauth2.models import AccessToken, Client
from rest_framework.test import APIClient from rest_framework.test import APIClient
from social_django.models import Partial
from student.tests.factories import UserFactory from student.tests.factories import UserFactory
from third_party_auth.tests.utils import ThirdPartyOAuthTestMixinFacebook, ThirdPartyOAuthTestMixinGoogle from third_party_auth.tests.utils import ThirdPartyOAuthTestMixinFacebook, ThirdPartyOAuthTestMixinGoogle
from .mixins import DOPAdapterMixin, DOTAdapterMixin from .mixins import DOPAdapterMixin, DOTAdapterMixin
from .utils import AccessTokenExchangeTestMixin from .utils import AccessTokenExchangeTestMixin, TPA_FEATURE_ENABLED, TPA_FEATURES_KEY
@ddt.ddt @ddt.ddt
...@@ -34,6 +35,10 @@ class AccessTokenExchangeViewTest(AccessTokenExchangeTestMixin): ...@@ -34,6 +35,10 @@ class AccessTokenExchangeViewTest(AccessTokenExchangeTestMixin):
self.url = reverse("exchange_access_token", kwargs={"backend": self.BACKEND}) self.url = reverse("exchange_access_token", kwargs={"backend": self.BACKEND})
self.csrf_client = APIClient(enforce_csrf_checks=True) self.csrf_client = APIClient(enforce_csrf_checks=True)
def tearDown(self):
super(AccessTokenExchangeViewTest, self).tearDown()
Partial.objects.all().delete()
def _assert_error(self, data, expected_error, expected_error_description): def _assert_error(self, data, expected_error, expected_error_description):
response = self.csrf_client.post(self.url, data) response = self.csrf_client.post(self.url, data)
self.assertEqual(response.status_code, 400) self.assertEqual(response.status_code, 400)
...@@ -42,7 +47,6 @@ class AccessTokenExchangeViewTest(AccessTokenExchangeTestMixin): ...@@ -42,7 +47,6 @@ class AccessTokenExchangeViewTest(AccessTokenExchangeTestMixin):
json.loads(response.content), json.loads(response.content),
{u"error": expected_error, u"error_description": expected_error_description} {u"error": expected_error, u"error_description": expected_error_description}
) )
self.assertNotIn("partial_pipeline", self.client.session)
def _assert_success(self, data, expected_scopes): def _assert_success(self, data, expected_scopes):
response = self.csrf_client.post(self.url, data) response = self.csrf_client.post(self.url, data)
...@@ -101,7 +105,7 @@ class AccessTokenExchangeViewTest(AccessTokenExchangeTestMixin): ...@@ -101,7 +105,7 @@ class AccessTokenExchangeViewTest(AccessTokenExchangeTestMixin):
# This is necessary because cms does not implement third party auth # This is necessary because cms does not implement third party auth
@unittest.skipUnless(settings.FEATURES.get("ENABLE_THIRD_PARTY_AUTH"), "third party auth not enabled") @unittest.skipUnless(TPA_FEATURE_ENABLED, TPA_FEATURES_KEY + " not enabled")
@httpretty.activate @httpretty.activate
class DOPAccessTokenExchangeViewTestFacebook( class DOPAccessTokenExchangeViewTestFacebook(
DOPAdapterMixin, DOPAdapterMixin,
...@@ -115,7 +119,7 @@ class DOPAccessTokenExchangeViewTestFacebook( ...@@ -115,7 +119,7 @@ class DOPAccessTokenExchangeViewTestFacebook(
pass pass
@unittest.skipUnless(settings.FEATURES.get("ENABLE_THIRD_PARTY_AUTH"), "third party auth not enabled") @unittest.skipUnless(TPA_FEATURE_ENABLED, TPA_FEATURES_KEY + " not enabled")
@httpretty.activate @httpretty.activate
class DOTAccessTokenExchangeViewTestFacebook( class DOTAccessTokenExchangeViewTestFacebook(
DOTAdapterMixin, DOTAdapterMixin,
...@@ -130,7 +134,7 @@ class DOTAccessTokenExchangeViewTestFacebook( ...@@ -130,7 +134,7 @@ class DOTAccessTokenExchangeViewTestFacebook(
# This is necessary because cms does not implement third party auth # This is necessary because cms does not implement third party auth
@unittest.skipUnless(settings.FEATURES.get("ENABLE_THIRD_PARTY_AUTH"), "third party auth not enabled") @unittest.skipUnless(TPA_FEATURE_ENABLED, TPA_FEATURES_KEY + " not enabled")
@httpretty.activate @httpretty.activate
class DOPAccessTokenExchangeViewTestGoogle( class DOPAccessTokenExchangeViewTestGoogle(
DOPAdapterMixin, DOPAdapterMixin,
...@@ -146,7 +150,7 @@ class DOPAccessTokenExchangeViewTestGoogle( ...@@ -146,7 +150,7 @@ class DOPAccessTokenExchangeViewTestGoogle(
# This is necessary because cms does not implement third party auth # This is necessary because cms does not implement third party auth
@unittest.skipUnless(settings.FEATURES.get("ENABLE_THIRD_PARTY_AUTH"), "third party auth not enabled") @unittest.skipUnless(TPA_FEATURE_ENABLED, TPA_FEATURES_KEY + " not enabled")
@httpretty.activate @httpretty.activate
class DOTAccessTokenExchangeViewTestGoogle( class DOTAccessTokenExchangeViewTestGoogle(
DOTAdapterMixin, DOTAdapterMixin,
......
...@@ -2,10 +2,16 @@ ...@@ -2,10 +2,16 @@
Test utilities for OAuth access token exchange Test utilities for OAuth access token exchange
""" """
from social.apps.django_app.default.models import UserSocialAuth from django.conf import settings
from social_django.models import UserSocialAuth, Partial
from third_party_auth.tests.utils import ThirdPartyOAuthTestMixin from third_party_auth.tests.utils import ThirdPartyOAuthTestMixin
TPA_FEATURES_KEY = 'ENABLE_THIRD_PARTY_AUTH'
TPA_FEATURE_ENABLED = TPA_FEATURES_KEY in settings.FEATURES
class AccessTokenExchangeTestMixin(ThirdPartyOAuthTestMixin): class AccessTokenExchangeTestMixin(ThirdPartyOAuthTestMixin):
""" """
A mixin to define test cases for access token exchange. The following A mixin to define test cases for access token exchange. The following
...@@ -86,16 +92,19 @@ class AccessTokenExchangeTestMixin(ThirdPartyOAuthTestMixin): ...@@ -86,16 +92,19 @@ class AccessTokenExchangeTestMixin(ThirdPartyOAuthTestMixin):
def test_no_linked_user(self): def test_no_linked_user(self):
UserSocialAuth.objects.all().delete() UserSocialAuth.objects.all().delete()
Partial.objects.all().delete()
self._setup_provider_response(success=True) self._setup_provider_response(success=True)
self._assert_error(self.data, "invalid_grant", "access_token is not valid") self._assert_error(self.data, "invalid_grant", "access_token is not valid")
def test_user_automatically_linked_by_email(self): def test_user_automatically_linked_by_email(self):
UserSocialAuth.objects.all().delete() UserSocialAuth.objects.all().delete()
Partial.objects.all().delete()
self._setup_provider_response(success=True, email=self.user.email) self._setup_provider_response(success=True, email=self.user.email)
self._assert_success(self.data, expected_scopes=[]) self._assert_success(self.data, expected_scopes=[])
def test_inactive_user_not_automatically_linked(self): def test_inactive_user_not_automatically_linked(self):
UserSocialAuth.objects.all().delete() UserSocialAuth.objects.all().delete()
Partial.objects.all().delete()
self._setup_provider_response(success=True, email=self.user.email) self._setup_provider_response(success=True, email=self.user.email)
self.user.is_active = False self.user.is_active = False
self.user.save() # pylint: disable=no-member self.user.save() # pylint: disable=no-member
......
...@@ -10,7 +10,7 @@ The following are currently implemented: ...@@ -10,7 +10,7 @@ The following are currently implemented:
# pylint: disable=abstract-method # pylint: disable=abstract-method
import django.contrib.auth as auth import django.contrib.auth as auth
import social.apps.django_app.utils as social_utils import social_django.utils as social_utils
from django.conf import settings from django.conf import settings
from django.contrib.auth import login from django.contrib.auth import login
from django.http import HttpResponse from django.http import HttpResponse
...@@ -37,7 +37,7 @@ class AccessTokenExchangeBase(APIView): ...@@ -37,7 +37,7 @@ class AccessTokenExchangeBase(APIView):
OAuth access token. OAuth access token.
""" """
@method_decorator(csrf_exempt) @method_decorator(csrf_exempt)
@method_decorator(social_utils.strategy("social:complete")) @method_decorator(social_utils.psa("social:complete"))
def dispatch(self, *args, **kwargs): def dispatch(self, *args, **kwargs):
return super(AccessTokenExchangeBase, self).dispatch(*args, **kwargs) return super(AccessTokenExchangeBase, self).dispatch(*args, **kwargs)
...@@ -137,11 +137,11 @@ class DOTAccessTokenExchangeView(AccessTokenExchangeBase, DOTAccessTokenView): ...@@ -137,11 +137,11 @@ class DOTAccessTokenExchangeView(AccessTokenExchangeBase, DOTAccessTokenView):
request.extra_credentials = None request.extra_credentials = None
request.grant_type = client.authorization_grant_type request.grant_type = client.authorization_grant_type
def error_response(self, form_errors): def error_response(self, form_errors, **kwargs):
""" """
Return an error response consisting of the errors in the form Return an error response consisting of the errors in the form
""" """
return Response(status=400, data=form_errors) return Response(status=400, data=form_errors, **kwargs)
class LoginWithAccessTokenView(APIView): class LoginWithAccessTokenView(APIView):
......
...@@ -268,7 +268,7 @@ class BookmarksListViewTests(BookmarksViewsTestsBase): ...@@ -268,7 +268,7 @@ class BookmarksListViewTests(BookmarksViewsTestsBase):
self.assertEqual(response.data['developer_message'], u'Parameter usage_id not provided.') self.assertEqual(response.data['developer_message'], u'Parameter usage_id not provided.')
# Send empty data dictionary. # Send empty data dictionary.
with self.assertNumQueries(7): # No queries for bookmark table. with self.assertNumQueries(8): # No queries for bookmark table.
response = self.send_post( response = self.send_post(
client=self.client, client=self.client,
url=reverse('bookmarks'), url=reverse('bookmarks'),
......
...@@ -174,7 +174,7 @@ class TestOwnUsernameAPI(CacheIsolationTestCase, UserAPITestCase): ...@@ -174,7 +174,7 @@ class TestOwnUsernameAPI(CacheIsolationTestCase, UserAPITestCase):
Test that a client (logged in) can get her own username. Test that a client (logged in) can get her own username.
""" """
self.client.login(username=self.user.username, password=TEST_PASSWORD) self.client.login(username=self.user.username, password=TEST_PASSWORD)
self._verify_get_own_username(14) self._verify_get_own_username(15)
def test_get_username_inactive(self): def test_get_username_inactive(self):
""" """
...@@ -184,7 +184,7 @@ class TestOwnUsernameAPI(CacheIsolationTestCase, UserAPITestCase): ...@@ -184,7 +184,7 @@ class TestOwnUsernameAPI(CacheIsolationTestCase, UserAPITestCase):
self.client.login(username=self.user.username, password=TEST_PASSWORD) self.client.login(username=self.user.username, password=TEST_PASSWORD)
self.user.is_active = False self.user.is_active = False
self.user.save() self.user.save()
self._verify_get_own_username(14) self._verify_get_own_username(15)
def test_get_username_not_logged_in(self): def test_get_username_not_logged_in(self):
""" """
...@@ -193,7 +193,7 @@ class TestOwnUsernameAPI(CacheIsolationTestCase, UserAPITestCase): ...@@ -193,7 +193,7 @@ class TestOwnUsernameAPI(CacheIsolationTestCase, UserAPITestCase):
""" """
# verify that the endpoint is inaccessible when not logged in # verify that the endpoint is inaccessible when not logged in
self._verify_get_own_username(12, expected_status=401) self._verify_get_own_username(13, expected_status=401)
@ddt.ddt @ddt.ddt
...@@ -305,7 +305,7 @@ class TestAccountsAPI(CacheIsolationTestCase, UserAPITestCase): ...@@ -305,7 +305,7 @@ class TestAccountsAPI(CacheIsolationTestCase, UserAPITestCase):
""" """
self.different_client.login(username=self.different_user.username, password=TEST_PASSWORD) self.different_client.login(username=self.different_user.username, password=TEST_PASSWORD)
self.create_mock_profile(self.user) self.create_mock_profile(self.user)
with self.assertNumQueries(18): with self.assertNumQueries(19):
response = self.send_get(self.different_client) response = self.send_get(self.different_client)
self._verify_full_shareable_account_response(response, account_privacy=ALL_USERS_VISIBILITY) self._verify_full_shareable_account_response(response, account_privacy=ALL_USERS_VISIBILITY)
...@@ -320,7 +320,7 @@ class TestAccountsAPI(CacheIsolationTestCase, UserAPITestCase): ...@@ -320,7 +320,7 @@ class TestAccountsAPI(CacheIsolationTestCase, UserAPITestCase):
""" """
self.different_client.login(username=self.different_user.username, password=TEST_PASSWORD) self.different_client.login(username=self.different_user.username, password=TEST_PASSWORD)
self.create_mock_profile(self.user) self.create_mock_profile(self.user)
with self.assertNumQueries(18): with self.assertNumQueries(19):
response = self.send_get(self.different_client) response = self.send_get(self.different_client)
self._verify_private_account_response(response, account_privacy=PRIVATE_VISIBILITY) self._verify_private_account_response(response, account_privacy=PRIVATE_VISIBILITY)
...@@ -395,12 +395,12 @@ class TestAccountsAPI(CacheIsolationTestCase, UserAPITestCase): ...@@ -395,12 +395,12 @@ class TestAccountsAPI(CacheIsolationTestCase, UserAPITestCase):
self.assertEqual(False, data["accomplishments_shared"]) self.assertEqual(False, data["accomplishments_shared"])
self.client.login(username=self.user.username, password=TEST_PASSWORD) self.client.login(username=self.user.username, password=TEST_PASSWORD)
verify_get_own_information(16) verify_get_own_information(17)
# Now make sure that the user can get the same information, even if not active # Now make sure that the user can get the same information, even if not active
self.user.is_active = False self.user.is_active = False
self.user.save() self.user.save()
verify_get_own_information(10) verify_get_own_information(11)
def test_get_account_empty_string(self): def test_get_account_empty_string(self):
""" """
...@@ -414,7 +414,7 @@ class TestAccountsAPI(CacheIsolationTestCase, UserAPITestCase): ...@@ -414,7 +414,7 @@ class TestAccountsAPI(CacheIsolationTestCase, UserAPITestCase):
legacy_profile.save() legacy_profile.save()
self.client.login(username=self.user.username, password=TEST_PASSWORD) self.client.login(username=self.user.username, password=TEST_PASSWORD)
with self.assertNumQueries(16): with self.assertNumQueries(17):
response = self.send_get(self.client) response = self.send_get(self.client)
for empty_field in ("level_of_education", "gender", "country", "bio"): for empty_field in ("level_of_education", "gender", "country", "bio"):
self.assertIsNone(response.data[empty_field]) self.assertIsNone(response.data[empty_field])
......
...@@ -16,7 +16,7 @@ from django.test.testcases import TransactionTestCase ...@@ -16,7 +16,7 @@ from django.test.testcases import TransactionTestCase
from django.test.utils import override_settings from django.test.utils import override_settings
from opaque_keys.edx.locations import SlashSeparatedCourseKey from opaque_keys.edx.locations import SlashSeparatedCourseKey
from pytz import common_timezones_set, UTC from pytz import common_timezones_set, UTC
from social.apps.django_app.default.models import UserSocialAuth from social_django.models import UserSocialAuth, Partial
from django_comment_common import models from django_comment_common import models
from openedx.core.djangoapps.site_configuration.helpers import get_value from openedx.core.djangoapps.site_configuration.helpers import get_value
...@@ -1976,6 +1976,10 @@ class ThirdPartyRegistrationTestMixin(ThirdPartyOAuthTestMixin, CacheIsolationTe ...@@ -1976,6 +1976,10 @@ class ThirdPartyRegistrationTestMixin(ThirdPartyOAuthTestMixin, CacheIsolationTe
super(ThirdPartyRegistrationTestMixin, self).setUp() super(ThirdPartyRegistrationTestMixin, self).setUp()
self.url = reverse('user_api_registration') self.url = reverse('user_api_registration')
def tearDown(self):
super(ThirdPartyRegistrationTestMixin, self).tearDown()
Partial.objects.all().delete()
def data(self, user=None): def data(self, user=None):
"""Returns the request data for the endpoint.""" """Returns the request data for the endpoint."""
return { return {
...@@ -1996,7 +2000,6 @@ class ThirdPartyRegistrationTestMixin(ThirdPartyOAuthTestMixin, CacheIsolationTe ...@@ -1996,7 +2000,6 @@ class ThirdPartyRegistrationTestMixin(ThirdPartyOAuthTestMixin, CacheIsolationTe
for conflict_attribute in ["username", "email"]: for conflict_attribute in ["username", "email"]:
self.assertIn(conflict_attribute, errors) self.assertIn(conflict_attribute, errors)
self.assertIn("belongs to an existing account", errors[conflict_attribute][0]["user_message"]) self.assertIn("belongs to an existing account", errors[conflict_attribute][0]["user_message"])
self.assertNotIn("partial_pipeline", self.client.session)
def _assert_access_token_error(self, response, expected_error_message): def _assert_access_token_error(self, response, expected_error_message):
"""Assert that the given response was an error for the access_token field with the given error message.""" """Assert that the given response was an error for the access_token field with the given error message."""
...@@ -2006,7 +2009,6 @@ class ThirdPartyRegistrationTestMixin(ThirdPartyOAuthTestMixin, CacheIsolationTe ...@@ -2006,7 +2009,6 @@ class ThirdPartyRegistrationTestMixin(ThirdPartyOAuthTestMixin, CacheIsolationTe
response_json, response_json,
{"access_token": [{"user_message": expected_error_message}]} {"access_token": [{"user_message": expected_error_message}]}
) )
self.assertNotIn("partial_pipeline", self.client.session)
def _assert_third_party_session_expired_error(self, response, expected_error_message): def _assert_third_party_session_expired_error(self, response, expected_error_message):
"""Assert that given response is an error due to third party session expiry""" """Assert that given response is an error due to third party session expiry"""
...@@ -2109,8 +2111,6 @@ class ThirdPartyRegistrationTestMixin(ThirdPartyOAuthTestMixin, CacheIsolationTe ...@@ -2109,8 +2111,6 @@ class ThirdPartyRegistrationTestMixin(ThirdPartyOAuthTestMixin, CacheIsolationTe
# to identify that request is made using browser # to identify that request is made using browser
data.update({"social_auth_provider": "Google"}) data.update({"social_auth_provider": "Google"})
response = self.client.post(self.url, data) response = self.client.post(self.url, data)
# NO partial_pipeline in session means pipeline is expired
self.assertNotIn("partial_pipeline", self.client.session)
self._assert_third_party_session_expired_error( self._assert_third_party_session_expired_error(
response, response,
u"Registration using {provider} has timed out.".format(provider="Google") u"Registration using {provider} has timed out.".format(provider="Google")
...@@ -2127,7 +2127,7 @@ class TestFacebookRegistrationView( ...@@ -2127,7 +2127,7 @@ class TestFacebookRegistrationView(
def test_social_auth_exception(self): def test_social_auth_exception(self):
""" """
According to the do_auth method in social.backends.facebook.py, According to the do_auth method in social_core.backends.facebook.py,
the Facebook API sometimes responds back a JSON with just False as value. the Facebook API sometimes responds back a JSON with just False as value.
""" """
self._setup_provider_response_with_body(200, json.dumps("false")) self._setup_provider_response_with_body(200, json.dumps("false"))
......
...@@ -89,7 +89,7 @@ class TestCourseHomePage(SharedModuleStoreTestCase): ...@@ -89,7 +89,7 @@ class TestCourseHomePage(SharedModuleStoreTestCase):
course_home_url(self.course) course_home_url(self.course)
# Fetch the view and verify the query counts # Fetch the view and verify the query counts
with self.assertNumQueries(47): with self.assertNumQueries(48):
with check_mongo_calls(5): with check_mongo_calls(5):
url = course_home_url(self.course) url = course_home_url(self.course)
self.client.get(url) self.client.get(url)
...@@ -124,7 +124,7 @@ class TestCourseUpdatesPage(SharedModuleStoreTestCase): ...@@ -124,7 +124,7 @@ class TestCourseUpdatesPage(SharedModuleStoreTestCase):
course_updates_url(self.course) course_updates_url(self.course)
# Fetch the view and verify that the query counts haven't changed # Fetch the view and verify that the query counts haven't changed
with self.assertNumQueries(35): with self.assertNumQueries(36):
with check_mongo_calls(4): with check_mongo_calls(4):
url = course_updates_url(self.course) url = course_updates_url(self.course)
self.client.get(url) self.client.get(url)
...@@ -320,7 +320,7 @@ def insert_enterprise_pipeline_elements(pipeline): ...@@ -320,7 +320,7 @@ def insert_enterprise_pipeline_elements(pipeline):
'enterprise.tpa_pipeline.handle_enterprise_logistration', 'enterprise.tpa_pipeline.handle_enterprise_logistration',
) )
# Find the item we need to insert the data sharing consent elements before # Find the item we need to insert the data sharing consent elements before
insert_point = pipeline.index('social.pipeline.social_auth.load_extra_data') insert_point = pipeline.index('social_core.pipeline.social_auth.load_extra_data')
for index, element in enumerate(additional_elements): for index, element in enumerate(additional_elements):
pipeline.insert(insert_point + index, element) pipeline.insert(insert_point + index, element)
......
...@@ -40,11 +40,11 @@ class TestEnterpriseApi(unittest.TestCase): ...@@ -40,11 +40,11 @@ class TestEnterpriseApi(unittest.TestCase):
the utilities to return the expected values. the utilities to return the expected values.
""" """
self.assertTrue(enterprise_enabled()) self.assertTrue(enterprise_enabled())
pipeline = ['abc', 'social.pipeline.social_auth.load_extra_data', 'def'] pipeline = ['abc', 'social_core.pipeline.social_auth.load_extra_data', 'def']
insert_enterprise_pipeline_elements(pipeline) insert_enterprise_pipeline_elements(pipeline)
self.assertEqual(pipeline, ['abc', self.assertEqual(pipeline, ['abc',
'enterprise.tpa_pipeline.handle_enterprise_logistration', 'enterprise.tpa_pipeline.handle_enterprise_logistration',
'social.pipeline.social_auth.load_extra_data', 'social_core.pipeline.social_auth.load_extra_data',
'def']) 'def'])
@override_settings(ENABLE_ENTERPRISE_INTEGRATION=True) @override_settings(ENABLE_ENTERPRISE_INTEGRATION=True)
......
...@@ -52,7 +52,7 @@ edx-lint==0.4.3 ...@@ -52,7 +52,7 @@ edx-lint==0.4.3
astroid==1.3.8 astroid==1.3.8
edx-django-oauth2-provider==1.1.4 edx-django-oauth2-provider==1.1.4
edx-django-sites-extensions==2.1.1 edx-django-sites-extensions==2.1.1
edx-enterprise==0.35.2 edx-enterprise==0.36.1
edx-oauth2-provider==1.2.0 edx-oauth2-provider==1.2.0
edx-opaque-keys==0.4.0 edx-opaque-keys==0.4.0
edx-organizations==0.4.4 edx-organizations==0.4.4
...@@ -92,10 +92,8 @@ python-memcached==1.48 ...@@ -92,10 +92,8 @@ python-memcached==1.48
django-memcached-hashring==0.1.2 django-memcached-hashring==0.1.2
python-openid==2.2.5 python-openid==2.2.5
python-dateutil==2.1 python-dateutil==2.1
# We need to be able to set a maximum session length on our third-party auth providers; social-auth-app-django==1.2.0
# our goal is to upstream these changes and return to the canonical version ASAP. social-auth-core==1.4.0
# python-social-auth==0.2.21
git+https://github.com/edx/python-social-auth@758985102cee98f440fae44ed99617b7cfef3473#egg=python-social-auth==0.2.21.edx.a
pytz==2016.7 pytz==2016.7
pysrt==0.4.7 pysrt==0.4.7
PyYAML==3.12 PyYAML==3.12
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment