Commit 698e542f by Nimisha Asthagiri

Merge pull request #9842 from ubc/tpa-mapping-api

Implement third party auth ID mapping API
parents dba67df2 f6930437
...@@ -2,12 +2,21 @@ ...@@ -2,12 +2,21 @@
""" """
Admin site configuration for third party authentication Admin site configuration for third party authentication
""" """
from django import forms
from django.contrib import admin from django.contrib import admin
from config_models.admin import ConfigurationModelAdmin, KeyedConfigurationModelAdmin from config_models.admin import ConfigurationModelAdmin, KeyedConfigurationModelAdmin
from .models import OAuth2ProviderConfig, SAMLProviderConfig, SAMLConfiguration, SAMLProviderData, LTIProviderConfig from .models import (
OAuth2ProviderConfig,
SAMLProviderConfig,
SAMLConfiguration,
SAMLProviderData,
LTIProviderConfig,
ProviderApiPermissions
)
from .tasks import fetch_saml_metadata from .tasks import fetch_saml_metadata
from third_party_auth.provider import Registry
class OAuth2ProviderConfigAdmin(KeyedConfigurationModelAdmin): class OAuth2ProviderConfigAdmin(KeyedConfigurationModelAdmin):
...@@ -111,3 +120,26 @@ class LTIProviderConfigAdmin(KeyedConfigurationModelAdmin): ...@@ -111,3 +120,26 @@ class LTIProviderConfigAdmin(KeyedConfigurationModelAdmin):
) )
admin.site.register(LTIProviderConfig, LTIProviderConfigAdmin) admin.site.register(LTIProviderConfig, LTIProviderConfigAdmin)
class ApiPermissionsAdminForm(forms.ModelForm):
""" Django admin form for ApiPermissions model """
class Meta(object): # pylint: disable=missing-docstring
model = ProviderApiPermissions
provider_id = forms.ChoiceField(choices=[], required=True)
def __init__(self, *args, **kwargs):
super(ApiPermissionsAdminForm, self).__init__(*args, **kwargs)
self.fields['provider_id'].choices = (
(provider.provider_id, "{} ({})".format(provider.name, provider.provider_id))
for provider in Registry.enabled()
)
class ApiPermissionsAdmin(admin.ModelAdmin):
""" Django Admin class for ApiPermissions """
list_display = ('client', 'provider_id')
form = ApiPermissionsAdminForm
admin.site.register(ProviderApiPermissions, ApiPermissionsAdmin)
"""
Third party auth API related permissions
"""
from rest_framework import permissions
from third_party_auth.models import ProviderApiPermissions
class ThirdPartyAuthProviderApiPermission(permissions.BasePermission):
"""
Allow someone to access the view if they have valid OAuth client credential.
"""
def __init__(self, provider_id):
""" Initialize the class with a provider_id """
self.provider_id = provider_id
def has_permission(self, request, view):
"""
Check if the OAuth client associated with auth token in current request has permission to access
the information for provider
"""
if not request.auth or not self.provider_id:
# doesn't have access token or no provider_id specified
return False
try:
ProviderApiPermissions.objects.get(client__pk=request.auth.client_id, provider_id=self.provider_id)
except ProviderApiPermissions.DoesNotExist:
return False
return True
""" Django REST Framework Serializers """
from rest_framework import serializers
class UserMappingSerializer(serializers.Serializer): # pylint: disable=abstract-method
""" Serializer for User Mapping"""
provider = None
username = serializers.SerializerMethodField()
remote_id = serializers.SerializerMethodField()
def __init__(self, *args, **kwargs):
self.provider = kwargs['context'].get('provider', None)
super(UserMappingSerializer, self).__init__(*args, **kwargs)
def get_username(self, social_user):
""" Gets the edx username from a social user """
return social_user.user.username
def get_remote_id(self, social_user):
""" Gets remote id from social user based on provider """
return self.provider.get_remote_id_from_social_auth(social_user)
"""
Tests for the Third Party Auth permissions
"""
import unittest
import ddt
from mock import Mock
from rest_framework.test import APITestCase
from django.conf import settings
from third_party_auth.api.permissions import ThirdPartyAuthProviderApiPermission
from third_party_auth.tests.testutil import ThirdPartyAuthTestMixin
IDP_SLUG_TESTSHIB = 'testshib'
PROVIDER_ID_TESTSHIB = 'saml-' + IDP_SLUG_TESTSHIB
@ddt.ddt
@unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms')
class ThirdPartyAuthApiPermissionTest(ThirdPartyAuthTestMixin, APITestCase):
""" Tests for third party auth API permission """
def setUp(self):
""" Create users and oauth client for use in the tests """
super(ThirdPartyAuthApiPermissionTest, self).setUp()
client = self.configure_oauth_client()
self.configure_api_permission(client, PROVIDER_ID_TESTSHIB)
@ddt.data(
(1, PROVIDER_ID_TESTSHIB, True),
(1, 'invalid-provider-id', False),
(999, PROVIDER_ID_TESTSHIB, False),
(999, 'invalid-provider-id', False),
(1, None, False),
)
@ddt.unpack
def test_api_permission(self, client_pk, provider_id, expect):
request = Mock()
request.auth = Mock()
request.auth.client_id = client_pk
result = ThirdPartyAuthProviderApiPermission(provider_id).has_permission(request, None)
self.assertEqual(result, expect)
def test_api_permission_unauthorized_client(self):
client = self.configure_oauth_client()
self.configure_api_permission(client, 'saml-anotherprovider')
request = Mock()
request.auth = Mock()
request.auth.client_id = client.pk
result = ThirdPartyAuthProviderApiPermission(PROVIDER_ID_TESTSHIB).has_permission(request, None)
self.assertEqual(result, False)
...@@ -2,11 +2,14 @@ ...@@ -2,11 +2,14 @@
from django.conf.urls import patterns, url from django.conf.urls import patterns, url
from .views import UserView from .views import UserView, UserMappingView
USERNAME_PATTERN = r'(?P<username>[\w.+-]+)' USERNAME_PATTERN = r'(?P<username>[\w.+-]+)'
PROVIDER_PATTERN = r'(?P<provider_id>[\w.+-]+)(?:\:(?P<idp_slug>[\w.+-]+))?'
urlpatterns = patterns( urlpatterns = patterns(
'', '',
url(r'^v0/users/' + USERNAME_PATTERN + '$', UserView.as_view(), name='third_party_auth_users_api'), url(r'^v0/users/' + USERNAME_PATTERN + '$', UserView.as_view(), name='third_party_auth_users_api'),
url(r'^v0/providers/' + PROVIDER_PATTERN + '/users$', UserMappingView.as_view(),
name='third_party_auth_user_mapping_api'),
) )
...@@ -2,6 +2,11 @@ ...@@ -2,6 +2,11 @@
Third Party Auth REST API views Third Party Auth REST API views
""" """
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.db.models import Q
from django.http import Http404
from rest_framework.generics import ListAPIView
from rest_framework_oauth.authentication import OAuth2Authentication
from social.apps.django_app.default.models import UserSocialAuth
from openedx.core.lib.api.authentication import ( from openedx.core.lib.api.authentication import (
OAuth2AuthenticationAllowInactiveUser, OAuth2AuthenticationAllowInactiveUser,
SessionAuthenticationAllowInactiveUser, SessionAuthenticationAllowInactiveUser,
...@@ -9,10 +14,13 @@ from openedx.core.lib.api.authentication import ( ...@@ -9,10 +14,13 @@ from openedx.core.lib.api.authentication import (
from openedx.core.lib.api.permissions import ( from openedx.core.lib.api.permissions import (
ApiKeyHeaderPermission, ApiKeyHeaderPermission,
) )
from rest_framework import status from rest_framework import status, exceptions
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.views import APIView from rest_framework.views import APIView
from third_party_auth import pipeline from third_party_auth import pipeline
from third_party_auth.api import serializers
from third_party_auth.api.permissions import ThirdPartyAuthProviderApiPermission
from third_party_auth.provider import Registry
class UserView(APIView): class UserView(APIView):
...@@ -89,3 +97,136 @@ class UserView(APIView): ...@@ -89,3 +97,136 @@ class UserView(APIView):
return Response({ return Response({
"active": active_providers "active": active_providers
}) })
class UserMappingView(ListAPIView):
"""
Map between the third party auth account IDs (remote_id) and EdX username.
This API is intended to be a server-to-server endpoint. An on-campus middleware or system should consume this.
**Use Case**
Get a paginated list of mappings between edX users and remote user IDs for all users currently
linked to the given backend.
The list can be filtered by edx username or third party ids. The filter is limited by the max length of URL.
It is suggested to query no more than 50 usernames or remote_ids in each request to stay within above
limitation
The page size can be changed by specifying `page_size` parameter in the request.
**Example Requests**
GET /api/third_party_auth/v0/providers/{provider_id}/users
GET /api/third_party_auth/v0/providers/{provider_id}/users?username={username1},{username2}
GET /api/third_party_auth/v0/providers/{provider_id}/users?username={username1}&usernames={username2}
GET /api/third_party_auth/v0/providers/{provider_id}/users?remote_id={remote_id1},{remote_id2}
GET /api/third_party_auth/v0/providers/{provider_id}/users?remote_id={remote_id1}&remote_id={remote_id2}
GET /api/third_party_auth/v0/providers/{provider_id}/users?username={username1}&remote_id={remote_id1}
**URL Parameters**
* provider_id: The unique identifier of third_party_auth provider (e.g. "saml-ubc", "oa2-google", etc.
This is not the same thing as the backend_name.). (Optional/future: We may also want to allow
this to be an 'external domain' like 'ssl:MIT' so that this API can also search the legacy
ExternalAuthMap table used by Standford/MIT)
**Query Parameters**
* remote_ids: Optional. List of comma separated remote (third party) user IDs to filter the result set.
e.g. ?remote_ids=8721384623
* usernames: Optional. List of comma separated edX usernames to filter the result set.
e.g. ?usernames=bob123,jane456
* page, page_size: Optional. Used for paging the result set, especially when getting
an unfiltered list.
**Response Values**
If the request for information about the user is successful, an HTTP 200 "OK" response
is returned.
The HTTP 200 response has the following values:
* count: The number of mappings for the backend.
* next: The URI to the next page of the mappings.
* previous: The URI to the previous page of the mappings.
* num_pages: The number of pages listing the mappings.
* results: A list of mappings returned. Each collection in the list
contains these fields.
* username: The edx username
* remote_id: The Id from third party auth provider
"""
authentication_classes = (
OAuth2Authentication,
)
serializer_class = serializers.UserMappingSerializer
provider = None
def get_queryset(self):
provider_id = self.kwargs.get('provider_id')
# permission checking. We allow both API_KEY access and OAuth2 client credential access
if not (
self.request.user.is_superuser or ApiKeyHeaderPermission().has_permission(self.request, self) or
ThirdPartyAuthProviderApiPermission(provider_id).has_permission(self.request, self)
):
raise exceptions.PermissionDenied()
# provider existence checking
self.provider = Registry.get(provider_id)
if not self.provider:
raise Http404
query_set = UserSocialAuth.objects.select_related('user').filter(provider=self.provider.backend_name)
# build our query filters
# When using multi-IdP backend, we only retrieve the ones that are for current IdP.
# test if the current provider has a slug
uid = self.provider.get_social_auth_uid('uid')
if uid is not 'uid':
# if yes, we add a filter for the slug on uid column
query_set = query_set.filter(uid__startswith=uid[:-3])
query = Q()
usernames = self.request.QUERY_PARAMS.getlist('username', None)
remote_ids = self.request.QUERY_PARAMS.getlist('remote_id', None)
if usernames:
usernames = ','.join(usernames)
usernames = set(usernames.split(',')) if usernames else set()
if len(usernames):
query = query | Q(user__username__in=usernames)
if remote_ids:
remote_ids = ','.join(remote_ids)
remote_ids = set(remote_ids.split(',')) if remote_ids else set()
if len(remote_ids):
query = query | Q(uid__in=[self.provider.get_social_auth_uid(remote_id) for remote_id in remote_ids])
return query_set.filter(query)
def get_serializer_context(self):
"""
Extra context provided to the serializer class with current provider. We need the provider to
remove idp_slug from the remote_id if there is any
"""
context = super(UserMappingView, self).get_serializer_context()
context['provider'] = self.provider
return context
...@@ -14,6 +14,7 @@ from django.utils.translation import ugettext_lazy as _ ...@@ -14,6 +14,7 @@ from django.utils.translation import ugettext_lazy as _
import json import json
import logging import logging
from provider.utils import long_token from provider.utils import long_token
from provider.oauth2.models import Client
from social.backends.base import BaseAuth from social.backends.base import BaseAuth
from social.backends.oauth import OAuthAuth from social.backends.oauth import OAuthAuth
from social.backends.saml import SAMLAuth, SAMLIdentityProvider from social.backends.saml import SAMLAuth, SAMLIdentityProvider
...@@ -136,6 +137,14 @@ class ProviderConfig(ConfigurationModel): ...@@ -136,6 +137,14 @@ class ProviderConfig(ConfigurationModel):
assert self.match_social_auth(social_auth) assert self.match_social_auth(social_auth)
return social_auth.uid return social_auth.uid
def get_social_auth_uid(self, remote_id):
"""
Return the uid in social auth.
This is default implementation. Subclass may override with a different one.
"""
return remote_id
@classmethod @classmethod
def get_register_form_data(cls, pipeline_kwargs): def get_register_form_data(cls, pipeline_kwargs):
"""Gets dict of data to display on the register form. """Gets dict of data to display on the register form.
...@@ -305,6 +314,10 @@ class SAMLProviderConfig(ProviderConfig): ...@@ -305,6 +314,10 @@ class SAMLProviderConfig(ProviderConfig):
# Remove the prefix from the UID # Remove the prefix from the UID
return social_auth.uid[len(self.idp_slug) + 1:] return social_auth.uid[len(self.idp_slug) + 1:]
def get_social_auth_uid(self, remote_id):
""" Get social auth uid from remote id by prepending idp_slug to the remote id """
return '{}:{}'.format(self.idp_slug, remote_id)
def get_config(self): def get_config(self):
""" """
Return a SAMLIdentityProvider instance for use by SAMLAuthBackend. Return a SAMLIdentityProvider instance for use by SAMLAuthBackend.
...@@ -554,3 +567,22 @@ class LTIProviderConfig(ProviderConfig): ...@@ -554,3 +567,22 @@ class LTIProviderConfig(ProviderConfig):
class Meta(object): class Meta(object):
verbose_name = "Provider Configuration (LTI)" verbose_name = "Provider Configuration (LTI)"
verbose_name_plural = verbose_name verbose_name_plural = verbose_name
class ProviderApiPermissions(models.Model):
"""
This model links OAuth2 client with provider Id.
It gives permission for a OAuth2 client to access the information under certain IdPs.
"""
client = models.ForeignKey(Client)
provider_id = models.CharField(
max_length=255,
help_text=(
'Uniquely identify a provider. This is different from backend_name.'
)
)
class Meta(object): # pylint: disable=missing-docstring
verbose_name = "Provider API Permission"
verbose_name_plural = verbose_name + 's'
...@@ -7,6 +7,8 @@ Used by Django and non-Django tests; must not have Django deps. ...@@ -7,6 +7,8 @@ Used by Django and non-Django tests; must not have Django deps.
from contextlib import contextmanager from contextlib import contextmanager
from django.conf import settings from django.conf import settings
from django.contrib.auth.models import User from django.contrib.auth.models import User
from provider.oauth2.models import Client as OAuth2Client
from provider import constants
import django.test import django.test
import mock import mock
import os.path import os.path
...@@ -17,6 +19,7 @@ from third_party_auth.models import ( ...@@ -17,6 +19,7 @@ from third_party_auth.models import (
SAMLConfiguration, SAMLConfiguration,
LTIProviderConfig, LTIProviderConfig,
cache as config_cache, cache as config_cache,
ProviderApiPermissions,
) )
...@@ -114,6 +117,16 @@ class ThirdPartyAuthTestMixin(object): ...@@ -114,6 +117,16 @@ class ThirdPartyAuthTestMixin(object):
user.save() user.save()
@staticmethod @staticmethod
def configure_oauth_client():
""" Configure a oauth client for testing """
return OAuth2Client.objects.create(client_type=constants.CONFIDENTIAL)
@staticmethod
def configure_api_permission(client, provider_id):
""" Configure the client and provider_id pair. This will give the access to a client for that provider. """
return ProviderApiPermissions.objects.create(client=client, provider_id=provider_id)
@staticmethod
def read_data_file(filename): def read_data_file(filename):
""" Read the contents of a file in the data folder """ """ Read the contents of a file in the data folder """
with open(os.path.join(os.path.dirname(__file__), 'data', filename)) as f: with open(os.path.join(os.path.dirname(__file__), 'data', filename)) as f:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment