Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
D
django-rest-framework
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
edx
django-rest-framework
Commits
df957c86
Commit
df957c86
authored
Jun 14, 2013
by
Tom Christie
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Fix and tests for ScopedRateThrottle. Closes #935
parent
6cc4fe56
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
121 additions
and
5 deletions
+121
-5
rest_framework/tests/test_throttling.py
+106
-3
rest_framework/throttling.py
+15
-2
No files found.
rest_framework/tests/test_throttling.py
View file @
df957c86
...
@@ -7,7 +7,7 @@ from django.contrib.auth.models import User
...
@@ -7,7 +7,7 @@ from django.contrib.auth.models import User
from
django.core.cache
import
cache
from
django.core.cache
import
cache
from
django.test.client
import
RequestFactory
from
django.test.client
import
RequestFactory
from
rest_framework.views
import
APIView
from
rest_framework.views
import
APIView
from
rest_framework.throttling
import
UserRateThrottle
from
rest_framework.throttling
import
UserRateThrottle
,
ScopedRateThrottle
from
rest_framework.response
import
Response
from
rest_framework.response
import
Response
...
@@ -36,8 +36,6 @@ class MockView_MinuteThrottling(APIView):
...
@@ -36,8 +36,6 @@ class MockView_MinuteThrottling(APIView):
class
ThrottlingTests
(
TestCase
):
class
ThrottlingTests
(
TestCase
):
urls
=
'rest_framework.tests.test_throttling'
def
setUp
(
self
):
def
setUp
(
self
):
"""
"""
Reset the cache so that no throttles will be active
Reset the cache so that no throttles will be active
...
@@ -141,3 +139,108 @@ class ThrottlingTests(TestCase):
...
@@ -141,3 +139,108 @@ class ThrottlingTests(TestCase):
(
60
,
None
),
(
60
,
None
),
(
80
,
None
)
(
80
,
None
)
))
))
class
ScopedRateThrottleTests
(
TestCase
):
"""
Tests for ScopedRateThrottle.
"""
def
setUp
(
self
):
class
XYScopedRateThrottle
(
ScopedRateThrottle
):
TIMER_SECONDS
=
0
THROTTLE_RATES
=
{
'x'
:
'3/min'
,
'y'
:
'1/min'
}
timer
=
lambda
self
:
self
.
TIMER_SECONDS
class
XView
(
APIView
):
throttle_classes
=
(
XYScopedRateThrottle
,)
throttle_scope
=
'x'
def
get
(
self
,
request
):
return
Response
(
'x'
)
class
YView
(
APIView
):
throttle_classes
=
(
XYScopedRateThrottle
,)
throttle_scope
=
'y'
def
get
(
self
,
request
):
return
Response
(
'y'
)
class
UnscopedView
(
APIView
):
throttle_classes
=
(
XYScopedRateThrottle
,)
def
get
(
self
,
request
):
return
Response
(
'y'
)
self
.
throttle_class
=
XYScopedRateThrottle
self
.
factory
=
RequestFactory
()
self
.
x_view
=
XView
.
as_view
()
self
.
y_view
=
YView
.
as_view
()
self
.
unscoped_view
=
UnscopedView
.
as_view
()
def
increment_timer
(
self
,
seconds
=
1
):
self
.
throttle_class
.
TIMER_SECONDS
+=
seconds
def
test_scoped_rate_throttle
(
self
):
request
=
self
.
factory
.
get
(
'/'
)
# Should be able to hit x view 3 times per minute.
response
=
self
.
x_view
(
request
)
self
.
assertEqual
(
200
,
response
.
status_code
)
self
.
increment_timer
()
response
=
self
.
x_view
(
request
)
self
.
assertEqual
(
200
,
response
.
status_code
)
self
.
increment_timer
()
response
=
self
.
x_view
(
request
)
self
.
assertEqual
(
200
,
response
.
status_code
)
self
.
increment_timer
()
response
=
self
.
x_view
(
request
)
self
.
assertEqual
(
429
,
response
.
status_code
)
# Should be able to hit y view 1 time per minute.
self
.
increment_timer
()
response
=
self
.
y_view
(
request
)
self
.
assertEqual
(
200
,
response
.
status_code
)
self
.
increment_timer
()
response
=
self
.
y_view
(
request
)
self
.
assertEqual
(
429
,
response
.
status_code
)
# Ensure throttles properly reset by advancing the rest of the minute
self
.
increment_timer
(
55
)
# Should still be able to hit x view 3 times per minute.
response
=
self
.
x_view
(
request
)
self
.
assertEqual
(
200
,
response
.
status_code
)
self
.
increment_timer
()
response
=
self
.
x_view
(
request
)
self
.
assertEqual
(
200
,
response
.
status_code
)
self
.
increment_timer
()
response
=
self
.
x_view
(
request
)
self
.
assertEqual
(
200
,
response
.
status_code
)
self
.
increment_timer
()
response
=
self
.
x_view
(
request
)
self
.
assertEqual
(
429
,
response
.
status_code
)
# Should still be able to hit y view 1 time per minute.
self
.
increment_timer
()
response
=
self
.
y_view
(
request
)
self
.
assertEqual
(
200
,
response
.
status_code
)
self
.
increment_timer
()
response
=
self
.
y_view
(
request
)
self
.
assertEqual
(
429
,
response
.
status_code
)
def
test_unscoped_view_not_throttled
(
self
):
request
=
self
.
factory
.
get
(
'/'
)
for
idx
in
range
(
10
):
self
.
increment_timer
()
response
=
self
.
unscoped_view
(
request
)
self
.
assertEqual
(
200
,
response
.
status_code
)
rest_framework/throttling.py
View file @
df957c86
...
@@ -40,9 +40,9 @@ class SimpleRateThrottle(BaseThrottle):
...
@@ -40,9 +40,9 @@ class SimpleRateThrottle(BaseThrottle):
"""
"""
timer
=
time
.
time
timer
=
time
.
time
settings
=
api_settings
cache_format
=
'throtte_
%(scope)
s_
%(ident)
s'
cache_format
=
'throtte_
%(scope)
s_
%(ident)
s'
scope
=
None
scope
=
None
THROTTLE_RATES
=
api_settings
.
DEFAULT_THROTTLE_RATES
def
__init__
(
self
):
def
__init__
(
self
):
if
not
getattr
(
self
,
'rate'
,
None
):
if
not
getattr
(
self
,
'rate'
,
None
):
...
@@ -68,7 +68,7 @@ class SimpleRateThrottle(BaseThrottle):
...
@@ -68,7 +68,7 @@ class SimpleRateThrottle(BaseThrottle):
raise
ImproperlyConfigured
(
msg
)
raise
ImproperlyConfigured
(
msg
)
try
:
try
:
return
self
.
settings
.
DEFAULT_
THROTTLE_RATES
[
self
.
scope
]
return
self
.
THROTTLE_RATES
[
self
.
scope
]
except
KeyError
:
except
KeyError
:
msg
=
"No default throttle rate set for '
%
s' scope"
%
self
.
scope
msg
=
"No default throttle rate set for '
%
s' scope"
%
self
.
scope
raise
ImproperlyConfigured
(
msg
)
raise
ImproperlyConfigured
(
msg
)
...
@@ -187,6 +187,19 @@ class ScopedRateThrottle(SimpleRateThrottle):
...
@@ -187,6 +187,19 @@ class ScopedRateThrottle(SimpleRateThrottle):
"""
"""
scope_attr
=
'throttle_scope'
scope_attr
=
'throttle_scope'
def
__init__
(
self
):
pass
def
allow_request
(
self
,
request
,
view
):
self
.
scope
=
getattr
(
view
,
self
.
scope_attr
,
None
)
if
not
self
.
scope
:
return
True
self
.
rate
=
self
.
get_rate
()
self
.
num_requests
,
self
.
duration
=
self
.
parse_rate
(
self
.
rate
)
return
super
(
ScopedRateThrottle
,
self
)
.
allow_request
(
request
,
view
)
def
get_cache_key
(
self
,
request
,
view
):
def
get_cache_key
(
self
,
request
,
view
):
"""
"""
If `view.throttle_scope` is not set, don't apply this throttle.
If `view.throttle_scope` is not set, don't apply this throttle.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment