Commit b98956bd by Braden MacDonald

Merge pull request #9672 from open-craft/tpa-providers-api

Initial implementation of API for listing a user's third party auth providers
parents 5ad3ed4d 4be8aa5d
"""
Tests for the Third Party Auth REST API
"""
import json
import unittest
import ddt
from mock import patch
from django.test import Client
from django.core.urlresolvers import reverse
from rest_framework.test import APITestCase
from rest_framework import status
from django.conf import settings
from django.test.utils import override_settings
from util.testing import UrlResetMixin
from openedx.core.lib.django_test_client_utils import get_absolute_url
from social.apps.django_app.default.models import UserSocialAuth
from student.tests.factories import UserFactory
from third_party_auth.tests.testutil import ThirdPartyAuthTestMixin
VALID_API_KEY = "i am a key"
@override_settings(EDX_API_KEY=VALID_API_KEY)
@ddt.ddt
@unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms')
class ThirdPartyAuthAPITests(ThirdPartyAuthTestMixin, APITestCase):
"""
Test the Third Party Auth REST API
"""
ALICE_USERNAME = "alice"
CARL_USERNAME = "carl"
STAFF_USERNAME = "staff"
ADMIN_USERNAME = "admin"
# These users will be created and linked to third party accounts:
LINKED_USERS = (ALICE_USERNAME, STAFF_USERNAME, ADMIN_USERNAME)
PASSWORD = "edx"
def setUp(self):
""" Create users for use in the tests """
super(ThirdPartyAuthAPITests, self).setUp()
google = self.configure_google_provider(enabled=True)
self.configure_facebook_provider(enabled=True)
self.configure_linkedin_provider(enabled=False)
self.enable_saml()
testshib = self.configure_saml_provider(name='TestShib', enabled=True, idp_slug='testshib')
# Create several users and link each user to Google and TestShib
for username in self.LINKED_USERS:
make_superuser = (username == self.ADMIN_USERNAME)
make_staff = (username == self.STAFF_USERNAME) or make_superuser
user = UserFactory.create(
username=username,
password=self.PASSWORD,
is_staff=make_staff,
is_superuser=make_superuser
)
UserSocialAuth.objects.create(
user=user,
provider=google.backend_name,
uid='{}@gmail.com'.format(username),
)
UserSocialAuth.objects.create(
user=user,
provider=testshib.backend_name,
uid='{}:{}'.format(testshib.idp_slug, username),
)
# Create another user not linked to any providers:
UserFactory.create(username=self.CARL_USERNAME, password=self.PASSWORD)
def expected_active(self, username):
""" The JSON active providers list response expected for the given user """
if username not in self.LINKED_USERS:
return []
return [
{
"provider_id": "oa2-google-oauth2",
"name": "Google",
"remote_id": "{}@gmail.com".format(username),
},
{
"provider_id": "saml-testshib",
"name": "TestShib",
# The "testshib:" prefix is stored in the UserSocialAuth.uid field but should
# not be present in the 'remote_id', since that's an implementation detail:
"remote_id": username,
},
]
@ddt.data(
# Any user can query their own list of providers
(ALICE_USERNAME, ALICE_USERNAME, 200),
(CARL_USERNAME, CARL_USERNAME, 200),
# A regular user cannot query another user nor deduce the existence of users based on the status code
(ALICE_USERNAME, STAFF_USERNAME, 403),
(ALICE_USERNAME, "nonexistent_user", 403),
# Even Staff cannot query other users
(STAFF_USERNAME, ALICE_USERNAME, 403),
# But admins can
(ADMIN_USERNAME, ALICE_USERNAME, 200),
(ADMIN_USERNAME, CARL_USERNAME, 200),
(ADMIN_USERNAME, "invalid_username", 404),
)
@ddt.unpack
def test_list_connected_providers(self, request_user, target_user, expect_result):
self.client.login(username=request_user, password=self.PASSWORD)
url = reverse('third_party_auth_users_api', kwargs={'username': target_user})
response = self.client.get(url)
self.assertEqual(response.status_code, expect_result)
if expect_result == 200:
self.assertIn("active", response.data)
self.assertItemsEqual(response.data["active"], self.expected_active(target_user))
@ddt.data(
# A server with a valid API key can query any user's list of providers
(VALID_API_KEY, ALICE_USERNAME, 200),
(VALID_API_KEY, "invalid_username", 404),
("i am an invalid key", ALICE_USERNAME, 403),
(None, ALICE_USERNAME, 403),
)
@ddt.unpack
def test_list_connected_providers__withapi_key(self, api_key, target_user, expect_result):
url = reverse('third_party_auth_users_api', kwargs={'username': target_user})
response = self.client.get(url, HTTP_X_EDX_API_KEY=api_key)
self.assertEqual(response.status_code, expect_result)
if expect_result == 200:
self.assertIn("active", response.data)
self.assertItemsEqual(response.data["active"], self.expected_active(target_user))
""" URL configuration for the third party auth API """
from django.conf.urls import patterns, url
from .views import UserView
USERNAME_PATTERN = r'(?P<username>[\w.+-]+)'
urlpatterns = patterns(
'',
url(r'^v0/users/' + USERNAME_PATTERN + '$', UserView.as_view(), name='third_party_auth_users_api'),
)
"""
Third Party Auth REST API views
"""
from django.contrib.auth.models import User
from openedx.core.lib.api.authentication import (
OAuth2AuthenticationAllowInactiveUser,
SessionAuthenticationAllowInactiveUser,
)
from openedx.core.lib.api.permissions import (
ApiKeyHeaderPermission,
)
from rest_framework import status
from rest_framework.response import Response
from rest_framework.views import APIView
from third_party_auth import pipeline
class UserView(APIView):
"""
List the third party auth accounts linked to the specified user account.
**Example Request**
GET /api/third_party_auth/v0/users/{username}
**Response Values**
If the request for information about the user is successful, an HTTP 200 "OK" response
is returned.
The HTTP 200 response has the following values.
* active: A list of all the third party auth providers currently linked
to the given user's account. Each object in this list has the
following attributes:
* provider_id: The unique identifier of this provider (string)
* name: The name of this provider (string)
* remote_id: The ID of the user according to the provider. This ID
is what is used to link the user to their edX account during
login.
"""
authentication_classes = (
# Users may want to view/edit the providers used for authentication before they've
# activated their account, so we allow inactive users.
OAuth2AuthenticationAllowInactiveUser,
SessionAuthenticationAllowInactiveUser,
)
def get(self, request, username):
"""Create, read, or update enrollment information for a user.
HTTP Endpoint for all CRUD operations for a user course enrollment. Allows creation, reading, and
updates of the current enrollment for a particular course.
Args:
request (Request): The HTTP GET request
username (str): Fetch the list of providers linked to this user
Return:
JSON serialized list of the providers linked to this user.
"""
if request.user.username != username:
# We are querying permissions for a user other than the current user.
if not request.user.is_superuser and not ApiKeyHeaderPermission().has_permission(request, self):
# Return a 403 (Unauthorized) without validating 'username', so that we
# do not let users probe the existence of other user accounts.
return Response(status=status.HTTP_403_FORBIDDEN)
try:
user = User.objects.get(username=username)
except User.DoesNotExist:
return Response(status=status.HTTP_404_NOT_FOUND)
providers = pipeline.get_provider_user_states(user)
active_providers = [
{
"provider_id": assoc.provider.provider_id,
"name": assoc.provider.name,
"remote_id": assoc.remote_id,
}
for assoc in providers if assoc.has_account
]
# In the future this can be trivially modified to return the inactive/disconnected providers as well.
return Response({
"active": active_providers
})
......@@ -130,6 +130,12 @@ class ProviderConfig(ConfigurationModel):
""" Is this provider being used for this UserSocialAuth entry? """
return self.backend_name == social_auth.provider
def get_remote_id_from_social_auth(self, social_auth):
""" Given a UserSocialAuth object, return the remote ID used by this provider. """
# This is generally the same thing as the UID, expect when one backend is used for multiple providers
assert self.match_social_auth(social_auth)
return social_auth.uid
@classmethod
def get_register_form_data(cls, pipeline_kwargs):
"""Gets dict of data to display on the register form.
......@@ -293,6 +299,12 @@ class SAMLProviderConfig(ProviderConfig):
prefix = self.idp_slug + ":"
return self.backend_name == social_auth.provider and social_auth.uid.startswith(prefix)
def get_remote_id_from_social_auth(self, social_auth):
""" Given a UserSocialAuth object, return the remote ID used by this provider. """
assert self.match_social_auth(social_auth)
# Remove the prefix from the UID
return social_auth.uid[len(self.idp_slug) + 1:]
def get_config(self):
"""
Return a SAMLIdentityProvider instance for use by SAMLAuthBackend.
......@@ -508,6 +520,12 @@ class LTIProviderConfig(ProviderConfig):
prefix = self.lti_consumer_key + ":"
return self.backend_name == social_auth.provider and social_auth.uid.startswith(prefix)
def get_remote_id_from_social_auth(self, social_auth):
""" Given a UserSocialAuth object, return the remote ID used by this provider. """
assert self.match_social_auth(social_auth)
# Remove the prefix from the UID
return social_auth.uid[len(self.lti_consumer_key) + 1:]
def is_active_for_pipeline(self, pipeline):
""" Is this provider being used for the specified pipeline? """
try:
......
......@@ -170,11 +170,17 @@ class ProviderUserState(object):
lms/templates/dashboard.html.
"""
def __init__(self, enabled_provider, user, association_id=None):
# UserSocialAuth row ID
self.association_id = association_id
def __init__(self, enabled_provider, user, association):
# Boolean. Whether the user has an account associated with the provider
self.has_account = association_id is not None
self.has_account = association is not None
if self.has_account:
# UserSocialAuth row ID
self.association_id = association.id
# Identifier of this user according to the remote provider:
self.remote_id = enabled_provider.get_remote_id_from_social_auth(association)
else:
self.association_id = None
self.remote_id = None
# provider.BaseProvider child. Callers must verify that the provider is
# enabled.
self.provider = enabled_provider
......@@ -367,14 +373,14 @@ def get_provider_user_states(user):
found_user_auths = list(models.DjangoStorage.user.get_social_auth_for_user(user))
for enabled_provider in provider.Registry.enabled():
association_id = None
association = None
for auth in found_user_auths:
if enabled_provider.match_social_auth(auth):
association_id = auth.id
association = auth
break
if enabled_provider.accepts_logins or association_id:
if enabled_provider.accepts_logins or association:
states.append(
ProviderUserState(enabled_provider, user, association_id)
ProviderUserState(enabled_provider, user, association)
)
return states
......
......@@ -41,5 +41,5 @@ class ProviderUserStateTestCase(testutil.TestCase):
def test_get_unlink_form_name(self):
google_provider = self.configure_google_provider(enabled=True)
state = pipeline.ProviderUserState(google_provider, object(), 1000)
state = pipeline.ProviderUserState(google_provider, object(), None)
self.assertEqual(google_provider.provider_id + '_unlink_form', state.get_unlink_form_name())
......@@ -653,6 +653,7 @@ if settings.FEATURES.get('AUTOMATIC_AUTH_FOR_TESTING'):
if settings.FEATURES.get('ENABLE_THIRD_PARTY_AUTH'):
urlpatterns += (
url(r'', include('third_party_auth.urls')),
url(r'api/third_party_auth/', include('third_party_auth.api.urls')),
# NOTE: The following login_oauth_token endpoint is DEPRECATED.
# Please use the exchange_access_token endpoint instead.
url(r'^login_oauth_token/(?P<backend>[^/]+)/$', 'student.views.login_oauth_token'),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment