Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
D
django-rest-framework
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
edx
django-rest-framework
Commits
c28b7193
Commit
c28b7193
authored
Sep 04, 2012
by
Tom Christie
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Refactored throttling
parent
8457c871
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
234 additions
and
183 deletions
+234
-183
djangorestframework/exceptions.py
+9
-3
djangorestframework/parsers.py
+1
-1
djangorestframework/permissions.py
+22
-150
djangorestframework/tests/throttling.py
+23
-20
djangorestframework/throttling.py
+139
-0
djangorestframework/views.py
+40
-9
No files found.
djangorestframework/exceptions.py
View file @
c28b7193
...
...
@@ -49,8 +49,14 @@ class UnsupportedMediaType(APIException):
class
Throttled
(
APIException
):
status_code
=
status
.
HTTP_429_TOO_MANY_REQUESTS
default_detail
=
"Request was throttled. Expected available in
%
d seconds."
default_detail
=
"Request was throttled."
extra_detail
=
"Expected available in
%
d second
%
s."
def
__init__
(
self
,
wait
,
detail
=
None
):
def
__init__
(
self
,
wait
=
None
,
detail
=
None
):
import
math
self
.
detail
=
(
detail
or
self
.
default_detail
)
%
int
(
math
.
ceil
(
wait
))
self
.
wait
=
wait
and
math
.
ceil
(
wait
)
or
None
if
wait
is
not
None
:
format
=
detail
or
self
.
default_detail
+
self
.
extra_detail
self
.
detail
=
format
%
(
self
.
wait
,
self
.
wait
!=
1
and
's'
or
''
)
else
:
self
.
detail
=
detail
or
self
.
default_detail
djangorestframework/parsers.py
View file @
c28b7193
...
...
@@ -81,7 +81,7 @@ class BaseParser(object):
Should return parsed data, or a DataAndFiles object consisting of the
parsed data and files.
"""
raise
NotImplementedError
(
".parse_stream()
Must be overridden to be implemented
."
)
raise
NotImplementedError
(
".parse_stream()
must be overridden
."
)
class
JSONParser
(
BaseParser
):
...
...
djangorestframework/permissions.py
View file @
c28b7193
...
...
@@ -5,10 +5,6 @@ for checking if a request passes a certain set of constraints.
Permission behavior is provided by mixing the :class:`mixins.PermissionsMixin` class into a :class:`View` class.
"""
from
django.core.cache
import
cache
from
djangorestframework.exceptions
import
PermissionDenied
,
Throttled
import
time
__all__
=
(
'BasePermission'
,
'FullAnonAccess'
,
...
...
@@ -32,20 +28,11 @@ class BasePermission(object):
"""
self
.
view
=
view
def
check_permission
(
self
,
auth
):
def
check_permission
(
self
,
request
,
obj
=
None
):
"""
Should simply return, or raise an :exc:`response.ImmediateResponse`.
"""
pass
class
FullAnonAccess
(
BasePermission
):
"""
Allows full access.
"""
def
check_permission
(
self
,
user
):
pass
raise
NotImplementedError
(
".check_permission() must be overridden."
)
class
IsAuthenticated
(
BasePermission
):
...
...
@@ -53,9 +40,10 @@ class IsAuthenticated(BasePermission):
Allows access only to authenticated users.
"""
def
check_permission
(
self
,
user
):
if
not
user
.
is_authenticated
():
raise
PermissionDenied
()
def
check_permission
(
self
,
request
,
obj
=
None
):
if
request
.
user
.
is_authenticated
():
return
True
return
False
class
IsAdminUser
(
BasePermission
):
...
...
@@ -63,20 +51,22 @@ class IsAdminUser(BasePermission):
Allows access only to admin users.
"""
def
check_permission
(
self
,
user
):
if
not
user
.
is_staff
:
raise
PermissionDenied
()
def
check_permission
(
self
,
request
,
obj
=
None
):
if
request
.
user
.
is_staff
:
return
True
return
False
class
Is
UserOrIsAnon
ReadOnly
(
BasePermission
):
class
Is
AuthenticatedOr
ReadOnly
(
BasePermission
):
"""
The request is authenticated as a user, or is a read-only request.
"""
def
check_permission
(
self
,
user
):
if
(
not
user
.
is_authenticated
()
and
self
.
view
.
method
not
in
SAFE_METHODS
):
raise
PermissionDenied
()
def
check_permission
(
self
,
request
,
obj
=
None
):
if
(
request
.
user
.
is_authenticated
()
or
request
.
method
in
SAFE_METHODS
):
return
True
return
False
class
DjangoModelPermissions
(
BasePermission
):
...
...
@@ -114,128 +104,10 @@ class DjangoModelPermissions(BasePermission):
}
return
[
perm
%
kwargs
for
perm
in
self
.
perms_map
[
method
]]
def
check_permission
(
self
,
user
):
method
=
self
.
view
.
method
model_cls
=
self
.
view
.
resource
.
model
perms
=
self
.
get_required_permissions
(
method
,
model_cls
)
if
not
user
.
is_authenticated
or
not
user
.
has_perms
(
perms
):
raise
PermissionDenied
()
class
BaseThrottle
(
BasePermission
):
"""
Rate throttling of requests.
The rate (requests / seconds) is set by a :attr:`throttle` attribute
on the :class:`.View` class. The attribute is a string of the form 'number of
requests/period'.
Period should be one of: ('s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day')
Previous request information used for throttling is stored in the cache.
"""
attr_name
=
'throttle'
default
=
'0/sec'
timer
=
time
.
time
def
get_cache_key
(
self
):
"""
Should return a unique cache-key which can be used for throttling.
Must be overridden.
"""
pass
def
check_permission
(
self
,
auth
):
"""
Check the throttling.
Return `None` or raise an :exc:`.ImmediateResponse`.
"""
num
,
period
=
getattr
(
self
.
view
,
self
.
attr_name
,
self
.
default
)
.
split
(
'/'
)
self
.
num_requests
=
int
(
num
)
self
.
duration
=
{
's'
:
1
,
'm'
:
60
,
'h'
:
3600
,
'd'
:
86400
}[
period
[
0
]]
self
.
auth
=
auth
self
.
check_throttle
()
def
check_throttle
(
self
):
"""
Implement the check to see if the request should be throttled.
On success calls :meth:`throttle_success`.
On failure calls :meth:`throttle_failure`.
"""
self
.
key
=
self
.
get_cache_key
()
self
.
history
=
cache
.
get
(
self
.
key
,
[])
self
.
now
=
self
.
timer
()
# Drop any requests from the history which have now passed the
# throttle duration
while
self
.
history
and
self
.
history
[
-
1
]
<=
self
.
now
-
self
.
duration
:
self
.
history
.
pop
()
if
len
(
self
.
history
)
>=
self
.
num_requests
:
self
.
throttle_failure
()
else
:
self
.
throttle_success
()
def
throttle_success
(
self
):
"""
Inserts the current request's timestamp along with the key
into the cache.
"""
self
.
history
.
insert
(
0
,
self
.
now
)
cache
.
set
(
self
.
key
,
self
.
history
,
self
.
duration
)
header
=
'status=SUCCESS; next=
%.2
f sec'
%
self
.
next
()
self
.
view
.
headers
[
'X-Throttle'
]
=
header
def
throttle_failure
(
self
):
"""
Called when a request to the API has failed due to throttling.
Raises a '503 service unavailable' response.
"""
wait
=
self
.
next
()
header
=
'status=FAILURE; next=
%.2
f sec'
%
wait
self
.
view
.
headers
[
'X-Throttle'
]
=
header
raise
Throttled
(
wait
)
def
next
(
self
):
"""
Returns the recommended next request time in seconds.
"""
if
self
.
history
:
remaining_duration
=
self
.
duration
-
(
self
.
now
-
self
.
history
[
-
1
])
else
:
remaining_duration
=
self
.
duration
available_requests
=
self
.
num_requests
-
len
(
self
.
history
)
+
1
return
remaining_duration
/
float
(
available_requests
)
class
PerUserThrottling
(
BaseThrottle
):
"""
Limits the rate of API calls that may be made by a given user.
The user id will be used as a unique identifier if the user is
authenticated. For anonymous requests, the IP address of the client will
be used.
"""
def
get_cache_key
(
self
):
if
self
.
auth
.
is_authenticated
():
ident
=
self
.
auth
.
id
else
:
ident
=
self
.
view
.
request
.
META
.
get
(
'REMOTE_ADDR'
,
None
)
return
'throttle_user_
%
s'
%
ident
class
PerViewThrottling
(
BaseThrottle
):
"""
Limits the rate of API calls that may be used on a given view.
The class name of the view is used as a unique identifier to
throttle against.
"""
def
check_permission
(
self
,
request
,
obj
=
None
):
model_cls
=
self
.
view
.
model
perms
=
self
.
get_required_permissions
(
request
.
method
,
model_cls
)
def
get_cache_key
(
self
):
return
'throttle_view_
%
s'
%
self
.
view
.
__class__
.
__name__
if
request
.
user
.
is_authenticated
()
and
request
.
user
.
has_perms
(
perms
,
obj
):
return
True
return
False
djangorestframework/tests/throttling.py
View file @
c28b7193
...
...
@@ -8,24 +8,24 @@ from django.core.cache import cache
from
djangorestframework.compat
import
RequestFactory
from
djangorestframework.views
import
APIView
from
djangorestframework.
permissions
import
PerUserThrottling
,
PerViewThrottling
from
djangorestframework.
throttling
import
PerUserThrottling
,
PerViewThrottling
from
djangorestframework.response
import
Response
class
MockView
(
APIView
):
permission
_classes
=
(
PerUserThrottling
,)
throttl
e
=
'3/sec'
throttle
_classes
=
(
PerUserThrottling
,)
rat
e
=
'3/sec'
def
get
(
self
,
request
):
return
Response
(
'foo'
)
class
MockView_PerViewThrottling
(
MockView
):
permission
_classes
=
(
PerViewThrottling
,)
throttle
_classes
=
(
PerViewThrottling
,)
class
MockView_MinuteThrottling
(
MockView
):
throttl
e
=
'3/min'
rat
e
=
'3/min'
class
ThrottlingTests
(
TestCase
):
...
...
@@ -51,7 +51,7 @@ class ThrottlingTests(TestCase):
"""
Explicitly set the timer, overriding time.time()
"""
view
.
permission
_classes
[
0
]
.
timer
=
lambda
self
:
value
view
.
throttle
_classes
[
0
]
.
timer
=
lambda
self
:
value
def
test_request_throttling_expires
(
self
):
"""
...
...
@@ -101,17 +101,20 @@ class ThrottlingTests(TestCase):
for
timer
,
expect
in
expected_headers
:
self
.
set_throttle_timer
(
view
,
timer
)
response
=
view
.
as_view
()(
request
)
self
.
assertEquals
(
response
[
'X-Throttle'
],
expect
)
if
expect
is
not
None
:
self
.
assertEquals
(
response
[
'X-Throttle-Wait-Seconds'
],
expect
)
else
:
self
.
assertFalse
(
'X-Throttle-Wait-Seconds'
in
response
.
headers
)
def
test_seconds_fields
(
self
):
"""
Ensure for second based throttles.
"""
self
.
ensure_response_header_contains_proper_throttle_field
(
MockView
,
((
0
,
'status=SUCCESS; next=0.33 sec'
),
(
0
,
'status=SUCCESS; next=0.50 sec'
),
(
0
,
'status=SUCCESS; next=1.00 sec'
),
(
0
,
'
status=FAILURE; next=1.00 sec
'
)
((
0
,
None
),
(
0
,
None
),
(
0
,
None
),
(
0
,
'
1
'
)
))
def
test_minutes_fields
(
self
):
...
...
@@ -119,10 +122,10 @@ class ThrottlingTests(TestCase):
Ensure for minute based throttles.
"""
self
.
ensure_response_header_contains_proper_throttle_field
(
MockView_MinuteThrottling
,
((
0
,
'status=SUCCESS; next=20.00 sec'
),
(
0
,
'status=SUCCESS; next=30.00 sec'
),
(
0
,
'status=SUCCESS; next=60.00 sec'
),
(
0
,
'
status=FAILURE; next=60.00 sec
'
)
((
0
,
None
),
(
0
,
None
),
(
0
,
None
),
(
0
,
'
60
'
)
))
def
test_next_rate_remains_constant_if_followed
(
self
):
...
...
@@ -131,9 +134,9 @@ class ThrottlingTests(TestCase):
the throttling rate should stay constant.
"""
self
.
ensure_response_header_contains_proper_throttle_field
(
MockView_MinuteThrottling
,
((
0
,
'status=SUCCESS; next=20.00 sec'
),
(
20
,
'status=SUCCESS; next=20.00 sec'
),
(
40
,
'status=SUCCESS; next=20.00 sec'
),
(
60
,
'status=SUCCESS; next=20.00 sec'
),
(
80
,
'status=SUCCESS; next=20.00 sec'
)
((
0
,
None
),
(
20
,
None
),
(
40
,
None
),
(
60
,
None
),
(
80
,
None
)
))
djangorestframework/throttling.py
0 → 100644
View file @
c28b7193
from
django.core.cache
import
cache
import
time
class
BaseThrottle
(
object
):
"""
Rate throttling of requests.
"""
def
__init__
(
self
,
view
=
None
):
"""
All throttles hold a reference to the instantiating view.
"""
self
.
view
=
view
def
check_throttle
(
self
,
request
):
"""
Return `True` if the request should be allowed, `False` otherwise.
"""
raise
NotImplementedError
(
'.check_throttle() must be overridden'
)
def
wait
(
self
):
"""
Optionally, return a recommeded number of seconds to wait before
the next request.
"""
return
None
class
SimpleCachingThrottle
(
BaseThrottle
):
"""
A simple cache implementation, that only requires `.get_cache_key()`
to be overridden.
The rate (requests / seconds) is set by a :attr:`throttle` attribute
on the :class:`.View` class. The attribute is a string of the form 'number of
requests/period'.
Period should be one of: ('s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day')
Previous request information used for throttling is stored in the cache.
"""
attr_name
=
'rate'
rate
=
'1000/day'
timer
=
time
.
time
def
__init__
(
self
,
view
):
"""
Check the throttling.
Return `None` or raise an :exc:`.ImmediateResponse`.
"""
super
(
SimpleCachingThrottle
,
self
)
.
__init__
(
view
)
num
,
period
=
getattr
(
view
,
self
.
attr_name
,
self
.
rate
)
.
split
(
'/'
)
self
.
num_requests
=
int
(
num
)
self
.
duration
=
{
's'
:
1
,
'm'
:
60
,
'h'
:
3600
,
'd'
:
86400
}[
period
[
0
]]
def
get_cache_key
(
self
,
request
):
"""
Should return a unique cache-key which can be used for throttling.
Must be overridden.
"""
raise
NotImplementedError
(
'.get_cache_key() must be overridden'
)
def
check_throttle
(
self
,
request
):
"""
Implement the check to see if the request should be throttled.
On success calls :meth:`throttle_success`.
On failure calls :meth:`throttle_failure`.
"""
self
.
key
=
self
.
get_cache_key
(
request
)
self
.
history
=
cache
.
get
(
self
.
key
,
[])
self
.
now
=
self
.
timer
()
# Drop any requests from the history which have now passed the
# throttle duration
while
self
.
history
and
self
.
history
[
-
1
]
<=
self
.
now
-
self
.
duration
:
self
.
history
.
pop
()
if
len
(
self
.
history
)
>=
self
.
num_requests
:
return
self
.
throttle_failure
()
return
self
.
throttle_success
()
def
throttle_success
(
self
):
"""
Inserts the current request's timestamp along with the key
into the cache.
"""
self
.
history
.
insert
(
0
,
self
.
now
)
cache
.
set
(
self
.
key
,
self
.
history
,
self
.
duration
)
return
True
def
throttle_failure
(
self
):
"""
Called when a request to the API has failed due to throttling.
"""
return
False
def
wait
(
self
):
"""
Returns the recommended next request time in seconds.
"""
if
self
.
history
:
remaining_duration
=
self
.
duration
-
(
self
.
now
-
self
.
history
[
-
1
])
else
:
remaining_duration
=
self
.
duration
available_requests
=
self
.
num_requests
-
len
(
self
.
history
)
+
1
return
remaining_duration
/
float
(
available_requests
)
class
PerUserThrottling
(
SimpleCachingThrottle
):
"""
Limits the rate of API calls that may be made by a given user.
The user id will be used as a unique identifier if the user is
authenticated. For anonymous requests, the IP address of the client will
be used.
"""
def
get_cache_key
(
self
,
request
):
if
request
.
user
.
is_authenticated
():
ident
=
request
.
user
.
id
else
:
ident
=
request
.
META
.
get
(
'REMOTE_ADDR'
,
None
)
return
'throttle_user_
%
s'
%
ident
class
PerViewThrottling
(
SimpleCachingThrottle
):
"""
Limits the rate of API calls that may be used on a given view.
The class name of the view is used as a unique identifier to
throttle against.
"""
def
get_cache_key
(
self
,
request
):
return
'throttle_view_
%
s'
%
self
.
view
.
__class__
.
__name__
djangorestframework/views.py
View file @
c28b7193
...
...
@@ -18,7 +18,7 @@ from djangorestframework.compat import View as _View, apply_markdown
from
djangorestframework.response
import
Response
from
djangorestframework.request
import
Request
from
djangorestframework.settings
import
api_settings
from
djangorestframework
import
parsers
,
authentication
,
permissions
,
status
,
exceptions
,
mixins
from
djangorestframework
import
parsers
,
authentication
,
status
,
exceptions
,
mixins
__all__
=
(
...
...
@@ -86,7 +86,12 @@ class APIView(_View):
List of all authenticating methods to attempt.
"""
permission_classes
=
(
permissions
.
FullAnonAccess
,)
throttle_classes
=
()
"""
List of all throttles to check.
"""
permission_classes
=
()
"""
List of all permissions that must be checked.
"""
...
...
@@ -195,12 +200,27 @@ class APIView(_View):
"""
return
[
permission
(
self
)
for
permission
in
self
.
permission_classes
]
def
check_permissions
(
self
,
user
):
def
get_throttles
(
self
):
"""
Check user permissions and either raise an ``ImmediateResponse`` or return.
Instantiates and returns the list of thottles that this view requires.
"""
return
[
throttle
(
self
)
for
throttle
in
self
.
throttle_classes
]
def
check_permissions
(
self
,
request
,
obj
=
None
):
"""
Check user permissions and either raise an ``PermissionDenied`` or return.
"""
for
permission
in
self
.
get_permissions
():
permission
.
check_permission
(
user
)
if
not
permission
.
check_permission
(
request
,
obj
):
raise
exceptions
.
PermissionDenied
()
def
check_throttles
(
self
,
request
):
"""
Check throttles and either raise a `Throttled` exception or return.
"""
for
throttle
in
self
.
get_throttles
():
if
not
throttle
.
check_throttle
(
request
):
raise
exceptions
.
Throttled
(
throttle
.
wait
())
def
initial
(
self
,
request
,
*
args
,
**
kargs
):
"""
...
...
@@ -232,6 +252,9 @@ class APIView(_View):
Handle any exception that occurs, by returning an appropriate response,
or re-raising the error.
"""
if
isinstance
(
exc
,
exceptions
.
Throttled
):
self
.
headers
[
'X-Throttle-Wait-Seconds'
]
=
'
%
d'
%
exc
.
wait
if
isinstance
(
exc
,
exceptions
.
APIException
):
return
Response
({
'detail'
:
exc
.
detail
},
status
=
exc
.
status_code
)
elif
isinstance
(
exc
,
Http404
):
...
...
@@ -255,8 +278,9 @@ class APIView(_View):
try
:
self
.
initial
(
request
,
*
args
,
**
kwargs
)
# check that user has the relevant permissions
self
.
check_permissions
(
request
.
user
)
# Check that the request is allowed
self
.
check_permissions
(
request
)
self
.
check_throttles
(
request
)
# Get the appropriate handler method
if
request
.
method
.
lower
()
in
self
.
http_method_names
:
...
...
@@ -283,11 +307,12 @@ class BaseView(APIView):
serializer_class
=
None
def
get_serializer
(
self
,
data
=
None
,
files
=
None
,
instance
=
None
):
# TODO: add support for files
context
=
{
'request'
:
self
.
request
,
'format'
:
self
.
kwargs
.
get
(
'format'
,
None
)
}
return
self
.
serializer_class
(
data
,
context
=
context
)
return
self
.
serializer_class
(
data
,
instance
=
instance
,
context
=
context
)
class
MultipleObjectBaseView
(
MultipleObjectMixin
,
BaseView
):
...
...
@@ -301,7 +326,13 @@ class SingleObjectBaseView(SingleObjectMixin, BaseView):
"""
Base class for generic views onto a model instance.
"""
pass
def
get_object
(
self
):
"""
Override default to add support for object-level permissions.
"""
super
(
self
,
SingleObjectBaseView
)
.
get_object
()
self
.
check_permissions
(
self
.
request
,
self
.
object
)
# Concrete view classes that provide method handlers
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment