Commit e344bb6c by Clinton Blackburn

Updated access token view to return a JWT as an access token

The JWT includes the user email and username, along with details pulled from the original access token (e.g. scope, expiration).

ECOM-4221
parent 543517ef
......@@ -111,8 +111,10 @@ class CcxRestApiTest(CcxTestCase, APITestCase):
token_resp = self.client.post('/oauth2/access_token/', data=token_data)
self.assertEqual(token_resp.status_code, status.HTTP_200_OK)
token_resp_json = json.loads(token_resp.content)
self.assertIn('access_token', token_resp_json)
return 'Bearer {0}'.format(token_resp_json.get('access_token'))
return '{token_type} {token}'.format(
token_type=token_resp_json['token_type'],
token=token_resp_json['access_token']
)
def expect_error(self, http_code, error_code_str, resp_obj):
"""
......
......@@ -17,6 +17,7 @@ from rest_framework_oauth.authentication import OAuth2Authentication
from ccx_keys.locator import CCXLocator
from courseware import courses
from edx_rest_framework_extensions.authentication import JwtAuthentication
from instructor.enrollment import (
enroll_email,
get_email_params,
......@@ -361,7 +362,7 @@ class CCXListView(GenericAPIView):
]
}
"""
authentication_classes = (OAuth2Authentication, SessionAuthentication,)
authentication_classes = (JwtAuthentication, OAuth2Authentication, SessionAuthentication,)
permission_classes = (IsAuthenticated, permissions.IsMasterCourseStaffInstructor)
serializer_class = CCXCourseSerializer
pagination_class = CCXAPIPagination
......@@ -599,7 +600,7 @@ class CCXDetailView(GenericAPIView):
response is returned.
"""
authentication_classes = (OAuth2Authentication, SessionAuthentication,)
authentication_classes = (JwtAuthentication, OAuth2Authentication, SessionAuthentication,)
permission_classes = (IsAuthenticated, permissions.IsCourseStaffInstructor)
serializer_class = CCXCourseSerializer
......
......@@ -12,7 +12,12 @@ class DOTAdapter(object):
backend = object()
def create_confidential_client(self, name, user, redirect_uri, client_id=None):
def create_confidential_client(self,
name,
user,
redirect_uri,
client_id=None,
authorization_grant_type=models.Application.GRANT_AUTHORIZATION_CODE):
"""
Create an oauth client application that is confidential.
"""
......@@ -21,7 +26,7 @@ class DOTAdapter(object):
user=user,
client_id=client_id,
client_type=models.Application.CLIENT_CONFIDENTIAL,
authorization_grant_type=models.Application.GRANT_AUTHORIZATION_CODE,
authorization_grant_type=authorization_grant_type,
redirect_uris=redirect_uri,
)
......
"""
Classes that override default django-oauth-toolkit behavior
"""
from __future__ import unicode_literals
from django.contrib.auth import authenticate, get_user_model
from oauth2_provider.oauth2_validators import OAuth2Validator
......@@ -41,3 +42,25 @@ class EdxOAuth2Validator(OAuth2Validator):
else:
authenticated_user = authenticate(username=email_user.username, password=password)
return authenticated_user
def save_bearer_token(self, token, request, *args, **kwargs):
"""
Ensure that access tokens issued via client credentials grant are associated with the owner of the
``Application``.
"""
grant_type = request.grant_type
user = request.user
if grant_type == 'client_credentials':
# Temporarily remove the grant type to avoid triggering the super method's code that removes request.user.
request.grant_type = None
# Ensure the tokens get associated with the correct user since DOT does not normally
# associate access tokens issued with the client_credentials grant to users.
request.user = request.client.user
super(EdxOAuth2Validator, self).save_bearer_token(token, request, *args, **kwargs)
# Restore the original request attributes
request.grant_type = grant_type
request.user = user
"""
OAuth Dispatch test mixins
"""
import jwt
from django.conf import settings
class AccessTokenMixin(object):
""" Mixin for tests dealing with OAuth 2 access tokens. """
def assert_valid_jwt_access_token(self, access_token, user, scopes=None):
"""
Verify the specified JWT access token is valid, and belongs to the specified user.
Args:
access_token (str): JWT
user (User): User whose information is contained in the JWT payload.
Returns:
dict: Decoded JWT payload
"""
scopes = scopes or []
audience = settings.JWT_AUTH['JWT_AUDIENCE']
issuer = settings.JWT_AUTH['JWT_ISSUER']
payload = jwt.decode(
access_token,
settings.JWT_AUTH['JWT_SECRET_KEY'],
algorithms=[settings.JWT_AUTH['JWT_ALGORITHM']],
audience=audience,
issuer=issuer
)
expected = {
'aud': audience,
'iss': issuer,
'preferred_username': user.username,
}
if 'email' in scopes:
expected['email'] = user.email
self.assertDictContainsSubset(expected, payload)
return payload
......@@ -4,33 +4,37 @@ import json
from django.core.urlresolvers import reverse
from django.test import TestCase
from edx_oauth2_provider.tests.factories import ClientFactory
from oauth2_provider.models import Application
from provider.oauth2.models import AccessToken
from student.tests.factories import UserFactory
from . import mixins
from .constants import DUMMY_REDIRECT_URL
from ..adapters import DOTAdapter
class ClientCredentialsTest(TestCase):
class ClientCredentialsTest(mixins.AccessTokenMixin, TestCase):
""" Tests validating the client credentials grant behavior. """
def setUp(self):
super(ClientCredentialsTest, self).setUp()
self.user = UserFactory()
self.oauth_client = ClientFactory(user=self.user)
def test_access_token(self):
""" Verify the client credentials grant can be used to obtain an access token whose default scopes allow access
to the user info endpoint.
"""
oauth_client = ClientFactory(user=self.user)
data = {
'grant_type': 'client_credentials',
'client_id': self.oauth_client.client_id,
'client_secret': self.oauth_client.client_secret
'client_id': oauth_client.client_id,
'client_secret': oauth_client.client_secret
}
response = self.client.post(reverse('oauth2:access_token'), data)
self.assertEqual(response.status_code, 200)
access_token = json.loads(response.content)['access_token']
expected = AccessToken.objects.filter(client=self.oauth_client, user=self.user).first().token
expected = AccessToken.objects.filter(client=oauth_client, user=self.user).first().token
self.assertEqual(access_token, expected)
headers = {
......@@ -38,3 +42,29 @@ class ClientCredentialsTest(TestCase):
}
response = self.client.get(reverse('oauth2:user_info'), **headers)
self.assertEqual(response.status_code, 200)
def test_jwt_access_token(self):
""" Verify the client credentials grant can be used to obtain a JWT access token. """
application = DOTAdapter().create_confidential_client(
name='test dot application',
user=self.user,
authorization_grant_type=Application.GRANT_CLIENT_CREDENTIALS,
redirect_uri=DUMMY_REDIRECT_URL,
client_id='dot-app-client-id',
)
scopes = ('read', 'write', 'email')
data = {
'grant_type': 'client_credentials',
'client_id': application.client_id,
'client_secret': application.client_secret,
'scope': ' '.join(scopes),
'token_type': 'jwt'
}
response = self.client.post(reverse('access_token'), data)
self.assertEqual(response.status_code, 200)
content = json.loads(response.content)
access_token = content['access_token']
self.assertEqual(content['scope'], data['scope'])
self.assert_valid_jwt_access_token(access_token, self.user, scopes)
......@@ -12,9 +12,10 @@ import httpretty
from student.tests.factories import UserFactory
from third_party_auth.tests.utils import ThirdPartyOAuthTestMixin, ThirdPartyOAuthTestMixinGoogle
from .constants import DUMMY_REDIRECT_URL
from .. import adapters
from .. import views
from .constants import DUMMY_REDIRECT_URL
from . import mixins
class _DispatchingViewTestCase(TestCase):
......@@ -43,14 +44,14 @@ class _DispatchingViewTestCase(TestCase):
client_id='dop-app-client-id',
)
def _post_request(self, user, client):
def _post_request(self, user, client, token_type=None):
"""
Call the view with a POST request objectwith the appropriate format,
returning the response object.
"""
return self.client.post(self.url, self._post_body(user, client))
return self.client.post(self.url, self._post_body(user, client, token_type))
def _post_body(self, user, client):
def _post_body(self, user, client, token_type=None):
"""
Return a dictionary to be used as the body of the POST request
"""
......@@ -58,7 +59,7 @@ class _DispatchingViewTestCase(TestCase):
@ddt.ddt
class TestAccessTokenView(_DispatchingViewTestCase):
class TestAccessTokenView(mixins.AccessTokenMixin, _DispatchingViewTestCase):
"""
Test class for AccessTokenView
"""
......@@ -66,17 +67,22 @@ class TestAccessTokenView(_DispatchingViewTestCase):
view_class = views.AccessTokenView
url = reverse('access_token')
def _post_body(self, user, client):
def _post_body(self, user, client, token_type=None):
"""
Return a dictionary to be used as the body of the POST request
"""
return {
body = {
'client_id': client.client_id,
'grant_type': 'password',
'username': user.username,
'password': 'test',
}
if token_type:
body['token_type'] = token_type
return body
@ddt.data('dop_client', 'dot_app')
def test_access_token_fields(self, client_attr):
client = getattr(self, client_attr)
......@@ -88,6 +94,16 @@ class TestAccessTokenView(_DispatchingViewTestCase):
self.assertIn('scope', data)
self.assertIn('token_type', data)
@ddt.data('dop_client', 'dot_app')
def test_jwt_access_token(self, client_attr):
client = getattr(self, client_attr)
response = self._post_request(self.user, client, token_type='jwt')
self.assertEqual(response.status_code, 200)
data = json.loads(response.content)
self.assertIn('expires_in', data)
self.assertEqual(data['token_type'], 'JWT')
self.assert_valid_jwt_access_token(data['access_token'], self.user, data['scope'].split(' '))
def test_dot_access_token_provides_refresh_token(self):
response = self._post_request(self.user, self.dot_app)
self.assertEqual(response.status_code, 200)
......@@ -111,7 +127,7 @@ class TestAccessTokenExchangeView(ThirdPartyOAuthTestMixinGoogle, ThirdPartyOAut
view_class = views.AccessTokenExchangeView
url = reverse('exchange_access_token', kwargs={'backend': 'google-oauth2'})
def _post_body(self, user, client):
def _post_body(self, user, client, token_type=None):
return {
'client_id': client.client_id,
'access_token': self.access_token,
......
......@@ -5,12 +5,17 @@ django-oauth-toolkit as appropriate.
from __future__ import unicode_literals
import json
from time import time
import jwt
from auth_exchange import views as auth_exchange_views
from django.conf import settings
from django.utils.functional import cached_property
from django.views.generic import View
from edx_oauth2_provider import views as dop_views # django-oauth2-provider views
from oauth2_provider import models as dot_models, views as dot_views # django-oauth-toolkit
from auth_exchange import views as auth_exchange_views
from . import adapters
......@@ -25,6 +30,15 @@ class _DispatchingView(View):
dot_adapter = adapters.DOTAdapter()
dop_adapter = adapters.DOPAdapter()
def get_adapter(self, request):
"""
Returns the appropriate adapter based on the OAuth client linked to the request.
"""
if dot_models.Application.objects.filter(client_id=self._get_client_id(request)).exists():
return self.dot_adapter
else:
return self.dop_adapter
def dispatch(self, request, *args, **kwargs):
"""
Dispatch the request to the selected backend's view.
......@@ -41,11 +55,7 @@ class _DispatchingView(View):
otherwise use the django-oauth2-provider (DOP) adapter, and allow the
calls to fail normally if the client does not exist.
"""
if dot_models.Application.objects.filter(client_id=self._get_client_id(request)).exists():
return self.dot_adapter.backend
else:
return self.dop_adapter.backend
return self.get_adapter(request).backend
def get_view_for_backend(self, backend):
"""
......@@ -72,6 +82,77 @@ class AccessTokenView(_DispatchingView):
dot_view = dot_views.TokenView
dop_view = dop_views.AccessTokenView
@cached_property
def claim_handlers(self):
""" Returns a dictionary mapping scopes to methods that will add claims to the JWT payload. """
return {
'email': self._attach_email_claim,
'profile': self._attach_profile_claim
}
def dispatch(self, request, *args, **kwargs):
response = super(AccessTokenView, self).dispatch(request, *args, **kwargs)
if response.status_code == 200 and request.POST.get('token_type', '').lower() == 'jwt':
expires_in, scopes, user = self._decompose_access_token_response(request, response)
content = {
'access_token': self._generate_jwt(user, scopes, expires_in),
'expires_in': expires_in,
'token_type': 'JWT',
'scope': ' '.join(scopes),
}
response.content = json.dumps(content)
return response
def _decompose_access_token_response(self, request, response):
""" Decomposes the access token in the request to an expiration date, scopes, and User. """
content = json.loads(response.content)
access_token = content['access_token']
scope = content['scope']
access_token_obj = self.get_adapter(request).get_access_token(access_token)
user = access_token_obj.user
scopes = scope.split(' ')
expires_in = content['expires_in']
return expires_in, scopes, user
def _generate_jwt(self, user, scopes, expires_in):
""" Returns a JWT access token. """
now = int(time())
payload = {
'iss': settings.JWT_AUTH['JWT_ISSUER'],
'aud': settings.JWT_AUTH['JWT_AUDIENCE'],
'exp': now + expires_in,
'iat': now,
'preferred_username': user.username,
}
for scope in scopes:
handler = self.claim_handlers.get(scope)
if handler:
handler(payload, user)
secret = settings.JWT_AUTH['JWT_SECRET_KEY']
token = jwt.encode(payload, secret, algorithm=settings.JWT_AUTH['JWT_ALGORITHM'])
return token
def _attach_email_claim(self, payload, user):
""" Add the email claim details to the JWT payload. """
payload['email'] = user.email
def _attach_profile_claim(self, payload, user):
""" Add the profile claim details to the JWT payload. """
payload.update({
'family_name': user.last_name,
'name': user.get_full_name(),
'given_name': user.first_name,
})
class AuthorizationView(_DispatchingView):
"""
......
......@@ -458,7 +458,13 @@ OAUTH_EXPIRE_PUBLIC_CLIENT_DAYS = 30
################################## DJANGO OAUTH TOOLKIT #######################################
OAUTH2_PROVIDER = {
'OAUTH2_VALIDATOR_CLASS': 'lms.djangoapps.oauth_dispatch.dot_overrides.EdxOAuth2Validator'
'OAUTH2_VALIDATOR_CLASS': 'lms.djangoapps.oauth_dispatch.dot_overrides.EdxOAuth2Validator',
'SCOPES': {
'read': 'Read scope',
'write': 'Write scope',
'email': 'Email scope',
'profile': 'Profile scope',
}
}
################################## TEMPLATE CONFIGURATION #####################################
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment