Commit 99794773 by Tom Christie

Merge pull request #962 from tomchristie/test-client

APIClient and APIRequestFactory
parents a890116a 7398464b
...@@ -217,6 +217,16 @@ Renders data into HTML for the Browsable API. This renderer will determine whic ...@@ -217,6 +217,16 @@ Renders data into HTML for the Browsable API. This renderer will determine whic
**.charset**: `utf-8` **.charset**: `utf-8`
## MultiPartRenderer
This renderer is used for rendering HTML multipart form data. **It is not suitable as a response renderer**, but is instead used for creating test requests, using REST framework's [test client and test request factory][testing].
**.media_type**: `multipart/form-data; boundary=BoUnDaRyStRiNg`
**.format**: `'.multipart'`
**.charset**: `utf-8`
--- ---
# Custom renderers # Custom renderers
...@@ -373,6 +383,7 @@ Comma-separated values are a plain-text tabular data format, that can be easily ...@@ -373,6 +383,7 @@ Comma-separated values are a plain-text tabular data format, that can be easily
[rfc4627]: http://www.ietf.org/rfc/rfc4627.txt [rfc4627]: http://www.ietf.org/rfc/rfc4627.txt
[cors]: http://www.w3.org/TR/cors/ [cors]: http://www.w3.org/TR/cors/
[cors-docs]: ../topics/ajax-csrf-cors.md [cors-docs]: ../topics/ajax-csrf-cors.md
[testing]: testing.md
[HATEOAS]: http://timelessrepo.com/haters-gonna-hateoas [HATEOAS]: http://timelessrepo.com/haters-gonna-hateoas
[quote]: http://roy.gbiv.com/untangled/2008/rest-apis-must-be-hypertext-driven [quote]: http://roy.gbiv.com/untangled/2008/rest-apis-must-be-hypertext-driven
[application/vnd.github+json]: http://developer.github.com/v3/media/ [application/vnd.github+json]: http://developer.github.com/v3/media/
......
...@@ -149,6 +149,33 @@ Default: `None` ...@@ -149,6 +149,33 @@ Default: `None`
--- ---
## Test settings
*The following settings control the behavior of APIRequestFactory and APIClient*
#### TEST_REQUEST_DEFAULT_FORMAT
The default format that should be used when making test requests.
This should match up with the format of one of the renderer classes in the `TEST_REQUEST_RENDERER_CLASSES` setting.
Default: `'multipart'`
#### TEST_REQUEST_RENDERER_CLASSES
The renderer classes that are supported when building test requests.
The format of any of these renderer classes may be used when contructing a test request, for example: `client.post('/users', {'username': 'jamie'}, format='json')`
Default:
(
'rest_framework.renderers.MultiPartRenderer',
'rest_framework.renderers.JSONRenderer'
)
---
## Browser overrides ## Browser overrides
*The following settings provide URL or form-based overrides of the default browser behavior.* *The following settings provide URL or form-based overrides of the default browser behavior.*
......
...@@ -26,6 +26,12 @@ def get_authorization_header(request): ...@@ -26,6 +26,12 @@ def get_authorization_header(request):
return auth return auth
class CSRFCheck(CsrfViewMiddleware):
def _reject(self, request, reason):
# Return the failure reason instead of an HttpResponse
return reason
class BaseAuthentication(object): class BaseAuthentication(object):
""" """
All authentication classes should extend BaseAuthentication. All authentication classes should extend BaseAuthentication.
...@@ -103,27 +109,27 @@ class SessionAuthentication(BaseAuthentication): ...@@ -103,27 +109,27 @@ class SessionAuthentication(BaseAuthentication):
""" """
# Get the underlying HttpRequest object # Get the underlying HttpRequest object
http_request = request._request request = request._request
user = getattr(http_request, 'user', None) user = getattr(request, 'user', None)
# Unauthenticated, CSRF validation not required # Unauthenticated, CSRF validation not required
if not user or not user.is_active: if not user or not user.is_active:
return None return None
# Enforce CSRF validation for session based authentication. self.enforce_csrf(request)
class CSRFCheck(CsrfViewMiddleware):
def _reject(self, request, reason):
# Return the failure reason instead of an HttpResponse
return reason
reason = CSRFCheck().process_view(http_request, None, (), {}) # CSRF passed with authenticated user
return (user, None)
def enforce_csrf(self, request):
"""
Enforce CSRF validation for session based authentication.
"""
reason = CSRFCheck().process_view(request, None, (), {})
if reason: if reason:
# CSRF failed, bail with explicit error message # CSRF failed, bail with explicit error message
raise exceptions.AuthenticationFailed('CSRF Failed: %s' % reason) raise exceptions.AuthenticationFailed('CSRF Failed: %s' % reason)
# CSRF passed with authenticated user
return (user, None)
class TokenAuthentication(BaseAuthentication): class TokenAuthentication(BaseAuthentication):
""" """
......
...@@ -8,6 +8,7 @@ from __future__ import unicode_literals ...@@ -8,6 +8,7 @@ from __future__ import unicode_literals
import django import django
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from django.conf import settings
# Try to import six from Django, fallback to included `six`. # Try to import six from Django, fallback to included `six`.
try: try:
...@@ -83,7 +84,6 @@ def get_concrete_model(model_cls): ...@@ -83,7 +84,6 @@ def get_concrete_model(model_cls):
# Django 1.5 add support for custom auth user model # Django 1.5 add support for custom auth user model
if django.VERSION >= (1, 5): if django.VERSION >= (1, 5):
from django.conf import settings
AUTH_USER_MODEL = settings.AUTH_USER_MODEL AUTH_USER_MODEL = settings.AUTH_USER_MODEL
else: else:
AUTH_USER_MODEL = 'auth.User' AUTH_USER_MODEL = 'auth.User'
...@@ -436,6 +436,42 @@ except ImportError: ...@@ -436,6 +436,42 @@ except ImportError:
return force_text(url) return force_text(url)
# RequestFactory only provide `generic` from 1.5 onwards
from django.test.client import RequestFactory as DjangoRequestFactory
from django.test.client import FakePayload
try:
# In 1.5 the test client uses force_bytes
from django.utils.encoding import force_bytes_or_smart_bytes
except ImportError:
# In 1.3 and 1.4 the test client just uses smart_str
from django.utils.encoding import smart_str as force_bytes_or_smart_bytes
class RequestFactory(DjangoRequestFactory):
def generic(self, method, path,
data='', content_type='application/octet-stream', **extra):
parsed = urlparse.urlparse(path)
data = force_bytes_or_smart_bytes(data, settings.DEFAULT_CHARSET)
r = {
'PATH_INFO': self._get_path(parsed),
'QUERY_STRING': force_text(parsed[4]),
'REQUEST_METHOD': str(method),
}
if data:
r.update({
'CONTENT_LENGTH': len(data),
'CONTENT_TYPE': str(content_type),
'wsgi.input': FakePayload(data),
})
elif django.VERSION <= (1, 4):
# For 1.3 we need an empty WSGI payload
r.update({
'wsgi.input': FakePayload('')
})
r.update(extra)
return self.request(**r)
# Markdown is optional # Markdown is optional
try: try:
import markdown import markdown
......
...@@ -14,6 +14,7 @@ from django import forms ...@@ -14,6 +14,7 @@ from django import forms
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from django.http.multipartparser import parse_header from django.http.multipartparser import parse_header
from django.template import RequestContext, loader, Template from django.template import RequestContext, loader, Template
from django.test.client import encode_multipart
from django.utils.xmlutils import SimplerXMLGenerator from django.utils.xmlutils import SimplerXMLGenerator
from rest_framework.compat import StringIO from rest_framework.compat import StringIO
from rest_framework.compat import six from rest_framework.compat import six
...@@ -571,3 +572,13 @@ class BrowsableAPIRenderer(BaseRenderer): ...@@ -571,3 +572,13 @@ class BrowsableAPIRenderer(BaseRenderer):
response.status_code = status.HTTP_200_OK response.status_code = status.HTTP_200_OK
return ret return ret
class MultiPartRenderer(BaseRenderer):
media_type = 'multipart/form-data; boundary=BoUnDaRyStRiNg'
format = 'multipart'
charset = 'utf-8'
BOUNDARY = 'BoUnDaRyStRiNg'
def render(self, data, accepted_media_type=None, renderer_context=None):
return encode_multipart(self.BOUNDARY, data)
...@@ -64,6 +64,20 @@ def clone_request(request, method): ...@@ -64,6 +64,20 @@ def clone_request(request, method):
return ret return ret
class ForcedAuthentication(object):
"""
This authentication class is used if the test client or request factory
forcibly authenticated the request.
"""
def __init__(self, force_user, force_token):
self.force_user = force_user
self.force_token = force_token
def authenticate(self, request):
return (self.force_user, self.force_token)
class Request(object): class Request(object):
""" """
Wrapper allowing to enhance a standard `HttpRequest` instance. Wrapper allowing to enhance a standard `HttpRequest` instance.
...@@ -98,6 +112,12 @@ class Request(object): ...@@ -98,6 +112,12 @@ class Request(object):
self.parser_context['request'] = self self.parser_context['request'] = self
self.parser_context['encoding'] = request.encoding or settings.DEFAULT_CHARSET self.parser_context['encoding'] = request.encoding or settings.DEFAULT_CHARSET
force_user = getattr(request, '_force_auth_user', None)
force_token = getattr(request, '_force_auth_token', None)
if (force_user is not None or force_token is not None):
forced_auth = ForcedAuthentication(force_user, force_token)
self.authenticators = (forced_auth,)
def _default_negotiator(self): def _default_negotiator(self):
return api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS() return api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS()
......
...@@ -73,6 +73,13 @@ DEFAULTS = { ...@@ -73,6 +73,13 @@ DEFAULTS = {
'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser', 'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser',
'UNAUTHENTICATED_TOKEN': None, 'UNAUTHENTICATED_TOKEN': None,
# Testing
'TEST_REQUEST_RENDERER_CLASSES': (
'rest_framework.renderers.MultiPartRenderer',
'rest_framework.renderers.JSONRenderer'
),
'TEST_REQUEST_DEFAULT_FORMAT': 'multipart',
# Browser enhancements # Browser enhancements
'FORM_METHOD_OVERRIDE': '_method', 'FORM_METHOD_OVERRIDE': '_method',
'FORM_CONTENT_OVERRIDE': '_content', 'FORM_CONTENT_OVERRIDE': '_content',
...@@ -115,6 +122,7 @@ IMPORT_STRINGS = ( ...@@ -115,6 +122,7 @@ IMPORT_STRINGS = (
'DEFAULT_PAGINATION_SERIALIZER_CLASS', 'DEFAULT_PAGINATION_SERIALIZER_CLASS',
'DEFAULT_FILTER_BACKENDS', 'DEFAULT_FILTER_BACKENDS',
'FILTER_BACKEND', 'FILTER_BACKEND',
'TEST_REQUEST_RENDERER_CLASSES',
'UNAUTHENTICATED_USER', 'UNAUTHENTICATED_USER',
'UNAUTHENTICATED_TOKEN', 'UNAUTHENTICATED_TOKEN',
) )
......
# -- coding: utf-8 --
# Note that we import as `DjangoRequestFactory` and `DjangoClient` in order
# to make it harder for the user to import the wrong thing without realizing.
from __future__ import unicode_literals
from django.conf import settings
from django.test.client import Client as DjangoClient
from django.test.client import ClientHandler
from rest_framework.settings import api_settings
from rest_framework.compat import RequestFactory as DjangoRequestFactory
from rest_framework.compat import force_bytes_or_smart_bytes, six
def force_authenticate(request, user=None, token=None):
request._force_auth_user = user
request._force_auth_token = token
class APIRequestFactory(DjangoRequestFactory):
renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES
default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT
def __init__(self, enforce_csrf_checks=False, **defaults):
self.enforce_csrf_checks = enforce_csrf_checks
self.renderer_classes = {}
for cls in self.renderer_classes_list:
self.renderer_classes[cls.format] = cls
super(APIRequestFactory, self).__init__(**defaults)
def _encode_data(self, data, format=None, content_type=None):
"""
Encode the data returning a two tuple of (bytes, content_type)
"""
if not data:
return ('', None)
assert format is None or content_type is None, (
'You may not set both `format` and `content_type`.'
)
if content_type:
# Content type specified explicitly, treat data as a raw bytestring
ret = force_bytes_or_smart_bytes(data, settings.DEFAULT_CHARSET)
else:
format = format or self.default_format
assert format in self.renderer_classes, ("Invalid format '{0}'. "
"Available formats are {1}. Set TEST_REQUEST_RENDERER_CLASSES "
"to enable extra request formats.".format(
format,
', '.join(["'" + fmt + "'" for fmt in self.renderer_classes.keys()])
)
)
# Use format and render the data into a bytestring
renderer = self.renderer_classes[format]()
ret = renderer.render(data)
# Determine the content-type header from the renderer
content_type = "{0}; charset={1}".format(
renderer.media_type, renderer.charset
)
# Coerce text to bytes if required.
if isinstance(ret, six.text_type):
ret = bytes(ret.encode(renderer.charset))
return ret, content_type
def post(self, path, data=None, format=None, content_type=None, **extra):
data, content_type = self._encode_data(data, format, content_type)
return self.generic('POST', path, data, content_type, **extra)
def put(self, path, data=None, format=None, content_type=None, **extra):
data, content_type = self._encode_data(data, format, content_type)
return self.generic('PUT', path, data, content_type, **extra)
def patch(self, path, data=None, format=None, content_type=None, **extra):
data, content_type = self._encode_data(data, format, content_type)
return self.generic('PATCH', path, data, content_type, **extra)
def delete(self, path, data=None, format=None, content_type=None, **extra):
data, content_type = self._encode_data(data, format, content_type)
return self.generic('DELETE', path, data, content_type, **extra)
def options(self, path, data=None, format=None, content_type=None, **extra):
data, content_type = self._encode_data(data, format, content_type)
return self.generic('OPTIONS', path, data, content_type, **extra)
def request(self, **kwargs):
request = super(APIRequestFactory, self).request(**kwargs)
request._dont_enforce_csrf_checks = not self.enforce_csrf_checks
return request
class ForceAuthClientHandler(ClientHandler):
"""
A patched version of ClientHandler that can enforce authentication
on the outgoing requests.
"""
def __init__(self, *args, **kwargs):
self._force_user = None
self._force_token = None
super(ForceAuthClientHandler, self).__init__(*args, **kwargs)
def get_response(self, request):
# This is the simplest place we can hook into to patch the
# request object.
force_authenticate(request, self._force_user, self._force_token)
return super(ForceAuthClientHandler, self).get_response(request)
class APIClient(APIRequestFactory, DjangoClient):
def __init__(self, enforce_csrf_checks=False, **defaults):
super(APIClient, self).__init__(**defaults)
self.handler = ForceAuthClientHandler(enforce_csrf_checks)
self._credentials = {}
def credentials(self, **kwargs):
"""
Sets headers that will be used on every outgoing request.
"""
self._credentials = kwargs
def force_authenticate(self, user=None, token=None):
"""
Forcibly authenticates outgoing requests with the given
user and/or token.
"""
self.handler._force_user = user
self.handler._force_token = token
def request(self, **kwargs):
# Ensure that any credentials set get added to every request.
kwargs.update(self._credentials)
return super(APIClient, self).request(**kwargs)
from __future__ import unicode_literals from __future__ import unicode_literals
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.http import HttpResponse from django.http import HttpResponse
from django.test import Client, TestCase from django.test import TestCase
from django.utils import unittest from django.utils import unittest
from rest_framework import HTTP_HEADER_ENCODING from rest_framework import HTTP_HEADER_ENCODING
from rest_framework import exceptions from rest_framework import exceptions
...@@ -21,14 +21,13 @@ from rest_framework.authtoken.models import Token ...@@ -21,14 +21,13 @@ from rest_framework.authtoken.models import Token
from rest_framework.compat import patterns, url, include from rest_framework.compat import patterns, url, include
from rest_framework.compat import oauth2_provider, oauth2_provider_models, oauth2_provider_scope from rest_framework.compat import oauth2_provider, oauth2_provider_models, oauth2_provider_scope
from rest_framework.compat import oauth, oauth_provider from rest_framework.compat import oauth, oauth_provider
from rest_framework.tests.utils import RequestFactory from rest_framework.test import APIRequestFactory, APIClient
from rest_framework.views import APIView from rest_framework.views import APIView
import json
import base64 import base64
import time import time
import datetime import datetime
factory = RequestFactory() factory = APIRequestFactory()
class MockView(APIView): class MockView(APIView):
...@@ -68,7 +67,7 @@ class BasicAuthTests(TestCase): ...@@ -68,7 +67,7 @@ class BasicAuthTests(TestCase):
urls = 'rest_framework.tests.test_authentication' urls = 'rest_framework.tests.test_authentication'
def setUp(self): def setUp(self):
self.csrf_client = Client(enforce_csrf_checks=True) self.csrf_client = APIClient(enforce_csrf_checks=True)
self.username = 'john' self.username = 'john'
self.email = 'lennon@thebeatles.com' self.email = 'lennon@thebeatles.com'
self.password = 'password' self.password = 'password'
...@@ -87,7 +86,7 @@ class BasicAuthTests(TestCase): ...@@ -87,7 +86,7 @@ class BasicAuthTests(TestCase):
credentials = ('%s:%s' % (self.username, self.password)) credentials = ('%s:%s' % (self.username, self.password))
base64_credentials = base64.b64encode(credentials.encode(HTTP_HEADER_ENCODING)).decode(HTTP_HEADER_ENCODING) base64_credentials = base64.b64encode(credentials.encode(HTTP_HEADER_ENCODING)).decode(HTTP_HEADER_ENCODING)
auth = 'Basic %s' % base64_credentials auth = 'Basic %s' % base64_credentials
response = self.csrf_client.post('/basic/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth) response = self.csrf_client.post('/basic/', {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
def test_post_form_failing_basic_auth(self): def test_post_form_failing_basic_auth(self):
...@@ -97,7 +96,7 @@ class BasicAuthTests(TestCase): ...@@ -97,7 +96,7 @@ class BasicAuthTests(TestCase):
def test_post_json_failing_basic_auth(self): def test_post_json_failing_basic_auth(self):
"""Ensure POSTing json over basic auth without correct credentials fails""" """Ensure POSTing json over basic auth without correct credentials fails"""
response = self.csrf_client.post('/basic/', json.dumps({'example': 'example'}), 'application/json') response = self.csrf_client.post('/basic/', {'example': 'example'}, format='json')
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
self.assertEqual(response['WWW-Authenticate'], 'Basic realm="api"') self.assertEqual(response['WWW-Authenticate'], 'Basic realm="api"')
...@@ -107,8 +106,8 @@ class SessionAuthTests(TestCase): ...@@ -107,8 +106,8 @@ class SessionAuthTests(TestCase):
urls = 'rest_framework.tests.test_authentication' urls = 'rest_framework.tests.test_authentication'
def setUp(self): def setUp(self):
self.csrf_client = Client(enforce_csrf_checks=True) self.csrf_client = APIClient(enforce_csrf_checks=True)
self.non_csrf_client = Client(enforce_csrf_checks=False) self.non_csrf_client = APIClient(enforce_csrf_checks=False)
self.username = 'john' self.username = 'john'
self.email = 'lennon@thebeatles.com' self.email = 'lennon@thebeatles.com'
self.password = 'password' self.password = 'password'
...@@ -154,7 +153,7 @@ class TokenAuthTests(TestCase): ...@@ -154,7 +153,7 @@ class TokenAuthTests(TestCase):
urls = 'rest_framework.tests.test_authentication' urls = 'rest_framework.tests.test_authentication'
def setUp(self): def setUp(self):
self.csrf_client = Client(enforce_csrf_checks=True) self.csrf_client = APIClient(enforce_csrf_checks=True)
self.username = 'john' self.username = 'john'
self.email = 'lennon@thebeatles.com' self.email = 'lennon@thebeatles.com'
self.password = 'password' self.password = 'password'
...@@ -172,7 +171,7 @@ class TokenAuthTests(TestCase): ...@@ -172,7 +171,7 @@ class TokenAuthTests(TestCase):
def test_post_json_passing_token_auth(self): def test_post_json_passing_token_auth(self):
"""Ensure POSTing form over token auth with correct credentials passes and does not require CSRF""" """Ensure POSTing form over token auth with correct credentials passes and does not require CSRF"""
auth = "Token " + self.key auth = "Token " + self.key
response = self.csrf_client.post('/token/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth) response = self.csrf_client.post('/token/', {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
def test_post_form_failing_token_auth(self): def test_post_form_failing_token_auth(self):
...@@ -182,7 +181,7 @@ class TokenAuthTests(TestCase): ...@@ -182,7 +181,7 @@ class TokenAuthTests(TestCase):
def test_post_json_failing_token_auth(self): def test_post_json_failing_token_auth(self):
"""Ensure POSTing json over token auth without correct credentials fails""" """Ensure POSTing json over token auth without correct credentials fails"""
response = self.csrf_client.post('/token/', json.dumps({'example': 'example'}), 'application/json') response = self.csrf_client.post('/token/', {'example': 'example'}, format='json')
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
def test_token_has_auto_assigned_key_if_none_provided(self): def test_token_has_auto_assigned_key_if_none_provided(self):
...@@ -193,33 +192,33 @@ class TokenAuthTests(TestCase): ...@@ -193,33 +192,33 @@ class TokenAuthTests(TestCase):
def test_token_login_json(self): def test_token_login_json(self):
"""Ensure token login view using JSON POST works.""" """Ensure token login view using JSON POST works."""
client = Client(enforce_csrf_checks=True) client = APIClient(enforce_csrf_checks=True)
response = client.post('/auth-token/', response = client.post('/auth-token/',
json.dumps({'username': self.username, 'password': self.password}), 'application/json') {'username': self.username, 'password': self.password}, format='json')
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(json.loads(response.content.decode('ascii'))['token'], self.key) self.assertEqual(response.data['token'], self.key)
def test_token_login_json_bad_creds(self): def test_token_login_json_bad_creds(self):
"""Ensure token login view using JSON POST fails if bad credentials are used.""" """Ensure token login view using JSON POST fails if bad credentials are used."""
client = Client(enforce_csrf_checks=True) client = APIClient(enforce_csrf_checks=True)
response = client.post('/auth-token/', response = client.post('/auth-token/',
json.dumps({'username': self.username, 'password': "badpass"}), 'application/json') {'username': self.username, 'password': "badpass"}, format='json')
self.assertEqual(response.status_code, 400) self.assertEqual(response.status_code, 400)
def test_token_login_json_missing_fields(self): def test_token_login_json_missing_fields(self):
"""Ensure token login view using JSON POST fails if missing fields.""" """Ensure token login view using JSON POST fails if missing fields."""
client = Client(enforce_csrf_checks=True) client = APIClient(enforce_csrf_checks=True)
response = client.post('/auth-token/', response = client.post('/auth-token/',
json.dumps({'username': self.username}), 'application/json') {'username': self.username}, format='json')
self.assertEqual(response.status_code, 400) self.assertEqual(response.status_code, 400)
def test_token_login_form(self): def test_token_login_form(self):
"""Ensure token login view using form POST works.""" """Ensure token login view using form POST works."""
client = Client(enforce_csrf_checks=True) client = APIClient(enforce_csrf_checks=True)
response = client.post('/auth-token/', response = client.post('/auth-token/',
{'username': self.username, 'password': self.password}) {'username': self.username, 'password': self.password})
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(json.loads(response.content.decode('ascii'))['token'], self.key) self.assertEqual(response.data['token'], self.key)
class IncorrectCredentialsTests(TestCase): class IncorrectCredentialsTests(TestCase):
...@@ -256,7 +255,7 @@ class OAuthTests(TestCase): ...@@ -256,7 +255,7 @@ class OAuthTests(TestCase):
self.consts = consts self.consts = consts
self.csrf_client = Client(enforce_csrf_checks=True) self.csrf_client = APIClient(enforce_csrf_checks=True)
self.username = 'john' self.username = 'john'
self.email = 'lennon@thebeatles.com' self.email = 'lennon@thebeatles.com'
self.password = 'password' self.password = 'password'
...@@ -470,12 +469,13 @@ class OAuthTests(TestCase): ...@@ -470,12 +469,13 @@ class OAuthTests(TestCase):
response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth) response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 401) self.assertEqual(response.status_code, 401)
class OAuth2Tests(TestCase): class OAuth2Tests(TestCase):
"""OAuth 2.0 authentication""" """OAuth 2.0 authentication"""
urls = 'rest_framework.tests.test_authentication' urls = 'rest_framework.tests.test_authentication'
def setUp(self): def setUp(self):
self.csrf_client = Client(enforce_csrf_checks=True) self.csrf_client = APIClient(enforce_csrf_checks=True)
self.username = 'john' self.username = 'john'
self.email = 'lennon@thebeatles.com' self.email = 'lennon@thebeatles.com'
self.password = 'password' self.password = 'password'
......
from __future__ import unicode_literals from __future__ import unicode_literals
from django.test import TestCase from django.test import TestCase
from rest_framework import status from rest_framework import status
from rest_framework.authentication import BasicAuthentication
from rest_framework.parsers import JSONParser
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.renderers import JSONRenderer from rest_framework.renderers import JSONRenderer
from rest_framework.parsers import JSONParser from rest_framework.test import APIRequestFactory
from rest_framework.authentication import BasicAuthentication
from rest_framework.throttling import UserRateThrottle from rest_framework.throttling import UserRateThrottle
from rest_framework.permissions import IsAuthenticated
from rest_framework.views import APIView from rest_framework.views import APIView
from rest_framework.decorators import ( from rest_framework.decorators import (
api_view, api_view,
...@@ -17,13 +18,11 @@ from rest_framework.decorators import ( ...@@ -17,13 +18,11 @@ from rest_framework.decorators import (
permission_classes, permission_classes,
) )
from rest_framework.tests.utils import RequestFactory
class DecoratorTestCase(TestCase): class DecoratorTestCase(TestCase):
def setUp(self): def setUp(self):
self.factory = RequestFactory() self.factory = APIRequestFactory()
def _finalize_response(self, request, response, *args, **kwargs): def _finalize_response(self, request, response, *args, **kwargs):
response.request = request response.request = request
......
...@@ -4,13 +4,13 @@ from decimal import Decimal ...@@ -4,13 +4,13 @@ from decimal import Decimal
from django.db import models from django.db import models
from django.core.urlresolvers import reverse from django.core.urlresolvers import reverse
from django.test import TestCase from django.test import TestCase
from django.test.client import RequestFactory
from django.utils import unittest from django.utils import unittest
from rest_framework import generics, serializers, status, filters from rest_framework import generics, serializers, status, filters
from rest_framework.compat import django_filters, patterns, url from rest_framework.compat import django_filters, patterns, url
from rest_framework.test import APIRequestFactory
from rest_framework.tests.models import BasicModel from rest_framework.tests.models import BasicModel
factory = RequestFactory() factory = APIRequestFactory()
class FilterableItem(models.Model): class FilterableItem(models.Model):
......
...@@ -3,12 +3,11 @@ from django.db import models ...@@ -3,12 +3,11 @@ from django.db import models
from django.shortcuts import get_object_or_404 from django.shortcuts import get_object_or_404
from django.test import TestCase from django.test import TestCase
from rest_framework import generics, renderers, serializers, status from rest_framework import generics, renderers, serializers, status
from rest_framework.tests.utils import RequestFactory from rest_framework.test import APIRequestFactory
from rest_framework.tests.models import BasicModel, Comment, SlugBasedModel from rest_framework.tests.models import BasicModel, Comment, SlugBasedModel
from rest_framework.compat import six from rest_framework.compat import six
import json
factory = RequestFactory() factory = APIRequestFactory()
class RootView(generics.ListCreateAPIView): class RootView(generics.ListCreateAPIView):
...@@ -71,9 +70,8 @@ class TestRootView(TestCase): ...@@ -71,9 +70,8 @@ class TestRootView(TestCase):
""" """
POST requests to ListCreateAPIView should create a new object. POST requests to ListCreateAPIView should create a new object.
""" """
content = {'text': 'foobar'} data = {'text': 'foobar'}
request = factory.post('/', json.dumps(content), request = factory.post('/', data, format='json')
content_type='application/json')
with self.assertNumQueries(1): with self.assertNumQueries(1):
response = self.view(request).render() response = self.view(request).render()
self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertEqual(response.status_code, status.HTTP_201_CREATED)
...@@ -85,9 +83,8 @@ class TestRootView(TestCase): ...@@ -85,9 +83,8 @@ class TestRootView(TestCase):
""" """
PUT requests to ListCreateAPIView should not be allowed PUT requests to ListCreateAPIView should not be allowed
""" """
content = {'text': 'foobar'} data = {'text': 'foobar'}
request = factory.put('/', json.dumps(content), request = factory.put('/', data, format='json')
content_type='application/json')
with self.assertNumQueries(0): with self.assertNumQueries(0):
response = self.view(request).render() response = self.view(request).render()
self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
...@@ -148,9 +145,8 @@ class TestRootView(TestCase): ...@@ -148,9 +145,8 @@ class TestRootView(TestCase):
""" """
POST requests to create a new object should not be able to set the id. POST requests to create a new object should not be able to set the id.
""" """
content = {'id': 999, 'text': 'foobar'} data = {'id': 999, 'text': 'foobar'}
request = factory.post('/', json.dumps(content), request = factory.post('/', data, format='json')
content_type='application/json')
with self.assertNumQueries(1): with self.assertNumQueries(1):
response = self.view(request).render() response = self.view(request).render()
self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertEqual(response.status_code, status.HTTP_201_CREATED)
...@@ -189,9 +185,8 @@ class TestInstanceView(TestCase): ...@@ -189,9 +185,8 @@ class TestInstanceView(TestCase):
""" """
POST requests to RetrieveUpdateDestroyAPIView should not be allowed POST requests to RetrieveUpdateDestroyAPIView should not be allowed
""" """
content = {'text': 'foobar'} data = {'text': 'foobar'}
request = factory.post('/', json.dumps(content), request = factory.post('/', data, format='json')
content_type='application/json')
with self.assertNumQueries(0): with self.assertNumQueries(0):
response = self.view(request).render() response = self.view(request).render()
self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
...@@ -201,9 +196,8 @@ class TestInstanceView(TestCase): ...@@ -201,9 +196,8 @@ class TestInstanceView(TestCase):
""" """
PUT requests to RetrieveUpdateDestroyAPIView should update an object. PUT requests to RetrieveUpdateDestroyAPIView should update an object.
""" """
content = {'text': 'foobar'} data = {'text': 'foobar'}
request = factory.put('/1', json.dumps(content), request = factory.put('/1', data, format='json')
content_type='application/json')
with self.assertNumQueries(2): with self.assertNumQueries(2):
response = self.view(request, pk='1').render() response = self.view(request, pk='1').render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
...@@ -215,9 +209,8 @@ class TestInstanceView(TestCase): ...@@ -215,9 +209,8 @@ class TestInstanceView(TestCase):
""" """
PATCH requests to RetrieveUpdateDestroyAPIView should update an object. PATCH requests to RetrieveUpdateDestroyAPIView should update an object.
""" """
content = {'text': 'foobar'} data = {'text': 'foobar'}
request = factory.patch('/1', json.dumps(content), request = factory.patch('/1', data, format='json')
content_type='application/json')
with self.assertNumQueries(2): with self.assertNumQueries(2):
response = self.view(request, pk=1).render() response = self.view(request, pk=1).render()
...@@ -293,9 +286,8 @@ class TestInstanceView(TestCase): ...@@ -293,9 +286,8 @@ class TestInstanceView(TestCase):
""" """
PUT requests to create a new object should not be able to set the id. PUT requests to create a new object should not be able to set the id.
""" """
content = {'id': 999, 'text': 'foobar'} data = {'id': 999, 'text': 'foobar'}
request = factory.put('/1', json.dumps(content), request = factory.put('/1', data, format='json')
content_type='application/json')
with self.assertNumQueries(2): with self.assertNumQueries(2):
response = self.view(request, pk=1).render() response = self.view(request, pk=1).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
...@@ -309,9 +301,8 @@ class TestInstanceView(TestCase): ...@@ -309,9 +301,8 @@ class TestInstanceView(TestCase):
if it does not currently exist. if it does not currently exist.
""" """
self.objects.get(id=1).delete() self.objects.get(id=1).delete()
content = {'text': 'foobar'} data = {'text': 'foobar'}
request = factory.put('/1', json.dumps(content), request = factory.put('/1', data, format='json')
content_type='application/json')
with self.assertNumQueries(3): with self.assertNumQueries(3):
response = self.view(request, pk=1).render() response = self.view(request, pk=1).render()
self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertEqual(response.status_code, status.HTTP_201_CREATED)
...@@ -324,10 +315,9 @@ class TestInstanceView(TestCase): ...@@ -324,10 +315,9 @@ class TestInstanceView(TestCase):
PUT requests to RetrieveUpdateDestroyAPIView should create an object PUT requests to RetrieveUpdateDestroyAPIView should create an object
at the requested url if it doesn't exist. at the requested url if it doesn't exist.
""" """
content = {'text': 'foobar'} data = {'text': 'foobar'}
# pk fields can not be created on demand, only the database can set the pk for a new object # pk fields can not be created on demand, only the database can set the pk for a new object
request = factory.put('/5', json.dumps(content), request = factory.put('/5', data, format='json')
content_type='application/json')
with self.assertNumQueries(3): with self.assertNumQueries(3):
response = self.view(request, pk=5).render() response = self.view(request, pk=5).render()
self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertEqual(response.status_code, status.HTTP_201_CREATED)
...@@ -339,9 +329,8 @@ class TestInstanceView(TestCase): ...@@ -339,9 +329,8 @@ class TestInstanceView(TestCase):
PUT requests to RetrieveUpdateDestroyAPIView should create an object PUT requests to RetrieveUpdateDestroyAPIView should create an object
at the requested url if possible, else return HTTP_403_FORBIDDEN error-response. at the requested url if possible, else return HTTP_403_FORBIDDEN error-response.
""" """
content = {'text': 'foobar'} data = {'text': 'foobar'}
request = factory.put('/test_slug', json.dumps(content), request = factory.put('/test_slug', data, format='json')
content_type='application/json')
with self.assertNumQueries(2): with self.assertNumQueries(2):
response = self.slug_based_view(request, slug='test_slug').render() response = self.slug_based_view(request, slug='test_slug').render()
self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertEqual(response.status_code, status.HTTP_201_CREATED)
...@@ -415,9 +404,8 @@ class TestCreateModelWithAutoNowAddField(TestCase): ...@@ -415,9 +404,8 @@ class TestCreateModelWithAutoNowAddField(TestCase):
https://github.com/tomchristie/django-rest-framework/issues/285 https://github.com/tomchristie/django-rest-framework/issues/285
""" """
content = {'email': 'foobar@example.com', 'content': 'foobar'} data = {'email': 'foobar@example.com', 'content': 'foobar'}
request = factory.post('/', json.dumps(content), request = factory.post('/', data, format='json')
content_type='application/json')
response = self.view(request).render() response = self.view(request).render()
self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertEqual(response.status_code, status.HTTP_201_CREATED)
created = self.objects.get(id=1) created = self.objects.get(id=1)
......
from __future__ import unicode_literals from __future__ import unicode_literals
import json import json
from django.test import TestCase from django.test import TestCase
from django.test.client import RequestFactory
from rest_framework import generics, status, serializers from rest_framework import generics, status, serializers
from rest_framework.compat import patterns, url from rest_framework.compat import patterns, url
from rest_framework.tests.models import Anchor, BasicModel, ManyToManyModel, BlogPost, BlogPostComment, Album, Photo, OptionalRelationModel from rest_framework.test import APIRequestFactory
from rest_framework.tests.models import (
Anchor, BasicModel, ManyToManyModel, BlogPost, BlogPostComment,
Album, Photo, OptionalRelationModel
)
factory = RequestFactory() factory = APIRequestFactory()
class BlogPostCommentSerializer(serializers.ModelSerializer): class BlogPostCommentSerializer(serializers.ModelSerializer):
...@@ -21,7 +24,7 @@ class BlogPostCommentSerializer(serializers.ModelSerializer): ...@@ -21,7 +24,7 @@ class BlogPostCommentSerializer(serializers.ModelSerializer):
class PhotoSerializer(serializers.Serializer): class PhotoSerializer(serializers.Serializer):
description = serializers.CharField() description = serializers.CharField()
album_url = serializers.HyperlinkedRelatedField(source='album', view_name='album-detail', queryset=Album.objects.all(), slug_field='title', slug_url_kwarg='title') album_url = serializers.HyperlinkedRelatedField(source='album', view_name='album-detail', queryset=Album.objects.all(), lookup_field='title', slug_url_kwarg='title')
def restore_object(self, attrs, instance=None): def restore_object(self, attrs, instance=None):
return Photo(**attrs) return Photo(**attrs)
......
from __future__ import unicode_literals from __future__ import unicode_literals
from django.test import TestCase from django.test import TestCase
from django.test.client import RequestFactory
from rest_framework.negotiation import DefaultContentNegotiation from rest_framework.negotiation import DefaultContentNegotiation
from rest_framework.request import Request from rest_framework.request import Request
from rest_framework.renderers import BaseRenderer from rest_framework.renderers import BaseRenderer
from rest_framework.test import APIRequestFactory
factory = RequestFactory() factory = APIRequestFactory()
class MockJSONRenderer(BaseRenderer): class MockJSONRenderer(BaseRenderer):
......
...@@ -4,13 +4,13 @@ from decimal import Decimal ...@@ -4,13 +4,13 @@ from decimal import Decimal
from django.db import models from django.db import models
from django.core.paginator import Paginator from django.core.paginator import Paginator
from django.test import TestCase from django.test import TestCase
from django.test.client import RequestFactory
from django.utils import unittest from django.utils import unittest
from rest_framework import generics, status, pagination, filters, serializers from rest_framework import generics, status, pagination, filters, serializers
from rest_framework.compat import django_filters from rest_framework.compat import django_filters
from rest_framework.test import APIRequestFactory
from rest_framework.tests.models import BasicModel from rest_framework.tests.models import BasicModel
factory = RequestFactory() factory = APIRequestFactory()
class FilterableItem(models.Model): class FilterableItem(models.Model):
...@@ -369,7 +369,7 @@ class TestCustomPaginationSerializer(TestCase): ...@@ -369,7 +369,7 @@ class TestCustomPaginationSerializer(TestCase):
self.page = paginator.page(1) self.page = paginator.page(1)
def test_custom_pagination_serializer(self): def test_custom_pagination_serializer(self):
request = RequestFactory().get('/foobar') request = APIRequestFactory().get('/foobar')
serializer = CustomPaginationSerializer( serializer = CustomPaginationSerializer(
instance=self.page, instance=self.page,
context={'request': request} context={'request': request}
......
...@@ -3,11 +3,10 @@ from django.contrib.auth.models import User, Permission ...@@ -3,11 +3,10 @@ from django.contrib.auth.models import User, Permission
from django.db import models from django.db import models
from django.test import TestCase from django.test import TestCase
from rest_framework import generics, status, permissions, authentication, HTTP_HEADER_ENCODING from rest_framework import generics, status, permissions, authentication, HTTP_HEADER_ENCODING
from rest_framework.tests.utils import RequestFactory from rest_framework.test import APIRequestFactory
import base64 import base64
import json
factory = RequestFactory() factory = APIRequestFactory()
class BasicModel(models.Model): class BasicModel(models.Model):
...@@ -56,15 +55,13 @@ class ModelPermissionsIntegrationTests(TestCase): ...@@ -56,15 +55,13 @@ class ModelPermissionsIntegrationTests(TestCase):
BasicModel(text='foo').save() BasicModel(text='foo').save()
def test_has_create_permissions(self): def test_has_create_permissions(self):
request = factory.post('/', json.dumps({'text': 'foobar'}), request = factory.post('/', {'text': 'foobar'}, format='json',
content_type='application/json',
HTTP_AUTHORIZATION=self.permitted_credentials) HTTP_AUTHORIZATION=self.permitted_credentials)
response = root_view(request, pk=1) response = root_view(request, pk=1)
self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertEqual(response.status_code, status.HTTP_201_CREATED)
def test_has_put_permissions(self): def test_has_put_permissions(self):
request = factory.put('/1', json.dumps({'text': 'foobar'}), request = factory.put('/1', {'text': 'foobar'}, format='json',
content_type='application/json',
HTTP_AUTHORIZATION=self.permitted_credentials) HTTP_AUTHORIZATION=self.permitted_credentials)
response = instance_view(request, pk='1') response = instance_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
...@@ -75,15 +72,13 @@ class ModelPermissionsIntegrationTests(TestCase): ...@@ -75,15 +72,13 @@ class ModelPermissionsIntegrationTests(TestCase):
self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
def test_does_not_have_create_permissions(self): def test_does_not_have_create_permissions(self):
request = factory.post('/', json.dumps({'text': 'foobar'}), request = factory.post('/', {'text': 'foobar'}, format='json',
content_type='application/json',
HTTP_AUTHORIZATION=self.disallowed_credentials) HTTP_AUTHORIZATION=self.disallowed_credentials)
response = root_view(request, pk=1) response = root_view(request, pk=1)
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
def test_does_not_have_put_permissions(self): def test_does_not_have_put_permissions(self):
request = factory.put('/1', json.dumps({'text': 'foobar'}), request = factory.put('/1', {'text': 'foobar'}, format='json',
content_type='application/json',
HTTP_AUTHORIZATION=self.disallowed_credentials) HTTP_AUTHORIZATION=self.disallowed_credentials)
response = instance_view(request, pk='1') response = instance_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
...@@ -95,28 +90,26 @@ class ModelPermissionsIntegrationTests(TestCase): ...@@ -95,28 +90,26 @@ class ModelPermissionsIntegrationTests(TestCase):
def test_has_put_as_create_permissions(self): def test_has_put_as_create_permissions(self):
# User only has update permissions - should be able to update an entity. # User only has update permissions - should be able to update an entity.
request = factory.put('/1', json.dumps({'text': 'foobar'}), request = factory.put('/1', {'text': 'foobar'}, format='json',
content_type='application/json',
HTTP_AUTHORIZATION=self.updateonly_credentials) HTTP_AUTHORIZATION=self.updateonly_credentials)
response = instance_view(request, pk='1') response = instance_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
# But if PUTing to a new entity, permission should be denied. # But if PUTing to a new entity, permission should be denied.
request = factory.put('/2', json.dumps({'text': 'foobar'}), request = factory.put('/2', {'text': 'foobar'}, format='json',
content_type='application/json',
HTTP_AUTHORIZATION=self.updateonly_credentials) HTTP_AUTHORIZATION=self.updateonly_credentials)
response = instance_view(request, pk='2') response = instance_view(request, pk='2')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
def test_options_permitted(self): def test_options_permitted(self):
request = factory.options('/', content_type='application/json', request = factory.options('/',
HTTP_AUTHORIZATION=self.permitted_credentials) HTTP_AUTHORIZATION=self.permitted_credentials)
response = root_view(request, pk='1') response = root_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertIn('actions', response.data) self.assertIn('actions', response.data)
self.assertEqual(list(response.data['actions'].keys()), ['POST']) self.assertEqual(list(response.data['actions'].keys()), ['POST'])
request = factory.options('/1', content_type='application/json', request = factory.options('/1',
HTTP_AUTHORIZATION=self.permitted_credentials) HTTP_AUTHORIZATION=self.permitted_credentials)
response = instance_view(request, pk='1') response = instance_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
...@@ -124,26 +117,26 @@ class ModelPermissionsIntegrationTests(TestCase): ...@@ -124,26 +117,26 @@ class ModelPermissionsIntegrationTests(TestCase):
self.assertEqual(list(response.data['actions'].keys()), ['PUT']) self.assertEqual(list(response.data['actions'].keys()), ['PUT'])
def test_options_disallowed(self): def test_options_disallowed(self):
request = factory.options('/', content_type='application/json', request = factory.options('/',
HTTP_AUTHORIZATION=self.disallowed_credentials) HTTP_AUTHORIZATION=self.disallowed_credentials)
response = root_view(request, pk='1') response = root_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertNotIn('actions', response.data) self.assertNotIn('actions', response.data)
request = factory.options('/1', content_type='application/json', request = factory.options('/1',
HTTP_AUTHORIZATION=self.disallowed_credentials) HTTP_AUTHORIZATION=self.disallowed_credentials)
response = instance_view(request, pk='1') response = instance_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertNotIn('actions', response.data) self.assertNotIn('actions', response.data)
def test_options_updateonly(self): def test_options_updateonly(self):
request = factory.options('/', content_type='application/json', request = factory.options('/',
HTTP_AUTHORIZATION=self.updateonly_credentials) HTTP_AUTHORIZATION=self.updateonly_credentials)
response = root_view(request, pk='1') response = root_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertNotIn('actions', response.data) self.assertNotIn('actions', response.data)
request = factory.options('/1', content_type='application/json', request = factory.options('/1',
HTTP_AUTHORIZATION=self.updateonly_credentials) HTTP_AUTHORIZATION=self.updateonly_credentials)
response = instance_view(request, pk='1') response = instance_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
......
from __future__ import unicode_literals from __future__ import unicode_literals
from django.test import TestCase from django.test import TestCase
from django.test.client import RequestFactory
from rest_framework import serializers from rest_framework import serializers
from rest_framework.compat import patterns, url from rest_framework.compat import patterns, url
from rest_framework.test import APIRequestFactory
from rest_framework.tests.models import ( from rest_framework.tests.models import (
BlogPost, BlogPost,
ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource,
NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource
) )
factory = RequestFactory() factory = APIRequestFactory()
request = factory.get('/') # Just to ensure we have a request in the serializer context request = factory.get('/') # Just to ensure we have a request in the serializer context
......
...@@ -4,19 +4,17 @@ from __future__ import unicode_literals ...@@ -4,19 +4,17 @@ from __future__ import unicode_literals
from decimal import Decimal from decimal import Decimal
from django.core.cache import cache from django.core.cache import cache
from django.test import TestCase from django.test import TestCase
from django.test.client import RequestFactory
from django.utils import unittest from django.utils import unittest
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from rest_framework import status, permissions from rest_framework import status, permissions
from rest_framework.compat import yaml, etree, patterns, url, include from rest_framework.compat import yaml, etree, patterns, url, include, six, StringIO
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.views import APIView from rest_framework.views import APIView
from rest_framework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer, \ from rest_framework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer, \
XMLRenderer, JSONPRenderer, BrowsableAPIRenderer, UnicodeJSONRenderer XMLRenderer, JSONPRenderer, BrowsableAPIRenderer, UnicodeJSONRenderer
from rest_framework.parsers import YAMLParser, XMLParser from rest_framework.parsers import YAMLParser, XMLParser
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.compat import StringIO from rest_framework.test import APIRequestFactory
from rest_framework.compat import six
import datetime import datetime
import pickle import pickle
import re import re
...@@ -121,7 +119,7 @@ class POSTDeniedView(APIView): ...@@ -121,7 +119,7 @@ class POSTDeniedView(APIView):
class DocumentingRendererTests(TestCase): class DocumentingRendererTests(TestCase):
def test_only_permitted_forms_are_displayed(self): def test_only_permitted_forms_are_displayed(self):
view = POSTDeniedView.as_view() view = POSTDeniedView.as_view()
request = RequestFactory().get('/') request = APIRequestFactory().get('/')
response = view(request).render() response = view(request).render()
self.assertNotContains(response, '>POST<') self.assertNotContains(response, '>POST<')
self.assertContains(response, '>PUT<') self.assertContains(response, '>PUT<')
......
...@@ -5,8 +5,7 @@ from __future__ import unicode_literals ...@@ -5,8 +5,7 @@ from __future__ import unicode_literals
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.contrib.auth import authenticate, login, logout from django.contrib.auth import authenticate, login, logout
from django.contrib.sessions.middleware import SessionMiddleware from django.contrib.sessions.middleware import SessionMiddleware
from django.test import TestCase, Client from django.test import TestCase
from django.test.client import RequestFactory
from rest_framework import status from rest_framework import status
from rest_framework.authentication import SessionAuthentication from rest_framework.authentication import SessionAuthentication
from rest_framework.compat import patterns from rest_framework.compat import patterns
...@@ -19,12 +18,13 @@ from rest_framework.parsers import ( ...@@ -19,12 +18,13 @@ from rest_framework.parsers import (
from rest_framework.request import Request from rest_framework.request import Request
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.test import APIRequestFactory, APIClient
from rest_framework.views import APIView from rest_framework.views import APIView
from rest_framework.compat import six from rest_framework.compat import six
import json import json
factory = RequestFactory() factory = APIRequestFactory()
class PlainTextParser(BaseParser): class PlainTextParser(BaseParser):
...@@ -116,16 +116,7 @@ class TestContentParsing(TestCase): ...@@ -116,16 +116,7 @@ class TestContentParsing(TestCase):
Ensure request.DATA returns content for PUT request with form content. Ensure request.DATA returns content for PUT request with form content.
""" """
data = {'qwerty': 'uiop'} data = {'qwerty': 'uiop'}
from django import VERSION
if VERSION >= (1, 5):
from django.test.client import MULTIPART_CONTENT, BOUNDARY, encode_multipart
request = Request(factory.put('/', encode_multipart(BOUNDARY, data),
content_type=MULTIPART_CONTENT))
else:
request = Request(factory.put('/', data)) request = Request(factory.put('/', data))
request.parsers = (FormParser(), MultiPartParser()) request.parsers = (FormParser(), MultiPartParser())
self.assertEqual(list(request.DATA.items()), list(data.items())) self.assertEqual(list(request.DATA.items()), list(data.items()))
...@@ -257,7 +248,7 @@ class TestContentParsingWithAuthentication(TestCase): ...@@ -257,7 +248,7 @@ class TestContentParsingWithAuthentication(TestCase):
urls = 'rest_framework.tests.test_request' urls = 'rest_framework.tests.test_request'
def setUp(self): def setUp(self):
self.csrf_client = Client(enforce_csrf_checks=True) self.csrf_client = APIClient(enforce_csrf_checks=True)
self.username = 'john' self.username = 'john'
self.email = 'lennon@thebeatles.com' self.email = 'lennon@thebeatles.com'
self.password = 'password' self.password = 'password'
......
from __future__ import unicode_literals from __future__ import unicode_literals
from django.test import TestCase from django.test import TestCase
from django.test.client import RequestFactory
from rest_framework.compat import patterns, url from rest_framework.compat import patterns, url
from rest_framework.reverse import reverse from rest_framework.reverse import reverse
from rest_framework.test import APIRequestFactory
factory = RequestFactory() factory = APIRequestFactory()
def null_view(request): def null_view(request):
......
from __future__ import unicode_literals from __future__ import unicode_literals
from django.db import models from django.db import models
from django.test import TestCase from django.test import TestCase
from django.test.client import RequestFactory
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from rest_framework import serializers, viewsets, permissions from rest_framework import serializers, viewsets, permissions
from rest_framework.compat import include, patterns, url from rest_framework.compat import include, patterns, url
from rest_framework.decorators import link, action from rest_framework.decorators import link, action
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.routers import SimpleRouter, DefaultRouter from rest_framework.routers import SimpleRouter, DefaultRouter
from rest_framework.test import APIRequestFactory
factory = RequestFactory() factory = APIRequestFactory()
urlpatterns = patterns('',) urlpatterns = patterns('',)
...@@ -193,6 +193,7 @@ class TestActionKeywordArgs(TestCase): ...@@ -193,6 +193,7 @@ class TestActionKeywordArgs(TestCase):
{'permission_classes': [permissions.AllowAny]} {'permission_classes': [permissions.AllowAny]}
) )
class TestActionAppliedToExistingRoute(TestCase): class TestActionAppliedToExistingRoute(TestCase):
""" """
Ensure `@action` decorator raises an except when applied Ensure `@action` decorator raises an except when applied
......
# -- coding: utf-8 --
from __future__ import unicode_literals
from django.contrib.auth.models import User
from django.test import TestCase
from rest_framework.compat import patterns, url
from rest_framework.decorators import api_view
from rest_framework.response import Response
from rest_framework.test import APIClient, APIRequestFactory, force_authenticate
@api_view(['GET', 'POST'])
def view(request):
return Response({
'auth': request.META.get('HTTP_AUTHORIZATION', b''),
'user': request.user.username
})
urlpatterns = patterns('',
url(r'^view/$', view),
)
class TestAPITestClient(TestCase):
urls = 'rest_framework.tests.test_testing'
def setUp(self):
self.client = APIClient()
def test_credentials(self):
"""
Setting `.credentials()` adds the required headers to each request.
"""
self.client.credentials(HTTP_AUTHORIZATION='example')
for _ in range(0, 3):
response = self.client.get('/view/')
self.assertEqual(response.data['auth'], 'example')
def test_force_authenticate(self):
"""
Setting `.force_authenticate()` forcibly authenticates each request.
"""
user = User.objects.create_user('example', 'example@example.com')
self.client.force_authenticate(user)
response = self.client.get('/view/')
self.assertEqual(response.data['user'], 'example')
def test_csrf_exempt_by_default(self):
"""
By default, the test client is CSRF exempt.
"""
User.objects.create_user('example', 'example@example.com', 'password')
self.client.login(username='example', password='password')
response = self.client.post('/view/')
self.assertEqual(response.status_code, 200)
def test_explicitly_enforce_csrf_checks(self):
"""
The test client can enforce CSRF checks.
"""
client = APIClient(enforce_csrf_checks=True)
User.objects.create_user('example', 'example@example.com', 'password')
client.login(username='example', password='password')
response = client.post('/view/')
expected = {'detail': 'CSRF Failed: CSRF cookie not set.'}
self.assertEqual(response.status_code, 403)
self.assertEqual(response.data, expected)
class TestAPIRequestFactory(TestCase):
def test_csrf_exempt_by_default(self):
"""
By default, the test client is CSRF exempt.
"""
user = User.objects.create_user('example', 'example@example.com', 'password')
factory = APIRequestFactory()
request = factory.post('/view/')
request.user = user
response = view(request)
self.assertEqual(response.status_code, 200)
def test_explicitly_enforce_csrf_checks(self):
"""
The test client can enforce CSRF checks.
"""
user = User.objects.create_user('example', 'example@example.com', 'password')
factory = APIRequestFactory(enforce_csrf_checks=True)
request = factory.post('/view/')
request.user = user
response = view(request)
expected = {'detail': 'CSRF Failed: CSRF cookie not set.'}
self.assertEqual(response.status_code, 403)
self.assertEqual(response.data, expected)
def test_invalid_format(self):
"""
Attempting to use a format that is not configured will raise an
assertion error.
"""
factory = APIRequestFactory()
self.assertRaises(AssertionError, factory.post,
path='/view/', data={'example': 1}, format='xml'
)
def test_force_authenticate(self):
"""
Setting `force_authenticate()` forcibly authenticates the request.
"""
user = User.objects.create_user('example', 'example@example.com')
factory = APIRequestFactory()
request = factory.get('/view')
force_authenticate(request, user=user)
response = view(request)
self.assertEqual(response.data['user'], 'example')
...@@ -5,7 +5,7 @@ from __future__ import unicode_literals ...@@ -5,7 +5,7 @@ from __future__ import unicode_literals
from django.test import TestCase from django.test import TestCase
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.core.cache import cache from django.core.cache import cache
from django.test.client import RequestFactory from rest_framework.test import APIRequestFactory
from rest_framework.views import APIView from rest_framework.views import APIView
from rest_framework.throttling import UserRateThrottle, ScopedRateThrottle from rest_framework.throttling import UserRateThrottle, ScopedRateThrottle
from rest_framework.response import Response from rest_framework.response import Response
...@@ -41,7 +41,7 @@ class ThrottlingTests(TestCase): ...@@ -41,7 +41,7 @@ class ThrottlingTests(TestCase):
Reset the cache so that no throttles will be active Reset the cache so that no throttles will be active
""" """
cache.clear() cache.clear()
self.factory = RequestFactory() self.factory = APIRequestFactory()
def test_requests_are_throttled(self): def test_requests_are_throttled(self):
""" """
...@@ -173,7 +173,7 @@ class ScopedRateThrottleTests(TestCase): ...@@ -173,7 +173,7 @@ class ScopedRateThrottleTests(TestCase):
return Response('y') return Response('y')
self.throttle_class = XYScopedRateThrottle self.throttle_class = XYScopedRateThrottle
self.factory = RequestFactory() self.factory = APIRequestFactory()
self.x_view = XView.as_view() self.x_view = XView.as_view()
self.y_view = YView.as_view() self.y_view = YView.as_view()
self.unscoped_view = UnscopedView.as_view() self.unscoped_view = UnscopedView.as_view()
......
...@@ -2,7 +2,7 @@ from __future__ import unicode_literals ...@@ -2,7 +2,7 @@ from __future__ import unicode_literals
from collections import namedtuple from collections import namedtuple
from django.core import urlresolvers from django.core import urlresolvers
from django.test import TestCase from django.test import TestCase
from django.test.client import RequestFactory from rest_framework.test import APIRequestFactory
from rest_framework.compat import patterns, url, include from rest_framework.compat import patterns, url, include
from rest_framework.urlpatterns import format_suffix_patterns from rest_framework.urlpatterns import format_suffix_patterns
...@@ -20,7 +20,7 @@ class FormatSuffixTests(TestCase): ...@@ -20,7 +20,7 @@ class FormatSuffixTests(TestCase):
Tests `format_suffix_patterns` against different URLPatterns to ensure the URLs still resolve properly, including any captured parameters. Tests `format_suffix_patterns` against different URLPatterns to ensure the URLs still resolve properly, including any captured parameters.
""" """
def _resolve_urlpatterns(self, urlpatterns, test_paths): def _resolve_urlpatterns(self, urlpatterns, test_paths):
factory = RequestFactory() factory = APIRequestFactory()
try: try:
urlpatterns = format_suffix_patterns(urlpatterns) urlpatterns = format_suffix_patterns(urlpatterns)
except Exception: except Exception:
......
...@@ -2,10 +2,9 @@ from __future__ import unicode_literals ...@@ -2,10 +2,9 @@ from __future__ import unicode_literals
from django.db import models from django.db import models
from django.test import TestCase from django.test import TestCase
from rest_framework import generics, serializers, status from rest_framework import generics, serializers, status
from rest_framework.tests.utils import RequestFactory from rest_framework.test import APIRequestFactory
import json
factory = RequestFactory() factory = APIRequestFactory()
# Regression for #666 # Regression for #666
...@@ -33,8 +32,7 @@ class TestPreSaveValidationExclusions(TestCase): ...@@ -33,8 +32,7 @@ class TestPreSaveValidationExclusions(TestCase):
validation on read only fields. validation on read only fields.
""" """
obj = ValidationModel.objects.create(blank_validated_field='') obj = ValidationModel.objects.create(blank_validated_field='')
request = factory.put('/', json.dumps({}), request = factory.put('/', {}, format='json')
content_type='application/json')
view = UpdateValidationModel().as_view() view = UpdateValidationModel().as_view()
response = view(request, pk=obj.pk).render() response = view(request, pk=obj.pk).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
......
from __future__ import unicode_literals from __future__ import unicode_literals
import copy import copy
from django.test import TestCase from django.test import TestCase
from django.test.client import RequestFactory
from rest_framework import status from rest_framework import status
from rest_framework.decorators import api_view from rest_framework.decorators import api_view
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.test import APIRequestFactory
from rest_framework.views import APIView from rest_framework.views import APIView
factory = RequestFactory() factory = APIRequestFactory()
class BasicView(APIView): class BasicView(APIView):
......
from __future__ import unicode_literals
from django.test.client import FakePayload, Client as _Client, RequestFactory as _RequestFactory
from django.test.client import MULTIPART_CONTENT
from rest_framework.compat import urlparse
class RequestFactory(_RequestFactory):
def __init__(self, **defaults):
super(RequestFactory, self).__init__(**defaults)
def patch(self, path, data={}, content_type=MULTIPART_CONTENT,
**extra):
"Construct a PATCH request."
patch_data = self._encode_data(data, content_type)
parsed = urlparse.urlparse(path)
r = {
'CONTENT_LENGTH': len(patch_data),
'CONTENT_TYPE': content_type,
'PATH_INFO': self._get_path(parsed),
'QUERY_STRING': parsed[4],
'REQUEST_METHOD': 'PATCH',
'wsgi.input': FakePayload(patch_data),
}
r.update(extra)
return self.request(**r)
class Client(_Client, RequestFactory):
def patch(self, path, data={}, content_type=MULTIPART_CONTENT,
follow=False, **extra):
"""
Send a resource to the server using PATCH.
"""
response = super(Client, self).patch(path, data=data, content_type=content_type, **extra)
if follow:
response = self._handle_redirects(response, **extra)
return response
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment