Commit 1ba96f37 by Cliff Dyer

Merge pull request #11182 from edx/mobile/access-token-messages-MA-1900

Provide useful response messages with access token errors
parents db694e28 864e081f
""" Common Authentication Handlers used across projects. """
"""
Common Authentication Handlers used across projects.
"""
from rest_framework.authentication import SessionAuthentication
from rest_framework import exceptions as drf_exceptions
from rest_framework_oauth.authentication import OAuth2Authentication
from rest_framework.exceptions import AuthenticationFailed
from .exceptions import AuthenticationFailed
from rest_framework_oauth.compat import oauth2_provider, provider_now
OAUTH2_TOKEN_ERROR = u'token_error'
OAUTH2_TOKEN_ERROR_EXPIRED = u'token_expired'
OAUTH2_TOKEN_ERROR_MALFORMED = u'token_malformed'
OAUTH2_TOKEN_ERROR_NONEXISTENT = u'token_nonexistent'
OAUTH2_TOKEN_ERROR_NOT_PROVIDED = u'token_not_provided'
class SessionAuthenticationAllowInactiveUser(SessionAuthentication):
"""Ensure that the user is logged in, but do not require the account to be active.
......@@ -65,19 +76,55 @@ class OAuth2AuthenticationAllowInactiveUser(OAuth2Authentication):
This class can be used for an OAuth2-accessible endpoint that allows users to access
that endpoint without having their email verified. For example, this is used
for mobile endpoints.
"""
def authenticate(self, *args, **kwargs):
"""
Returns two-tuple of (user, token) if access token authentication
succeeds, raises an AuthenticationFailed (HTTP 401) if authentication
fails or None if the user did not try to authenticate using an access
token.
Overrides base class implementation to return edX-style error
responses.
"""
try:
return super(OAuth2AuthenticationAllowInactiveUser, self).authenticate(*args, **kwargs)
except AuthenticationFailed:
# AuthenticationFailed is a subclass of drf_exceptions.AuthenticationFailed,
# but we don't want to post-process the exception detail for our own class.
raise
except drf_exceptions.AuthenticationFailed as exc:
if 'No credentials provided' in exc.detail:
error_code = OAUTH2_TOKEN_ERROR_NOT_PROVIDED
elif 'Token string should not contain spaces' in exc.detail:
error_code = OAUTH2_TOKEN_ERROR_MALFORMED
else:
error_code = OAUTH2_TOKEN_ERROR
raise AuthenticationFailed({
u'error_code': error_code,
u'developer_message': exc.detail
})
def authenticate_credentials(self, request, access_token):
"""
Authenticate the request, given the access token.
Override base class implementation to discard failure if user is inactive.
Overrides base class implementation to discard failure if user is inactive.
"""
try:
token = oauth2_provider.oauth2.models.AccessToken.objects.select_related('user')
# provider_now switches to timezone aware datetime when
# the oauth2_provider version supports to it.
token = token.get(token=access_token, expires__gt=provider_now())
except oauth2_provider.oauth2.models.AccessToken.DoesNotExist:
raise AuthenticationFailed('Invalid token')
return token.user, token
token_query = oauth2_provider.oauth2.models.AccessToken.objects.select_related('user')
token = token_query.filter(token=access_token).first()
if not token:
raise AuthenticationFailed({
u'error_code': OAUTH2_TOKEN_ERROR_NONEXISTENT,
u'developer_message': u'The provided access token does not match any valid tokens.'
})
# provider_now switches to timezone aware datetime when
# the oauth2_provider version supports it.
elif token.expires < provider_now():
raise AuthenticationFailed({
u'error_code': OAUTH2_TOKEN_ERROR_EXPIRED,
u'developer_message': u'The provided access token has expired and is no longer valid.',
})
else:
return token.user, token
"""
Custom exceptions, that allow details to be passed as dict values (which can be
converted to JSON, like other API responses.
"""
from rest_framework import exceptions
# TODO: Override Throttled, UnsupportedMediaType, ValidationError. These types require
# more careful handling of arguments.
class _DictAPIException(exceptions.APIException):
"""
Intermediate class to allow exceptions to pass dict detail values. Use by
subclassing this along with another subclass of `exceptions.APIException`.
"""
def __init__(self, detail):
if isinstance(detail, dict):
self.detail = detail
else:
super(_DictAPIException, self).__init__(detail)
class AuthenticationFailed(exceptions.AuthenticationFailed, _DictAPIException):
"""
Override of DRF's AuthenticationFailed exception to allow dictionary responses.
"""
pass
class MethodNotAllowed(exceptions.MethodNotAllowed, _DictAPIException):
"""
Override of DRF's MethodNotAllowed exception to allow dictionary responses.
"""
def __init__(self, method, detail=None):
if isinstance(detail, dict):
self.detail = detail
else:
super(MethodNotAllowed, self).__init__(method, detail)
class NotAcceptable(exceptions.NotAcceptable, _DictAPIException):
"""
Override of DRF's NotAcceptable exception to allow dictionary responses.
"""
def __init__(self, detail=None, available_renderers=None):
self.available_renderers = available_renderers
if isinstance(detail, dict):
self.detail = detail
else:
super(NotAcceptable, self).__init__(detail, available_renderers)
class NotAuthenticated(exceptions.NotAuthenticated, _DictAPIException):
"""
Override of DRF's NotAuthenticated exception to allow dictionary responses.
"""
pass
class NotFound(exceptions.NotFound, _DictAPIException):
"""
Override of DRF's NotFound exception to allow dictionary responses.
"""
pass
class ParseError(exceptions.ParseError, _DictAPIException):
"""
Override of DRF's ParseError exception to allow dictionary responses.
"""
pass
class PermissionDenied(exceptions.PermissionDenied, _DictAPIException):
"""
Override of DRF's PermissionDenied exception to allow dictionary responses.
"""
pass
"""
Tests for OAuth2. This module is copied from django-rest-framework-oauth (tests/test_authentication.py)
and updated to use our subclass of OAuth2Authentication.
Tests for OAuth2. This module is copied from django-rest-framework-oauth
(tests/test_authentication.py) and updated to use our subclass of OAuth2Authentication.
"""
from __future__ import unicode_literals
import datetime
from collections import namedtuple
from datetime import datetime, timedelta
import itertools
import json
import ddt
from django.conf.urls import patterns, url, include
from django.contrib.auth.models import User
......@@ -22,7 +27,7 @@ from rest_framework.views import APIView
from provider import scope, constants
from ..authentication import OAuth2AuthenticationAllowInactiveUser
from .. import authentication
factory = APIRequestFactory() # pylint: disable=invalid-name
......@@ -43,30 +48,35 @@ class MockView(APIView): # pylint: disable=missing-docstring
# This is the a change we've made from the django-rest-framework-oauth version
# of these tests. We're subclassing our custom OAuth2AuthenticationAllowInactiveUser
# instead of OAuth2Authentication.
class OAuth2AuthenticationDebug(OAuth2AuthenticationAllowInactiveUser): # pylint: disable=missing-docstring
class OAuth2AuthenticationDebug(authentication.OAuth2AuthenticationAllowInactiveUser): # pylint: disable=missing-docstring
allow_query_params_token = True
urlpatterns = patterns(
'',
url(r'^oauth2/', include('provider.oauth2.urls', namespace='oauth2')),
url(r'^oauth2-test/$', MockView.as_view(authentication_classes=[OAuth2AuthenticationAllowInactiveUser])),
url(
r'^oauth2-test/$',
MockView.as_view(authentication_classes=[authentication.OAuth2AuthenticationAllowInactiveUser])
),
url(r'^oauth2-test-debug/$', MockView.as_view(authentication_classes=[OAuth2AuthenticationDebug])),
url(
r'^oauth2-with-scope-test/$',
MockView.as_view(
authentication_classes=[OAuth2AuthenticationAllowInactiveUser],
authentication_classes=[authentication.OAuth2AuthenticationAllowInactiveUser],
permission_classes=[permissions.TokenHasReadWriteScope]
)
),
)
@ddt.ddt
class OAuth2Tests(TestCase):
"""OAuth 2.0 authentication"""
urls = 'openedx.core.lib.api.tests.test_authentication'
def setUp(self):
super(OAuth2Tests, self).setUp()
self.csrf_client = APIClient(enforce_csrf_checks=True)
self.username = 'john'
self.email = 'lennon@thebeatles.com'
......@@ -75,8 +85,8 @@ class OAuth2Tests(TestCase):
self.CLIENT_ID = 'client_key' # pylint: disable=invalid-name
self.CLIENT_SECRET = 'client_secret' # pylint: disable=invalid-name
self.ACCESS_TOKEN = "access_token" # pylint: disable=invalid-name
self.REFRESH_TOKEN = "refresh_token" # pylint: disable=invalid-name
self.ACCESS_TOKEN = 'access_token' # pylint: disable=invalid-name
self.REFRESH_TOKEN = 'refresh_token' # pylint: disable=invalid-name
self.oauth2_client = oauth2_provider.oauth2.models.Client.objects.create(
client_id=self.CLIENT_ID,
......@@ -111,51 +121,59 @@ class OAuth2Tests(TestCase):
# edx-auth2-provider.
scope.SCOPE_NAME_DICT = {'read': constants.READ, 'write': constants.WRITE}
def get_with_bearer_token(self, target_url, params=None, token=None):
"""
Make a GET request to the specified URL with an OAuth2 bearer token. If
no token is provided, a valid token will be used. Query parameters can
also be passed in if desired.
"""
auth = self._create_authorization_header(token)
return self.csrf_client.get(target_url, params, HTTP_AUTHORIZATION=auth)
def post_with_bearer_token(self, target_url, token=None):
"""
Make a POST request to the specified URL with an OAuth2 bearer token. If
no token is provided, a valid token will be used.
"""
auth = self._create_authorization_header(token)
return self.csrf_client.post(target_url, HTTP_AUTHORIZATION=auth)
def check_error_codes(self, response, status_code, error_code):
"""
Ensure that the response has the appropriate HTTP status, and provides
the expected error_code in the JSON response body.
"""
response_dict = json.loads(response.content)
self.assertEqual(response.status_code, status_code)
self.assertEqual(response_dict['error_code'], error_code)
def _create_authorization_header(self, token=None): # pylint: disable=missing-docstring
return "Bearer {0}".format(token or self.access_token.token)
if token is None:
token = self.access_token.token
return "Bearer {0}".format(token)
@ddt.data(None, {})
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
def test_get_form_with_wrong_authorization_header_token_type_failing(self):
def test_get_form_with_wrong_authorization_header_token_type_failing(self, params):
"""Ensure that a wrong token type lead to the correct HTTP error status code"""
auth = "Wrong token-type-obviously"
response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 401)
response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 401)
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
def test_get_form_with_wrong_authorization_header_token_format_failing(self):
"""Ensure that a wrong token format lead to the correct HTTP error status code"""
auth = "Bearer wrong token format"
response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 401)
response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 401)
response = self.csrf_client.get(
'/oauth2-test/',
params,
HTTP_AUTHORIZATION='Wrong token-type-obviously'
)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
def test_get_form_with_wrong_authorization_header_token_failing(self):
"""Ensure that a wrong token lead to the correct HTTP error status code"""
auth = "Bearer wrong-token"
response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 401)
response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 401)
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
def test_get_form_with_wrong_authorization_header_token_missing(self):
"""Ensure that a missing token lead to the correct HTTP error status code"""
auth = "Bearer"
response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 401)
response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 401)
# If no Authorization header is provided that contains a bearer token,
# authorization passes to the next registered authorization class, or
# (in this case) to standard DRF fallback code, so no error_code is
# provided (yet).
self.assertNotIn('error_code', json.loads(response.content))
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
def test_get_form_passing_auth(self):
"""Ensure GETing form over OAuth with correct client credentials succeed"""
auth = self._create_authorization_header()
response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 200)
response = self.get_with_bearer_token('/oauth2-test/')
self.assertEqual(response.status_code, status.HTTP_200_OK)
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
def test_post_form_passing_auth_url_transport(self):
......@@ -164,72 +182,106 @@ class OAuth2Tests(TestCase):
'/oauth2-test/',
data={'access_token': self.access_token.token}
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.status_code, status.HTTP_200_OK)
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
def test_get_form_passing_auth_url_transport(self):
"""Ensure GETing form over OAuth with correct client credentials in query succeed when DEBUG is True"""
query = urlencode({'access_token': self.access_token.token})
response = self.csrf_client.get('/oauth2-test-debug/?%s' % query)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.status_code, status.HTTP_200_OK)
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
def test_get_form_failing_auth_url_transport(self):
"""Ensure GETing form over OAuth with correct client credentials in query fails when DEBUG is False"""
query = urlencode({'access_token': self.access_token.token})
response = self.csrf_client.get('/oauth2-test/?%s' % query)
self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
# This case is handled directly by DRF so no error_code is provided (yet).
self.assertNotIn('error_code', json.loads(response.content))
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
def test_post_form_passing_auth(self):
"""Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF"""
auth = self._create_authorization_header()
response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 200)
response = self.post_with_bearer_token('/oauth2-test/')
self.assertEqual(response.status_code, status.HTTP_200_OK)
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
def test_post_form_token_removed_failing_auth(self):
"""Ensure POSTing when there is no OAuth access token in db fails"""
self.access_token.delete()
auth = self._create_authorization_header()
response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth)
self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
response = self.post_with_bearer_token('/oauth2-test/')
self.check_error_codes(
response,
status_code=status.HTTP_401_UNAUTHORIZED,
error_code=authentication.OAUTH2_TOKEN_ERROR_NONEXISTENT
)
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
def test_post_form_with_refresh_token_failing_auth(self):
"""Ensure POSTing with refresh token instead of access token fails"""
auth = self._create_authorization_header(token=self.refresh_token.token)
response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth)
self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
response = self.post_with_bearer_token('/oauth2-test/', token=self.refresh_token.token)
self.check_error_codes(
response,
status_code=status.HTTP_401_UNAUTHORIZED,
error_code=authentication.OAUTH2_TOKEN_ERROR_NONEXISTENT
)
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
def test_post_form_with_expired_access_token_failing_auth(self):
"""Ensure POSTing with expired access token fails with an 'Invalid token' error"""
self.access_token.expires = datetime.datetime.now() - datetime.timedelta(seconds=10) # 10 seconds late
"""Ensure POSTing with expired access token fails with a 'token_expired' error"""
self.access_token.expires = datetime.now() - timedelta(seconds=10) # 10 seconds late
self.access_token.save()
auth = self._create_authorization_header()
response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth)
self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
self.assertIn('Invalid token', response.content)
response = self.post_with_bearer_token('/oauth2-test/')
self.check_error_codes(
response,
status_code=status.HTTP_401_UNAUTHORIZED,
error_code=authentication.OAUTH2_TOKEN_ERROR_EXPIRED
)
TokenErrorDDT = namedtuple('TokenErrorDDT', ['token', 'error_code'])
@ddt.data(
*itertools.product(
[None, {}],
[
TokenErrorDDT('wrong format', authentication.OAUTH2_TOKEN_ERROR_MALFORMED),
TokenErrorDDT('wrong-token', authentication.OAUTH2_TOKEN_ERROR_NONEXISTENT),
TokenErrorDDT('', authentication.OAUTH2_TOKEN_ERROR_NOT_PROVIDED),
]
)
)
@ddt.unpack
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
def test_post_form_with_invalid_scope_failing_auth(self):
"""Ensure POSTing with a readonly scope instead of a write scope fails"""
read_only_access_token = self.access_token
read_only_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['read']
read_only_access_token.save()
auth = self._create_authorization_header(token=read_only_access_token.token)
response = self.csrf_client.get('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 200)
response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
def test_response_for_get_request_with_bad_auth_token(self, http_params, token_error):
response = self.get_with_bearer_token('/oauth2-test/', http_params, token=token_error.token)
self.check_error_codes(
response,
status_code=status.HTTP_401_UNAUTHORIZED,
error_code=token_error.error_code
)
@ddt.data(
TokenErrorDDT('notatoken', authentication.OAUTH2_TOKEN_ERROR_NONEXISTENT),
TokenErrorDDT('malformed token', authentication.OAUTH2_TOKEN_ERROR_MALFORMED),
TokenErrorDDT('', authentication.OAUTH2_TOKEN_ERROR_NOT_PROVIDED),
)
def test_response_for_post_request_with_bad_auth_token(self, token_error):
response = self.post_with_bearer_token('/oauth2-test/', token=token_error.token)
self.check_error_codes(response, status_code=status.HTTP_401_UNAUTHORIZED, error_code=token_error.error_code)
ScopeStatusDDT = namedtuple('ScopeStatusDDT', ['scope', 'read_status', 'write_status'])
@ddt.data(
ScopeStatusDDT('read', read_status=status.HTTP_200_OK, write_status=status.HTTP_403_FORBIDDEN),
ScopeStatusDDT('write', status.HTTP_403_FORBIDDEN, status.HTTP_200_OK),
)
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
def test_post_form_with_valid_scope_passing_auth(self):
"""Ensure POSTing with a write scope succeed"""
read_write_access_token = self.access_token
read_write_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['write']
read_write_access_token.save()
auth = self._create_authorization_header(token=read_write_access_token.token)
response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 200)
def test_responses_to_scoped_requests(self, scope_statuses):
self.access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT[scope_statuses.scope]
self.access_token.save()
response = self.get_with_bearer_token('/oauth2-with-scope-test/', token=self.access_token.token)
self.assertEqual(response.status_code, scope_statuses.read_status)
response = self.post_with_bearer_token('/oauth2-with-scope-test/', token=self.access_token.token)
self.assertEqual(response.status_code, scope_statuses.write_status)
"""
Test Custom Exceptions
"""
import ddt
from django.test import TestCase
from rest_framework import exceptions as drf_exceptions
from .. import exceptions
@ddt.ddt
class TestDictExceptionsAllowDictDetails(TestCase):
"""
Standard DRF exceptions coerce detail inputs to strings. We want to use
dicts to allow better customization of error messages. Demonstrate that
we can provide dictionaries as exception details, and that custom
classes subclass the relevant DRF exceptions, to provide consistent
exception catching behavior.
"""
def test_drf_errors_coerce_strings(self):
# Demonstrate the base issue we are trying to solve.
exc = drf_exceptions.AuthenticationFailed({u'error_code': -1})
self.assertEqual(exc.detail, u"{u'error_code': -1}")
@ddt.data(
exceptions.AuthenticationFailed,
exceptions.NotAuthenticated,
exceptions.NotFound,
exceptions.ParseError,
exceptions.PermissionDenied,
)
def test_exceptions_allows_dict_detail(self, exception_class):
exc = exception_class({u'error_code': -1})
self.assertEqual(exc.detail, {u'error_code': -1})
def test_method_not_allowed_allows_dict_detail(self):
exc = exceptions.MethodNotAllowed(u'POST', {u'error_code': -1})
self.assertEqual(exc.detail, {u'error_code': -1})
def test_not_acceptable_allows_dict_detail(self):
exc = exceptions.NotAcceptable({u'error_code': -1}, available_renderers=['application/json'])
self.assertEqual(exc.detail, {u'error_code': -1})
self.assertEqual(exc.available_renderers, ['application/json'])
@ddt.ddt
class TestDictExceptionSubclassing(TestCase):
"""
Custom exceptions should subclass standard DRF exceptions, so code that
catches the DRF exceptions also catches ours.
"""
@ddt.data(
(exceptions.AuthenticationFailed, drf_exceptions.AuthenticationFailed),
(exceptions.NotAcceptable, drf_exceptions.NotAcceptable),
(exceptions.NotAuthenticated, drf_exceptions.NotAuthenticated),
(exceptions.NotFound, drf_exceptions.NotFound),
(exceptions.ParseError, drf_exceptions.ParseError),
(exceptions.PermissionDenied, drf_exceptions.PermissionDenied),
)
@ddt.unpack
def test_exceptions_subclass_drf_exceptions(self, exception_class, drf_exception_class):
exc = exception_class({u'error_code': -1})
self.assertIsInstance(exc, drf_exception_class)
def test_method_not_allowed_subclasses_drf_exception(self):
exc = exceptions.MethodNotAllowed(u'POST', {u'error_code': -1})
self.assertIsInstance(exc, drf_exceptions.MethodNotAllowed)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment