Commit 4739e1c0 by Tom Christie

Merge work from sebpiq

parents 44df8345 1ff741d1
""" """
The :mod:`authentication` module provides a set of pluggable authentication classes. The :mod:`authentication` module provides a set of pluggable authentication classes.
Authentication behavior is provided by mixing the :class:`mixins.AuthMixin` class into a :class:`View` class. Authentication behavior is provided by mixing the :class:`mixins.RequestMixin` class into a :class:`View` class.
The set of authentication methods which are used is then specified by setting the
:attr:`authentication` attribute on the :class:`View` class, and listing a set of :class:`authentication` classes.
""" """
from django.contrib.auth import authenticate from django.contrib.auth import authenticate
...@@ -23,12 +20,6 @@ class BaseAuthentication(object): ...@@ -23,12 +20,6 @@ class BaseAuthentication(object):
All authentication classes should extend BaseAuthentication. All authentication classes should extend BaseAuthentication.
""" """
def __init__(self, view):
"""
:class:`Authentication` classes are always passed the current view on creation.
"""
self.view = view
def authenticate(self, request): def authenticate(self, request):
""" """
Authenticate the :obj:`request` and return a :obj:`User` or :const:`None`. [*]_ Authenticate the :obj:`request` and return a :obj:`User` or :const:`None`. [*]_
...@@ -87,12 +78,14 @@ class UserLoggedInAuthentication(BaseAuthentication): ...@@ -87,12 +78,14 @@ class UserLoggedInAuthentication(BaseAuthentication):
Returns a :obj:`User` if the request session currently has a logged in user. Returns a :obj:`User` if the request session currently has a logged in user.
Otherwise returns :const:`None`. Otherwise returns :const:`None`.
""" """
if getattr(request, 'user', None) and request.user.is_active: user = getattr(request._request, 'user', None)
if user and user.is_active:
# Enforce CSRF validation for session based authentication. # Enforce CSRF validation for session based authentication.
resp = CsrfViewMiddleware().process_view(request, None, (), {}) resp = CsrfViewMiddleware().process_view(request, None, (), {})
if resp is None: # csrf passed if resp is None: # csrf passed
return request.user return user
return None return None
......
...@@ -3,7 +3,6 @@ The :mod:`mixins` module provides a set of reusable `mixin` ...@@ -3,7 +3,6 @@ The :mod:`mixins` module provides a set of reusable `mixin`
classes that can be added to a `View`. classes that can be added to a `View`.
""" """
from django.contrib.auth.models import AnonymousUser
from django.core.paginator import Paginator from django.core.paginator import Paginator
from django.db.models.fields.related import ForeignKey from django.db.models.fields.related import ForeignKey
from urlobject import URLObject from urlobject import URLObject
...@@ -19,7 +18,7 @@ __all__ = ( ...@@ -19,7 +18,7 @@ __all__ = (
# Base behavior mixins # Base behavior mixins
'RequestMixin', 'RequestMixin',
'ResponseMixin', 'ResponseMixin',
'AuthMixin', 'PermissionsMixin',
'ResourceMixin', 'ResourceMixin',
# Model behavior mixins # Model behavior mixins
'ReadModelMixin', 'ReadModelMixin',
...@@ -49,7 +48,7 @@ class RequestMixin(object): ...@@ -49,7 +48,7 @@ class RequestMixin(object):
This new instance wraps the `request` passed as a parameter, and use This new instance wraps the `request` passed as a parameter, and use
the parsers set on the view. the parsers set on the view.
""" """
return self.request_class(request, parsers=self.parsers) return self.request_class(request, parsers=self.parsers, authentication=self.authentication)
@property @property
def _parsed_media_types(self): def _parsed_media_types(self):
...@@ -101,57 +100,32 @@ class ResponseMixin(object): ...@@ -101,57 +100,32 @@ class ResponseMixin(object):
return self.renderers[0] return self.renderers[0]
########## Auth Mixin ########## ########## Permissions Mixin ##########
class AuthMixin(object): class PermissionsMixin(object):
""" """
Simple :class:`mixin` class to add authentication and permission checking to a :class:`View` class. Simple :class:`mixin` class to add permission checking to a :class:`View` class.
""" """
authentication = () permissions_classes = ()
"""
The set of authentication types that this view can handle.
Should be a tuple/list of classes as described in the :mod:`authentication` module.
"""
permissions = ()
""" """
The set of permissions that will be enforced on this view. The set of permissions that will be enforced on this view.
Should be a tuple/list of classes as described in the :mod:`permissions` module. Should be a tuple/list of classes as described in the :mod:`permissions` module.
""" """
@property def get_permissions(self):
def user(self):
"""
Returns the :obj:`user` for the current request, as determined by the set of
:class:`authentication` classes applied to the :class:`View`.
"""
if not hasattr(self, '_user'):
self._user = self._authenticate()
return self._user
def _authenticate(self):
""" """
Attempt to authenticate the request using each authentication class in turn. Instantiates and returns the list of permissions that this view requires.
Returns a ``User`` object, which may be ``AnonymousUser``.
""" """
for authentication_cls in self.authentication: return [p(self) for p in self.permissions_classes]
authentication = authentication_cls(self)
user = authentication.authenticate(self.request)
if user:
return user
return AnonymousUser()
# TODO: wrap this behavior around dispatch() # TODO: wrap this behavior around dispatch()
def _check_permissions(self): def check_permissions(self, user):
""" """
Check user permissions and either raise an ``ImmediateResponse`` or return. Check user permissions and either raise an ``ImmediateResponse`` or return.
""" """
user = self.user for permission in self.get_permissions():
for permission_cls in self.permissions:
permission = permission_cls(self)
permission.check_permission(user) permission.check_permission(user)
......
""" """
The :mod:`permissions` module bundles a set of permission classes that are used The :mod:`permissions` module bundles a set of permission classes that are used
for checking if a request passes a certain set of constraints. You can assign a permission for checking if a request passes a certain set of constraints.
class to your view by setting your View's :attr:`permissions` class attribute.
Permission behavior is provided by mixing the :class:`mixins.PermissionsMixin` class into a :class:`View` class.
""" """
from django.core.cache import cache from django.core.cache import cache
...@@ -126,7 +127,7 @@ class DjangoModelPermissions(BasePermission): ...@@ -126,7 +127,7 @@ class DjangoModelPermissions(BasePermission):
try: try:
return [perm % kwargs for perm in self.perms_map[method]] return [perm % kwargs for perm in self.perms_map[method]]
except KeyError: except KeyError:
ErrorResponse(status.HTTP_405_METHOD_NOT_ALLOWED) ImmediateResponse(status.HTTP_405_METHOD_NOT_ALLOWED)
def check_permission(self, user): def check_permission(self, user):
method = self.view.method method = self.view.method
......
...@@ -9,12 +9,13 @@ The wrapped request then offers a richer API, in particular : ...@@ -9,12 +9,13 @@ The wrapped request then offers a richer API, in particular :
- full support of PUT method, including support for file uploads - full support of PUT method, including support for file uploads
- form overloading of HTTP method, content type and content - form overloading of HTTP method, content type and content
""" """
from StringIO import StringIO
from django.contrib.auth.models import AnonymousUser
from djangorestframework import status from djangorestframework import status
from djangorestframework.utils.mediatypes import is_form_media_type from djangorestframework.utils.mediatypes import is_form_media_type
from StringIO import StringIO
__all__ = ('Request',) __all__ = ('Request',)
...@@ -34,6 +35,7 @@ class Request(object): ...@@ -34,6 +35,7 @@ class Request(object):
Kwargs: Kwargs:
- request(HttpRequest). The original request instance. - request(HttpRequest). The original request instance.
- parsers(list/tuple). The parsers to use for parsing the request content. - parsers(list/tuple). The parsers to use for parsing the request content.
- authentications(list/tuple). The authentications used to try authenticating the request's user.
""" """
_USE_FORM_OVERLOADING = True _USE_FORM_OVERLOADING = True
...@@ -41,9 +43,10 @@ class Request(object): ...@@ -41,9 +43,10 @@ class Request(object):
_CONTENTTYPE_PARAM = '_content_type' _CONTENTTYPE_PARAM = '_content_type'
_CONTENT_PARAM = '_content' _CONTENT_PARAM = '_content'
def __init__(self, request=None, parsers=None): def __init__(self, request=None, parsers=None, authentication=None):
self._request = request self._request = request
self.parsers = parsers or () self.parsers = parsers or ()
self.authentication = authentication or ()
self._data = Empty self._data = Empty
self._files = Empty self._files = Empty
self._method = Empty self._method = Empty
...@@ -56,6 +59,12 @@ class Request(object): ...@@ -56,6 +59,12 @@ class Request(object):
""" """
return [parser() for parser in self.parsers] return [parser() for parser in self.parsers]
def get_authentications(self):
"""
Instantiates and returns the list of parsers the request will use.
"""
return [authentication() for authentication in self.authentication]
@property @property
def method(self): def method(self):
""" """
...@@ -113,6 +122,16 @@ class Request(object): ...@@ -113,6 +122,16 @@ class Request(object):
self._load_data_and_files() self._load_data_and_files()
return self._files return self._files
@property
def user(self):
"""
Returns the :obj:`user` for the current request, authenticated
with the set of :class:`authentication` instances applied to the :class:`Request`.
"""
if not hasattr(self, '_user'):
self._user = self._authenticate()
return self._user
def _load_data_and_files(self): def _load_data_and_files(self):
""" """
Parses the request content into self.DATA and self.FILES. Parses the request content into self.DATA and self.FILES.
...@@ -205,6 +224,17 @@ class Request(object): ...@@ -205,6 +224,17 @@ class Request(object):
}, },
status=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE) status=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE)
def _authenticate(self):
"""
Attempt to authenticate the request using each authentication instance in turn.
Returns a ``User`` object, which may be ``AnonymousUser``.
"""
for authentication in self.get_authentications():
user = authentication.authenticate(self)
if user:
return user
return AnonymousUser()
def __getattr__(self, name): def __getattr__(self, name):
""" """
Proxy other attributes to the underlying HttpRequest object. Proxy other attributes to the underlying HttpRequest object.
......
...@@ -168,6 +168,18 @@ class ImmediateResponse(Response, Exception): ...@@ -168,6 +168,18 @@ class ImmediateResponse(Response, Exception):
An exception representing an Response that should be returned immediately. An exception representing an Response that should be returned immediately.
Any content should be serialized as-is, without being filtered. Any content should be serialized as-is, without being filtered.
""" """
#TODO: this is just a temporary fix, the whole rendering/support for ImmediateResponse, should be remade : see issue #163
def render(self):
try:
return super(Response, self).render()
except ImmediateResponse:
renderer, media_type = self._determine_renderer()
self.renderers.remove(renderer)
if len(self.renderers) == 0:
raise RuntimeError('Caught an ImmediateResponse while '\
'trying to render an ImmediateResponse')
return self.render()
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.response = Response(*args, **kwargs) self.response = Response(*args, **kwargs)
...@@ -12,7 +12,7 @@ import base64 ...@@ -12,7 +12,7 @@ import base64
class MockView(View): class MockView(View):
permissions = (permissions.IsAuthenticated,) permissions_classes = (permissions.IsAuthenticated,)
def post(self, request): def post(self, request):
return HttpResponse({'a': 1, 'b': 2, 'c': 3}) return HttpResponse({'a': 1, 'b': 2, 'c': 3})
......
...@@ -281,6 +281,6 @@ class TestPagination(TestCase): ...@@ -281,6 +281,6 @@ class TestPagination(TestCase):
paginated URLs. So page 1 should contain ?page=2, not ?page=1&page=2 """ paginated URLs. So page 1 should contain ?page=2, not ?page=1&page=2 """
request = self.req.get('/paginator/?page=1') request = self.req.get('/paginator/?page=1')
response = MockPaginatorView.as_view()(request) response = MockPaginatorView.as_view()(request)
content = json.loads(response.rendered_content) content = response.raw_content
self.assertTrue('page=2' in content['next']) self.assertTrue('page=2' in content['next'])
self.assertFalse('page=1' in content['next']) self.assertFalse('page=1' in content['next'])
...@@ -4,7 +4,7 @@ import unittest ...@@ -4,7 +4,7 @@ import unittest
from django.conf.urls.defaults import patterns, url, include from django.conf.urls.defaults import patterns, url, include
from django.test import TestCase from django.test import TestCase
from djangorestframework.response import Response, NotAcceptable from djangorestframework.response import Response, NotAcceptable, ImmediateResponse
from djangorestframework.views import View from djangorestframework.views import View
from djangorestframework.compat import RequestFactory from djangorestframework.compat import RequestFactory
from djangorestframework import status from djangorestframework import status
...@@ -95,10 +95,9 @@ class TestResponseDetermineRenderer(TestCase): ...@@ -95,10 +95,9 @@ class TestResponseDetermineRenderer(TestCase):
class TestResponseRenderContent(TestCase): class TestResponseRenderContent(TestCase):
def get_response(self, url='', accept_list=[], content=None, renderers=None):
def get_response(self, url='', accept_list=[], content=None):
request = RequestFactory().get(url, HTTP_ACCEPT=','.join(accept_list)) request = RequestFactory().get(url, HTTP_ACCEPT=','.join(accept_list))
return Response(request=request, content=content, renderers=DEFAULT_RENDERERS) return Response(request=request, content=content, renderers=renderers or DEFAULT_RENDERERS)
def test_render(self): def test_render(self):
""" """
...@@ -107,10 +106,43 @@ class TestResponseRenderContent(TestCase): ...@@ -107,10 +106,43 @@ class TestResponseRenderContent(TestCase):
content = {'a': 1, 'b': [1, 2, 3]} content = {'a': 1, 'b': [1, 2, 3]}
content_type = 'application/json' content_type = 'application/json'
response = self.get_response(accept_list=[content_type], content=content) response = self.get_response(accept_list=[content_type], content=content)
response.render() response = response.render()
self.assertEqual(json.loads(response.content), content) self.assertEqual(json.loads(response.content), content)
self.assertEqual(response['Content-Type'], content_type) self.assertEqual(response['Content-Type'], content_type)
def test_render_no_renderer(self):
"""
Test rendering response when no renderer can satisfy accept.
"""
content = 'bla'
content_type = 'weirdcontenttype'
response = self.get_response(accept_list=[content_type], content=content)
response = response.render()
self.assertEqual(response.status_code, 406)
self.assertIsNotNone(response.content)
# def test_render_renderer_raises_ImmediateResponse(self):
# """
# Test rendering response when renderer raises ImmediateResponse
# """
# class PickyJSONRenderer(BaseRenderer):
# """
# A renderer that doesn't make much sense, just to try
# out raising an ImmediateResponse
# """
# media_type = 'application/json'
# def render(self, obj=None, media_type=None):
# raise ImmediateResponse({'error': '!!!'}, status=400)
# response = self.get_response(
# accept_list=['application/json'],
# renderers=[PickyJSONRenderer, JSONRenderer]
# )
# response = response.render()
# self.assertEqual(response.status_code, 400)
# self.assertEqual(response.content, json.dumps({'error': '!!!'}))
DUMMYSTATUS = status.HTTP_200_OK DUMMYSTATUS = status.HTTP_200_OK
DUMMYCONTENT = 'dummycontent' DUMMYCONTENT = 'dummycontent'
......
...@@ -13,17 +13,17 @@ from djangorestframework.resources import FormResource ...@@ -13,17 +13,17 @@ from djangorestframework.resources import FormResource
from djangorestframework.response import Response from djangorestframework.response import Response
class MockView(View): class MockView(View):
permissions = ( PerUserThrottling, ) permissions_classes = ( PerUserThrottling, )
throttle = '3/sec' throttle = '3/sec'
def get(self, request): def get(self, request):
return Response('foo') return Response('foo')
class MockView_PerViewThrottling(MockView): class MockView_PerViewThrottling(MockView):
permissions = ( PerViewThrottling, ) permissions_classes = ( PerViewThrottling, )
class MockView_PerResourceThrottling(MockView): class MockView_PerResourceThrottling(MockView):
permissions = ( PerResourceThrottling, ) permissions_classes = ( PerResourceThrottling, )
resource = FormResource resource = FormResource
class MockView_MinuteThrottling(MockView): class MockView_MinuteThrottling(MockView):
...@@ -54,7 +54,7 @@ class ThrottlingTests(TestCase): ...@@ -54,7 +54,7 @@ class ThrottlingTests(TestCase):
""" """
Explicitly set the timer, overriding time.time() Explicitly set the timer, overriding time.time()
""" """
view.permissions[0].timer = lambda self: value view.permissions_classes[0].timer = lambda self: value
def test_request_throttling_expires(self): def test_request_throttling_expires(self):
""" """
......
...@@ -67,7 +67,7 @@ _resource_classes = ( ...@@ -67,7 +67,7 @@ _resource_classes = (
) )
class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView): class View(ResourceMixin, RequestMixin, ResponseMixin, PermissionsMixin, DjangoView):
""" """
Handles incoming requests and maps them to REST operations. Handles incoming requests and maps them to REST operations.
Performs request deserialization, response serialization, authentication and input validation. Performs request deserialization, response serialization, authentication and input validation.
...@@ -90,7 +90,7 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView): ...@@ -90,7 +90,7 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView):
""" """
authentication = (authentication.UserLoggedInAuthentication, authentication = (authentication.UserLoggedInAuthentication,
authentication.BasicAuthentication) authentication.BasicAuthentication)
""" """
List of all authenticating methods to attempt. List of all authenticating methods to attempt.
""" """
...@@ -215,6 +215,7 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView): ...@@ -215,6 +215,7 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView):
def dispatch(self, request, *args, **kwargs): def dispatch(self, request, *args, **kwargs):
request = self.create_request(request) request = self.create_request(request)
self.request = request self.request = request
self.args = args self.args = args
self.kwargs = kwargs self.kwargs = kwargs
self.headers = self.default_response_headers self.headers = self.default_response_headers
...@@ -222,8 +223,8 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView): ...@@ -222,8 +223,8 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView):
try: try:
self.initial(request, *args, **kwargs) self.initial(request, *args, **kwargs)
# Authenticate and check request has the relevant permissions # check that user has the relevant permissions
self._check_permissions() self.check_permissions(request.user)
# Get the appropriate handler method # Get the appropriate handler method
if request.method.lower() in self.http_method_names: if request.method.lower() in self.http_method_names:
......
from django.test import TestCase
from django.core.urlresolvers import reverse
from django.test.client import Client
class NaviguatePermissionsExamples(TestCase):
"""
Sanity checks for permissions examples
"""
def test_throttled_resource(self):
url = reverse('throttled-resource')
for i in range(0, 10):
response = self.client.get(url)
self.assertEqual(response.status_code, 200)
response = self.client.get(url)
self.assertEqual(response.status_code, 503)
def test_loggedin_resource(self):
url = reverse('loggedin-resource')
response = self.client.get(url)
self.assertEqual(response.status_code, 403)
loggedin_client = Client()
loggedin_client.login(username='test', password='test')
response = loggedin_client.get(url)
self.assertEqual(response.status_code, 200)
...@@ -30,7 +30,7 @@ class ThrottlingExampleView(View): ...@@ -30,7 +30,7 @@ class ThrottlingExampleView(View):
throttle will be applied until 60 seconds have passed since the first request. throttle will be applied until 60 seconds have passed since the first request.
""" """
permissions = (PerUserThrottling,) permissions_classes = (PerUserThrottling,)
throttle = '10/min' throttle = '10/min'
def get(self, request): def get(self, request):
...@@ -47,7 +47,7 @@ class LoggedInExampleView(View): ...@@ -47,7 +47,7 @@ class LoggedInExampleView(View):
`curl -X GET -H 'Accept: application/json' -u test:test http://localhost:8000/permissions-example` `curl -X GET -H 'Accept: application/json' -u test:test http://localhost:8000/permissions-example`
""" """
permissions = (IsAuthenticated, ) permissions_classes = (IsAuthenticated, )
def get(self, request): def get(self, request):
return Response('You have permission to view this resource') return Response('You have permission to view this resource')
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment