Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
D
django-rest-framework
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
edx
django-rest-framework
Commits
f585480e
Commit
f585480e
authored
Jun 28, 2013
by
Tom Christie
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Added APIClient
parent
7224b20d
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
93 additions
and
38 deletions
+93
-38
rest_framework/test.py
+68
-13
rest_framework/tests/test_authentication.py
+22
-22
rest_framework/tests/test_request.py
+3
-3
No files found.
rest_framework/test.py
View file @
f585480e
from
rest_framework.compat
import
six
,
RequestFactory
# Note that we use `DjangoRequestFactory` and `DjangoClient` names in order
# to make it harder for the user to import the wrong thing without realizing.
from
django.conf
import
settings
from
django.test.client
import
Client
as
DjangoClient
from
rest_framework.compat
import
RequestFactory
as
DjangoRequestFactory
from
rest_framework.compat
import
force_bytes_or_smart_bytes
,
six
from
rest_framework.renderers
import
JSONRenderer
,
MultiPartRenderer
class
APIRequestFactory
(
RequestFactory
):
class
APIRequestFactory
(
Django
RequestFactory
):
renderer_classes
=
{
'json'
:
JSONRenderer
,
'form'
:
MultiPartRenderer
}
default_format
=
'form'
def
__init__
(
self
,
format
=
None
,
**
defaults
):
self
.
format
=
format
or
self
.
default_format
super
(
APIRequestFactory
,
self
)
.
__init__
(
**
defaults
)
def
_encode_data
(
self
,
data
,
format
=
None
,
content_type
=
None
):
"""
Encode the data returning a two tuple of (bytes, content_type)
"""
def
_encode_data
(
self
,
data
,
format
,
content_type
):
if
not
data
:
return
(
''
,
None
)
format
=
format
or
self
.
format
assert
format
is
None
or
content_type
is
None
,
(
'You may not set both `format` and `content_type`.'
)
if
content_type
is
None
and
data
is
not
None
:
if
content_type
:
# Content type specified explicitly, treat data as a raw bytestring
ret
=
force_bytes_or_smart_bytes
(
data
,
settings
.
DEFAULT_CHARSET
)
else
:
# Use format and render the data into a bytestring
format
=
format
or
self
.
default_format
renderer
=
self
.
renderer_classes
[
format
]()
data
=
renderer
.
render
(
data
)
# Determine the content-type header
ret
=
renderer
.
render
(
data
)
# Determine the content-type header from the renderer
if
';'
in
renderer
.
media_type
:
content_type
=
renderer
.
media_type
else
:
content_type
=
"{0}; charset={1}"
.
format
(
renderer
.
media_type
,
renderer
.
charset
)
# Coerce text to bytes if required.
if
isinstance
(
data
,
six
.
text_type
):
data
=
bytes
(
data
.
encode
(
renderer
.
charset
))
if
isinstance
(
ret
,
six
.
text_type
):
ret
=
bytes
(
ret
.
encode
(
renderer
.
charset
))
return
data
,
content_type
return
ret
,
content_type
def
post
(
self
,
path
,
data
=
None
,
format
=
None
,
content_type
=
None
,
**
extra
):
data
,
content_type
=
self
.
_encode_data
(
data
,
format
,
content_type
)
...
...
@@ -46,3 +61,43 @@ class APIRequestFactory(RequestFactory):
def
patch
(
self
,
path
,
data
=
None
,
format
=
None
,
content_type
=
None
,
**
extra
):
data
,
content_type
=
self
.
_encode_data
(
data
,
format
,
content_type
)
return
self
.
generic
(
'PATCH'
,
path
,
data
,
content_type
,
**
extra
)
def
delete
(
self
,
path
,
data
=
None
,
format
=
None
,
content_type
=
None
,
**
extra
):
data
,
content_type
=
self
.
_encode_data
(
data
,
format
,
content_type
)
return
self
.
generic
(
'DELETE'
,
path
,
data
,
content_type
,
**
extra
)
def
options
(
self
,
path
,
data
=
None
,
format
=
None
,
content_type
=
None
,
**
extra
):
data
,
content_type
=
self
.
_encode_data
(
data
,
format
,
content_type
)
return
self
.
generic
(
'OPTIONS'
,
path
,
data
,
content_type
,
**
extra
)
class
APIClient
(
APIRequestFactory
,
DjangoClient
):
def
post
(
self
,
path
,
data
=
None
,
format
=
None
,
content_type
=
None
,
follow
=
False
,
**
extra
):
response
=
super
(
APIClient
,
self
)
.
post
(
path
,
data
=
data
,
format
=
format
,
content_type
=
content_type
,
**
extra
)
if
follow
:
response
=
self
.
_handle_redirects
(
response
,
**
extra
)
return
response
def
put
(
self
,
path
,
data
=
None
,
format
=
None
,
content_type
=
None
,
follow
=
False
,
**
extra
):
response
=
super
(
APIClient
,
self
)
.
post
(
path
,
data
=
data
,
format
=
format
,
content_type
=
content_type
,
**
extra
)
if
follow
:
response
=
self
.
_handle_redirects
(
response
,
**
extra
)
return
response
def
patch
(
self
,
path
,
data
=
None
,
format
=
None
,
content_type
=
None
,
follow
=
False
,
**
extra
):
response
=
super
(
APIClient
,
self
)
.
post
(
path
,
data
=
data
,
format
=
format
,
content_type
=
content_type
,
**
extra
)
if
follow
:
response
=
self
.
_handle_redirects
(
response
,
**
extra
)
return
response
def
delete
(
self
,
path
,
data
=
None
,
format
=
None
,
content_type
=
None
,
follow
=
False
,
**
extra
):
response
=
super
(
APIClient
,
self
)
.
post
(
path
,
data
=
data
,
format
=
format
,
content_type
=
content_type
,
**
extra
)
if
follow
:
response
=
self
.
_handle_redirects
(
response
,
**
extra
)
return
response
def
options
(
self
,
path
,
data
=
None
,
format
=
None
,
content_type
=
None
,
follow
=
False
,
**
extra
):
response
=
super
(
APIClient
,
self
)
.
post
(
path
,
data
=
data
,
format
=
format
,
content_type
=
content_type
,
**
extra
)
if
follow
:
response
=
self
.
_handle_redirects
(
response
,
**
extra
)
return
response
rest_framework/tests/test_authentication.py
View file @
f585480e
from
__future__
import
unicode_literals
from
django.contrib.auth.models
import
User
from
django.http
import
HttpResponse
from
django.test
import
Client
,
TestCase
from
django.test
import
TestCase
from
django.utils
import
unittest
from
rest_framework
import
HTTP_HEADER_ENCODING
from
rest_framework
import
exceptions
...
...
@@ -21,12 +21,11 @@ from rest_framework.authtoken.models import Token
from
rest_framework.compat
import
patterns
,
url
,
include
from
rest_framework.compat
import
oauth2_provider
,
oauth2_provider_models
,
oauth2_provider_scope
from
rest_framework.compat
import
oauth
,
oauth_provider
from
rest_framework.test
import
APIRequestFactory
from
rest_framework.test
import
APIRequestFactory
,
APIClient
from
rest_framework.views
import
APIView
import
base64
import
time
import
datetime
import
json
factory
=
APIRequestFactory
()
...
...
@@ -68,7 +67,7 @@ class BasicAuthTests(TestCase):
urls
=
'rest_framework.tests.test_authentication'
def
setUp
(
self
):
self
.
csrf_client
=
Client
(
enforce_csrf_checks
=
True
)
self
.
csrf_client
=
API
Client
(
enforce_csrf_checks
=
True
)
self
.
username
=
'john'
self
.
email
=
'lennon@thebeatles.com'
self
.
password
=
'password'
...
...
@@ -87,7 +86,7 @@ class BasicAuthTests(TestCase):
credentials
=
(
'
%
s:
%
s'
%
(
self
.
username
,
self
.
password
))
base64_credentials
=
base64
.
b64encode
(
credentials
.
encode
(
HTTP_HEADER_ENCODING
))
.
decode
(
HTTP_HEADER_ENCODING
)
auth
=
'Basic
%
s'
%
base64_credentials
response
=
self
.
csrf_client
.
post
(
'/basic/'
,
json
.
dumps
({
'example'
:
'example'
}),
'application/
json'
,
HTTP_AUTHORIZATION
=
auth
)
response
=
self
.
csrf_client
.
post
(
'/basic/'
,
{
'example'
:
'example'
},
format
=
'
json'
,
HTTP_AUTHORIZATION
=
auth
)
self
.
assertEqual
(
response
.
status_code
,
status
.
HTTP_200_OK
)
def
test_post_form_failing_basic_auth
(
self
):
...
...
@@ -97,7 +96,7 @@ class BasicAuthTests(TestCase):
def
test_post_json_failing_basic_auth
(
self
):
"""Ensure POSTing json over basic auth without correct credentials fails"""
response
=
self
.
csrf_client
.
post
(
'/basic/'
,
json
.
dumps
({
'example'
:
'example'
}),
'application/
json'
)
response
=
self
.
csrf_client
.
post
(
'/basic/'
,
{
'example'
:
'example'
},
format
=
'
json'
)
self
.
assertEqual
(
response
.
status_code
,
status
.
HTTP_401_UNAUTHORIZED
)
self
.
assertEqual
(
response
[
'WWW-Authenticate'
],
'Basic realm="api"'
)
...
...
@@ -107,8 +106,8 @@ class SessionAuthTests(TestCase):
urls
=
'rest_framework.tests.test_authentication'
def
setUp
(
self
):
self
.
csrf_client
=
Client
(
enforce_csrf_checks
=
True
)
self
.
non_csrf_client
=
Client
(
enforce_csrf_checks
=
False
)
self
.
csrf_client
=
API
Client
(
enforce_csrf_checks
=
True
)
self
.
non_csrf_client
=
API
Client
(
enforce_csrf_checks
=
False
)
self
.
username
=
'john'
self
.
email
=
'lennon@thebeatles.com'
self
.
password
=
'password'
...
...
@@ -154,7 +153,7 @@ class TokenAuthTests(TestCase):
urls
=
'rest_framework.tests.test_authentication'
def
setUp
(
self
):
self
.
csrf_client
=
Client
(
enforce_csrf_checks
=
True
)
self
.
csrf_client
=
API
Client
(
enforce_csrf_checks
=
True
)
self
.
username
=
'john'
self
.
email
=
'lennon@thebeatles.com'
self
.
password
=
'password'
...
...
@@ -172,7 +171,7 @@ class TokenAuthTests(TestCase):
def
test_post_json_passing_token_auth
(
self
):
"""Ensure POSTing form over token auth with correct credentials passes and does not require CSRF"""
auth
=
"Token "
+
self
.
key
response
=
self
.
csrf_client
.
post
(
'/token/'
,
json
.
dumps
({
'example'
:
'example'
}),
'application/
json'
,
HTTP_AUTHORIZATION
=
auth
)
response
=
self
.
csrf_client
.
post
(
'/token/'
,
{
'example'
:
'example'
},
format
=
'
json'
,
HTTP_AUTHORIZATION
=
auth
)
self
.
assertEqual
(
response
.
status_code
,
status
.
HTTP_200_OK
)
def
test_post_form_failing_token_auth
(
self
):
...
...
@@ -182,7 +181,7 @@ class TokenAuthTests(TestCase):
def
test_post_json_failing_token_auth
(
self
):
"""Ensure POSTing json over token auth without correct credentials fails"""
response
=
self
.
csrf_client
.
post
(
'/token/'
,
json
.
dumps
({
'example'
:
'example'
}),
'application/
json'
)
response
=
self
.
csrf_client
.
post
(
'/token/'
,
{
'example'
:
'example'
},
format
=
'
json'
)
self
.
assertEqual
(
response
.
status_code
,
status
.
HTTP_401_UNAUTHORIZED
)
def
test_token_has_auto_assigned_key_if_none_provided
(
self
):
...
...
@@ -193,33 +192,33 @@ class TokenAuthTests(TestCase):
def
test_token_login_json
(
self
):
"""Ensure token login view using JSON POST works."""
client
=
Client
(
enforce_csrf_checks
=
True
)
client
=
API
Client
(
enforce_csrf_checks
=
True
)
response
=
client
.
post
(
'/auth-token/'
,
json
.
dumps
({
'username'
:
self
.
username
,
'password'
:
self
.
password
}),
'application/
json'
)
{
'username'
:
self
.
username
,
'password'
:
self
.
password
},
format
=
'
json'
)
self
.
assertEqual
(
response
.
status_code
,
status
.
HTTP_200_OK
)
self
.
assertEqual
(
json
.
loads
(
response
.
content
.
decode
(
'ascii'
))
[
'token'
],
self
.
key
)
self
.
assertEqual
(
response
.
data
[
'token'
],
self
.
key
)
def
test_token_login_json_bad_creds
(
self
):
"""Ensure token login view using JSON POST fails if bad credentials are used."""
client
=
Client
(
enforce_csrf_checks
=
True
)
client
=
API
Client
(
enforce_csrf_checks
=
True
)
response
=
client
.
post
(
'/auth-token/'
,
json
.
dumps
({
'username'
:
self
.
username
,
'password'
:
"badpass"
}),
'application/
json'
)
{
'username'
:
self
.
username
,
'password'
:
"badpass"
},
format
=
'
json'
)
self
.
assertEqual
(
response
.
status_code
,
400
)
def
test_token_login_json_missing_fields
(
self
):
"""Ensure token login view using JSON POST fails if missing fields."""
client
=
Client
(
enforce_csrf_checks
=
True
)
client
=
API
Client
(
enforce_csrf_checks
=
True
)
response
=
client
.
post
(
'/auth-token/'
,
json
.
dumps
({
'username'
:
self
.
username
}),
'application/
json'
)
{
'username'
:
self
.
username
},
format
=
'
json'
)
self
.
assertEqual
(
response
.
status_code
,
400
)
def
test_token_login_form
(
self
):
"""Ensure token login view using form POST works."""
client
=
Client
(
enforce_csrf_checks
=
True
)
client
=
API
Client
(
enforce_csrf_checks
=
True
)
response
=
client
.
post
(
'/auth-token/'
,
{
'username'
:
self
.
username
,
'password'
:
self
.
password
})
self
.
assertEqual
(
response
.
status_code
,
status
.
HTTP_200_OK
)
self
.
assertEqual
(
json
.
loads
(
response
.
content
.
decode
(
'ascii'
))
[
'token'
],
self
.
key
)
self
.
assertEqual
(
response
.
data
[
'token'
],
self
.
key
)
class
IncorrectCredentialsTests
(
TestCase
):
...
...
@@ -256,7 +255,7 @@ class OAuthTests(TestCase):
self
.
consts
=
consts
self
.
csrf_client
=
Client
(
enforce_csrf_checks
=
True
)
self
.
csrf_client
=
API
Client
(
enforce_csrf_checks
=
True
)
self
.
username
=
'john'
self
.
email
=
'lennon@thebeatles.com'
self
.
password
=
'password'
...
...
@@ -470,12 +469,13 @@ class OAuthTests(TestCase):
response
=
self
.
csrf_client
.
post
(
'/oauth/'
,
HTTP_AUTHORIZATION
=
auth
)
self
.
assertEqual
(
response
.
status_code
,
401
)
class
OAuth2Tests
(
TestCase
):
"""OAuth 2.0 authentication"""
urls
=
'rest_framework.tests.test_authentication'
def
setUp
(
self
):
self
.
csrf_client
=
Client
(
enforce_csrf_checks
=
True
)
self
.
csrf_client
=
API
Client
(
enforce_csrf_checks
=
True
)
self
.
username
=
'john'
self
.
email
=
'lennon@thebeatles.com'
self
.
password
=
'password'
...
...
rest_framework/tests/test_request.py
View file @
f585480e
...
...
@@ -5,7 +5,7 @@ from __future__ import unicode_literals
from
django.contrib.auth.models
import
User
from
django.contrib.auth
import
authenticate
,
login
,
logout
from
django.contrib.sessions.middleware
import
SessionMiddleware
from
django.test
import
TestCase
,
Client
from
django.test
import
TestCase
from
rest_framework
import
status
from
rest_framework.authentication
import
SessionAuthentication
from
rest_framework.compat
import
patterns
...
...
@@ -18,7 +18,7 @@ from rest_framework.parsers import (
from
rest_framework.request
import
Request
from
rest_framework.response
import
Response
from
rest_framework.settings
import
api_settings
from
rest_framework.test
import
APIRequestFactory
from
rest_framework.test
import
APIRequestFactory
,
APIClient
from
rest_framework.views
import
APIView
from
rest_framework.compat
import
six
import
json
...
...
@@ -248,7 +248,7 @@ class TestContentParsingWithAuthentication(TestCase):
urls
=
'rest_framework.tests.test_request'
def
setUp
(
self
):
self
.
csrf_client
=
Client
(
enforce_csrf_checks
=
True
)
self
.
csrf_client
=
API
Client
(
enforce_csrf_checks
=
True
)
self
.
username
=
'john'
self
.
email
=
'lennon@thebeatles.com'
self
.
password
=
'password'
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment