Commit 37f7b76f by Tom Christie

Merge pull request #3785 from sheppard/authtoken-import

don't import authtoken model until needed
parents dceb6867 4f407141
...@@ -10,7 +10,6 @@ from django.middleware.csrf import CsrfViewMiddleware ...@@ -10,7 +10,6 @@ from django.middleware.csrf import CsrfViewMiddleware
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from rest_framework import HTTP_HEADER_ENCODING, exceptions from rest_framework import HTTP_HEADER_ENCODING, exceptions
from rest_framework.authtoken.models import Token
def get_authorization_header(request): def get_authorization_header(request):
...@@ -149,7 +148,14 @@ class TokenAuthentication(BaseAuthentication): ...@@ -149,7 +148,14 @@ class TokenAuthentication(BaseAuthentication):
Authorization: Token 401f7ac837da42b97f613d789819ff93537bee6a Authorization: Token 401f7ac837da42b97f613d789819ff93537bee6a
""" """
model = Token model = None
def get_model(self):
if self.model is not None:
return self.model
from rest_framework.authtoken.models import Token
return Token
""" """
A custom token model may be used, but must have the following properties. A custom token model may be used, but must have the following properties.
...@@ -179,9 +185,10 @@ class TokenAuthentication(BaseAuthentication): ...@@ -179,9 +185,10 @@ class TokenAuthentication(BaseAuthentication):
return self.authenticate_credentials(token) return self.authenticate_credentials(token)
def authenticate_credentials(self, key): def authenticate_credentials(self, key):
model = self.get_model()
try: try:
token = self.model.objects.select_related('user').get(key=key) token = model.objects.select_related('user').get(key=key)
except self.model.DoesNotExist: except model.DoesNotExist:
raise exceptions.AuthenticationFailed(_('Invalid token.')) raise exceptions.AuthenticationFailed(_('Invalid token.'))
if not token.user.is_active: if not token.user.is_active:
......
...@@ -21,14 +21,6 @@ class Token(models.Model): ...@@ -21,14 +21,6 @@ class Token(models.Model):
user = models.OneToOneField(AUTH_USER_MODEL, related_name='auth_token') user = models.OneToOneField(AUTH_USER_MODEL, related_name='auth_token')
created = models.DateTimeField(auto_now_add=True) created = models.DateTimeField(auto_now_add=True)
class Meta:
# Work around for a bug in Django:
# https://code.djangoproject.com/ticket/19422
#
# Also see corresponding ticket:
# https://github.com/tomchristie/django-rest-framework/issues/705
abstract = 'rest_framework.authtoken' not in settings.INSTALLED_APPS
def save(self, *args, **kwargs): def save(self, *args, **kwargs):
if not self.key: if not self.key:
self.key = self.generate_key() self.key = self.generate_key()
......
...@@ -6,6 +6,7 @@ import base64 ...@@ -6,6 +6,7 @@ import base64
from django.conf.urls import include, url from django.conf.urls import include, url
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.db import models
from django.http import HttpResponse from django.http import HttpResponse
from django.test import TestCase from django.test import TestCase
from django.utils import six from django.utils import six
...@@ -25,6 +26,15 @@ from rest_framework.views import APIView ...@@ -25,6 +26,15 @@ from rest_framework.views import APIView
factory = APIRequestFactory() factory = APIRequestFactory()
class CustomToken(models.Model):
key = models.CharField(max_length=40, primary_key=True)
user = models.OneToOneField(User)
class CustomTokenAuthentication(TokenAuthentication):
model = CustomToken
class MockView(APIView): class MockView(APIView):
permission_classes = (permissions.IsAuthenticated,) permission_classes = (permissions.IsAuthenticated,)
...@@ -42,6 +52,7 @@ urlpatterns = [ ...@@ -42,6 +52,7 @@ urlpatterns = [
url(r'^session/$', MockView.as_view(authentication_classes=[SessionAuthentication])), url(r'^session/$', MockView.as_view(authentication_classes=[SessionAuthentication])),
url(r'^basic/$', MockView.as_view(authentication_classes=[BasicAuthentication])), url(r'^basic/$', MockView.as_view(authentication_classes=[BasicAuthentication])),
url(r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])), url(r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])),
url(r'^customtoken/$', MockView.as_view(authentication_classes=[CustomTokenAuthentication])),
url(r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'), url(r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'),
url(r'^auth/', include('rest_framework.urls', namespace='rest_framework')), url(r'^auth/', include('rest_framework.urls', namespace='rest_framework')),
] ]
...@@ -142,9 +153,11 @@ class SessionAuthTests(TestCase): ...@@ -142,9 +153,11 @@ class SessionAuthTests(TestCase):
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
class TokenAuthTests(TestCase): class BaseTokenAuthTests(object):
"""Token authentication""" """Token authentication"""
urls = 'tests.test_authentication' urls = 'tests.test_authentication'
model = None
path = None
def setUp(self): def setUp(self):
self.csrf_client = APIClient(enforce_csrf_checks=True) self.csrf_client = APIClient(enforce_csrf_checks=True)
...@@ -154,24 +167,30 @@ class TokenAuthTests(TestCase): ...@@ -154,24 +167,30 @@ class TokenAuthTests(TestCase):
self.user = User.objects.create_user(self.username, self.email, self.password) self.user = User.objects.create_user(self.username, self.email, self.password)
self.key = 'abcd1234' self.key = 'abcd1234'
self.token = Token.objects.create(key=self.key, user=self.user) self.token = self.model.objects.create(key=self.key, user=self.user)
def test_post_form_passing_token_auth(self): def test_post_form_passing_token_auth(self):
"""Ensure POSTing json over token auth with correct credentials passes and does not require CSRF""" """Ensure POSTing json over token auth with correct credentials passes and does not require CSRF"""
auth = 'Token ' + self.key auth = 'Token ' + self.key
response = self.csrf_client.post('/token/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) response = self.csrf_client.post(self.path, {'example': 'example'}, HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
def test_fail_post_form_passing_nonexistent_token_auth(self):
# use a nonexistent token key
auth = 'Token wxyz6789'
response = self.csrf_client.post(self.path, {'example': 'example'}, HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
def test_fail_post_form_passing_invalid_token_auth(self): def test_fail_post_form_passing_invalid_token_auth(self):
# add an 'invalid' unicode character # add an 'invalid' unicode character
auth = 'Token ' + self.key + "¸" auth = 'Token ' + self.key + "¸"
response = self.csrf_client.post('/token/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) response = self.csrf_client.post(self.path, {'example': 'example'}, HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
def test_post_json_passing_token_auth(self): def test_post_json_passing_token_auth(self):
"""Ensure POSTing form over token auth with correct credentials passes and does not require CSRF""" """Ensure POSTing form over token auth with correct credentials passes and does not require CSRF"""
auth = "Token " + self.key auth = "Token " + self.key
response = self.csrf_client.post('/token/', {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth) response = self.csrf_client.post(self.path, {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
def test_post_json_makes_one_db_query(self): def test_post_json_makes_one_db_query(self):
...@@ -179,29 +198,34 @@ class TokenAuthTests(TestCase): ...@@ -179,29 +198,34 @@ class TokenAuthTests(TestCase):
auth = "Token " + self.key auth = "Token " + self.key
def func_to_test(): def func_to_test():
return self.csrf_client.post('/token/', {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth) return self.csrf_client.post(self.path, {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth)
self.assertNumQueries(1, func_to_test) self.assertNumQueries(1, func_to_test)
def test_post_form_failing_token_auth(self): def test_post_form_failing_token_auth(self):
"""Ensure POSTing form over token auth without correct credentials fails""" """Ensure POSTing form over token auth without correct credentials fails"""
response = self.csrf_client.post('/token/', {'example': 'example'}) response = self.csrf_client.post(self.path, {'example': 'example'})
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
def test_post_json_failing_token_auth(self): def test_post_json_failing_token_auth(self):
"""Ensure POSTing json over token auth without correct credentials fails""" """Ensure POSTing json over token auth without correct credentials fails"""
response = self.csrf_client.post('/token/', {'example': 'example'}, format='json') response = self.csrf_client.post(self.path, {'example': 'example'}, format='json')
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
class TokenAuthTests(BaseTokenAuthTests, TestCase):
model = Token
path = '/token/'
def test_token_has_auto_assigned_key_if_none_provided(self): def test_token_has_auto_assigned_key_if_none_provided(self):
"""Ensure creating a token with no key will auto-assign a key""" """Ensure creating a token with no key will auto-assign a key"""
self.token.delete() self.token.delete()
token = Token.objects.create(user=self.user) token = self.model.objects.create(user=self.user)
self.assertTrue(bool(token.key)) self.assertTrue(bool(token.key))
def test_generate_key_returns_string(self): def test_generate_key_returns_string(self):
"""Ensure generate_key returns a string""" """Ensure generate_key returns a string"""
token = Token() token = self.model()
key = token.generate_key() key = token.generate_key()
self.assertTrue(isinstance(key, six.string_types)) self.assertTrue(isinstance(key, six.string_types))
...@@ -236,6 +260,11 @@ class TokenAuthTests(TestCase): ...@@ -236,6 +260,11 @@ class TokenAuthTests(TestCase):
self.assertEqual(response.data['token'], self.key) self.assertEqual(response.data['token'], self.key)
class CustomTokenAuthTests(BaseTokenAuthTests, TestCase):
model = CustomToken
path = '/customtoken/'
class IncorrectCredentialsTests(TestCase): class IncorrectCredentialsTests(TestCase):
def test_incorrect_credentials(self): def test_incorrect_credentials(self):
""" """
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment