Commit 1cb84cd4 by Tom Christie

Merge throttling and fix up a coupla things

parents ff6e7832 49a2817e
...@@ -31,11 +31,6 @@ _503_SERVICE_UNAVAILABLE = ErrorResponse( ...@@ -31,11 +31,6 @@ _503_SERVICE_UNAVAILABLE = ErrorResponse(
{'detail': 'request was throttled'}) {'detail': 'request was throttled'})
class ConfigurationException(BaseException):
"""To alert for bad configuration decisions as a convenience."""
pass
class BasePermission(object): class BasePermission(object):
""" """
A base class from which all permission classes should inherit. A base class from which all permission classes should inherit.
...@@ -142,14 +137,13 @@ class BaseThrottle(BasePermission): ...@@ -142,14 +137,13 @@ class BaseThrottle(BasePermission):
# Drop any requests from the history which have now passed the # Drop any requests from the history which have now passed the
# throttle duration # throttle duration
while self.history and self.history[0] <= self.now - self.duration: while self.history and self.history[-1] <= self.now - self.duration:
self.history.pop() self.history.pop()
if len(self.history) >= self.num_requests: if len(self.history) >= self.num_requests:
self.throttle_failure() self.throttle_failure()
else: else:
self.throttle_success() self.throttle_success()
def throttle_success(self): def throttle_success(self):
""" """
Inserts the current request's timestamp along with the key Inserts the current request's timestamp along with the key
...@@ -157,13 +151,30 @@ class BaseThrottle(BasePermission): ...@@ -157,13 +151,30 @@ class BaseThrottle(BasePermission):
""" """
self.history.insert(0, self.now) self.history.insert(0, self.now)
cache.set(self.key, self.history, self.duration) cache.set(self.key, self.history, self.duration)
header = 'status=SUCCESS; next=%s sec' % self.next()
self.view.add_header('X-Throttle', header)
def throttle_failure(self): def throttle_failure(self):
""" """
Called when a request to the API has failed due to throttling. Called when a request to the API has failed due to throttling.
Raises a '503 service unavailable' response. Raises a '503 service unavailable' response.
""" """
header = 'status=FAILURE; next=%s sec' % self.next()
self.view.add_header('X-Throttle', header)
raise _503_SERVICE_UNAVAILABLE raise _503_SERVICE_UNAVAILABLE
def next(self):
"""
Returns the recommended next request time in seconds.
"""
if self.history:
remaining_duration = self.duration - (self.now - self.history[-1])
else:
remaining_duration = self.duration
available_requests = self.num_requests - len(self.history) + 1
return '%.2f' % (remaining_duration / float(available_requests))
class PerUserThrottling(BaseThrottle): class PerUserThrottling(BaseThrottle):
......
...@@ -13,25 +13,27 @@ os.environ['DJANGO_SETTINGS_MODULE'] = 'djangorestframework.runtests.settings' ...@@ -13,25 +13,27 @@ os.environ['DJANGO_SETTINGS_MODULE'] = 'djangorestframework.runtests.settings'
from django.conf import settings from django.conf import settings
from django.test.utils import get_runner from django.test.utils import get_runner
def usage():
return """
Usage: python runtests.py [UnitTestClass].[method]
You can pass the Class name of the `UnitTestClass` you want to test.
Append a method name if you only want to test a specific method of that class.
"""
def main(): def main():
TestRunner = get_runner(settings) TestRunner = get_runner(settings)
if hasattr(TestRunner, 'func_name'): test_runner = TestRunner()
# Pre 1.2 test runners were just functions, if len(sys.argv) == 2:
# and did not support the 'failfast' option. test_case = '.' + sys.argv[1]
import warnings elif len(sys.argv) == 1:
warnings.warn( test_case = ''
'Function-based test runners are deprecated. Test runners should be classes with a run_tests() method.',
DeprecationWarning
)
failures = TestRunner(['djangorestframework'])
else: else:
test_runner = TestRunner() print usage()
if len(sys.argv) > 1: sys.exit(1)
test_case = '.' + sys.argv[1] failures = test_runner.run_tests(['djangorestframework' + test_case])
else:
test_case = ''
failures = test_runner.run_tests(['djangorestframework' + test_case])
sys.exit(failures) sys.exit(failures)
......
""" """
Tests for the throttling implementations in the permissions module. Tests for the throttling implementations in the permissions module.
""" """
import time
from django.conf.urls.defaults import patterns
from django.test import TestCase from django.test import TestCase
from django.utils import simplejson as json
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.core.cache import cache from django.core.cache import cache
from djangorestframework.compat import RequestFactory from djangorestframework.compat import RequestFactory
from djangorestframework.views import View from djangorestframework.views import View
from djangorestframework.permissions import PerUserThrottling, PerViewThrottling, PerResourceThrottling, ConfigurationException from djangorestframework.permissions import PerUserThrottling, PerViewThrottling, PerResourceThrottling
from djangorestframework.resources import FormResource from djangorestframework.resources import FormResource
class MockView(View): class MockView(View):
permissions = ( PerUserThrottling, ) permissions = ( PerUserThrottling, )
throttle = '3/sec' # 3 requests per second throttle = '3/sec'
def get(self, request): def get(self, request):
return 'foo' return 'foo'
class MockView1(MockView): class MockView_PerViewThrottling(MockView):
permissions = ( PerViewThrottling, ) permissions = ( PerViewThrottling, )
class MockView2(MockView): class MockView_PerResourceThrottling(MockView):
permissions = ( PerResourceThrottling, ) permissions = ( PerResourceThrottling, )
#No resource set
class MockView3(MockView2):
resource = FormResource resource = FormResource
class MockView_MinuteThrottling(MockView):
throttle = '3/min'
class ThrottlingTests(TestCase): class ThrottlingTests(TestCase):
urls = 'djangorestframework.tests.throttling' urls = 'djangorestframework.tests.throttling'
def setUp(self): def setUp(self):
"""Reset the cache so that no throttles will be active""" """
Reset the cache so that no throttles will be active
"""
cache.clear() cache.clear()
self.factory = RequestFactory() self.factory = RequestFactory()
def test_requests_are_throttled(self): def test_requests_are_throttled(self):
"""Ensure request rate is limited""" """
Ensure request rate is limited
"""
request = self.factory.get('/') request = self.factory.get('/')
for dummy in range(4): for dummy in range(4):
response = MockView.as_view()(request) response = MockView.as_view()(request)
self.assertEqual(503, response.status_code) self.assertEqual(503, response.status_code)
def set_throttle_timer(self, view, value):
"""
Explicitly set the timer, overriding time.time()
"""
view.permissions[0].timer = lambda self: value
def test_request_throttling_expires(self): def test_request_throttling_expires(self):
""" """
Ensure request rate is limited for a limited duration only Ensure request rate is limited for a limited duration only
""" """
# Explicitly set the timer, overridding time.time() self.set_throttle_timer(MockView, 0)
MockView.permissions[0].timer = lambda self: 0
request = self.factory.get('/') request = self.factory.get('/')
for dummy in range(4): for dummy in range(4):
...@@ -59,7 +67,7 @@ class ThrottlingTests(TestCase): ...@@ -59,7 +67,7 @@ class ThrottlingTests(TestCase):
self.assertEqual(503, response.status_code) self.assertEqual(503, response.status_code)
# Advance the timer by one second # Advance the timer by one second
MockView.permissions[0].timer = lambda self: 1 self.set_throttle_timer(MockView, 1)
response = MockView.as_view()(request) response = MockView.as_view()(request)
self.assertEqual(200, response.status_code) self.assertEqual(200, response.status_code)
...@@ -68,20 +76,73 @@ class ThrottlingTests(TestCase): ...@@ -68,20 +76,73 @@ class ThrottlingTests(TestCase):
request = self.factory.get('/') request = self.factory.get('/')
request.user = User.objects.create(username='a') request.user = User.objects.create(username='a')
for dummy in range(3): for dummy in range(3):
response = view.as_view()(request) view.as_view()(request)
request.user = User.objects.create(username='b') request.user = User.objects.create(username='b')
response = view.as_view()(request) response = view.as_view()(request)
self.assertEqual(expect, response.status_code) self.assertEqual(expect, response.status_code)
def test_request_throttling_is_per_user(self): def test_request_throttling_is_per_user(self):
"""Ensure request rate is only limited per user, not globally for PerUserThrottles""" """
Ensure request rate is only limited per user, not globally for
PerUserThrottles
"""
self.ensure_is_throttled(MockView, 200) self.ensure_is_throttled(MockView, 200)
def test_request_throttling_is_per_view(self): def test_request_throttling_is_per_view(self):
"""Ensure request rate is limited globally per View for PerViewThrottles""" """
self.ensure_is_throttled(MockView1, 503) Ensure request rate is limited globally per View for PerViewThrottles
"""
self.ensure_is_throttled(MockView_PerViewThrottling, 503)
def test_request_throttling_is_per_resource(self): def test_request_throttling_is_per_resource(self):
"""Ensure request rate is limited globally per Resource for PerResourceThrottles""" """
self.ensure_is_throttled(MockView3, 503) Ensure request rate is limited globally per Resource for PerResourceThrottles
"""
\ No newline at end of file self.ensure_is_throttled(MockView_PerResourceThrottling, 503)
def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers):
"""
Ensure the response returns an X-Throttle field with status and next attributes
set properly.
"""
request = self.factory.get('/')
for timer, expect in expected_headers:
self.set_throttle_timer(view, timer)
response = view.as_view()(request)
self.assertEquals(response['X-Throttle'], expect)
def test_seconds_fields(self):
"""
Ensure for second based throttles.
"""
self.ensure_response_header_contains_proper_throttle_field(MockView,
((0, 'status=SUCCESS; next=0.33 sec'),
(0, 'status=SUCCESS; next=0.50 sec'),
(0, 'status=SUCCESS; next=1.00 sec'),
(0, 'status=FAILURE; next=1.00 sec')
))
def test_minutes_fields(self):
"""
Ensure for minute based throttles.
"""
self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling,
((0, 'status=SUCCESS; next=20.00 sec'),
(0, 'status=SUCCESS; next=30.00 sec'),
(0, 'status=SUCCESS; next=60.00 sec'),
(0, 'status=FAILURE; next=60.00 sec')
))
def test_next_rate_remains_constant_if_followed(self):
"""
If a client follows the recommended next request rate,
the throttling rate should stay constant.
"""
self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling,
((0, 'status=SUCCESS; next=20.00 sec'),
(20, 'status=SUCCESS; next=20.00 sec'),
(40, 'status=SUCCESS; next=20.00 sec'),
(60, 'status=SUCCESS; next=20.00 sec'),
(80, 'status=SUCCESS; next=20.00 sec')
))
...@@ -64,7 +64,7 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView): ...@@ -64,7 +64,7 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView):
""" """
permissions = ( permissions.FullAnonAccess, ) permissions = ( permissions.FullAnonAccess, )
@classmethod @classmethod
def as_view(cls, **initkwargs): def as_view(cls, **initkwargs):
""" """
...@@ -101,6 +101,14 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView): ...@@ -101,6 +101,14 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView):
""" """
pass pass
def add_header(self, field, value):
"""
Add *field* and *value* to the :attr:`headers` attribute of the :class:`View` class.
"""
self.headers[field] = value
# Note: session based authentication is explicitly CSRF validated, # Note: session based authentication is explicitly CSRF validated,
# all other authentication is CSRF exempt. # all other authentication is CSRF exempt.
@csrf_exempt @csrf_exempt
...@@ -108,6 +116,7 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView): ...@@ -108,6 +116,7 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView):
self.request = request self.request = request
self.args = args self.args = args
self.kwargs = kwargs self.kwargs = kwargs
self.headers = {}
# Calls to 'reverse' will not be fully qualified unless we set the scheme/host/port here. # Calls to 'reverse' will not be fully qualified unless we set the scheme/host/port here.
prefix = '%s://%s' % (request.is_secure() and 'https' or 'http', request.get_host()) prefix = '%s://%s' % (request.is_secure() and 'https' or 'http', request.get_host())
...@@ -149,9 +158,12 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView): ...@@ -149,9 +158,12 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView):
# also it's currently sub-obtimal for HTTP caching - need to sort that out. # also it's currently sub-obtimal for HTTP caching - need to sort that out.
response.headers['Allow'] = ', '.join(self.allowed_methods) response.headers['Allow'] = ', '.join(self.allowed_methods)
response.headers['Vary'] = 'Authenticate, Accept' response.headers['Vary'] = 'Authenticate, Accept'
# merge with headers possibly set at some point in the view
response.headers.update(self.headers)
return self.render(response)
return self.render(response)
class ModelView(View): class ModelView(View):
"""A RESTful view that maps to a model in the database.""" """A RESTful view that maps to a model in the database."""
......
...@@ -31,7 +31,7 @@ Resources ...@@ -31,7 +31,7 @@ Resources
* The ``djangorestframework`` package is `available on PyPI <http://pypi.python.org/pypi/djangorestframework>`_. * The ``djangorestframework`` package is `available on PyPI <http://pypi.python.org/pypi/djangorestframework>`_.
* We have an active `discussion group <http://groups.google.com/group/django-rest-framework>`_ and a `project blog <http://blog.django-rest-framework.org>`_. * We have an active `discussion group <http://groups.google.com/group/django-rest-framework>`_ and a `project blog <http://blog.django-rest-framework.org>`_.
* Bug reports are handled on the `issue tracker <https://github.com/tomchristie/django-rest-framework/issues>`_. * Bug reports are handled on the `issue tracker <https://github.com/tomchristie/django-rest-framework/issues>`_.
* There is a `Jenkins CI server <http://datacenter.tibold.nl/job/djangorestframework/>`_ which tracks test status and coverage reporting. (Thanks Marko!) * There is a `Jenkins CI server <http://jenkins.tibold.nl/job/djangorestframework/>`_ which tracks test status and coverage reporting. (Thanks Marko!)
Any and all questions, thoughts, bug reports and contributions are *hugely appreciated*. Any and all questions, thoughts, bug reports and contributions are *hugely appreciated*.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment