Commit f854bc90 by markotibold

* fixed `test_request_throttling_is_per_user` - it didn't make a request for the 2nd user

* implemented per_resource_throttling  + test

needs refactoring
parent 87db5fbd
...@@ -122,3 +122,35 @@ class PerUserThrottling(BasePermission): ...@@ -122,3 +122,35 @@ class PerUserThrottling(BasePermission):
history.insert(0, now) history.insert(0, now)
cache.set(key, history, duration) cache.set(key, history, duration)
class PerResourceThrottling(BasePermission):
"""
Rate throttling of requests on a per-resource basis.
The rate (requests / seconds) is set by a :attr:`throttle` attribute on the ``View`` class.
The attribute is a two tuple of the form (number of requests, duration in seconds).
The user id will be used as a unique identifier if the user is authenticated.
For anonymous requests, the IP address of the client will be used.
Previous request information used for throttling is stored in the cache.
"""
def check_permission(self, ignore):
(num_requests, duration) = getattr(self.view, 'throttle', (0, 0))
key = 'throttle_%s' % self.view.__class__.__name__
history = cache.get(key, [])
now = time.time()
# Drop any requests from the history which have now passed the throttle duration
while history and history[0] < now - duration:
history.pop()
if len(history) >= num_requests:
raise _503_THROTTLED_RESPONSE
history.insert(0, now)
cache.set(key, history, duration)
...@@ -8,7 +8,7 @@ from django.core.cache import cache ...@@ -8,7 +8,7 @@ from django.core.cache import cache
from djangorestframework.compat import RequestFactory from djangorestframework.compat import RequestFactory
from djangorestframework.views import View from djangorestframework.views import View
from djangorestframework.permissions import PerUserThrottling from djangorestframework.permissions import PerUserThrottling, PerResourceThrottling
class MockView(View): class MockView(View):
...@@ -18,8 +18,16 @@ class MockView(View): ...@@ -18,8 +18,16 @@ class MockView(View):
def get(self, request): def get(self, request):
return 'foo' return 'foo'
class MockView1(View):
permissions = ( PerResourceThrottling, )
throttle = (3, 1) # 3 requests per second
def get(self, request):
return 'foo'
urlpatterns = patterns('', urlpatterns = patterns('',
(r'^$', MockView.as_view()), (r'^$', MockView.as_view()),
(r'^1$', MockView1.as_view()),
) )
class ThrottlingTests(TestCase): class ThrottlingTests(TestCase):
...@@ -37,7 +45,6 @@ class ThrottlingTests(TestCase): ...@@ -37,7 +45,6 @@ class ThrottlingTests(TestCase):
self.assertEqual(503, response.status_code) self.assertEqual(503, response.status_code)
def test_request_throttling_is_per_user(self): def test_request_throttling_is_per_user(self):
#Can not login user.....Dunno why...
"""Ensure request rate is only limited per user, not globally""" """Ensure request rate is only limited per user, not globally"""
for username in ('testuser', 'another_testuser'): for username in ('testuser', 'another_testuser'):
user = User.objects.create(username=username) user = User.objects.create(username=username)
...@@ -49,8 +56,24 @@ class ThrottlingTests(TestCase): ...@@ -49,8 +56,24 @@ class ThrottlingTests(TestCase):
response = self.client.get('/') response = self.client.get('/')
self.client.logout() self.client.logout()
self.assertTrue(self.client.login(username='another_testuser', password='test'), msg='Login failed') self.assertTrue(self.client.login(username='another_testuser', password='test'), msg='Login failed')
response = self.client.get('/')
self.assertEqual(200, response.status_code) self.assertEqual(200, response.status_code)
def test_request_throttling_is_per_resource(self):
"""Ensure request rate is limited globally per View"""
for username in ('testuser', 'another_testuser'):
user = User.objects.create(username=username)
user.set_password('test')
user.save()
self.assertTrue(self.client.login(username='testuser', password='test'), msg='Login Failed')
for dummy in range(3):
response = self.client.get('/1')
self.client.logout()
self.assertTrue(self.client.login(username='another_testuser', password='test'), msg='Login failed')
response = self.client.get('/1')
self.assertEqual(503, response.status_code)
def test_request_throttling_expires(self): def test_request_throttling_expires(self):
"""Ensure request rate is limited for a limited duration only""" """Ensure request rate is limited for a limited duration only"""
for dummy in range(3): for dummy in range(3):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment