Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
D
django-rest-framework
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
edx
django-rest-framework
Commits
399ac70b
Commit
399ac70b
authored
Mar 30, 2013
by
Tom Christie
Browse files
Options
Browse Files
Download
Plain Diff
Merge branch 'master' of
https://github.com/tomchristie/django-rest-framework
parents
c4eda3a6
2e06f5c8
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
176 additions
and
106 deletions
+176
-106
docs/api-guide/authentication.md
+1
-1
docs/topics/release-notes.md
+6
-0
rest_framework/authentication.py
+11
-15
rest_framework/compat.py
+31
-2
rest_framework/filters.py
+1
-1
rest_framework/templatetags/rest_framework.py
+46
-38
rest_framework/tests/authentication.py
+11
-33
rest_framework/tests/filterset.py
+69
-6
rest_framework/tests/pagination.py
+0
-10
No files found.
docs/api-guide/authentication.md
View file @
399ac70b
...
...
@@ -300,7 +300,7 @@ The only thing needed to make the `OAuth2Authentication` class work is to insert
The command line to test the authentication looks like:
curl -H "Authorization: Bearer <your-access-token>" http://localhost:8000/api/
?client_id=YOUR_CLIENT_ID\&client_secret=YOUR_CLIENT_SECRET
curl -H "Authorization: Bearer <your-access-token>" http://localhost:8000/api/
---
...
...
docs/topics/release-notes.md
View file @
399ac70b
...
...
@@ -40,6 +40,12 @@ You can determine your currently installed version using `pip freeze`:
## 2.2.x series
### Master
*
OAuth2 authentication no longer requires unneccessary URL parameters in addition to the token.
*
URL hyperlinking in browseable API now handles more cases correctly.
*
Bugfix
:
Fix regression with DjangoFilterBackend not worthing correctly with single object views.
### 2.2.5
*
*Date
**
:
26th March 2013
...
...
rest_framework/authentication.py
View file @
399ac70b
...
...
@@ -2,14 +2,16 @@
Provides a set of pluggable authentication policies.
"""
from
__future__
import
unicode_literals
import
base64
from
datetime
import
datetime
from
django.contrib.auth
import
authenticate
from
django.core.exceptions
import
ImproperlyConfigured
from
rest_framework
import
exceptions
,
HTTP_HEADER_ENCODING
from
rest_framework.compat
import
CsrfViewMiddleware
from
rest_framework.compat
import
oauth
,
oauth_provider
,
oauth_provider_store
from
rest_framework.compat
import
oauth2_provider
,
oauth2_provider_forms
,
oauth2_provider_backends
from
rest_framework.compat
import
oauth2_provider
,
oauth2_provider_forms
from
rest_framework.authtoken.models
import
Token
import
base64
def
get_authorization_header
(
request
):
...
...
@@ -315,21 +317,15 @@ class OAuth2Authentication(BaseAuthentication):
Authenticate the request, given the access token.
"""
# Authenticate the client
oauth2_client_form
=
oauth2_provider_forms
.
ClientAuthForm
(
request
.
REQUEST
)
if
not
oauth2_client_form
.
is_valid
():
raise
exceptions
.
AuthenticationFailed
(
'Client could not be validated'
)
client
=
oauth2_client_form
.
cleaned_data
.
get
(
'client'
)
# Retrieve the `OAuth2AccessToken` instance from the access_token
auth_backend
=
oauth2_provider_backends
.
AccessTokenBackend
()
token
=
auth_backend
.
authenticate
(
access_token
,
client
)
if
token
is
None
:
try
:
token
=
oauth2_provider
.
models
.
AccessToken
.
objects
.
select_related
(
'user'
)
# TODO: Change to timezone aware datetime when oauth2_provider add
# support to it.
token
=
token
.
get
(
token
=
access_token
,
expires__gt
=
datetime
.
now
())
except
oauth2_provider
.
models
.
AccessToken
.
DoesNotExist
:
raise
exceptions
.
AuthenticationFailed
(
'Invalid token'
)
user
=
token
.
user
if
not
user
.
is_active
:
if
not
token
.
user
.
is_active
:
msg
=
'User inactive or deleted:
%
s'
%
user
.
username
raise
exceptions
.
AuthenticationFailed
(
msg
)
...
...
rest_framework/compat.py
View file @
399ac70b
...
...
@@ -395,6 +395,37 @@ except ImportError:
kw
=
dict
((
k
,
int
(
v
))
for
k
,
v
in
kw
.
iteritems
()
if
v
is
not
None
)
return
datetime
.
datetime
(
**
kw
)
# smart_urlquote is new on Django 1.4
try
:
from
django.utils.html
import
smart_urlquote
except
ImportError
:
try
:
from
urllib.parse
import
quote
,
urlsplit
,
urlunsplit
except
ImportError
:
# Python 2
from
urllib
import
quote
from
urlparse
import
urlsplit
,
urlunsplit
def
smart_urlquote
(
url
):
"Quotes a URL if it isn't already quoted."
# Handle IDN before quoting.
scheme
,
netloc
,
path
,
query
,
fragment
=
urlsplit
(
url
)
try
:
netloc
=
netloc
.
encode
(
'idna'
)
.
decode
(
'ascii'
)
# IDN -> ACE
except
UnicodeError
:
# invalid domain part
pass
else
:
url
=
urlunsplit
((
scheme
,
netloc
,
path
,
query
,
fragment
))
# An URL is considered unquoted if it contains no % characters or
# contains a % not followed by two hexadecimal digits. See #9655.
if
'
%
'
not
in
url
or
unquoted_percents_re
.
search
(
url
):
# See http://bugs.python.org/issue2637
url
=
quote
(
force_bytes
(
url
),
safe
=
b
'!*
\'
();:@&=+$,/?#[]~'
)
return
force_text
(
url
)
# Markdown is optional
try
:
import
markdown
...
...
@@ -445,14 +476,12 @@ except ImportError:
# OAuth 2 support is optional
try
:
import
provider.oauth2
as
oauth2_provider
from
provider.oauth2
import
backends
as
oauth2_provider_backends
from
provider.oauth2
import
models
as
oauth2_provider_models
from
provider.oauth2
import
forms
as
oauth2_provider_forms
from
provider
import
scope
as
oauth2_provider_scope
from
provider
import
constants
as
oauth2_constants
except
ImportError
:
oauth2_provider
=
None
oauth2_provider_backends
=
None
oauth2_provider_models
=
None
oauth2_provider_forms
=
None
oauth2_provider_scope
=
None
...
...
rest_framework/filters.py
View file @
399ac70b
...
...
@@ -55,6 +55,6 @@ class DjangoFilterBackend(BaseFilterBackend):
filter_class
=
self
.
get_filter_class
(
view
)
if
filter_class
:
return
filter_class
(
request
.
QUERY_PARAMS
,
queryset
=
queryset
)
return
filter_class
(
request
.
QUERY_PARAMS
,
queryset
=
queryset
)
.
qs
return
queryset
rest_framework/templatetags/rest_framework.py
View file @
399ac70b
...
...
@@ -4,11 +4,8 @@ from django.core.urlresolvers import reverse, NoReverseMatch
from
django.http
import
QueryDict
from
django.utils.html
import
escape
from
django.utils.safestring
import
SafeData
,
mark_safe
from
rest_framework.compat
import
urlparse
from
rest_framework.compat
import
force_text
from
rest_framework.compat
import
six
import
re
import
string
from
rest_framework.compat
import
urlparse
,
force_text
,
six
,
smart_urlquote
import
re
,
string
register
=
template
.
Library
()
...
...
@@ -112,22 +109,6 @@ def replace_query_param(url, key, val):
class_re
=
re
.
compile
(
r'(?<=class=["\'])(.*)(?=["\'])'
)
# Bunch of stuff cloned from urlize
LEADING_PUNCTUATION
=
[
'('
,
'<'
,
'<'
,
'"'
,
"'"
]
TRAILING_PUNCTUATION
=
[
'.'
,
','
,
')'
,
'>'
,
'
\n
'
,
'>'
,
'"'
,
"'"
]
DOTS
=
[
'·'
,
'*'
,
'
\xe2\x80\xa2
'
,
'•'
,
'•'
,
'•'
]
unencoded_ampersands_re
=
re
.
compile
(
r'&(?!(\w+|#\d+);)'
)
word_split_re
=
re
.
compile
(
r'(\s+)'
)
punctuation_re
=
re
.
compile
(
'^(?P<lead>(?:
%
s)*)(?P<middle>.*?)(?P<trail>(?:
%
s)*)$'
%
\
(
'|'
.
join
([
re
.
escape
(
x
)
for
x
in
LEADING_PUNCTUATION
]),
'|'
.
join
([
re
.
escape
(
x
)
for
x
in
TRAILING_PUNCTUATION
])))
simple_email_re
=
re
.
compile
(
r'^\S+@[a-zA-Z0-9._-]+\.[a-zA-Z0-9._-]+$'
)
link_target_attribute_re
=
re
.
compile
(
r'(<a [^>]*?)target=[^\s>]+'
)
html_gunk_re
=
re
.
compile
(
r'(?:<br clear="all">|<i><\/i>|<b><\/b>|<em><\/em>|<strong><\/strong>|<\/?smallcaps>|<\/?uppercase>)'
,
re
.
IGNORECASE
)
hard_coded_bullets_re
=
re
.
compile
(
r'((?:<p>(?:
%
s).*?[a-zA-Z].*?</p>\s*)+)'
%
'|'
.
join
([
re
.
escape
(
x
)
for
x
in
DOTS
]),
re
.
DOTALL
)
trailing_empty_content_re
=
re
.
compile
(
r'(?:<p>(?: |\s|<br \/>)*?</p>\s*)+\Z'
)
# And the template tags themselves...
@register.simple_tag
...
...
@@ -195,15 +176,25 @@ def add_class(value, css_class):
return
value
# Bunch of stuff cloned from urlize
TRAILING_PUNCTUATION
=
[
'.'
,
','
,
':'
,
';'
,
'.)'
,
'"'
,
"'"
]
WRAPPING_PUNCTUATION
=
[(
'('
,
')'
),
(
'<'
,
'>'
),
(
'['
,
']'
),
(
'<'
,
'>'
),
(
'"'
,
'"'
),
(
"'"
,
"'"
)]
word_split_re
=
re
.
compile
(
r'(\s+)'
)
simple_url_re
=
re
.
compile
(
r'^https?://\w'
,
re
.
IGNORECASE
)
simple_url_2_re
=
re
.
compile
(
r'^www\.|^(?!http)\w[^@]+\.(com|edu|gov|int|mil|net|org)$'
,
re
.
IGNORECASE
)
simple_email_re
=
re
.
compile
(
r'^\S+@\S+\.\S+$'
)
@register.filter
def
urlize_quoted_links
(
text
,
trim_url_limit
=
None
,
nofollow
=
True
,
autoescape
=
True
):
"""
Converts any URLs in text into clickable links.
Works on http://, https://, www. links
and links ending in .org, .net or
.com. Links can have trailing punctuation (periods, commas, close-parens)
and leading punctuation (opening parens) and it'll still do the right
thing.
Works on http://, https://, www. links
, and also on links ending in one of
the original seven gTLDs (.com, .edu, .gov, .int, .mil, .net, and .org).
Links can have trailing punctuation (periods, commas, close-parens) and
leading punctuation (opening parens) and it'll still do the right
thing.
If trim_url_limit is not None, the URLs in link text longer than this limit
will truncated to trim_url_limit-3 characters and appended with an elipsis.
...
...
@@ -216,24 +207,41 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru
trim_url
=
lambda
x
,
limit
=
trim_url_limit
:
limit
is
not
None
and
(
len
(
x
)
>
limit
and
(
'
%
s...'
%
x
[:
max
(
0
,
limit
-
3
)]))
or
x
safe_input
=
isinstance
(
text
,
SafeData
)
words
=
word_split_re
.
split
(
force_text
(
text
))
nofollow_attr
=
nofollow
and
' rel="nofollow"'
or
''
for
i
,
word
in
enumerate
(
words
):
match
=
None
if
'.'
in
word
or
'@'
in
word
or
':'
in
word
:
match
=
punctuation_re
.
match
(
word
)
if
match
:
lead
,
middle
,
trail
=
match
.
groups
()
# Deal with punctuation.
lead
,
middle
,
trail
=
''
,
word
,
''
for
punctuation
in
TRAILING_PUNCTUATION
:
if
middle
.
endswith
(
punctuation
):
middle
=
middle
[:
-
len
(
punctuation
)]
trail
=
punctuation
+
trail
for
opening
,
closing
in
WRAPPING_PUNCTUATION
:
if
middle
.
startswith
(
opening
):
middle
=
middle
[
len
(
opening
):]
lead
=
lead
+
opening
# Keep parentheses at the end only if they're balanced.
if
(
middle
.
endswith
(
closing
)
and
middle
.
count
(
closing
)
==
middle
.
count
(
opening
)
+
1
):
middle
=
middle
[:
-
len
(
closing
)]
trail
=
closing
+
trail
# Make URL we want to point to.
url
=
None
if
middle
.
startswith
(
'http://'
)
or
middle
.
startswith
(
'https://'
):
url
=
middle
elif
middle
.
startswith
(
'www.'
)
or
(
'@'
not
in
middle
and
\
middle
and
middle
[
0
]
in
string
.
ascii_letters
+
string
.
digits
and
\
(
middle
.
endswith
(
'.org'
)
or
middle
.
endswith
(
'.net'
)
or
middle
.
endswith
(
'.com'
))):
url
=
'http://
%
s'
%
middle
elif
'@'
in
middle
and
not
':'
in
middle
and
simple_email_re
.
match
(
middle
):
url
=
'mailto:
%
s'
%
middle
nofollow_attr
=
' rel="nofollow"'
if
nofollow
else
''
if
simple_url_re
.
match
(
middle
):
url
=
smart_urlquote
(
middle
)
elif
simple_url_2_re
.
match
(
middle
):
url
=
smart_urlquote
(
'http://
%
s'
%
middle
)
elif
not
':'
in
middle
and
simple_email_re
.
match
(
middle
):
local
,
domain
=
middle
.
rsplit
(
'@'
,
1
)
try
:
domain
=
domain
.
encode
(
'idna'
)
.
decode
(
'ascii'
)
except
UnicodeError
:
continue
url
=
'mailto:
%
s@
%
s'
%
(
local
,
domain
)
nofollow_attr
=
''
# Make link.
if
url
:
trimmed
=
trim_url
(
middle
)
...
...
@@ -251,4 +259,4 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru
words
[
i
]
=
mark_safe
(
word
)
elif
autoescape
:
words
[
i
]
=
escape
(
word
)
return
mark_safe
(
''
.
join
(
words
)
)
return
''
.
join
(
words
)
rest_framework/tests/authentication.py
View file @
399ac70b
...
...
@@ -466,17 +466,13 @@ class OAuth2Tests(TestCase):
def
_create_authorization_header
(
self
,
token
=
None
):
return
"Bearer {0}"
.
format
(
token
or
self
.
access_token
.
token
)
def
_client_credentials_params
(
self
):
return
{
'client_id'
:
self
.
CLIENT_ID
,
'client_secret'
:
self
.
CLIENT_SECRET
}
@unittest.skipUnless
(
oauth2_provider
,
'django-oauth2-provider not installed'
)
def
test_get_form_with_wrong_authorization_header_token_type_failing
(
self
):
"""Ensure that a wrong token type lead to the correct HTTP error status code"""
auth
=
"Wrong token-type-obsviously"
response
=
self
.
csrf_client
.
get
(
'/oauth2-test/'
,
{},
HTTP_AUTHORIZATION
=
auth
)
self
.
assertEqual
(
response
.
status_code
,
401
)
params
=
self
.
_client_credentials_params
()
response
=
self
.
csrf_client
.
get
(
'/oauth2-test/'
,
params
,
HTTP_AUTHORIZATION
=
auth
)
response
=
self
.
csrf_client
.
get
(
'/oauth2-test/'
,
HTTP_AUTHORIZATION
=
auth
)
self
.
assertEqual
(
response
.
status_code
,
401
)
@unittest.skipUnless
(
oauth2_provider
,
'django-oauth2-provider not installed'
)
...
...
@@ -485,8 +481,7 @@ class OAuth2Tests(TestCase):
auth
=
"Bearer wrong token format"
response
=
self
.
csrf_client
.
get
(
'/oauth2-test/'
,
{},
HTTP_AUTHORIZATION
=
auth
)
self
.
assertEqual
(
response
.
status_code
,
401
)
params
=
self
.
_client_credentials_params
()
response
=
self
.
csrf_client
.
get
(
'/oauth2-test/'
,
params
,
HTTP_AUTHORIZATION
=
auth
)
response
=
self
.
csrf_client
.
get
(
'/oauth2-test/'
,
HTTP_AUTHORIZATION
=
auth
)
self
.
assertEqual
(
response
.
status_code
,
401
)
@unittest.skipUnless
(
oauth2_provider
,
'django-oauth2-provider not installed'
)
...
...
@@ -495,33 +490,21 @@ class OAuth2Tests(TestCase):
auth
=
"Bearer wrong-token"
response
=
self
.
csrf_client
.
get
(
'/oauth2-test/'
,
{},
HTTP_AUTHORIZATION
=
auth
)
self
.
assertEqual
(
response
.
status_code
,
401
)
params
=
self
.
_client_credentials_params
()
response
=
self
.
csrf_client
.
get
(
'/oauth2-test/'
,
params
,
HTTP_AUTHORIZATION
=
auth
)
self
.
assertEqual
(
response
.
status_code
,
401
)
@unittest.skipUnless
(
oauth2_provider
,
'django-oauth2-provider not installed'
)
def
test_get_form_with_wrong_client_data_failing_auth
(
self
):
"""Ensure GETing form over OAuth with incorrect client credentials fails"""
auth
=
self
.
_create_authorization_header
()
params
=
self
.
_client_credentials_params
()
params
[
'client_id'
]
+=
'a'
response
=
self
.
csrf_client
.
get
(
'/oauth2-test/'
,
params
,
HTTP_AUTHORIZATION
=
auth
)
response
=
self
.
csrf_client
.
get
(
'/oauth2-test/'
,
HTTP_AUTHORIZATION
=
auth
)
self
.
assertEqual
(
response
.
status_code
,
401
)
@unittest.skipUnless
(
oauth2_provider
,
'django-oauth2-provider not installed'
)
def
test_get_form_passing_auth
(
self
):
"""Ensure GETing form over OAuth with correct client credentials succeed"""
auth
=
self
.
_create_authorization_header
()
params
=
self
.
_client_credentials_params
()
response
=
self
.
csrf_client
.
get
(
'/oauth2-test/'
,
params
,
HTTP_AUTHORIZATION
=
auth
)
response
=
self
.
csrf_client
.
get
(
'/oauth2-test/'
,
HTTP_AUTHORIZATION
=
auth
)
self
.
assertEqual
(
response
.
status_code
,
200
)
@unittest.skipUnless
(
oauth2_provider
,
'django-oauth2-provider not installed'
)
def
test_post_form_passing_auth
(
self
):
"""Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF"""
auth
=
self
.
_create_authorization_header
()
params
=
self
.
_client_credentials_params
()
response
=
self
.
csrf_client
.
post
(
'/oauth2-test/'
,
params
,
HTTP_AUTHORIZATION
=
auth
)
response
=
self
.
csrf_client
.
post
(
'/oauth2-test/'
,
HTTP_AUTHORIZATION
=
auth
)
self
.
assertEqual
(
response
.
status_code
,
200
)
@unittest.skipUnless
(
oauth2_provider
,
'django-oauth2-provider not installed'
)
...
...
@@ -529,16 +512,14 @@ class OAuth2Tests(TestCase):
"""Ensure POSTing when there is no OAuth access token in db fails"""
self
.
access_token
.
delete
()
auth
=
self
.
_create_authorization_header
()
params
=
self
.
_client_credentials_params
()
response
=
self
.
csrf_client
.
post
(
'/oauth2-test/'
,
params
,
HTTP_AUTHORIZATION
=
auth
)
response
=
self
.
csrf_client
.
post
(
'/oauth2-test/'
,
HTTP_AUTHORIZATION
=
auth
)
self
.
assertIn
(
response
.
status_code
,
(
status
.
HTTP_401_UNAUTHORIZED
,
status
.
HTTP_403_FORBIDDEN
))
@unittest.skipUnless
(
oauth2_provider
,
'django-oauth2-provider not installed'
)
def
test_post_form_with_refresh_token_failing_auth
(
self
):
"""Ensure POSTing with refresh token instead of access token fails"""
auth
=
self
.
_create_authorization_header
(
token
=
self
.
refresh_token
.
token
)
params
=
self
.
_client_credentials_params
()
response
=
self
.
csrf_client
.
post
(
'/oauth2-test/'
,
params
,
HTTP_AUTHORIZATION
=
auth
)
response
=
self
.
csrf_client
.
post
(
'/oauth2-test/'
,
HTTP_AUTHORIZATION
=
auth
)
self
.
assertIn
(
response
.
status_code
,
(
status
.
HTTP_401_UNAUTHORIZED
,
status
.
HTTP_403_FORBIDDEN
))
@unittest.skipUnless
(
oauth2_provider
,
'django-oauth2-provider not installed'
)
...
...
@@ -547,8 +528,7 @@ class OAuth2Tests(TestCase):
self
.
access_token
.
expires
=
datetime
.
datetime
.
now
()
-
datetime
.
timedelta
(
seconds
=
10
)
# 10 seconds late
self
.
access_token
.
save
()
auth
=
self
.
_create_authorization_header
()
params
=
self
.
_client_credentials_params
()
response
=
self
.
csrf_client
.
post
(
'/oauth2-test/'
,
params
,
HTTP_AUTHORIZATION
=
auth
)
response
=
self
.
csrf_client
.
post
(
'/oauth2-test/'
,
HTTP_AUTHORIZATION
=
auth
)
self
.
assertIn
(
response
.
status_code
,
(
status
.
HTTP_401_UNAUTHORIZED
,
status
.
HTTP_403_FORBIDDEN
))
self
.
assertIn
(
'Invalid token'
,
response
.
content
)
...
...
@@ -559,10 +539,9 @@ class OAuth2Tests(TestCase):
read_only_access_token
.
scope
=
oauth2_provider_scope
.
SCOPE_NAME_DICT
[
'read'
]
read_only_access_token
.
save
()
auth
=
self
.
_create_authorization_header
(
token
=
read_only_access_token
.
token
)
params
=
self
.
_client_credentials_params
()
response
=
self
.
csrf_client
.
get
(
'/oauth2-with-scope-test/'
,
params
,
HTTP_AUTHORIZATION
=
auth
)
response
=
self
.
csrf_client
.
get
(
'/oauth2-with-scope-test/'
,
HTTP_AUTHORIZATION
=
auth
)
self
.
assertEqual
(
response
.
status_code
,
200
)
response
=
self
.
csrf_client
.
post
(
'/oauth2-with-scope-test/'
,
params
,
HTTP_AUTHORIZATION
=
auth
)
response
=
self
.
csrf_client
.
post
(
'/oauth2-with-scope-test/'
,
HTTP_AUTHORIZATION
=
auth
)
self
.
assertEqual
(
response
.
status_code
,
status
.
HTTP_403_FORBIDDEN
)
@unittest.skipUnless
(
oauth2_provider
,
'django-oauth2-provider not installed'
)
...
...
@@ -572,6 +551,5 @@ class OAuth2Tests(TestCase):
read_write_access_token
.
scope
=
oauth2_provider_scope
.
SCOPE_NAME_DICT
[
'write'
]
read_write_access_token
.
save
()
auth
=
self
.
_create_authorization_header
(
token
=
read_write_access_token
.
token
)
params
=
self
.
_client_credentials_params
()
response
=
self
.
csrf_client
.
post
(
'/oauth2-with-scope-test/'
,
params
,
HTTP_AUTHORIZATION
=
auth
)
response
=
self
.
csrf_client
.
post
(
'/oauth2-with-scope-test/'
,
HTTP_AUTHORIZATION
=
auth
)
self
.
assertEqual
(
response
.
status_code
,
200
)
rest_framework/tests/filterset.py
View file @
399ac70b
from
__future__
import
unicode_literals
import
datetime
from
decimal
import
Decimal
from
django.core.urlresolvers
import
reverse
from
django.test
import
TestCase
from
django.test.client
import
RequestFactory
from
django.utils
import
unittest
from
rest_framework
import
generics
,
status
,
filters
from
rest_framework.compat
import
django_filters
from
rest_framework.compat
import
django_filters
,
patterns
,
url
from
rest_framework.tests.models
import
FilterableItem
,
BasicModel
factory
=
RequestFactory
()
...
...
@@ -46,12 +47,21 @@ if django_filters:
filter_class
=
MisconfiguredFilter
filter_backend
=
filters
.
DjangoFilterBackend
class
FilterClassDetailView
(
generics
.
RetrieveAPIView
):
model
=
FilterableItem
filter_class
=
SeveralFieldsFilter
filter_backend
=
filters
.
DjangoFilterBackend
urlpatterns
=
patterns
(
''
,
url
(
r'^(?P<pk>\d+)/$'
,
FilterClassDetailView
.
as_view
(),
name
=
'detail-view'
),
url
(
r'^$'
,
FilterClassRootView
.
as_view
(),
name
=
'root-view'
),
)
class
IntegrationTestFiltering
(
TestCase
):
"""
Integration tests for filtered list views.
"""
class
CommonFilteringTestCase
(
TestCase
):
def
_serialize_object
(
self
,
obj
):
return
{
'id'
:
obj
.
id
,
'text'
:
obj
.
text
,
'decimal'
:
obj
.
decimal
,
'date'
:
obj
.
date
}
def
setUp
(
self
):
"""
Create 10 FilterableItem instances.
...
...
@@ -65,10 +75,16 @@ class IntegrationTestFiltering(TestCase):
self
.
objects
=
FilterableItem
.
objects
self
.
data
=
[
{
'id'
:
obj
.
id
,
'text'
:
obj
.
text
,
'decimal'
:
obj
.
decimal
,
'date'
:
obj
.
date
}
self
.
_serialize_object
(
obj
)
for
obj
in
self
.
objects
.
all
()
]
class
IntegrationTestFiltering
(
CommonFilteringTestCase
):
"""
Integration tests for filtered list views.
"""
@unittest.skipUnless
(
django_filters
,
'django-filters not installed'
)
def
test_get_filtered_fields_root_view
(
self
):
"""
...
...
@@ -167,3 +183,50 @@ class IntegrationTestFiltering(TestCase):
request
=
factory
.
get
(
'/?integer=
%
s'
%
search_integer
)
response
=
view
(
request
)
.
render
()
self
.
assertEqual
(
response
.
status_code
,
status
.
HTTP_200_OK
)
class
IntegrationTestDetailFiltering
(
CommonFilteringTestCase
):
"""
Integration tests for filtered detail views.
"""
urls
=
'rest_framework.tests.filterset'
def
_get_url
(
self
,
item
):
return
reverse
(
'detail-view'
,
kwargs
=
dict
(
pk
=
item
.
pk
))
@unittest.skipUnless
(
django_filters
,
'django-filters not installed'
)
def
test_get_filtered_detail_view
(
self
):
"""
GET requests to filtered RetrieveAPIView that have a filter_class set
should return filtered results.
"""
item
=
self
.
objects
.
all
()[
0
]
data
=
self
.
_serialize_object
(
item
)
# Basic test with no filter.
response
=
self
.
client
.
get
(
self
.
_get_url
(
item
))
self
.
assertEqual
(
response
.
status_code
,
status
.
HTTP_200_OK
)
self
.
assertEqual
(
response
.
data
,
data
)
# Tests that the decimal filter set that should fail.
search_decimal
=
Decimal
(
'4.25'
)
high_item
=
self
.
objects
.
filter
(
decimal__gt
=
search_decimal
)[
0
]
response
=
self
.
client
.
get
(
'{url}?decimal={param}'
.
format
(
url
=
self
.
_get_url
(
high_item
),
param
=
search_decimal
))
self
.
assertEqual
(
response
.
status_code
,
status
.
HTTP_404_NOT_FOUND
)
# Tests that the decimal filter set that should succeed.
search_decimal
=
Decimal
(
'4.25'
)
low_item
=
self
.
objects
.
filter
(
decimal__lt
=
search_decimal
)[
0
]
low_item_data
=
self
.
_serialize_object
(
low_item
)
response
=
self
.
client
.
get
(
'{url}?decimal={param}'
.
format
(
url
=
self
.
_get_url
(
low_item
),
param
=
search_decimal
))
self
.
assertEqual
(
response
.
status_code
,
status
.
HTTP_200_OK
)
self
.
assertEqual
(
response
.
data
,
low_item_data
)
# Tests that multiple filters works.
search_decimal
=
Decimal
(
'5.25'
)
search_date
=
datetime
.
date
(
2012
,
10
,
2
)
valid_item
=
self
.
objects
.
filter
(
decimal__lt
=
search_decimal
,
date__gt
=
search_date
)[
0
]
valid_item_data
=
self
.
_serialize_object
(
valid_item
)
response
=
self
.
client
.
get
(
'{url}?decimal={decimal}&date={date}'
.
format
(
url
=
self
.
_get_url
(
valid_item
),
decimal
=
search_decimal
,
date
=
search_date
))
self
.
assertEqual
(
response
.
status_code
,
status
.
HTTP_200_OK
)
self
.
assertEqual
(
response
.
data
,
valid_item_data
)
rest_framework/tests/pagination.py
View file @
399ac70b
...
...
@@ -129,16 +129,6 @@ class IntegrationTestPaginationAndFiltering(TestCase):
view
=
FilterFieldsRootView
.
as_view
()
EXPECTED_NUM_QUERIES
=
2
if
django
.
VERSION
<
(
1
,
4
):
# On Django 1.3 we need to use django-filter 0.5.4
#
# The filter objects there don't expose a `.count()` method,
# which means we only make a single query *but* it's a single
# query across *all* of the queryset, instead of a COUNT and then
# a SELECT with a LIMIT.
#
# Although this is fewer queries, it's actually a regression.
EXPECTED_NUM_QUERIES
=
1
request
=
factory
.
get
(
'/?decimal=15.20'
)
with
self
.
assertNumQueries
(
EXPECTED_NUM_QUERIES
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment