Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
D
django-rest-framework
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
edx
django-rest-framework
Commits
399ac70b
Commit
399ac70b
authored
Mar 30, 2013
by
Tom Christie
Browse files
Options
Browse Files
Download
Plain Diff
Merge branch 'master' of
https://github.com/tomchristie/django-rest-framework
parents
c4eda3a6
2e06f5c8
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
176 additions
and
106 deletions
+176
-106
docs/api-guide/authentication.md
+1
-1
docs/topics/release-notes.md
+6
-0
rest_framework/authentication.py
+11
-15
rest_framework/compat.py
+31
-2
rest_framework/filters.py
+1
-1
rest_framework/templatetags/rest_framework.py
+46
-38
rest_framework/tests/authentication.py
+11
-33
rest_framework/tests/filterset.py
+69
-6
rest_framework/tests/pagination.py
+0
-10
No files found.
docs/api-guide/authentication.md
View file @
399ac70b
...
@@ -300,7 +300,7 @@ The only thing needed to make the `OAuth2Authentication` class work is to insert
...
@@ -300,7 +300,7 @@ The only thing needed to make the `OAuth2Authentication` class work is to insert
The command line to test the authentication looks like:
The command line to test the authentication looks like:
curl -H "Authorization: Bearer <your-access-token>" http://localhost:8000/api/
?client_id=YOUR_CLIENT_ID\&client_secret=YOUR_CLIENT_SECRET
curl -H "Authorization: Bearer <your-access-token>" http://localhost:8000/api/
---
---
...
...
docs/topics/release-notes.md
View file @
399ac70b
...
@@ -40,6 +40,12 @@ You can determine your currently installed version using `pip freeze`:
...
@@ -40,6 +40,12 @@ You can determine your currently installed version using `pip freeze`:
## 2.2.x series
## 2.2.x series
### Master
*
OAuth2 authentication no longer requires unneccessary URL parameters in addition to the token.
*
URL hyperlinking in browseable API now handles more cases correctly.
*
Bugfix
:
Fix regression with DjangoFilterBackend not worthing correctly with single object views.
### 2.2.5
### 2.2.5
*
*Date
**
:
26th March 2013
*
*Date
**
:
26th March 2013
...
...
rest_framework/authentication.py
View file @
399ac70b
...
@@ -2,14 +2,16 @@
...
@@ -2,14 +2,16 @@
Provides a set of pluggable authentication policies.
Provides a set of pluggable authentication policies.
"""
"""
from
__future__
import
unicode_literals
from
__future__
import
unicode_literals
import
base64
from
datetime
import
datetime
from
django.contrib.auth
import
authenticate
from
django.contrib.auth
import
authenticate
from
django.core.exceptions
import
ImproperlyConfigured
from
django.core.exceptions
import
ImproperlyConfigured
from
rest_framework
import
exceptions
,
HTTP_HEADER_ENCODING
from
rest_framework
import
exceptions
,
HTTP_HEADER_ENCODING
from
rest_framework.compat
import
CsrfViewMiddleware
from
rest_framework.compat
import
CsrfViewMiddleware
from
rest_framework.compat
import
oauth
,
oauth_provider
,
oauth_provider_store
from
rest_framework.compat
import
oauth
,
oauth_provider
,
oauth_provider_store
from
rest_framework.compat
import
oauth2_provider
,
oauth2_provider_forms
,
oauth2_provider_backends
from
rest_framework.compat
import
oauth2_provider
,
oauth2_provider_forms
from
rest_framework.authtoken.models
import
Token
from
rest_framework.authtoken.models
import
Token
import
base64
def
get_authorization_header
(
request
):
def
get_authorization_header
(
request
):
...
@@ -315,21 +317,15 @@ class OAuth2Authentication(BaseAuthentication):
...
@@ -315,21 +317,15 @@ class OAuth2Authentication(BaseAuthentication):
Authenticate the request, given the access token.
Authenticate the request, given the access token.
"""
"""
# Authenticate the client
try
:
oauth2_client_form
=
oauth2_provider_forms
.
ClientAuthForm
(
request
.
REQUEST
)
token
=
oauth2_provider
.
models
.
AccessToken
.
objects
.
select_related
(
'user'
)
if
not
oauth2_client_form
.
is_valid
():
# TODO: Change to timezone aware datetime when oauth2_provider add
raise
exceptions
.
AuthenticationFailed
(
'Client could not be validated'
)
# support to it.
client
=
oauth2_client_form
.
cleaned_data
.
get
(
'client'
)
token
=
token
.
get
(
token
=
access_token
,
expires__gt
=
datetime
.
now
())
except
oauth2_provider
.
models
.
AccessToken
.
DoesNotExist
:
# Retrieve the `OAuth2AccessToken` instance from the access_token
auth_backend
=
oauth2_provider_backends
.
AccessTokenBackend
()
token
=
auth_backend
.
authenticate
(
access_token
,
client
)
if
token
is
None
:
raise
exceptions
.
AuthenticationFailed
(
'Invalid token'
)
raise
exceptions
.
AuthenticationFailed
(
'Invalid token'
)
user
=
token
.
user
if
not
token
.
user
.
is_active
:
if
not
user
.
is_active
:
msg
=
'User inactive or deleted:
%
s'
%
user
.
username
msg
=
'User inactive or deleted:
%
s'
%
user
.
username
raise
exceptions
.
AuthenticationFailed
(
msg
)
raise
exceptions
.
AuthenticationFailed
(
msg
)
...
...
rest_framework/compat.py
View file @
399ac70b
...
@@ -395,6 +395,37 @@ except ImportError:
...
@@ -395,6 +395,37 @@ except ImportError:
kw
=
dict
((
k
,
int
(
v
))
for
k
,
v
in
kw
.
iteritems
()
if
v
is
not
None
)
kw
=
dict
((
k
,
int
(
v
))
for
k
,
v
in
kw
.
iteritems
()
if
v
is
not
None
)
return
datetime
.
datetime
(
**
kw
)
return
datetime
.
datetime
(
**
kw
)
# smart_urlquote is new on Django 1.4
try
:
from
django.utils.html
import
smart_urlquote
except
ImportError
:
try
:
from
urllib.parse
import
quote
,
urlsplit
,
urlunsplit
except
ImportError
:
# Python 2
from
urllib
import
quote
from
urlparse
import
urlsplit
,
urlunsplit
def
smart_urlquote
(
url
):
"Quotes a URL if it isn't already quoted."
# Handle IDN before quoting.
scheme
,
netloc
,
path
,
query
,
fragment
=
urlsplit
(
url
)
try
:
netloc
=
netloc
.
encode
(
'idna'
)
.
decode
(
'ascii'
)
# IDN -> ACE
except
UnicodeError
:
# invalid domain part
pass
else
:
url
=
urlunsplit
((
scheme
,
netloc
,
path
,
query
,
fragment
))
# An URL is considered unquoted if it contains no % characters or
# contains a % not followed by two hexadecimal digits. See #9655.
if
'
%
'
not
in
url
or
unquoted_percents_re
.
search
(
url
):
# See http://bugs.python.org/issue2637
url
=
quote
(
force_bytes
(
url
),
safe
=
b
'!*
\'
();:@&=+$,/?#[]~'
)
return
force_text
(
url
)
# Markdown is optional
# Markdown is optional
try
:
try
:
import
markdown
import
markdown
...
@@ -445,14 +476,12 @@ except ImportError:
...
@@ -445,14 +476,12 @@ except ImportError:
# OAuth 2 support is optional
# OAuth 2 support is optional
try
:
try
:
import
provider.oauth2
as
oauth2_provider
import
provider.oauth2
as
oauth2_provider
from
provider.oauth2
import
backends
as
oauth2_provider_backends
from
provider.oauth2
import
models
as
oauth2_provider_models
from
provider.oauth2
import
models
as
oauth2_provider_models
from
provider.oauth2
import
forms
as
oauth2_provider_forms
from
provider.oauth2
import
forms
as
oauth2_provider_forms
from
provider
import
scope
as
oauth2_provider_scope
from
provider
import
scope
as
oauth2_provider_scope
from
provider
import
constants
as
oauth2_constants
from
provider
import
constants
as
oauth2_constants
except
ImportError
:
except
ImportError
:
oauth2_provider
=
None
oauth2_provider
=
None
oauth2_provider_backends
=
None
oauth2_provider_models
=
None
oauth2_provider_models
=
None
oauth2_provider_forms
=
None
oauth2_provider_forms
=
None
oauth2_provider_scope
=
None
oauth2_provider_scope
=
None
...
...
rest_framework/filters.py
View file @
399ac70b
...
@@ -55,6 +55,6 @@ class DjangoFilterBackend(BaseFilterBackend):
...
@@ -55,6 +55,6 @@ class DjangoFilterBackend(BaseFilterBackend):
filter_class
=
self
.
get_filter_class
(
view
)
filter_class
=
self
.
get_filter_class
(
view
)
if
filter_class
:
if
filter_class
:
return
filter_class
(
request
.
QUERY_PARAMS
,
queryset
=
queryset
)
return
filter_class
(
request
.
QUERY_PARAMS
,
queryset
=
queryset
)
.
qs
return
queryset
return
queryset
rest_framework/templatetags/rest_framework.py
View file @
399ac70b
...
@@ -4,11 +4,8 @@ from django.core.urlresolvers import reverse, NoReverseMatch
...
@@ -4,11 +4,8 @@ from django.core.urlresolvers import reverse, NoReverseMatch
from
django.http
import
QueryDict
from
django.http
import
QueryDict
from
django.utils.html
import
escape
from
django.utils.html
import
escape
from
django.utils.safestring
import
SafeData
,
mark_safe
from
django.utils.safestring
import
SafeData
,
mark_safe
from
rest_framework.compat
import
urlparse
from
rest_framework.compat
import
urlparse
,
force_text
,
six
,
smart_urlquote
from
rest_framework.compat
import
force_text
import
re
,
string
from
rest_framework.compat
import
six
import
re
import
string
register
=
template
.
Library
()
register
=
template
.
Library
()
...
@@ -112,22 +109,6 @@ def replace_query_param(url, key, val):
...
@@ -112,22 +109,6 @@ def replace_query_param(url, key, val):
class_re
=
re
.
compile
(
r'(?<=class=["\'])(.*)(?=["\'])'
)
class_re
=
re
.
compile
(
r'(?<=class=["\'])(.*)(?=["\'])'
)
# Bunch of stuff cloned from urlize
LEADING_PUNCTUATION
=
[
'('
,
'<'
,
'<'
,
'"'
,
"'"
]
TRAILING_PUNCTUATION
=
[
'.'
,
','
,
')'
,
'>'
,
'
\n
'
,
'>'
,
'"'
,
"'"
]
DOTS
=
[
'·'
,
'*'
,
'
\xe2\x80\xa2
'
,
'•'
,
'•'
,
'•'
]
unencoded_ampersands_re
=
re
.
compile
(
r'&(?!(\w+|#\d+);)'
)
word_split_re
=
re
.
compile
(
r'(\s+)'
)
punctuation_re
=
re
.
compile
(
'^(?P<lead>(?:
%
s)*)(?P<middle>.*?)(?P<trail>(?:
%
s)*)$'
%
\
(
'|'
.
join
([
re
.
escape
(
x
)
for
x
in
LEADING_PUNCTUATION
]),
'|'
.
join
([
re
.
escape
(
x
)
for
x
in
TRAILING_PUNCTUATION
])))
simple_email_re
=
re
.
compile
(
r'^\S+@[a-zA-Z0-9._-]+\.[a-zA-Z0-9._-]+$'
)
link_target_attribute_re
=
re
.
compile
(
r'(<a [^>]*?)target=[^\s>]+'
)
html_gunk_re
=
re
.
compile
(
r'(?:<br clear="all">|<i><\/i>|<b><\/b>|<em><\/em>|<strong><\/strong>|<\/?smallcaps>|<\/?uppercase>)'
,
re
.
IGNORECASE
)
hard_coded_bullets_re
=
re
.
compile
(
r'((?:<p>(?:
%
s).*?[a-zA-Z].*?</p>\s*)+)'
%
'|'
.
join
([
re
.
escape
(
x
)
for
x
in
DOTS
]),
re
.
DOTALL
)
trailing_empty_content_re
=
re
.
compile
(
r'(?:<p>(?: |\s|<br \/>)*?</p>\s*)+\Z'
)
# And the template tags themselves...
# And the template tags themselves...
@register.simple_tag
@register.simple_tag
...
@@ -195,15 +176,25 @@ def add_class(value, css_class):
...
@@ -195,15 +176,25 @@ def add_class(value, css_class):
return
value
return
value
# Bunch of stuff cloned from urlize
TRAILING_PUNCTUATION
=
[
'.'
,
','
,
':'
,
';'
,
'.)'
,
'"'
,
"'"
]
WRAPPING_PUNCTUATION
=
[(
'('
,
')'
),
(
'<'
,
'>'
),
(
'['
,
']'
),
(
'<'
,
'>'
),
(
'"'
,
'"'
),
(
"'"
,
"'"
)]
word_split_re
=
re
.
compile
(
r'(\s+)'
)
simple_url_re
=
re
.
compile
(
r'^https?://\w'
,
re
.
IGNORECASE
)
simple_url_2_re
=
re
.
compile
(
r'^www\.|^(?!http)\w[^@]+\.(com|edu|gov|int|mil|net|org)$'
,
re
.
IGNORECASE
)
simple_email_re
=
re
.
compile
(
r'^\S+@\S+\.\S+$'
)
@register.filter
@register.filter
def
urlize_quoted_links
(
text
,
trim_url_limit
=
None
,
nofollow
=
True
,
autoescape
=
True
):
def
urlize_quoted_links
(
text
,
trim_url_limit
=
None
,
nofollow
=
True
,
autoescape
=
True
):
"""
"""
Converts any URLs in text into clickable links.
Converts any URLs in text into clickable links.
Works on http://, https://, www. links
and links ending in .org, .net or
Works on http://, https://, www. links
, and also on links ending in one of
.com. Links can have trailing punctuation (periods, commas, close-parens)
the original seven gTLDs (.com, .edu, .gov, .int, .mil, .net, and .org).
and leading punctuation (opening parens) and it'll still do the right
Links can have trailing punctuation (periods, commas, close-parens) and
thing.
leading punctuation (opening parens) and it'll still do the right
thing.
If trim_url_limit is not None, the URLs in link text longer than this limit
If trim_url_limit is not None, the URLs in link text longer than this limit
will truncated to trim_url_limit-3 characters and appended with an elipsis.
will truncated to trim_url_limit-3 characters and appended with an elipsis.
...
@@ -216,24 +207,41 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru
...
@@ -216,24 +207,41 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru
trim_url
=
lambda
x
,
limit
=
trim_url_limit
:
limit
is
not
None
and
(
len
(
x
)
>
limit
and
(
'
%
s...'
%
x
[:
max
(
0
,
limit
-
3
)]))
or
x
trim_url
=
lambda
x
,
limit
=
trim_url_limit
:
limit
is
not
None
and
(
len
(
x
)
>
limit
and
(
'
%
s...'
%
x
[:
max
(
0
,
limit
-
3
)]))
or
x
safe_input
=
isinstance
(
text
,
SafeData
)
safe_input
=
isinstance
(
text
,
SafeData
)
words
=
word_split_re
.
split
(
force_text
(
text
))
words
=
word_split_re
.
split
(
force_text
(
text
))
nofollow_attr
=
nofollow
and
' rel="nofollow"'
or
''
for
i
,
word
in
enumerate
(
words
):
for
i
,
word
in
enumerate
(
words
):
match
=
None
match
=
None
if
'.'
in
word
or
'@'
in
word
or
':'
in
word
:
if
'.'
in
word
or
'@'
in
word
or
':'
in
word
:
match
=
punctuation_re
.
match
(
word
)
# Deal with punctuation.
if
match
:
lead
,
middle
,
trail
=
''
,
word
,
''
lead
,
middle
,
trail
=
match
.
groups
()
for
punctuation
in
TRAILING_PUNCTUATION
:
if
middle
.
endswith
(
punctuation
):
middle
=
middle
[:
-
len
(
punctuation
)]
trail
=
punctuation
+
trail
for
opening
,
closing
in
WRAPPING_PUNCTUATION
:
if
middle
.
startswith
(
opening
):
middle
=
middle
[
len
(
opening
):]
lead
=
lead
+
opening
# Keep parentheses at the end only if they're balanced.
if
(
middle
.
endswith
(
closing
)
and
middle
.
count
(
closing
)
==
middle
.
count
(
opening
)
+
1
):
middle
=
middle
[:
-
len
(
closing
)]
trail
=
closing
+
trail
# Make URL we want to point to.
# Make URL we want to point to.
url
=
None
url
=
None
if
middle
.
startswith
(
'http://'
)
or
middle
.
startswith
(
'https://'
):
nofollow_attr
=
' rel="nofollow"'
if
nofollow
else
''
url
=
middle
if
simple_url_re
.
match
(
middle
):
elif
middle
.
startswith
(
'www.'
)
or
(
'@'
not
in
middle
and
\
url
=
smart_urlquote
(
middle
)
middle
and
middle
[
0
]
in
string
.
ascii_letters
+
string
.
digits
and
\
elif
simple_url_2_re
.
match
(
middle
):
(
middle
.
endswith
(
'.org'
)
or
middle
.
endswith
(
'.net'
)
or
middle
.
endswith
(
'.com'
))):
url
=
smart_urlquote
(
'http://
%
s'
%
middle
)
url
=
'http://
%
s'
%
middle
elif
not
':'
in
middle
and
simple_email_re
.
match
(
middle
):
elif
'@'
in
middle
and
not
':'
in
middle
and
simple_email_re
.
match
(
middle
):
local
,
domain
=
middle
.
rsplit
(
'@'
,
1
)
url
=
'mailto:
%
s'
%
middle
try
:
domain
=
domain
.
encode
(
'idna'
)
.
decode
(
'ascii'
)
except
UnicodeError
:
continue
url
=
'mailto:
%
s@
%
s'
%
(
local
,
domain
)
nofollow_attr
=
''
nofollow_attr
=
''
# Make link.
# Make link.
if
url
:
if
url
:
trimmed
=
trim_url
(
middle
)
trimmed
=
trim_url
(
middle
)
...
@@ -251,4 +259,4 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru
...
@@ -251,4 +259,4 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru
words
[
i
]
=
mark_safe
(
word
)
words
[
i
]
=
mark_safe
(
word
)
elif
autoescape
:
elif
autoescape
:
words
[
i
]
=
escape
(
word
)
words
[
i
]
=
escape
(
word
)
return
mark_safe
(
''
.
join
(
words
)
)
return
''
.
join
(
words
)
rest_framework/tests/authentication.py
View file @
399ac70b
...
@@ -466,17 +466,13 @@ class OAuth2Tests(TestCase):
...
@@ -466,17 +466,13 @@ class OAuth2Tests(TestCase):
def
_create_authorization_header
(
self
,
token
=
None
):
def
_create_authorization_header
(
self
,
token
=
None
):
return
"Bearer {0}"
.
format
(
token
or
self
.
access_token
.
token
)
return
"Bearer {0}"
.
format
(
token
or
self
.
access_token
.
token
)
def
_client_credentials_params
(
self
):
return
{
'client_id'
:
self
.
CLIENT_ID
,
'client_secret'
:
self
.
CLIENT_SECRET
}
@unittest.skipUnless
(
oauth2_provider
,
'django-oauth2-provider not installed'
)
@unittest.skipUnless
(
oauth2_provider
,
'django-oauth2-provider not installed'
)
def
test_get_form_with_wrong_authorization_header_token_type_failing
(
self
):
def
test_get_form_with_wrong_authorization_header_token_type_failing
(
self
):
"""Ensure that a wrong token type lead to the correct HTTP error status code"""
"""Ensure that a wrong token type lead to the correct HTTP error status code"""
auth
=
"Wrong token-type-obsviously"
auth
=
"Wrong token-type-obsviously"
response
=
self
.
csrf_client
.
get
(
'/oauth2-test/'
,
{},
HTTP_AUTHORIZATION
=
auth
)
response
=
self
.
csrf_client
.
get
(
'/oauth2-test/'
,
{},
HTTP_AUTHORIZATION
=
auth
)
self
.
assertEqual
(
response
.
status_code
,
401
)
self
.
assertEqual
(
response
.
status_code
,
401
)
params
=
self
.
_client_credentials_params
()
response
=
self
.
csrf_client
.
get
(
'/oauth2-test/'
,
HTTP_AUTHORIZATION
=
auth
)
response
=
self
.
csrf_client
.
get
(
'/oauth2-test/'
,
params
,
HTTP_AUTHORIZATION
=
auth
)
self
.
assertEqual
(
response
.
status_code
,
401
)
self
.
assertEqual
(
response
.
status_code
,
401
)
@unittest.skipUnless
(
oauth2_provider
,
'django-oauth2-provider not installed'
)
@unittest.skipUnless
(
oauth2_provider
,
'django-oauth2-provider not installed'
)
...
@@ -485,8 +481,7 @@ class OAuth2Tests(TestCase):
...
@@ -485,8 +481,7 @@ class OAuth2Tests(TestCase):
auth
=
"Bearer wrong token format"
auth
=
"Bearer wrong token format"
response
=
self
.
csrf_client
.
get
(
'/oauth2-test/'
,
{},
HTTP_AUTHORIZATION
=
auth
)
response
=
self
.
csrf_client
.
get
(
'/oauth2-test/'
,
{},
HTTP_AUTHORIZATION
=
auth
)
self
.
assertEqual
(
response
.
status_code
,
401
)
self
.
assertEqual
(
response
.
status_code
,
401
)
params
=
self
.
_client_credentials_params
()
response
=
self
.
csrf_client
.
get
(
'/oauth2-test/'
,
HTTP_AUTHORIZATION
=
auth
)
response
=
self
.
csrf_client
.
get
(
'/oauth2-test/'
,
params
,
HTTP_AUTHORIZATION
=
auth
)
self
.
assertEqual
(
response
.
status_code
,
401
)
self
.
assertEqual
(
response
.
status_code
,
401
)
@unittest.skipUnless
(
oauth2_provider
,
'django-oauth2-provider not installed'
)
@unittest.skipUnless
(
oauth2_provider
,
'django-oauth2-provider not installed'
)
...
@@ -495,33 +490,21 @@ class OAuth2Tests(TestCase):
...
@@ -495,33 +490,21 @@ class OAuth2Tests(TestCase):
auth
=
"Bearer wrong-token"
auth
=
"Bearer wrong-token"
response
=
self
.
csrf_client
.
get
(
'/oauth2-test/'
,
{},
HTTP_AUTHORIZATION
=
auth
)
response
=
self
.
csrf_client
.
get
(
'/oauth2-test/'
,
{},
HTTP_AUTHORIZATION
=
auth
)
self
.
assertEqual
(
response
.
status_code
,
401
)
self
.
assertEqual
(
response
.
status_code
,
401
)
params
=
self
.
_client_credentials_params
()
response
=
self
.
csrf_client
.
get
(
'/oauth2-test/'
,
HTTP_AUTHORIZATION
=
auth
)
response
=
self
.
csrf_client
.
get
(
'/oauth2-test/'
,
params
,
HTTP_AUTHORIZATION
=
auth
)
self
.
assertEqual
(
response
.
status_code
,
401
)
@unittest.skipUnless
(
oauth2_provider
,
'django-oauth2-provider not installed'
)
def
test_get_form_with_wrong_client_data_failing_auth
(
self
):
"""Ensure GETing form over OAuth with incorrect client credentials fails"""
auth
=
self
.
_create_authorization_header
()
params
=
self
.
_client_credentials_params
()
params
[
'client_id'
]
+=
'a'
response
=
self
.
csrf_client
.
get
(
'/oauth2-test/'
,
params
,
HTTP_AUTHORIZATION
=
auth
)
self
.
assertEqual
(
response
.
status_code
,
401
)
self
.
assertEqual
(
response
.
status_code
,
401
)
@unittest.skipUnless
(
oauth2_provider
,
'django-oauth2-provider not installed'
)
@unittest.skipUnless
(
oauth2_provider
,
'django-oauth2-provider not installed'
)
def
test_get_form_passing_auth
(
self
):
def
test_get_form_passing_auth
(
self
):
"""Ensure GETing form over OAuth with correct client credentials succeed"""
"""Ensure GETing form over OAuth with correct client credentials succeed"""
auth
=
self
.
_create_authorization_header
()
auth
=
self
.
_create_authorization_header
()
params
=
self
.
_client_credentials_params
()
response
=
self
.
csrf_client
.
get
(
'/oauth2-test/'
,
HTTP_AUTHORIZATION
=
auth
)
response
=
self
.
csrf_client
.
get
(
'/oauth2-test/'
,
params
,
HTTP_AUTHORIZATION
=
auth
)
self
.
assertEqual
(
response
.
status_code
,
200
)
self
.
assertEqual
(
response
.
status_code
,
200
)
@unittest.skipUnless
(
oauth2_provider
,
'django-oauth2-provider not installed'
)
@unittest.skipUnless
(
oauth2_provider
,
'django-oauth2-provider not installed'
)
def
test_post_form_passing_auth
(
self
):
def
test_post_form_passing_auth
(
self
):
"""Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF"""
"""Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF"""
auth
=
self
.
_create_authorization_header
()
auth
=
self
.
_create_authorization_header
()
params
=
self
.
_client_credentials_params
()
response
=
self
.
csrf_client
.
post
(
'/oauth2-test/'
,
HTTP_AUTHORIZATION
=
auth
)
response
=
self
.
csrf_client
.
post
(
'/oauth2-test/'
,
params
,
HTTP_AUTHORIZATION
=
auth
)
self
.
assertEqual
(
response
.
status_code
,
200
)
self
.
assertEqual
(
response
.
status_code
,
200
)
@unittest.skipUnless
(
oauth2_provider
,
'django-oauth2-provider not installed'
)
@unittest.skipUnless
(
oauth2_provider
,
'django-oauth2-provider not installed'
)
...
@@ -529,16 +512,14 @@ class OAuth2Tests(TestCase):
...
@@ -529,16 +512,14 @@ class OAuth2Tests(TestCase):
"""Ensure POSTing when there is no OAuth access token in db fails"""
"""Ensure POSTing when there is no OAuth access token in db fails"""
self
.
access_token
.
delete
()
self
.
access_token
.
delete
()
auth
=
self
.
_create_authorization_header
()
auth
=
self
.
_create_authorization_header
()
params
=
self
.
_client_credentials_params
()
response
=
self
.
csrf_client
.
post
(
'/oauth2-test/'
,
HTTP_AUTHORIZATION
=
auth
)
response
=
self
.
csrf_client
.
post
(
'/oauth2-test/'
,
params
,
HTTP_AUTHORIZATION
=
auth
)
self
.
assertIn
(
response
.
status_code
,
(
status
.
HTTP_401_UNAUTHORIZED
,
status
.
HTTP_403_FORBIDDEN
))
self
.
assertIn
(
response
.
status_code
,
(
status
.
HTTP_401_UNAUTHORIZED
,
status
.
HTTP_403_FORBIDDEN
))
@unittest.skipUnless
(
oauth2_provider
,
'django-oauth2-provider not installed'
)
@unittest.skipUnless
(
oauth2_provider
,
'django-oauth2-provider not installed'
)
def
test_post_form_with_refresh_token_failing_auth
(
self
):
def
test_post_form_with_refresh_token_failing_auth
(
self
):
"""Ensure POSTing with refresh token instead of access token fails"""
"""Ensure POSTing with refresh token instead of access token fails"""
auth
=
self
.
_create_authorization_header
(
token
=
self
.
refresh_token
.
token
)
auth
=
self
.
_create_authorization_header
(
token
=
self
.
refresh_token
.
token
)
params
=
self
.
_client_credentials_params
()
response
=
self
.
csrf_client
.
post
(
'/oauth2-test/'
,
HTTP_AUTHORIZATION
=
auth
)
response
=
self
.
csrf_client
.
post
(
'/oauth2-test/'
,
params
,
HTTP_AUTHORIZATION
=
auth
)
self
.
assertIn
(
response
.
status_code
,
(
status
.
HTTP_401_UNAUTHORIZED
,
status
.
HTTP_403_FORBIDDEN
))
self
.
assertIn
(
response
.
status_code
,
(
status
.
HTTP_401_UNAUTHORIZED
,
status
.
HTTP_403_FORBIDDEN
))
@unittest.skipUnless
(
oauth2_provider
,
'django-oauth2-provider not installed'
)
@unittest.skipUnless
(
oauth2_provider
,
'django-oauth2-provider not installed'
)
...
@@ -547,8 +528,7 @@ class OAuth2Tests(TestCase):
...
@@ -547,8 +528,7 @@ class OAuth2Tests(TestCase):
self
.
access_token
.
expires
=
datetime
.
datetime
.
now
()
-
datetime
.
timedelta
(
seconds
=
10
)
# 10 seconds late
self
.
access_token
.
expires
=
datetime
.
datetime
.
now
()
-
datetime
.
timedelta
(
seconds
=
10
)
# 10 seconds late
self
.
access_token
.
save
()
self
.
access_token
.
save
()
auth
=
self
.
_create_authorization_header
()
auth
=
self
.
_create_authorization_header
()
params
=
self
.
_client_credentials_params
()
response
=
self
.
csrf_client
.
post
(
'/oauth2-test/'
,
HTTP_AUTHORIZATION
=
auth
)
response
=
self
.
csrf_client
.
post
(
'/oauth2-test/'
,
params
,
HTTP_AUTHORIZATION
=
auth
)
self
.
assertIn
(
response
.
status_code
,
(
status
.
HTTP_401_UNAUTHORIZED
,
status
.
HTTP_403_FORBIDDEN
))
self
.
assertIn
(
response
.
status_code
,
(
status
.
HTTP_401_UNAUTHORIZED
,
status
.
HTTP_403_FORBIDDEN
))
self
.
assertIn
(
'Invalid token'
,
response
.
content
)
self
.
assertIn
(
'Invalid token'
,
response
.
content
)
...
@@ -559,10 +539,9 @@ class OAuth2Tests(TestCase):
...
@@ -559,10 +539,9 @@ class OAuth2Tests(TestCase):
read_only_access_token
.
scope
=
oauth2_provider_scope
.
SCOPE_NAME_DICT
[
'read'
]
read_only_access_token
.
scope
=
oauth2_provider_scope
.
SCOPE_NAME_DICT
[
'read'
]
read_only_access_token
.
save
()
read_only_access_token
.
save
()
auth
=
self
.
_create_authorization_header
(
token
=
read_only_access_token
.
token
)
auth
=
self
.
_create_authorization_header
(
token
=
read_only_access_token
.
token
)
params
=
self
.
_client_credentials_params
()
response
=
self
.
csrf_client
.
get
(
'/oauth2-with-scope-test/'
,
HTTP_AUTHORIZATION
=
auth
)
response
=
self
.
csrf_client
.
get
(
'/oauth2-with-scope-test/'
,
params
,
HTTP_AUTHORIZATION
=
auth
)
self
.
assertEqual
(
response
.
status_code
,
200
)
self
.
assertEqual
(
response
.
status_code
,
200
)
response
=
self
.
csrf_client
.
post
(
'/oauth2-with-scope-test/'
,
params
,
HTTP_AUTHORIZATION
=
auth
)
response
=
self
.
csrf_client
.
post
(
'/oauth2-with-scope-test/'
,
HTTP_AUTHORIZATION
=
auth
)
self
.
assertEqual
(
response
.
status_code
,
status
.
HTTP_403_FORBIDDEN
)
self
.
assertEqual
(
response
.
status_code
,
status
.
HTTP_403_FORBIDDEN
)
@unittest.skipUnless
(
oauth2_provider
,
'django-oauth2-provider not installed'
)
@unittest.skipUnless
(
oauth2_provider
,
'django-oauth2-provider not installed'
)
...
@@ -572,6 +551,5 @@ class OAuth2Tests(TestCase):
...
@@ -572,6 +551,5 @@ class OAuth2Tests(TestCase):
read_write_access_token
.
scope
=
oauth2_provider_scope
.
SCOPE_NAME_DICT
[
'write'
]
read_write_access_token
.
scope
=
oauth2_provider_scope
.
SCOPE_NAME_DICT
[
'write'
]
read_write_access_token
.
save
()
read_write_access_token
.
save
()
auth
=
self
.
_create_authorization_header
(
token
=
read_write_access_token
.
token
)
auth
=
self
.
_create_authorization_header
(
token
=
read_write_access_token
.
token
)
params
=
self
.
_client_credentials_params
()
response
=
self
.
csrf_client
.
post
(
'/oauth2-with-scope-test/'
,
HTTP_AUTHORIZATION
=
auth
)
response
=
self
.
csrf_client
.
post
(
'/oauth2-with-scope-test/'
,
params
,
HTTP_AUTHORIZATION
=
auth
)
self
.
assertEqual
(
response
.
status_code
,
200
)
self
.
assertEqual
(
response
.
status_code
,
200
)
rest_framework/tests/filterset.py
View file @
399ac70b
from
__future__
import
unicode_literals
from
__future__
import
unicode_literals
import
datetime
import
datetime
from
decimal
import
Decimal
from
decimal
import
Decimal
from
django.core.urlresolvers
import
reverse
from
django.test
import
TestCase
from
django.test
import
TestCase
from
django.test.client
import
RequestFactory
from
django.test.client
import
RequestFactory
from
django.utils
import
unittest
from
django.utils
import
unittest
from
rest_framework
import
generics
,
status
,
filters
from
rest_framework
import
generics
,
status
,
filters
from
rest_framework.compat
import
django_filters
from
rest_framework.compat
import
django_filters
,
patterns
,
url
from
rest_framework.tests.models
import
FilterableItem
,
BasicModel
from
rest_framework.tests.models
import
FilterableItem
,
BasicModel
factory
=
RequestFactory
()
factory
=
RequestFactory
()
...
@@ -46,11 +47,20 @@ if django_filters:
...
@@ -46,11 +47,20 @@ if django_filters:
filter_class
=
MisconfiguredFilter
filter_class
=
MisconfiguredFilter
filter_backend
=
filters
.
DjangoFilterBackend
filter_backend
=
filters
.
DjangoFilterBackend
class
FilterClassDetailView
(
generics
.
RetrieveAPIView
):
model
=
FilterableItem
filter_class
=
SeveralFieldsFilter
filter_backend
=
filters
.
DjangoFilterBackend
urlpatterns
=
patterns
(
''
,
url
(
r'^(?P<pk>\d+)/$'
,
FilterClassDetailView
.
as_view
(),
name
=
'detail-view'
),
url
(
r'^$'
,
FilterClassRootView
.
as_view
(),
name
=
'root-view'
),
)
class
IntegrationTestFiltering
(
TestCase
):
"""
class
CommonFilteringTestCase
(
TestCase
):
Integration tests for filtered list views.
def
_serialize_object
(
self
,
obj
):
"""
return
{
'id'
:
obj
.
id
,
'text'
:
obj
.
text
,
'decimal'
:
obj
.
decimal
,
'date'
:
obj
.
date
}
def
setUp
(
self
):
def
setUp
(
self
):
"""
"""
...
@@ -65,10 +75,16 @@ class IntegrationTestFiltering(TestCase):
...
@@ -65,10 +75,16 @@ class IntegrationTestFiltering(TestCase):
self
.
objects
=
FilterableItem
.
objects
self
.
objects
=
FilterableItem
.
objects
self
.
data
=
[
self
.
data
=
[
{
'id'
:
obj
.
id
,
'text'
:
obj
.
text
,
'decimal'
:
obj
.
decimal
,
'date'
:
obj
.
date
}
self
.
_serialize_object
(
obj
)
for
obj
in
self
.
objects
.
all
()
for
obj
in
self
.
objects
.
all
()
]
]
class
IntegrationTestFiltering
(
CommonFilteringTestCase
):
"""
Integration tests for filtered list views.
"""
@unittest.skipUnless
(
django_filters
,
'django-filters not installed'
)
@unittest.skipUnless
(
django_filters
,
'django-filters not installed'
)
def
test_get_filtered_fields_root_view
(
self
):
def
test_get_filtered_fields_root_view
(
self
):
"""
"""
...
@@ -167,3 +183,50 @@ class IntegrationTestFiltering(TestCase):
...
@@ -167,3 +183,50 @@ class IntegrationTestFiltering(TestCase):
request
=
factory
.
get
(
'/?integer=
%
s'
%
search_integer
)
request
=
factory
.
get
(
'/?integer=
%
s'
%
search_integer
)
response
=
view
(
request
)
.
render
()
response
=
view
(
request
)
.
render
()
self
.
assertEqual
(
response
.
status_code
,
status
.
HTTP_200_OK
)
self
.
assertEqual
(
response
.
status_code
,
status
.
HTTP_200_OK
)
class
IntegrationTestDetailFiltering
(
CommonFilteringTestCase
):
"""
Integration tests for filtered detail views.
"""
urls
=
'rest_framework.tests.filterset'
def
_get_url
(
self
,
item
):
return
reverse
(
'detail-view'
,
kwargs
=
dict
(
pk
=
item
.
pk
))
@unittest.skipUnless
(
django_filters
,
'django-filters not installed'
)
def
test_get_filtered_detail_view
(
self
):
"""
GET requests to filtered RetrieveAPIView that have a filter_class set
should return filtered results.
"""
item
=
self
.
objects
.
all
()[
0
]
data
=
self
.
_serialize_object
(
item
)
# Basic test with no filter.
response
=
self
.
client
.
get
(
self
.
_get_url
(
item
))
self
.
assertEqual
(
response
.
status_code
,
status
.
HTTP_200_OK
)
self
.
assertEqual
(
response
.
data
,
data
)
# Tests that the decimal filter set that should fail.
search_decimal
=
Decimal
(
'4.25'
)
high_item
=
self
.
objects
.
filter
(
decimal__gt
=
search_decimal
)[
0
]
response
=
self
.
client
.
get
(
'{url}?decimal={param}'
.
format
(
url
=
self
.
_get_url
(
high_item
),
param
=
search_decimal
))
self
.
assertEqual
(
response
.
status_code
,
status
.
HTTP_404_NOT_FOUND
)
# Tests that the decimal filter set that should succeed.
search_decimal
=
Decimal
(
'4.25'
)
low_item
=
self
.
objects
.
filter
(
decimal__lt
=
search_decimal
)[
0
]
low_item_data
=
self
.
_serialize_object
(
low_item
)
response
=
self
.
client
.
get
(
'{url}?decimal={param}'
.
format
(
url
=
self
.
_get_url
(
low_item
),
param
=
search_decimal
))
self
.
assertEqual
(
response
.
status_code
,
status
.
HTTP_200_OK
)
self
.
assertEqual
(
response
.
data
,
low_item_data
)
# Tests that multiple filters works.
search_decimal
=
Decimal
(
'5.25'
)
search_date
=
datetime
.
date
(
2012
,
10
,
2
)
valid_item
=
self
.
objects
.
filter
(
decimal__lt
=
search_decimal
,
date__gt
=
search_date
)[
0
]
valid_item_data
=
self
.
_serialize_object
(
valid_item
)
response
=
self
.
client
.
get
(
'{url}?decimal={decimal}&date={date}'
.
format
(
url
=
self
.
_get_url
(
valid_item
),
decimal
=
search_decimal
,
date
=
search_date
))
self
.
assertEqual
(
response
.
status_code
,
status
.
HTTP_200_OK
)
self
.
assertEqual
(
response
.
data
,
valid_item_data
)
rest_framework/tests/pagination.py
View file @
399ac70b
...
@@ -129,16 +129,6 @@ class IntegrationTestPaginationAndFiltering(TestCase):
...
@@ -129,16 +129,6 @@ class IntegrationTestPaginationAndFiltering(TestCase):
view
=
FilterFieldsRootView
.
as_view
()
view
=
FilterFieldsRootView
.
as_view
()
EXPECTED_NUM_QUERIES
=
2
EXPECTED_NUM_QUERIES
=
2
if
django
.
VERSION
<
(
1
,
4
):
# On Django 1.3 we need to use django-filter 0.5.4
#
# The filter objects there don't expose a `.count()` method,
# which means we only make a single query *but* it's a single
# query across *all* of the queryset, instead of a COUNT and then
# a SELECT with a LIMIT.
#
# Although this is fewer queries, it's actually a regression.
EXPECTED_NUM_QUERIES
=
1
request
=
factory
.
get
(
'/?decimal=15.20'
)
request
=
factory
.
get
(
'/?decimal=15.20'
)
with
self
.
assertNumQueries
(
EXPECTED_NUM_QUERIES
):
with
self
.
assertNumQueries
(
EXPECTED_NUM_QUERIES
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment