Commit 6941fcd7 by Clinton Blackburn Committed by Clinton Blackburn

Updated access token view to return a JWT as an access token

The JWT includes the user email and username, along with details pulled from the original access token (e.g. scope, expiration).

ECOM-4221
parent 5adf6fec
...@@ -111,8 +111,10 @@ class CcxRestApiTest(CcxTestCase, APITestCase): ...@@ -111,8 +111,10 @@ class CcxRestApiTest(CcxTestCase, APITestCase):
token_resp = self.client.post('/oauth2/access_token/', data=token_data) token_resp = self.client.post('/oauth2/access_token/', data=token_data)
self.assertEqual(token_resp.status_code, status.HTTP_200_OK) self.assertEqual(token_resp.status_code, status.HTTP_200_OK)
token_resp_json = json.loads(token_resp.content) token_resp_json = json.loads(token_resp.content)
self.assertIn('access_token', token_resp_json) return '{token_type} {token}'.format(
return 'Bearer {0}'.format(token_resp_json.get('access_token')) token_type=token_resp_json['token_type'],
token=token_resp_json['access_token']
)
def expect_error(self, http_code, error_code_str, resp_obj): def expect_error(self, http_code, error_code_str, resp_obj):
""" """
......
...@@ -17,6 +17,7 @@ from rest_framework_oauth.authentication import OAuth2Authentication ...@@ -17,6 +17,7 @@ from rest_framework_oauth.authentication import OAuth2Authentication
from ccx_keys.locator import CCXLocator from ccx_keys.locator import CCXLocator
from courseware import courses from courseware import courses
from edx_rest_framework_extensions.authentication import JwtAuthentication
from instructor.enrollment import ( from instructor.enrollment import (
enroll_email, enroll_email,
get_email_params, get_email_params,
...@@ -361,7 +362,7 @@ class CCXListView(GenericAPIView): ...@@ -361,7 +362,7 @@ class CCXListView(GenericAPIView):
] ]
} }
""" """
authentication_classes = (OAuth2Authentication, SessionAuthentication,) authentication_classes = (JwtAuthentication, OAuth2Authentication, SessionAuthentication,)
permission_classes = (IsAuthenticated, permissions.IsMasterCourseStaffInstructor) permission_classes = (IsAuthenticated, permissions.IsMasterCourseStaffInstructor)
serializer_class = CCXCourseSerializer serializer_class = CCXCourseSerializer
pagination_class = CCXAPIPagination pagination_class = CCXAPIPagination
...@@ -599,7 +600,7 @@ class CCXDetailView(GenericAPIView): ...@@ -599,7 +600,7 @@ class CCXDetailView(GenericAPIView):
response is returned. response is returned.
""" """
authentication_classes = (OAuth2Authentication, SessionAuthentication,) authentication_classes = (JwtAuthentication, OAuth2Authentication, SessionAuthentication,)
permission_classes = (IsAuthenticated, permissions.IsCourseStaffInstructor) permission_classes = (IsAuthenticated, permissions.IsCourseStaffInstructor)
serializer_class = CCXCourseSerializer serializer_class = CCXCourseSerializer
......
...@@ -12,7 +12,12 @@ class DOTAdapter(object): ...@@ -12,7 +12,12 @@ class DOTAdapter(object):
backend = object() backend = object()
def create_confidential_client(self, name, user, redirect_uri, client_id=None): def create_confidential_client(self,
name,
user,
redirect_uri,
client_id=None,
authorization_grant_type=models.Application.GRANT_AUTHORIZATION_CODE):
""" """
Create an oauth client application that is confidential. Create an oauth client application that is confidential.
""" """
...@@ -21,7 +26,7 @@ class DOTAdapter(object): ...@@ -21,7 +26,7 @@ class DOTAdapter(object):
user=user, user=user,
client_id=client_id, client_id=client_id,
client_type=models.Application.CLIENT_CONFIDENTIAL, client_type=models.Application.CLIENT_CONFIDENTIAL,
authorization_grant_type=models.Application.GRANT_AUTHORIZATION_CODE, authorization_grant_type=authorization_grant_type,
redirect_uris=redirect_uri, redirect_uris=redirect_uri,
) )
......
""" """
Classes that override default django-oauth-toolkit behavior Classes that override default django-oauth-toolkit behavior
""" """
from __future__ import unicode_literals
from django.contrib.auth import authenticate, get_user_model from django.contrib.auth import authenticate, get_user_model
from oauth2_provider.oauth2_validators import OAuth2Validator from oauth2_provider.oauth2_validators import OAuth2Validator
...@@ -41,3 +42,25 @@ class EdxOAuth2Validator(OAuth2Validator): ...@@ -41,3 +42,25 @@ class EdxOAuth2Validator(OAuth2Validator):
else: else:
authenticated_user = authenticate(username=email_user.username, password=password) authenticated_user = authenticate(username=email_user.username, password=password)
return authenticated_user return authenticated_user
def save_bearer_token(self, token, request, *args, **kwargs):
"""
Ensure that access tokens issued via client credentials grant are associated with the owner of the
``Application``.
"""
grant_type = request.grant_type
user = request.user
if grant_type == 'client_credentials':
# Temporarily remove the grant type to avoid triggering the super method's code that removes request.user.
request.grant_type = None
# Ensure the tokens get associated with the correct user since DOT does not normally
# associate access tokens issued with the client_credentials grant to users.
request.user = request.client.user
super(EdxOAuth2Validator, self).save_bearer_token(token, request, *args, **kwargs)
# Restore the original request attributes
request.grant_type = grant_type
request.user = user
""" """
OAuth Dispatch test mixins OAuth Dispatch test mixins
""" """
import jwt
from django.conf import settings
class AccessTokenMixin(object):
""" Mixin for tests dealing with OAuth 2 access tokens. """
def assert_valid_jwt_access_token(self, access_token, user, scopes=None):
"""
Verify the specified JWT access token is valid, and belongs to the specified user.
Args:
access_token (str): JWT
user (User): User whose information is contained in the JWT payload.
Returns:
dict: Decoded JWT payload
"""
scopes = scopes or []
audience = settings.JWT_AUTH['JWT_AUDIENCE']
issuer = settings.JWT_AUTH['JWT_ISSUER']
payload = jwt.decode(
access_token,
settings.JWT_AUTH['JWT_SECRET_KEY'],
algorithms=[settings.JWT_AUTH['JWT_ALGORITHM']],
audience=audience,
issuer=issuer
)
expected = {
'aud': audience,
'iss': issuer,
'preferred_username': user.username,
}
if 'email' in scopes:
expected['email'] = user.email
self.assertDictContainsSubset(expected, payload)
return payload
...@@ -4,33 +4,37 @@ import json ...@@ -4,33 +4,37 @@ import json
from django.core.urlresolvers import reverse from django.core.urlresolvers import reverse
from django.test import TestCase from django.test import TestCase
from edx_oauth2_provider.tests.factories import ClientFactory from edx_oauth2_provider.tests.factories import ClientFactory
from oauth2_provider.models import Application
from provider.oauth2.models import AccessToken from provider.oauth2.models import AccessToken
from student.tests.factories import UserFactory from student.tests.factories import UserFactory
from . import mixins
from .constants import DUMMY_REDIRECT_URL
from ..adapters import DOTAdapter
class ClientCredentialsTest(TestCase):
class ClientCredentialsTest(mixins.AccessTokenMixin, TestCase):
""" Tests validating the client credentials grant behavior. """ """ Tests validating the client credentials grant behavior. """
def setUp(self): def setUp(self):
super(ClientCredentialsTest, self).setUp() super(ClientCredentialsTest, self).setUp()
self.user = UserFactory() self.user = UserFactory()
self.oauth_client = ClientFactory(user=self.user)
def test_access_token(self): def test_access_token(self):
""" Verify the client credentials grant can be used to obtain an access token whose default scopes allow access """ Verify the client credentials grant can be used to obtain an access token whose default scopes allow access
to the user info endpoint. to the user info endpoint.
""" """
oauth_client = ClientFactory(user=self.user)
data = { data = {
'grant_type': 'client_credentials', 'grant_type': 'client_credentials',
'client_id': self.oauth_client.client_id, 'client_id': oauth_client.client_id,
'client_secret': self.oauth_client.client_secret 'client_secret': oauth_client.client_secret
} }
response = self.client.post(reverse('oauth2:access_token'), data) response = self.client.post(reverse('oauth2:access_token'), data)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
access_token = json.loads(response.content)['access_token'] access_token = json.loads(response.content)['access_token']
expected = AccessToken.objects.filter(client=self.oauth_client, user=self.user).first().token expected = AccessToken.objects.filter(client=oauth_client, user=self.user).first().token
self.assertEqual(access_token, expected) self.assertEqual(access_token, expected)
headers = { headers = {
...@@ -38,3 +42,29 @@ class ClientCredentialsTest(TestCase): ...@@ -38,3 +42,29 @@ class ClientCredentialsTest(TestCase):
} }
response = self.client.get(reverse('oauth2:user_info'), **headers) response = self.client.get(reverse('oauth2:user_info'), **headers)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
def test_jwt_access_token(self):
""" Verify the client credentials grant can be used to obtain a JWT access token. """
application = DOTAdapter().create_confidential_client(
name='test dot application',
user=self.user,
authorization_grant_type=Application.GRANT_CLIENT_CREDENTIALS,
redirect_uri=DUMMY_REDIRECT_URL,
client_id='dot-app-client-id',
)
scopes = ('read', 'write', 'email')
data = {
'grant_type': 'client_credentials',
'client_id': application.client_id,
'client_secret': application.client_secret,
'scope': ' '.join(scopes),
'token_type': 'jwt'
}
response = self.client.post(reverse('access_token'), data)
self.assertEqual(response.status_code, 200)
content = json.loads(response.content)
access_token = content['access_token']
self.assertEqual(content['scope'], data['scope'])
self.assert_valid_jwt_access_token(access_token, self.user, scopes)
...@@ -12,9 +12,10 @@ import httpretty ...@@ -12,9 +12,10 @@ import httpretty
from student.tests.factories import UserFactory from student.tests.factories import UserFactory
from third_party_auth.tests.utils import ThirdPartyOAuthTestMixin, ThirdPartyOAuthTestMixinGoogle from third_party_auth.tests.utils import ThirdPartyOAuthTestMixin, ThirdPartyOAuthTestMixinGoogle
from .constants import DUMMY_REDIRECT_URL
from .. import adapters from .. import adapters
from .. import views from .. import views
from .constants import DUMMY_REDIRECT_URL from . import mixins
class _DispatchingViewTestCase(TestCase): class _DispatchingViewTestCase(TestCase):
...@@ -43,14 +44,14 @@ class _DispatchingViewTestCase(TestCase): ...@@ -43,14 +44,14 @@ class _DispatchingViewTestCase(TestCase):
client_id='dop-app-client-id', client_id='dop-app-client-id',
) )
def _post_request(self, user, client): def _post_request(self, user, client, token_type=None):
""" """
Call the view with a POST request objectwith the appropriate format, Call the view with a POST request objectwith the appropriate format,
returning the response object. returning the response object.
""" """
return self.client.post(self.url, self._post_body(user, client)) return self.client.post(self.url, self._post_body(user, client, token_type))
def _post_body(self, user, client): def _post_body(self, user, client, token_type=None):
""" """
Return a dictionary to be used as the body of the POST request Return a dictionary to be used as the body of the POST request
""" """
...@@ -58,7 +59,7 @@ class _DispatchingViewTestCase(TestCase): ...@@ -58,7 +59,7 @@ class _DispatchingViewTestCase(TestCase):
@ddt.ddt @ddt.ddt
class TestAccessTokenView(_DispatchingViewTestCase): class TestAccessTokenView(mixins.AccessTokenMixin, _DispatchingViewTestCase):
""" """
Test class for AccessTokenView Test class for AccessTokenView
""" """
...@@ -66,17 +67,22 @@ class TestAccessTokenView(_DispatchingViewTestCase): ...@@ -66,17 +67,22 @@ class TestAccessTokenView(_DispatchingViewTestCase):
view_class = views.AccessTokenView view_class = views.AccessTokenView
url = reverse('access_token') url = reverse('access_token')
def _post_body(self, user, client): def _post_body(self, user, client, token_type=None):
""" """
Return a dictionary to be used as the body of the POST request Return a dictionary to be used as the body of the POST request
""" """
return { body = {
'client_id': client.client_id, 'client_id': client.client_id,
'grant_type': 'password', 'grant_type': 'password',
'username': user.username, 'username': user.username,
'password': 'test', 'password': 'test',
} }
if token_type:
body['token_type'] = token_type
return body
@ddt.data('dop_client', 'dot_app') @ddt.data('dop_client', 'dot_app')
def test_access_token_fields(self, client_attr): def test_access_token_fields(self, client_attr):
client = getattr(self, client_attr) client = getattr(self, client_attr)
...@@ -88,6 +94,16 @@ class TestAccessTokenView(_DispatchingViewTestCase): ...@@ -88,6 +94,16 @@ class TestAccessTokenView(_DispatchingViewTestCase):
self.assertIn('scope', data) self.assertIn('scope', data)
self.assertIn('token_type', data) self.assertIn('token_type', data)
@ddt.data('dop_client', 'dot_app')
def test_jwt_access_token(self, client_attr):
client = getattr(self, client_attr)
response = self._post_request(self.user, client, token_type='jwt')
self.assertEqual(response.status_code, 200)
data = json.loads(response.content)
self.assertIn('expires_in', data)
self.assertEqual(data['token_type'], 'JWT')
self.assert_valid_jwt_access_token(data['access_token'], self.user, data['scope'].split(' '))
def test_dot_access_token_provides_refresh_token(self): def test_dot_access_token_provides_refresh_token(self):
response = self._post_request(self.user, self.dot_app) response = self._post_request(self.user, self.dot_app)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
...@@ -111,7 +127,7 @@ class TestAccessTokenExchangeView(ThirdPartyOAuthTestMixinGoogle, ThirdPartyOAut ...@@ -111,7 +127,7 @@ class TestAccessTokenExchangeView(ThirdPartyOAuthTestMixinGoogle, ThirdPartyOAut
view_class = views.AccessTokenExchangeView view_class = views.AccessTokenExchangeView
url = reverse('exchange_access_token', kwargs={'backend': 'google-oauth2'}) url = reverse('exchange_access_token', kwargs={'backend': 'google-oauth2'})
def _post_body(self, user, client): def _post_body(self, user, client, token_type=None):
return { return {
'client_id': client.client_id, 'client_id': client.client_id,
'access_token': self.access_token, 'access_token': self.access_token,
......
...@@ -5,12 +5,17 @@ django-oauth-toolkit as appropriate. ...@@ -5,12 +5,17 @@ django-oauth-toolkit as appropriate.
from __future__ import unicode_literals from __future__ import unicode_literals
import json
from time import time
import jwt
from auth_exchange import views as auth_exchange_views
from django.conf import settings
from django.utils.functional import cached_property
from django.views.generic import View from django.views.generic import View
from edx_oauth2_provider import views as dop_views # django-oauth2-provider views from edx_oauth2_provider import views as dop_views # django-oauth2-provider views
from oauth2_provider import models as dot_models, views as dot_views # django-oauth-toolkit from oauth2_provider import models as dot_models, views as dot_views # django-oauth-toolkit
from auth_exchange import views as auth_exchange_views
from . import adapters from . import adapters
...@@ -25,6 +30,15 @@ class _DispatchingView(View): ...@@ -25,6 +30,15 @@ class _DispatchingView(View):
dot_adapter = adapters.DOTAdapter() dot_adapter = adapters.DOTAdapter()
dop_adapter = adapters.DOPAdapter() dop_adapter = adapters.DOPAdapter()
def get_adapter(self, request):
"""
Returns the appropriate adapter based on the OAuth client linked to the request.
"""
if dot_models.Application.objects.filter(client_id=self._get_client_id(request)).exists():
return self.dot_adapter
else:
return self.dop_adapter
def dispatch(self, request, *args, **kwargs): def dispatch(self, request, *args, **kwargs):
""" """
Dispatch the request to the selected backend's view. Dispatch the request to the selected backend's view.
...@@ -41,11 +55,7 @@ class _DispatchingView(View): ...@@ -41,11 +55,7 @@ class _DispatchingView(View):
otherwise use the django-oauth2-provider (DOP) adapter, and allow the otherwise use the django-oauth2-provider (DOP) adapter, and allow the
calls to fail normally if the client does not exist. calls to fail normally if the client does not exist.
""" """
return self.get_adapter(request).backend
if dot_models.Application.objects.filter(client_id=self._get_client_id(request)).exists():
return self.dot_adapter.backend
else:
return self.dop_adapter.backend
def get_view_for_backend(self, backend): def get_view_for_backend(self, backend):
""" """
...@@ -72,6 +82,77 @@ class AccessTokenView(_DispatchingView): ...@@ -72,6 +82,77 @@ class AccessTokenView(_DispatchingView):
dot_view = dot_views.TokenView dot_view = dot_views.TokenView
dop_view = dop_views.AccessTokenView dop_view = dop_views.AccessTokenView
@cached_property
def claim_handlers(self):
""" Returns a dictionary mapping scopes to methods that will add claims to the JWT payload. """
return {
'email': self._attach_email_claim,
'profile': self._attach_profile_claim
}
def dispatch(self, request, *args, **kwargs):
response = super(AccessTokenView, self).dispatch(request, *args, **kwargs)
if response.status_code == 200 and request.POST.get('token_type', '').lower() == 'jwt':
expires_in, scopes, user = self._decompose_access_token_response(request, response)
content = {
'access_token': self._generate_jwt(user, scopes, expires_in),
'expires_in': expires_in,
'token_type': 'JWT',
'scope': ' '.join(scopes),
}
response.content = json.dumps(content)
return response
def _decompose_access_token_response(self, request, response):
""" Decomposes the access token in the request to an expiration date, scopes, and User. """
content = json.loads(response.content)
access_token = content['access_token']
scope = content['scope']
access_token_obj = self.get_adapter(request).get_access_token(access_token)
user = access_token_obj.user
scopes = scope.split(' ')
expires_in = content['expires_in']
return expires_in, scopes, user
def _generate_jwt(self, user, scopes, expires_in):
""" Returns a JWT access token. """
now = int(time())
payload = {
'iss': settings.JWT_AUTH['JWT_ISSUER'],
'aud': settings.JWT_AUTH['JWT_AUDIENCE'],
'exp': now + expires_in,
'iat': now,
'preferred_username': user.username,
}
for scope in scopes:
handler = self.claim_handlers.get(scope)
if handler:
handler(payload, user)
secret = settings.JWT_AUTH['JWT_SECRET_KEY']
token = jwt.encode(payload, secret, algorithm=settings.JWT_AUTH['JWT_ALGORITHM'])
return token
def _attach_email_claim(self, payload, user):
""" Add the email claim details to the JWT payload. """
payload['email'] = user.email
def _attach_profile_claim(self, payload, user):
""" Add the profile claim details to the JWT payload. """
payload.update({
'family_name': user.last_name,
'name': user.get_full_name(),
'given_name': user.first_name,
})
class AuthorizationView(_DispatchingView): class AuthorizationView(_DispatchingView):
""" """
......
...@@ -458,7 +458,13 @@ OAUTH_EXPIRE_PUBLIC_CLIENT_DAYS = 30 ...@@ -458,7 +458,13 @@ OAUTH_EXPIRE_PUBLIC_CLIENT_DAYS = 30
################################## DJANGO OAUTH TOOLKIT ####################################### ################################## DJANGO OAUTH TOOLKIT #######################################
OAUTH2_PROVIDER = { OAUTH2_PROVIDER = {
'OAUTH2_VALIDATOR_CLASS': 'lms.djangoapps.oauth_dispatch.dot_overrides.EdxOAuth2Validator' 'OAUTH2_VALIDATOR_CLASS': 'lms.djangoapps.oauth_dispatch.dot_overrides.EdxOAuth2Validator',
'SCOPES': {
'read': 'Read scope',
'write': 'Write scope',
'email': 'Email scope',
'profile': 'Profile scope',
}
} }
################################## TEMPLATE CONFIGURATION ##################################### ################################## TEMPLATE CONFIGURATION #####################################
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment