Commit df957c86 by Tom Christie

Fix and tests for ScopedRateThrottle. Closes #935

parent 6cc4fe56
...@@ -7,7 +7,7 @@ from django.contrib.auth.models import User ...@@ -7,7 +7,7 @@ from django.contrib.auth.models import User
from django.core.cache import cache from django.core.cache import cache
from django.test.client import RequestFactory from django.test.client import RequestFactory
from rest_framework.views import APIView from rest_framework.views import APIView
from rest_framework.throttling import UserRateThrottle from rest_framework.throttling import UserRateThrottle, ScopedRateThrottle
from rest_framework.response import Response from rest_framework.response import Response
...@@ -36,8 +36,6 @@ class MockView_MinuteThrottling(APIView): ...@@ -36,8 +36,6 @@ class MockView_MinuteThrottling(APIView):
class ThrottlingTests(TestCase): class ThrottlingTests(TestCase):
urls = 'rest_framework.tests.test_throttling'
def setUp(self): def setUp(self):
""" """
Reset the cache so that no throttles will be active Reset the cache so that no throttles will be active
...@@ -141,3 +139,108 @@ class ThrottlingTests(TestCase): ...@@ -141,3 +139,108 @@ class ThrottlingTests(TestCase):
(60, None), (60, None),
(80, None) (80, None)
)) ))
class ScopedRateThrottleTests(TestCase):
"""
Tests for ScopedRateThrottle.
"""
def setUp(self):
class XYScopedRateThrottle(ScopedRateThrottle):
TIMER_SECONDS = 0
THROTTLE_RATES = {'x': '3/min', 'y': '1/min'}
timer = lambda self: self.TIMER_SECONDS
class XView(APIView):
throttle_classes = (XYScopedRateThrottle,)
throttle_scope = 'x'
def get(self, request):
return Response('x')
class YView(APIView):
throttle_classes = (XYScopedRateThrottle,)
throttle_scope = 'y'
def get(self, request):
return Response('y')
class UnscopedView(APIView):
throttle_classes = (XYScopedRateThrottle,)
def get(self, request):
return Response('y')
self.throttle_class = XYScopedRateThrottle
self.factory = RequestFactory()
self.x_view = XView.as_view()
self.y_view = YView.as_view()
self.unscoped_view = UnscopedView.as_view()
def increment_timer(self, seconds=1):
self.throttle_class.TIMER_SECONDS += seconds
def test_scoped_rate_throttle(self):
request = self.factory.get('/')
# Should be able to hit x view 3 times per minute.
response = self.x_view(request)
self.assertEqual(200, response.status_code)
self.increment_timer()
response = self.x_view(request)
self.assertEqual(200, response.status_code)
self.increment_timer()
response = self.x_view(request)
self.assertEqual(200, response.status_code)
self.increment_timer()
response = self.x_view(request)
self.assertEqual(429, response.status_code)
# Should be able to hit y view 1 time per minute.
self.increment_timer()
response = self.y_view(request)
self.assertEqual(200, response.status_code)
self.increment_timer()
response = self.y_view(request)
self.assertEqual(429, response.status_code)
# Ensure throttles properly reset by advancing the rest of the minute
self.increment_timer(55)
# Should still be able to hit x view 3 times per minute.
response = self.x_view(request)
self.assertEqual(200, response.status_code)
self.increment_timer()
response = self.x_view(request)
self.assertEqual(200, response.status_code)
self.increment_timer()
response = self.x_view(request)
self.assertEqual(200, response.status_code)
self.increment_timer()
response = self.x_view(request)
self.assertEqual(429, response.status_code)
# Should still be able to hit y view 1 time per minute.
self.increment_timer()
response = self.y_view(request)
self.assertEqual(200, response.status_code)
self.increment_timer()
response = self.y_view(request)
self.assertEqual(429, response.status_code)
def test_unscoped_view_not_throttled(self):
request = self.factory.get('/')
for idx in range(10):
self.increment_timer()
response = self.unscoped_view(request)
self.assertEqual(200, response.status_code)
...@@ -40,9 +40,9 @@ class SimpleRateThrottle(BaseThrottle): ...@@ -40,9 +40,9 @@ class SimpleRateThrottle(BaseThrottle):
""" """
timer = time.time timer = time.time
settings = api_settings
cache_format = 'throtte_%(scope)s_%(ident)s' cache_format = 'throtte_%(scope)s_%(ident)s'
scope = None scope = None
THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES
def __init__(self): def __init__(self):
if not getattr(self, 'rate', None): if not getattr(self, 'rate', None):
...@@ -68,7 +68,7 @@ class SimpleRateThrottle(BaseThrottle): ...@@ -68,7 +68,7 @@ class SimpleRateThrottle(BaseThrottle):
raise ImproperlyConfigured(msg) raise ImproperlyConfigured(msg)
try: try:
return self.settings.DEFAULT_THROTTLE_RATES[self.scope] return self.THROTTLE_RATES[self.scope]
except KeyError: except KeyError:
msg = "No default throttle rate set for '%s' scope" % self.scope msg = "No default throttle rate set for '%s' scope" % self.scope
raise ImproperlyConfigured(msg) raise ImproperlyConfigured(msg)
...@@ -187,6 +187,19 @@ class ScopedRateThrottle(SimpleRateThrottle): ...@@ -187,6 +187,19 @@ class ScopedRateThrottle(SimpleRateThrottle):
""" """
scope_attr = 'throttle_scope' scope_attr = 'throttle_scope'
def __init__(self):
pass
def allow_request(self, request, view):
self.scope = getattr(view, self.scope_attr, None)
if not self.scope:
return True
self.rate = self.get_rate()
self.num_requests, self.duration = self.parse_rate(self.rate)
return super(ScopedRateThrottle, self).allow_request(request, view)
def get_cache_key(self, request, view): def get_cache_key(self, request, view):
""" """
If `view.throttle_scope` is not set, don't apply this throttle. If `view.throttle_scope` is not set, don't apply this throttle.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment