Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
D
django-rest-framework
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
edx
django-rest-framework
Commits
664f8c63
Commit
664f8c63
authored
Jun 29, 2013
by
Tom Christie
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Added APIClient.authenticate()
parent
35022ca9
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
95 additions
and
8 deletions
+95
-8
rest_framework/renderers.py
+1
-1
rest_framework/request.py
+20
-0
rest_framework/test.py
+35
-4
rest_framework/tests/test_testing.py
+39
-3
No files found.
rest_framework/renderers.py
View file @
664f8c63
...
@@ -576,7 +576,7 @@ class BrowsableAPIRenderer(BaseRenderer):
...
@@ -576,7 +576,7 @@ class BrowsableAPIRenderer(BaseRenderer):
class
MultiPartRenderer
(
BaseRenderer
):
class
MultiPartRenderer
(
BaseRenderer
):
media_type
=
'multipart/form-data; boundary=BoUnDaRyStRiNg'
media_type
=
'multipart/form-data; boundary=BoUnDaRyStRiNg'
format
=
'
form
'
format
=
'
multipart
'
charset
=
'utf-8'
charset
=
'utf-8'
BOUNDARY
=
'BoUnDaRyStRiNg'
BOUNDARY
=
'BoUnDaRyStRiNg'
...
...
rest_framework/request.py
View file @
664f8c63
...
@@ -64,6 +64,20 @@ def clone_request(request, method):
...
@@ -64,6 +64,20 @@ def clone_request(request, method):
return
ret
return
ret
class
ForcedAuthentication
(
object
):
"""
This authentication class is used if the test client or request factory
forcibly authenticated the request.
"""
def
__init__
(
self
,
force_user
,
force_token
):
self
.
force_user
=
force_user
self
.
force_token
=
force_token
def
authenticate
(
self
,
request
):
return
(
self
.
force_user
,
self
.
force_token
)
class
Request
(
object
):
class
Request
(
object
):
"""
"""
Wrapper allowing to enhance a standard `HttpRequest` instance.
Wrapper allowing to enhance a standard `HttpRequest` instance.
...
@@ -98,6 +112,12 @@ class Request(object):
...
@@ -98,6 +112,12 @@ class Request(object):
self
.
parser_context
[
'request'
]
=
self
self
.
parser_context
[
'request'
]
=
self
self
.
parser_context
[
'encoding'
]
=
request
.
encoding
or
settings
.
DEFAULT_CHARSET
self
.
parser_context
[
'encoding'
]
=
request
.
encoding
or
settings
.
DEFAULT_CHARSET
force_user
=
getattr
(
request
,
'_force_auth_user'
,
None
)
force_token
=
getattr
(
request
,
'_force_auth_token'
,
None
)
if
(
force_user
is
not
None
or
force_token
is
not
None
):
forced_auth
=
ForcedAuthentication
(
force_user
,
force_token
)
self
.
authenticators
=
(
forced_auth
,)
def
_default_negotiator
(
self
):
def
_default_negotiator
(
self
):
return
api_settings
.
DEFAULT_CONTENT_NEGOTIATION_CLASS
()
return
api_settings
.
DEFAULT_CONTENT_NEGOTIATION_CLASS
()
...
...
rest_framework/test.py
View file @
664f8c63
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
from
__future__
import
unicode_literals
from
__future__
import
unicode_literals
from
django.conf
import
settings
from
django.conf
import
settings
from
django.test.client
import
Client
as
DjangoClient
from
django.test.client
import
Client
as
DjangoClient
from
django.test.client
import
ClientHandler
from
rest_framework.compat
import
RequestFactory
as
DjangoRequestFactory
from
rest_framework.compat
import
RequestFactory
as
DjangoRequestFactory
from
rest_framework.compat
import
force_bytes_or_smart_bytes
,
six
from
rest_framework.compat
import
force_bytes_or_smart_bytes
,
six
from
rest_framework.renderers
import
JSONRenderer
,
MultiPartRenderer
from
rest_framework.renderers
import
JSONRenderer
,
MultiPartRenderer
...
@@ -13,9 +14,9 @@ from rest_framework.renderers import JSONRenderer, MultiPartRenderer
...
@@ -13,9 +14,9 @@ from rest_framework.renderers import JSONRenderer, MultiPartRenderer
class
APIRequestFactory
(
DjangoRequestFactory
):
class
APIRequestFactory
(
DjangoRequestFactory
):
renderer_classes
=
{
renderer_classes
=
{
'json'
:
JSONRenderer
,
'json'
:
JSONRenderer
,
'
form
'
:
MultiPartRenderer
'
multipart
'
:
MultiPartRenderer
}
}
default_format
=
'
form
'
default_format
=
'
multipart
'
def
_encode_data
(
self
,
data
,
format
=
None
,
content_type
=
None
):
def
_encode_data
(
self
,
data
,
format
=
None
,
content_type
=
None
):
"""
"""
...
@@ -74,14 +75,44 @@ class APIRequestFactory(DjangoRequestFactory):
...
@@ -74,14 +75,44 @@ class APIRequestFactory(DjangoRequestFactory):
return
self
.
generic
(
'OPTIONS'
,
path
,
data
,
content_type
,
**
extra
)
return
self
.
generic
(
'OPTIONS'
,
path
,
data
,
content_type
,
**
extra
)
class
APIClient
(
APIRequestFactory
,
DjangoClient
):
class
ForceAuthClientHandler
(
ClientHandler
):
"""
A patched version of ClientHandler that can enforce authentication
on the outgoing requests.
"""
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
self
.
_force_auth_user
=
None
self
.
_force_auth_token
=
None
super
(
ForceAuthClientHandler
,
self
)
.
__init__
(
*
args
,
**
kwargs
)
def
force_authenticate
(
self
,
user
=
None
,
token
=
None
):
self
.
_force_auth_user
=
user
self
.
_force_auth_token
=
token
def
get_response
(
self
,
request
):
# This is the simplest place we can hook into to patch the
# request object.
request
.
_force_auth_user
=
self
.
_force_auth_user
request
.
_force_auth_token
=
self
.
_force_auth_token
return
super
(
ForceAuthClientHandler
,
self
)
.
get_response
(
request
)
class
APIClient
(
APIRequestFactory
,
DjangoClient
):
def
__init__
(
self
,
enforce_csrf_checks
=
False
,
**
defaults
):
# Note that our super call skips Client.__init__
# since we don't need to instantiate a regular ClientHandler
super
(
DjangoClient
,
self
)
.
__init__
(
**
defaults
)
self
.
handler
=
ForceAuthClientHandler
(
enforce_csrf_checks
)
self
.
exc_info
=
None
self
.
_credentials
=
{}
self
.
_credentials
=
{}
super
(
APIClient
,
self
)
.
__init__
(
*
args
,
**
kwargs
)
def
credentials
(
self
,
**
kwargs
):
def
credentials
(
self
,
**
kwargs
):
self
.
_credentials
=
kwargs
self
.
_credentials
=
kwargs
def
authenticate
(
self
,
user
=
None
,
token
=
None
):
self
.
handler
.
force_authenticate
(
user
,
token
)
def
get
(
self
,
path
,
data
=
{},
follow
=
False
,
**
extra
):
def
get
(
self
,
path
,
data
=
{},
follow
=
False
,
**
extra
):
extra
.
update
(
self
.
_credentials
)
extra
.
update
(
self
.
_credentials
)
response
=
super
(
APIClient
,
self
)
.
get
(
path
,
data
=
data
,
**
extra
)
response
=
super
(
APIClient
,
self
)
.
get
(
path
,
data
=
data
,
**
extra
)
...
...
rest_framework/tests/test_testing.py
View file @
664f8c63
# -- coding: utf-8 --
# -- coding: utf-8 --
from
__future__
import
unicode_literals
from
__future__
import
unicode_literals
from
django.contrib.auth.models
import
User
from
django.test
import
TestCase
from
django.test
import
TestCase
from
rest_framework.compat
import
patterns
,
url
from
rest_framework.compat
import
patterns
,
url
from
rest_framework.decorators
import
api_view
from
rest_framework.decorators
import
api_view
...
@@ -8,10 +9,11 @@ from rest_framework.response import Response
...
@@ -8,10 +9,11 @@ from rest_framework.response import Response
from
rest_framework.test
import
APIClient
from
rest_framework.test
import
APIClient
@api_view
([
'GET'
])
@api_view
([
'GET'
,
'POST'
])
def
mirror
(
request
):
def
mirror
(
request
):
return
Response
({
return
Response
({
'auth'
:
request
.
META
.
get
(
'HTTP_AUTHORIZATION'
,
b
''
)
'auth'
:
request
.
META
.
get
(
'HTTP_AUTHORIZATION'
,
b
''
),
'user'
:
request
.
user
.
username
})
})
...
@@ -27,6 +29,40 @@ class CheckTestClient(TestCase):
...
@@ -27,6 +29,40 @@ class CheckTestClient(TestCase):
self
.
client
=
APIClient
()
self
.
client
=
APIClient
()
def
test_credentials
(
self
):
def
test_credentials
(
self
):
"""
Setting `.credentials()` adds the required headers to each request.
"""
self
.
client
.
credentials
(
HTTP_AUTHORIZATION
=
'example'
)
self
.
client
.
credentials
(
HTTP_AUTHORIZATION
=
'example'
)
for
_
in
range
(
0
,
3
):
response
=
self
.
client
.
get
(
'/view/'
)
self
.
assertEqual
(
response
.
data
[
'auth'
],
'example'
)
def
test_authenticate
(
self
):
"""
Setting `.authenticate()` forcibly authenticates each request.
"""
user
=
User
.
objects
.
create_user
(
'example'
,
'example@example.com'
)
self
.
client
.
authenticate
(
user
)
response
=
self
.
client
.
get
(
'/view/'
)
response
=
self
.
client
.
get
(
'/view/'
)
self
.
assertEqual
(
response
.
data
[
'auth'
],
'example'
)
self
.
assertEqual
(
response
.
data
[
'user'
],
'example'
)
def
test_csrf_exempt_by_default
(
self
):
"""
By default, the test client is CSRF exempt.
"""
User
.
objects
.
create_user
(
'example'
,
'example@example.com'
,
'password'
)
self
.
client
.
login
(
username
=
'example'
,
password
=
'password'
)
response
=
self
.
client
.
post
(
'/view/'
)
self
.
assertEqual
(
response
.
status_code
,
200
)
def
test_explicitly_enforce_csrf_checks
(
self
):
"""
The test client can enforce CSRF checks.
"""
client
=
APIClient
(
enforce_csrf_checks
=
True
)
User
.
objects
.
create_user
(
'example'
,
'example@example.com'
,
'password'
)
client
.
login
(
username
=
'example'
,
password
=
'password'
)
response
=
client
.
post
(
'/view/'
)
expected
=
{
'detail'
:
'CSRF Failed: CSRF cookie not set.'
}
self
.
assertEqual
(
response
.
status_code
,
403
)
self
.
assertEqual
(
response
.
data
,
expected
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment