Commit c28b7193 by Tom Christie

Refactored throttling

parent 8457c871
...@@ -49,8 +49,14 @@ class UnsupportedMediaType(APIException): ...@@ -49,8 +49,14 @@ class UnsupportedMediaType(APIException):
class Throttled(APIException): class Throttled(APIException):
status_code = status.HTTP_429_TOO_MANY_REQUESTS status_code = status.HTTP_429_TOO_MANY_REQUESTS
default_detail = "Request was throttled. Expected available in %d seconds." default_detail = "Request was throttled."
extra_detail = "Expected available in %d second%s."
def __init__(self, wait, detail=None): def __init__(self, wait=None, detail=None):
import math import math
self.detail = (detail or self.default_detail) % int(math.ceil(wait)) self.wait = wait and math.ceil(wait) or None
if wait is not None:
format = detail or self.default_detail + self.extra_detail
self.detail = format % (self.wait, self.wait != 1 and 's' or '')
else:
self.detail = detail or self.default_detail
...@@ -81,7 +81,7 @@ class BaseParser(object): ...@@ -81,7 +81,7 @@ class BaseParser(object):
Should return parsed data, or a DataAndFiles object consisting of the Should return parsed data, or a DataAndFiles object consisting of the
parsed data and files. parsed data and files.
""" """
raise NotImplementedError(".parse_stream() Must be overridden to be implemented.") raise NotImplementedError(".parse_stream() must be overridden.")
class JSONParser(BaseParser): class JSONParser(BaseParser):
......
...@@ -5,10 +5,6 @@ for checking if a request passes a certain set of constraints. ...@@ -5,10 +5,6 @@ for checking if a request passes a certain set of constraints.
Permission behavior is provided by mixing the :class:`mixins.PermissionsMixin` class into a :class:`View` class. Permission behavior is provided by mixing the :class:`mixins.PermissionsMixin` class into a :class:`View` class.
""" """
from django.core.cache import cache
from djangorestframework.exceptions import PermissionDenied, Throttled
import time
__all__ = ( __all__ = (
'BasePermission', 'BasePermission',
'FullAnonAccess', 'FullAnonAccess',
...@@ -32,20 +28,11 @@ class BasePermission(object): ...@@ -32,20 +28,11 @@ class BasePermission(object):
""" """
self.view = view self.view = view
def check_permission(self, auth): def check_permission(self, request, obj=None):
""" """
Should simply return, or raise an :exc:`response.ImmediateResponse`. Should simply return, or raise an :exc:`response.ImmediateResponse`.
""" """
pass raise NotImplementedError(".check_permission() must be overridden.")
class FullAnonAccess(BasePermission):
"""
Allows full access.
"""
def check_permission(self, user):
pass
class IsAuthenticated(BasePermission): class IsAuthenticated(BasePermission):
...@@ -53,9 +40,10 @@ class IsAuthenticated(BasePermission): ...@@ -53,9 +40,10 @@ class IsAuthenticated(BasePermission):
Allows access only to authenticated users. Allows access only to authenticated users.
""" """
def check_permission(self, user): def check_permission(self, request, obj=None):
if not user.is_authenticated(): if request.user.is_authenticated():
raise PermissionDenied() return True
return False
class IsAdminUser(BasePermission): class IsAdminUser(BasePermission):
...@@ -63,20 +51,22 @@ class IsAdminUser(BasePermission): ...@@ -63,20 +51,22 @@ class IsAdminUser(BasePermission):
Allows access only to admin users. Allows access only to admin users.
""" """
def check_permission(self, user): def check_permission(self, request, obj=None):
if not user.is_staff: if request.user.is_staff:
raise PermissionDenied() return True
return False
class IsUserOrIsAnonReadOnly(BasePermission): class IsAuthenticatedOrReadOnly(BasePermission):
""" """
The request is authenticated as a user, or is a read-only request. The request is authenticated as a user, or is a read-only request.
""" """
def check_permission(self, user): def check_permission(self, request, obj=None):
if (not user.is_authenticated() and if (request.user.is_authenticated() or
self.view.method not in SAFE_METHODS): request.method in SAFE_METHODS):
raise PermissionDenied() return True
return False
class DjangoModelPermissions(BasePermission): class DjangoModelPermissions(BasePermission):
...@@ -114,128 +104,10 @@ class DjangoModelPermissions(BasePermission): ...@@ -114,128 +104,10 @@ class DjangoModelPermissions(BasePermission):
} }
return [perm % kwargs for perm in self.perms_map[method]] return [perm % kwargs for perm in self.perms_map[method]]
def check_permission(self, user): def check_permission(self, request, obj=None):
method = self.view.method model_cls = self.view.model
model_cls = self.view.resource.model perms = self.get_required_permissions(request.method, model_cls)
perms = self.get_required_permissions(method, model_cls)
if not user.is_authenticated or not user.has_perms(perms):
raise PermissionDenied()
class BaseThrottle(BasePermission):
"""
Rate throttling of requests.
The rate (requests / seconds) is set by a :attr:`throttle` attribute
on the :class:`.View` class. The attribute is a string of the form 'number of
requests/period'.
Period should be one of: ('s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day')
Previous request information used for throttling is stored in the cache.
"""
attr_name = 'throttle'
default = '0/sec'
timer = time.time
def get_cache_key(self):
"""
Should return a unique cache-key which can be used for throttling.
Must be overridden.
"""
pass
def check_permission(self, auth):
"""
Check the throttling.
Return `None` or raise an :exc:`.ImmediateResponse`.
"""
num, period = getattr(self.view, self.attr_name, self.default).split('/')
self.num_requests = int(num)
self.duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
self.auth = auth
self.check_throttle()
def check_throttle(self):
"""
Implement the check to see if the request should be throttled.
On success calls :meth:`throttle_success`.
On failure calls :meth:`throttle_failure`.
"""
self.key = self.get_cache_key()
self.history = cache.get(self.key, [])
self.now = self.timer()
# Drop any requests from the history which have now passed the
# throttle duration
while self.history and self.history[-1] <= self.now - self.duration:
self.history.pop()
if len(self.history) >= self.num_requests:
self.throttle_failure()
else:
self.throttle_success()
def throttle_success(self):
"""
Inserts the current request's timestamp along with the key
into the cache.
"""
self.history.insert(0, self.now)
cache.set(self.key, self.history, self.duration)
header = 'status=SUCCESS; next=%.2f sec' % self.next()
self.view.headers['X-Throttle'] = header
def throttle_failure(self):
"""
Called when a request to the API has failed due to throttling.
Raises a '503 service unavailable' response.
"""
wait = self.next()
header = 'status=FAILURE; next=%.2f sec' % wait
self.view.headers['X-Throttle'] = header
raise Throttled(wait)
def next(self):
"""
Returns the recommended next request time in seconds.
"""
if self.history:
remaining_duration = self.duration - (self.now - self.history[-1])
else:
remaining_duration = self.duration
available_requests = self.num_requests - len(self.history) + 1
return remaining_duration / float(available_requests)
class PerUserThrottling(BaseThrottle):
"""
Limits the rate of API calls that may be made by a given user.
The user id will be used as a unique identifier if the user is
authenticated. For anonymous requests, the IP address of the client will
be used.
"""
def get_cache_key(self):
if self.auth.is_authenticated():
ident = self.auth.id
else:
ident = self.view.request.META.get('REMOTE_ADDR', None)
return 'throttle_user_%s' % ident
class PerViewThrottling(BaseThrottle):
"""
Limits the rate of API calls that may be used on a given view.
The class name of the view is used as a unique identifier to
throttle against.
"""
def get_cache_key(self): if request.user.is_authenticated() and request.user.has_perms(perms, obj):
return 'throttle_view_%s' % self.view.__class__.__name__ return True
return False
...@@ -8,24 +8,24 @@ from django.core.cache import cache ...@@ -8,24 +8,24 @@ from django.core.cache import cache
from djangorestframework.compat import RequestFactory from djangorestframework.compat import RequestFactory
from djangorestframework.views import APIView from djangorestframework.views import APIView
from djangorestframework.permissions import PerUserThrottling, PerViewThrottling from djangorestframework.throttling import PerUserThrottling, PerViewThrottling
from djangorestframework.response import Response from djangorestframework.response import Response
class MockView(APIView): class MockView(APIView):
permission_classes = (PerUserThrottling,) throttle_classes = (PerUserThrottling,)
throttle = '3/sec' rate = '3/sec'
def get(self, request): def get(self, request):
return Response('foo') return Response('foo')
class MockView_PerViewThrottling(MockView): class MockView_PerViewThrottling(MockView):
permission_classes = (PerViewThrottling,) throttle_classes = (PerViewThrottling,)
class MockView_MinuteThrottling(MockView): class MockView_MinuteThrottling(MockView):
throttle = '3/min' rate = '3/min'
class ThrottlingTests(TestCase): class ThrottlingTests(TestCase):
...@@ -51,7 +51,7 @@ class ThrottlingTests(TestCase): ...@@ -51,7 +51,7 @@ class ThrottlingTests(TestCase):
""" """
Explicitly set the timer, overriding time.time() Explicitly set the timer, overriding time.time()
""" """
view.permission_classes[0].timer = lambda self: value view.throttle_classes[0].timer = lambda self: value
def test_request_throttling_expires(self): def test_request_throttling_expires(self):
""" """
...@@ -101,17 +101,20 @@ class ThrottlingTests(TestCase): ...@@ -101,17 +101,20 @@ class ThrottlingTests(TestCase):
for timer, expect in expected_headers: for timer, expect in expected_headers:
self.set_throttle_timer(view, timer) self.set_throttle_timer(view, timer)
response = view.as_view()(request) response = view.as_view()(request)
self.assertEquals(response['X-Throttle'], expect) if expect is not None:
self.assertEquals(response['X-Throttle-Wait-Seconds'], expect)
else:
self.assertFalse('X-Throttle-Wait-Seconds' in response.headers)
def test_seconds_fields(self): def test_seconds_fields(self):
""" """
Ensure for second based throttles. Ensure for second based throttles.
""" """
self.ensure_response_header_contains_proper_throttle_field(MockView, self.ensure_response_header_contains_proper_throttle_field(MockView,
((0, 'status=SUCCESS; next=0.33 sec'), ((0, None),
(0, 'status=SUCCESS; next=0.50 sec'), (0, None),
(0, 'status=SUCCESS; next=1.00 sec'), (0, None),
(0, 'status=FAILURE; next=1.00 sec') (0, '1')
)) ))
def test_minutes_fields(self): def test_minutes_fields(self):
...@@ -119,10 +122,10 @@ class ThrottlingTests(TestCase): ...@@ -119,10 +122,10 @@ class ThrottlingTests(TestCase):
Ensure for minute based throttles. Ensure for minute based throttles.
""" """
self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling, self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling,
((0, 'status=SUCCESS; next=20.00 sec'), ((0, None),
(0, 'status=SUCCESS; next=30.00 sec'), (0, None),
(0, 'status=SUCCESS; next=60.00 sec'), (0, None),
(0, 'status=FAILURE; next=60.00 sec') (0, '60')
)) ))
def test_next_rate_remains_constant_if_followed(self): def test_next_rate_remains_constant_if_followed(self):
...@@ -131,9 +134,9 @@ class ThrottlingTests(TestCase): ...@@ -131,9 +134,9 @@ class ThrottlingTests(TestCase):
the throttling rate should stay constant. the throttling rate should stay constant.
""" """
self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling, self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling,
((0, 'status=SUCCESS; next=20.00 sec'), ((0, None),
(20, 'status=SUCCESS; next=20.00 sec'), (20, None),
(40, 'status=SUCCESS; next=20.00 sec'), (40, None),
(60, 'status=SUCCESS; next=20.00 sec'), (60, None),
(80, 'status=SUCCESS; next=20.00 sec') (80, None)
)) ))
from django.core.cache import cache
import time
class BaseThrottle(object):
"""
Rate throttling of requests.
"""
def __init__(self, view=None):
"""
All throttles hold a reference to the instantiating view.
"""
self.view = view
def check_throttle(self, request):
"""
Return `True` if the request should be allowed, `False` otherwise.
"""
raise NotImplementedError('.check_throttle() must be overridden')
def wait(self):
"""
Optionally, return a recommeded number of seconds to wait before
the next request.
"""
return None
class SimpleCachingThrottle(BaseThrottle):
"""
A simple cache implementation, that only requires `.get_cache_key()`
to be overridden.
The rate (requests / seconds) is set by a :attr:`throttle` attribute
on the :class:`.View` class. The attribute is a string of the form 'number of
requests/period'.
Period should be one of: ('s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day')
Previous request information used for throttling is stored in the cache.
"""
attr_name = 'rate'
rate = '1000/day'
timer = time.time
def __init__(self, view):
"""
Check the throttling.
Return `None` or raise an :exc:`.ImmediateResponse`.
"""
super(SimpleCachingThrottle, self).__init__(view)
num, period = getattr(view, self.attr_name, self.rate).split('/')
self.num_requests = int(num)
self.duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
def get_cache_key(self, request):
"""
Should return a unique cache-key which can be used for throttling.
Must be overridden.
"""
raise NotImplementedError('.get_cache_key() must be overridden')
def check_throttle(self, request):
"""
Implement the check to see if the request should be throttled.
On success calls :meth:`throttle_success`.
On failure calls :meth:`throttle_failure`.
"""
self.key = self.get_cache_key(request)
self.history = cache.get(self.key, [])
self.now = self.timer()
# Drop any requests from the history which have now passed the
# throttle duration
while self.history and self.history[-1] <= self.now - self.duration:
self.history.pop()
if len(self.history) >= self.num_requests:
return self.throttle_failure()
return self.throttle_success()
def throttle_success(self):
"""
Inserts the current request's timestamp along with the key
into the cache.
"""
self.history.insert(0, self.now)
cache.set(self.key, self.history, self.duration)
return True
def throttle_failure(self):
"""
Called when a request to the API has failed due to throttling.
"""
return False
def wait(self):
"""
Returns the recommended next request time in seconds.
"""
if self.history:
remaining_duration = self.duration - (self.now - self.history[-1])
else:
remaining_duration = self.duration
available_requests = self.num_requests - len(self.history) + 1
return remaining_duration / float(available_requests)
class PerUserThrottling(SimpleCachingThrottle):
"""
Limits the rate of API calls that may be made by a given user.
The user id will be used as a unique identifier if the user is
authenticated. For anonymous requests, the IP address of the client will
be used.
"""
def get_cache_key(self, request):
if request.user.is_authenticated():
ident = request.user.id
else:
ident = request.META.get('REMOTE_ADDR', None)
return 'throttle_user_%s' % ident
class PerViewThrottling(SimpleCachingThrottle):
"""
Limits the rate of API calls that may be used on a given view.
The class name of the view is used as a unique identifier to
throttle against.
"""
def get_cache_key(self, request):
return 'throttle_view_%s' % self.view.__class__.__name__
...@@ -18,7 +18,7 @@ from djangorestframework.compat import View as _View, apply_markdown ...@@ -18,7 +18,7 @@ from djangorestframework.compat import View as _View, apply_markdown
from djangorestframework.response import Response from djangorestframework.response import Response
from djangorestframework.request import Request from djangorestframework.request import Request
from djangorestframework.settings import api_settings from djangorestframework.settings import api_settings
from djangorestframework import parsers, authentication, permissions, status, exceptions, mixins from djangorestframework import parsers, authentication, status, exceptions, mixins
__all__ = ( __all__ = (
...@@ -86,7 +86,12 @@ class APIView(_View): ...@@ -86,7 +86,12 @@ class APIView(_View):
List of all authenticating methods to attempt. List of all authenticating methods to attempt.
""" """
permission_classes = (permissions.FullAnonAccess,) throttle_classes = ()
"""
List of all throttles to check.
"""
permission_classes = ()
""" """
List of all permissions that must be checked. List of all permissions that must be checked.
""" """
...@@ -195,12 +200,27 @@ class APIView(_View): ...@@ -195,12 +200,27 @@ class APIView(_View):
""" """
return [permission(self) for permission in self.permission_classes] return [permission(self) for permission in self.permission_classes]
def check_permissions(self, user): def get_throttles(self):
""" """
Check user permissions and either raise an ``ImmediateResponse`` or return. Instantiates and returns the list of thottles that this view requires.
"""
return [throttle(self) for throttle in self.throttle_classes]
def check_permissions(self, request, obj=None):
"""
Check user permissions and either raise an ``PermissionDenied`` or return.
""" """
for permission in self.get_permissions(): for permission in self.get_permissions():
permission.check_permission(user) if not permission.check_permission(request, obj):
raise exceptions.PermissionDenied()
def check_throttles(self, request):
"""
Check throttles and either raise a `Throttled` exception or return.
"""
for throttle in self.get_throttles():
if not throttle.check_throttle(request):
raise exceptions.Throttled(throttle.wait())
def initial(self, request, *args, **kargs): def initial(self, request, *args, **kargs):
""" """
...@@ -232,6 +252,9 @@ class APIView(_View): ...@@ -232,6 +252,9 @@ class APIView(_View):
Handle any exception that occurs, by returning an appropriate response, Handle any exception that occurs, by returning an appropriate response,
or re-raising the error. or re-raising the error.
""" """
if isinstance(exc, exceptions.Throttled):
self.headers['X-Throttle-Wait-Seconds'] = '%d' % exc.wait
if isinstance(exc, exceptions.APIException): if isinstance(exc, exceptions.APIException):
return Response({'detail': exc.detail}, status=exc.status_code) return Response({'detail': exc.detail}, status=exc.status_code)
elif isinstance(exc, Http404): elif isinstance(exc, Http404):
...@@ -255,8 +278,9 @@ class APIView(_View): ...@@ -255,8 +278,9 @@ class APIView(_View):
try: try:
self.initial(request, *args, **kwargs) self.initial(request, *args, **kwargs)
# check that user has the relevant permissions # Check that the request is allowed
self.check_permissions(request.user) self.check_permissions(request)
self.check_throttles(request)
# Get the appropriate handler method # Get the appropriate handler method
if request.method.lower() in self.http_method_names: if request.method.lower() in self.http_method_names:
...@@ -283,11 +307,12 @@ class BaseView(APIView): ...@@ -283,11 +307,12 @@ class BaseView(APIView):
serializer_class = None serializer_class = None
def get_serializer(self, data=None, files=None, instance=None): def get_serializer(self, data=None, files=None, instance=None):
# TODO: add support for files
context = { context = {
'request': self.request, 'request': self.request,
'format': self.kwargs.get('format', None) 'format': self.kwargs.get('format', None)
} }
return self.serializer_class(data, context=context) return self.serializer_class(data, instance=instance, context=context)
class MultipleObjectBaseView(MultipleObjectMixin, BaseView): class MultipleObjectBaseView(MultipleObjectMixin, BaseView):
...@@ -301,7 +326,13 @@ class SingleObjectBaseView(SingleObjectMixin, BaseView): ...@@ -301,7 +326,13 @@ class SingleObjectBaseView(SingleObjectMixin, BaseView):
""" """
Base class for generic views onto a model instance. Base class for generic views onto a model instance.
""" """
pass
def get_object(self):
"""
Override default to add support for object-level permissions.
"""
super(self, SingleObjectBaseView).get_object()
self.check_permissions(self.request, self.object)
# Concrete view classes that provide method handlers # Concrete view classes that provide method handlers
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment