Commit bf09c32d by Tom Christie

Code linting and added runtests.py

parent e385a7b8
...@@ -28,7 +28,7 @@ install: ...@@ -28,7 +28,7 @@ install:
- export PYTHONPATH=. - export PYTHONPATH=.
script: script:
- py.test - ./runtests.py
matrix: matrix:
exclude: exclude:
......
[pytest]
addopts = --tb=short
# Test requirements
pytest-django==2.6
pytest==2.5.2
pytest-cov==1.6
# Optional packages
markdown>=2.1.0 markdown>=2.1.0
PyYAML>=3.10 PyYAML>=3.10
defusedxml>=0.3 defusedxml>=0.3
......
""" """
______ _____ _____ _____ __ _ ______ _____ _____ _____ __
| ___ \ ___/ ___|_ _| / _| | | | ___ \ ___/ ___|_ _| / _| | |
| |_/ / |__ \ `--. | | | |_ _ __ __ _ _ __ ___ _____ _____ _ __| | __ | |_/ / |__ \ `--. | | | |_ _ __ __ _ _ __ ___ _____ _____ _ __| |__
| /| __| `--. \ | | | _| '__/ _` | '_ ` _ \ / _ \ \ /\ / / _ \| '__| |/ / | /| __| `--. \ | | | _| '__/ _` | '_ ` _ \ / _ \ \ /\ / / _ \| '__| |/ /
| |\ \| |___/\__/ / | | | | | | | (_| | | | | | | __/\ V V / (_) | | | < | |\ \| |___/\__/ / | | | | | | | (_| | | | | | | __/\ V V / (_) | | | <
\_| \_\____/\____/ \_/ |_| |_| \__,_|_| |_| |_|\___| \_/\_/ \___/|_| |_|\_| \_| \_\____/\____/ \_/ |_| |_| \__,_|_| |_| |_|\___| \_/\_/ \___/|_| |_|\_|
""" """
......
...@@ -21,7 +21,7 @@ def get_authorization_header(request): ...@@ -21,7 +21,7 @@ def get_authorization_header(request):
Hide some test client ickyness where the header can be unicode. Hide some test client ickyness where the header can be unicode.
""" """
auth = request.META.get('HTTP_AUTHORIZATION', b'') auth = request.META.get('HTTP_AUTHORIZATION', b'')
if type(auth) == type(''): if isinstance(auth, type('')):
# Work around django test client oddness # Work around django test client oddness
auth = auth.encode(HTTP_HEADER_ENCODING) auth = auth.encode(HTTP_HEADER_ENCODING)
return auth return auth
......
import binascii import binascii
import os import os
from hashlib import sha1
from django.conf import settings from django.conf import settings
from django.db import models from django.db import models
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import datetime
from south.db import db from south.db import db
from south.v2 import SchemaMigration from south.v2 import SchemaMigration
from django.db import models
from rest_framework.settings import api_settings
try: try:
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
except ImportError: # django < 1.5 except ImportError: # django < 1.5
from django.contrib.auth.models import User from django.contrib.auth.models import User
else: else:
User = get_user_model() User = get_user_model()
...@@ -26,12 +21,10 @@ class Migration(SchemaMigration): ...@@ -26,12 +21,10 @@ class Migration(SchemaMigration):
)) ))
db.send_create_signal('authtoken', ['Token']) db.send_create_signal('authtoken', ['Token'])
def backwards(self, orm): def backwards(self, orm):
# Deleting model 'Token' # Deleting model 'Token'
db.delete_table('authtoken_token') db.delete_table('authtoken_token')
models = { models = {
'auth.group': { 'auth.group': {
'Meta': {'object_name': 'Group'}, 'Meta': {'object_name': 'Group'},
......
...@@ -131,6 +131,7 @@ def list_route(methods=['get'], **kwargs): ...@@ -131,6 +131,7 @@ def list_route(methods=['get'], **kwargs):
return func return func
return decorator return decorator
# These are now pending deprecation, in favor of `detail_route` and `list_route`. # These are now pending deprecation, in favor of `detail_route` and `list_route`.
def link(**kwargs): def link(**kwargs):
...@@ -139,11 +140,13 @@ def link(**kwargs): ...@@ -139,11 +140,13 @@ def link(**kwargs):
""" """
msg = 'link is pending deprecation. Use detail_route instead.' msg = 'link is pending deprecation. Use detail_route instead.'
warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
def decorator(func): def decorator(func):
func.bind_to_methods = ['get'] func.bind_to_methods = ['get']
func.detail = True func.detail = True
func.kwargs = kwargs func.kwargs = kwargs
return func return func
return decorator return decorator
...@@ -153,9 +156,11 @@ def action(methods=['post'], **kwargs): ...@@ -153,9 +156,11 @@ def action(methods=['post'], **kwargs):
""" """
msg = 'action is pending deprecation. Use detail_route instead.' msg = 'action is pending deprecation. Use detail_route instead.'
warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
def decorator(func): def decorator(func):
func.bind_to_methods = methods func.bind_to_methods = methods
func.detail = True func.detail = True
func.kwargs = kwargs func.kwargs = kwargs
return func return func
return decorator
\ No newline at end of file return decorator
...@@ -23,6 +23,7 @@ class APIException(Exception): ...@@ -23,6 +23,7 @@ class APIException(Exception):
def __str__(self): def __str__(self):
return self.detail return self.detail
class ParseError(APIException): class ParseError(APIException):
status_code = status.HTTP_400_BAD_REQUEST status_code = status.HTTP_400_BAD_REQUEST
default_detail = 'Malformed request.' default_detail = 'Malformed request.'
......
...@@ -116,7 +116,7 @@ class OrderingFilter(BaseFilterBackend): ...@@ -116,7 +116,7 @@ class OrderingFilter(BaseFilterBackend):
def get_ordering(self, request): def get_ordering(self, request):
""" """
Ordering is set by a comma delimited ?ordering=... query parameter. Ordering is set by a comma delimited ?ordering=... query parameter.
The `ordering` query parameter can be overridden by setting The `ordering` query parameter can be overridden by setting
the `ordering_param` value on the OrderingFilter or by the `ordering_param` value on the OrderingFilter or by
specifying an `ORDERING_PARAM` value in the API settings. specifying an `ORDERING_PARAM` value in the API settings.
......
...@@ -25,6 +25,7 @@ def strict_positive_int(integer_string, cutoff=None): ...@@ -25,6 +25,7 @@ def strict_positive_int(integer_string, cutoff=None):
ret = min(ret, cutoff) ret = min(ret, cutoff)
return ret return ret
def get_object_or_404(queryset, *filter_args, **filter_kwargs): def get_object_or_404(queryset, *filter_args, **filter_kwargs):
""" """
Same as Django's standard shortcut, but make sure to raise 404 Same as Django's standard shortcut, but make sure to raise 404
...@@ -162,10 +163,11 @@ class GenericAPIView(views.APIView): ...@@ -162,10 +163,11 @@ class GenericAPIView(views.APIView):
raise Http404(_("Page is not 'last', nor can it be converted to an int.")) raise Http404(_("Page is not 'last', nor can it be converted to an int."))
try: try:
page = paginator.page(page_number) page = paginator.page(page_number)
except InvalidPage as e: except InvalidPage as exc:
raise Http404(_('Invalid page (%(page_number)s): %(message)s') % { error_format = _('Invalid page (%(page_number)s): %(message)s')
'page_number': page_number, raise Http404(error_format % {
'message': str(e) 'page_number': page_number,
'message': str(exc)
}) })
if deprecated_style: if deprecated_style:
...@@ -208,7 +210,6 @@ class GenericAPIView(views.APIView): ...@@ -208,7 +210,6 @@ class GenericAPIView(views.APIView):
return filter_backends return filter_backends
######################## ########################
### The following methods provide default implementations ### The following methods provide default implementations
### that you may want to override for more complex cases. ### that you may want to override for more complex cases.
...@@ -284,8 +285,8 @@ class GenericAPIView(views.APIView): ...@@ -284,8 +285,8 @@ class GenericAPIView(views.APIView):
if self.model is not None: if self.model is not None:
return self.model._default_manager.all() return self.model._default_manager.all()
raise ImproperlyConfigured("'%s' must define 'queryset' or 'model'" error_format = "'%s' must define 'queryset' or 'model'"
% self.__class__.__name__) raise ImproperlyConfigured(error_format % self.__class__.__name__)
def get_object(self, queryset=None): def get_object(self, queryset=None):
""" """
......
...@@ -54,8 +54,10 @@ class DefaultContentNegotiation(BaseContentNegotiation): ...@@ -54,8 +54,10 @@ class DefaultContentNegotiation(BaseContentNegotiation):
for media_type in media_type_set: for media_type in media_type_set:
if media_type_matches(renderer.media_type, media_type): if media_type_matches(renderer.media_type, media_type):
# Return the most specific media type as accepted. # Return the most specific media type as accepted.
if (_MediaType(renderer.media_type).precedence > if (
_MediaType(media_type).precedence): _MediaType(renderer.media_type).precedence >
_MediaType(media_type).precedence
):
# Eg client requests '*/*' # Eg client requests '*/*'
# Accepted media type is 'application/json' # Accepted media type is 'application/json'
return renderer, renderer.media_type return renderer, renderer.media_type
......
...@@ -62,9 +62,11 @@ class IsAuthenticatedOrReadOnly(BasePermission): ...@@ -62,9 +62,11 @@ class IsAuthenticatedOrReadOnly(BasePermission):
""" """
def has_permission(self, request, view): def has_permission(self, request, view):
return (request.method in SAFE_METHODS or return (
request.user and request.method in SAFE_METHODS or
request.user.is_authenticated()) request.user and
request.user.is_authenticated()
)
class DjangoModelPermissions(BasePermission): class DjangoModelPermissions(BasePermission):
...@@ -122,9 +124,11 @@ class DjangoModelPermissions(BasePermission): ...@@ -122,9 +124,11 @@ class DjangoModelPermissions(BasePermission):
perms = self.get_required_permissions(request.method, model_cls) perms = self.get_required_permissions(request.method, model_cls)
return (request.user and return (
request.user and
(request.user.is_authenticated() or not self.authenticated_users_only) and (request.user.is_authenticated() or not self.authenticated_users_only) and
request.user.has_perms(perms)) request.user.has_perms(perms)
)
class DjangoModelPermissionsOrAnonReadOnly(DjangoModelPermissions): class DjangoModelPermissionsOrAnonReadOnly(DjangoModelPermissions):
...@@ -212,6 +216,8 @@ class TokenHasReadWriteScope(BasePermission): ...@@ -212,6 +216,8 @@ class TokenHasReadWriteScope(BasePermission):
required = oauth2_constants.READ if read_only else oauth2_constants.WRITE required = oauth2_constants.READ if read_only else oauth2_constants.WRITE
return oauth2_provider_scope.check(required, request.auth.scope) return oauth2_provider_scope.check(required, request.auth.scope)
assert False, ('TokenHasReadWriteScope requires either the' assert False, (
'`OAuthAuthentication` or `OAuth2Authentication` authentication ' 'TokenHasReadWriteScope requires either the'
'class to be used.') '`OAuthAuthentication` or `OAuth2Authentication` authentication '
'class to be used.'
)
...@@ -8,7 +8,6 @@ REST framework also provides an HTML renderer the renders the browsable API. ...@@ -8,7 +8,6 @@ REST framework also provides an HTML renderer the renders the browsable API.
""" """
from __future__ import unicode_literals from __future__ import unicode_literals
import copy
import json import json
import django import django
from django import forms from django import forms
...@@ -75,7 +74,6 @@ class JSONRenderer(BaseRenderer): ...@@ -75,7 +74,6 @@ class JSONRenderer(BaseRenderer):
# E.g. If we're being called by the BrowsableAPIRenderer. # E.g. If we're being called by the BrowsableAPIRenderer.
return renderer_context.get('indent', None) return renderer_context.get('indent', None)
def render(self, data, accepted_media_type=None, renderer_context=None): def render(self, data, accepted_media_type=None, renderer_context=None):
""" """
Render `data` into JSON, returning a bytestring. Render `data` into JSON, returning a bytestring.
...@@ -86,8 +84,10 @@ class JSONRenderer(BaseRenderer): ...@@ -86,8 +84,10 @@ class JSONRenderer(BaseRenderer):
renderer_context = renderer_context or {} renderer_context = renderer_context or {}
indent = self.get_indent(accepted_media_type, renderer_context) indent = self.get_indent(accepted_media_type, renderer_context)
ret = json.dumps(data, cls=self.encoder_class, ret = json.dumps(
indent=indent, ensure_ascii=self.ensure_ascii) data, cls=self.encoder_class,
indent=indent, ensure_ascii=self.ensure_ascii
)
# On python 2.x json.dumps() returns bytestrings if ensure_ascii=True, # On python 2.x json.dumps() returns bytestrings if ensure_ascii=True,
# but if ensure_ascii=False, the return type is underspecified, # but if ensure_ascii=False, the return type is underspecified,
...@@ -454,8 +454,10 @@ class BrowsableAPIRenderer(BaseRenderer): ...@@ -454,8 +454,10 @@ class BrowsableAPIRenderer(BaseRenderer):
if method in ('DELETE', 'OPTIONS'): if method in ('DELETE', 'OPTIONS'):
return True # Don't actually need to return a form return True # Don't actually need to return a form
if (not getattr(view, 'get_serializer', None) if (
or not any(is_form_media_type(parser.media_type) for parser in view.parser_classes)): not getattr(view, 'get_serializer', None)
or not any(is_form_media_type(parser.media_type) for parser in view.parser_classes)
):
return return
serializer = view.get_serializer(instance=obj, data=data, files=files) serializer = view.get_serializer(instance=obj, data=data, files=files)
...@@ -576,7 +578,7 @@ class BrowsableAPIRenderer(BaseRenderer): ...@@ -576,7 +578,7 @@ class BrowsableAPIRenderer(BaseRenderer):
'version': VERSION, 'version': VERSION,
'breadcrumblist': self.get_breadcrumbs(request), 'breadcrumblist': self.get_breadcrumbs(request),
'allowed_methods': view.allowed_methods, 'allowed_methods': view.allowed_methods,
'available_formats': [renderer.format for renderer in view.renderer_classes], 'available_formats': [renderer_cls.format for renderer_cls in view.renderer_classes],
'response_headers': response_headers, 'response_headers': response_headers,
'put_form': self.get_rendered_html_form(view, 'PUT', request), 'put_form': self.get_rendered_html_form(view, 'PUT', request),
...@@ -625,4 +627,3 @@ class MultiPartRenderer(BaseRenderer): ...@@ -625,4 +627,3 @@ class MultiPartRenderer(BaseRenderer):
def render(self, data, accepted_media_type=None, renderer_context=None): def render(self, data, accepted_media_type=None, renderer_context=None):
return encode_multipart(self.BOUNDARY, data) return encode_multipart(self.BOUNDARY, data)
...@@ -295,8 +295,11 @@ class Request(object): ...@@ -295,8 +295,11 @@ class Request(object):
Return the content body of the request, as a stream. Return the content body of the request, as a stream.
""" """
try: try:
content_length = int(self.META.get('CONTENT_LENGTH', content_length = int(
self.META.get('HTTP_CONTENT_LENGTH'))) self.META.get(
'CONTENT_LENGTH', self.META.get('HTTP_CONTENT_LENGTH')
)
)
except (ValueError, TypeError): except (ValueError, TypeError):
content_length = 0 content_length = 0
...@@ -320,9 +323,11 @@ class Request(object): ...@@ -320,9 +323,11 @@ class Request(object):
) )
# We only need to use form overloading on form POST requests. # We only need to use form overloading on form POST requests.
if (not USE_FORM_OVERLOADING if (
not USE_FORM_OVERLOADING
or self._request.method != 'POST' or self._request.method != 'POST'
or not is_form_media_type(self._content_type)): or not is_form_media_type(self._content_type)
):
return return
# At this point we're committed to parsing the request as form data. # At this point we're committed to parsing the request as form data.
...@@ -330,15 +335,19 @@ class Request(object): ...@@ -330,15 +335,19 @@ class Request(object):
self._files = self._request.FILES self._files = self._request.FILES
# Method overloading - change the method and remove the param from the content. # Method overloading - change the method and remove the param from the content.
if (self._METHOD_PARAM and if (
self._METHOD_PARAM in self._data): self._METHOD_PARAM and
self._METHOD_PARAM in self._data
):
self._method = self._data[self._METHOD_PARAM].upper() self._method = self._data[self._METHOD_PARAM].upper()
# Content overloading - modify the content type, and force re-parse. # Content overloading - modify the content type, and force re-parse.
if (self._CONTENT_PARAM and if (
self._CONTENT_PARAM and
self._CONTENTTYPE_PARAM and self._CONTENTTYPE_PARAM and
self._CONTENT_PARAM in self._data and self._CONTENT_PARAM in self._data and
self._CONTENTTYPE_PARAM in self._data): self._CONTENTTYPE_PARAM in self._data
):
self._content_type = self._data[self._CONTENTTYPE_PARAM] self._content_type = self._data[self._CONTENTTYPE_PARAM]
self._stream = BytesIO(self._data[self._CONTENT_PARAM].encode(self.parser_context['encoding'])) self._stream = BytesIO(self._data[self._CONTENT_PARAM].encode(self.parser_context['encoding']))
self._data, self._files = (Empty, Empty) self._data, self._files = (Empty, Empty)
......
...@@ -62,8 +62,10 @@ class Response(SimpleTemplateResponse): ...@@ -62,8 +62,10 @@ class Response(SimpleTemplateResponse):
ret = renderer.render(self.data, media_type, context) ret = renderer.render(self.data, media_type, context)
if isinstance(ret, six.text_type): if isinstance(ret, six.text_type):
assert charset, 'renderer returned unicode, and did not specify ' \ assert charset, (
'a charset value.' 'renderer returned unicode, and did not specify '
'a charset value.'
)
return bytes(ret.encode(charset)) return bytes(ret.encode(charset))
if not ret: if not ret:
......
...@@ -449,9 +449,11 @@ class BaseSerializer(WritableField): ...@@ -449,9 +449,11 @@ class BaseSerializer(WritableField):
# If we have a model manager or similar object then we need # If we have a model manager or similar object then we need
# to iterate through each instance. # to iterate through each instance.
if (self.many and if (
self.many and
not hasattr(obj, '__iter__') and not hasattr(obj, '__iter__') and
is_simple_callable(getattr(obj, 'all', None))): is_simple_callable(getattr(obj, 'all', None))
):
obj = obj.all() obj = obj.all()
kwargs = { kwargs = {
...@@ -601,8 +603,10 @@ class BaseSerializer(WritableField): ...@@ -601,8 +603,10 @@ class BaseSerializer(WritableField):
API schemas for auto-documentation. API schemas for auto-documentation.
""" """
return SortedDict( return SortedDict(
[(field_name, field.metadata()) [
for field_name, field in six.iteritems(self.fields)] (field_name, field.metadata())
for field_name, field in six.iteritems(self.fields)
]
) )
...@@ -656,8 +660,10 @@ class ModelSerializer(Serializer): ...@@ -656,8 +660,10 @@ class ModelSerializer(Serializer):
""" """
cls = self.opts.model cls = self.opts.model
assert cls is not None, \ assert cls is not None, (
"Serializer class '%s' is missing 'model' Meta option" % self.__class__.__name__ "Serializer class '%s' is missing 'model' Meta option" %
self.__class__.__name__
)
opts = cls._meta.concrete_model._meta opts = cls._meta.concrete_model._meta
ret = SortedDict() ret = SortedDict()
nested = bool(self.opts.depth) nested = bool(self.opts.depth)
...@@ -668,9 +674,9 @@ class ModelSerializer(Serializer): ...@@ -668,9 +674,9 @@ class ModelSerializer(Serializer):
# If model is a child via multitable inheritance, use parent's pk # If model is a child via multitable inheritance, use parent's pk
pk_field = pk_field.rel.to._meta.pk pk_field = pk_field.rel.to._meta.pk
field = self.get_pk_field(pk_field) serializer_pk_field = self.get_pk_field(pk_field)
if field: if serializer_pk_field:
ret[pk_field.name] = field ret[pk_field.name] = serializer_pk_field
# Deal with forward relationships # Deal with forward relationships
forward_rels = [field for field in opts.fields if field.serialize] forward_rels = [field for field in opts.fields if field.serialize]
...@@ -739,9 +745,11 @@ class ModelSerializer(Serializer): ...@@ -739,9 +745,11 @@ class ModelSerializer(Serializer):
is_m2m = isinstance(relation.field, is_m2m = isinstance(relation.field,
models.fields.related.ManyToManyField) models.fields.related.ManyToManyField)
if (is_m2m and if (
is_m2m and
hasattr(relation.field.rel, 'through') and hasattr(relation.field.rel, 'through') and
not relation.field.rel.through._meta.auto_created): not relation.field.rel.through._meta.auto_created
):
has_through_model = True has_through_model = True
if nested: if nested:
...@@ -911,10 +919,12 @@ class ModelSerializer(Serializer): ...@@ -911,10 +919,12 @@ class ModelSerializer(Serializer):
for field_name, field in self.fields.items(): for field_name, field in self.fields.items():
field_name = field.source or field_name field_name = field.source or field_name
if field_name in exclusions \ if (
and not field.read_only \ field_name in exclusions
and (field.required or hasattr(instance, field_name)) \ and not field.read_only
and not isinstance(field, Serializer): and (field.required or hasattr(instance, field_name))
and not isinstance(field, Serializer)
):
exclusions.remove(field_name) exclusions.remove(field_name)
return exclusions return exclusions
......
...@@ -46,16 +46,12 @@ DEFAULTS = { ...@@ -46,16 +46,12 @@ DEFAULTS = {
'DEFAULT_PERMISSION_CLASSES': ( 'DEFAULT_PERMISSION_CLASSES': (
'rest_framework.permissions.AllowAny', 'rest_framework.permissions.AllowAny',
), ),
'DEFAULT_THROTTLE_CLASSES': ( 'DEFAULT_THROTTLE_CLASSES': (),
), 'DEFAULT_CONTENT_NEGOTIATION_CLASS': 'rest_framework.negotiation.DefaultContentNegotiation',
'DEFAULT_CONTENT_NEGOTIATION_CLASS':
'rest_framework.negotiation.DefaultContentNegotiation',
# Genric view behavior # Genric view behavior
'DEFAULT_MODEL_SERIALIZER_CLASS': 'DEFAULT_MODEL_SERIALIZER_CLASS': 'rest_framework.serializers.ModelSerializer',
'rest_framework.serializers.ModelSerializer', 'DEFAULT_PAGINATION_SERIALIZER_CLASS': 'rest_framework.pagination.PaginationSerializer',
'DEFAULT_PAGINATION_SERIALIZER_CLASS':
'rest_framework.pagination.PaginationSerializer',
'DEFAULT_FILTER_BACKENDS': (), 'DEFAULT_FILTER_BACKENDS': (),
# Throttling # Throttling
......
...@@ -10,15 +10,19 @@ from __future__ import unicode_literals ...@@ -10,15 +10,19 @@ from __future__ import unicode_literals
def is_informational(code): def is_informational(code):
return code >= 100 and code <= 199 return code >= 100 and code <= 199
def is_success(code): def is_success(code):
return code >= 200 and code <= 299 return code >= 200 and code <= 299
def is_redirect(code): def is_redirect(code):
return code >= 300 and code <= 399 return code >= 300 and code <= 399
def is_client_error(code): def is_client_error(code):
return code >= 400 and code <= 499 return code >= 400 and code <= 499
def is_server_error(code): def is_server_error(code):
return code >= 500 and code <= 599 return code >= 500 and code <= 599
......
...@@ -152,8 +152,10 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru ...@@ -152,8 +152,10 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru
middle = middle[len(opening):] middle = middle[len(opening):]
lead = lead + opening lead = lead + opening
# Keep parentheses at the end only if they're balanced. # Keep parentheses at the end only if they're balanced.
if (middle.endswith(closing) if (
and middle.count(closing) == middle.count(opening) + 1): middle.endswith(closing)
and middle.count(closing) == middle.count(opening) + 1
):
middle = middle[:-len(closing)] middle = middle[:-len(closing)]
trail = closing + trail trail = closing + trail
......
...@@ -49,9 +49,10 @@ class APIRequestFactory(DjangoRequestFactory): ...@@ -49,9 +49,10 @@ class APIRequestFactory(DjangoRequestFactory):
else: else:
format = format or self.default_format format = format or self.default_format
assert format in self.renderer_classes, ("Invalid format '{0}'. " assert format in self.renderer_classes, (
"Available formats are {1}. Set TEST_REQUEST_RENDERER_CLASSES " "Invalid format '{0}'. Available formats are {1}. "
"to enable extra request formats.".format( "Set TEST_REQUEST_RENDERER_CLASSES to enable "
"extra request formats.".format(
format, format,
', '.join(["'" + fmt + "'" for fmt in self.renderer_classes.keys()]) ', '.join(["'" + fmt + "'" for fmt in self.renderer_classes.keys()])
) )
......
...@@ -8,17 +8,19 @@ your API requires authentication: ...@@ -8,17 +8,19 @@ your API requires authentication:
... ...
url(r'^auth', include('rest_framework.urls', namespace='rest_framework')) url(r'^auth', include('rest_framework.urls', namespace='rest_framework'))
) )
The urls must be namespaced as 'rest_framework', and you should make sure The urls must be namespaced as 'rest_framework', and you should make sure
your authentication settings include `SessionAuthentication`. your authentication settings include `SessionAuthentication`.
""" """
from __future__ import unicode_literals from __future__ import unicode_literals
from django.conf.urls import patterns, url from django.conf.urls import patterns, url
from django.contrib.auth import views
template_name = {'template_name': 'rest_framework/login.html'} template_name = {'template_name': 'rest_framework/login.html'}
urlpatterns = patterns('django.contrib.auth.views', urlpatterns = patterns(
url(r'^login/$', 'login', template_name, name='login'), '',
url(r'^logout/$', 'logout', template_name, name='logout'), url(r'^login/$', views.login, template_name, name='login'),
url(r'^logout/$', views.logout, template_name, name='logout')
) )
...@@ -98,14 +98,23 @@ else: ...@@ -98,14 +98,23 @@ else:
node.flow_style = best_style node.flow_style = best_style
return node return node
SafeDumper.add_representer(decimal.Decimal, SafeDumper.add_representer(
SafeDumper.represent_decimal) decimal.Decimal,
SafeDumper.represent_decimal
SafeDumper.add_representer(SortedDict, )
yaml.representer.SafeRepresenter.represent_dict) SafeDumper.add_representer(
SafeDumper.add_representer(DictWithMetadata, SortedDict,
yaml.representer.SafeRepresenter.represent_dict) yaml.representer.SafeRepresenter.represent_dict
SafeDumper.add_representer(SortedDictWithMetadata, )
yaml.representer.SafeRepresenter.represent_dict) SafeDumper.add_representer(
SafeDumper.add_representer(types.GeneratorType, DictWithMetadata,
yaml.representer.SafeRepresenter.represent_list) yaml.representer.SafeRepresenter.represent_dict
)
SafeDumper.add_representer(
SortedDictWithMetadata,
yaml.representer.SafeRepresenter.represent_dict
)
SafeDumper.add_representer(
types.GeneratorType,
yaml.representer.SafeRepresenter.represent_list
)
...@@ -6,8 +6,6 @@ from __future__ import unicode_literals ...@@ -6,8 +6,6 @@ from __future__ import unicode_literals
from django.utils.html import escape from django.utils.html import escape
from django.utils.safestring import mark_safe from django.utils.safestring import mark_safe
from rest_framework.compat import apply_markdown from rest_framework.compat import apply_markdown
from rest_framework.settings import api_settings
from textwrap import dedent
import re import re
...@@ -40,6 +38,7 @@ def dedent(content): ...@@ -40,6 +38,7 @@ def dedent(content):
return content.strip() return content.strip()
def camelcase_to_spaces(content): def camelcase_to_spaces(content):
""" """
Translate 'CamelCaseNames' to 'Camel Case Names'. Translate 'CamelCaseNames' to 'Camel Case Names'.
...@@ -49,6 +48,7 @@ def camelcase_to_spaces(content): ...@@ -49,6 +48,7 @@ def camelcase_to_spaces(content):
content = re.sub(camelcase_boundry, ' \\1', content).strip() content = re.sub(camelcase_boundry, ' \\1', content).strip()
return ' '.join(content.split('_')).title() return ' '.join(content.split('_')).title()
def markup_description(description): def markup_description(description):
""" """
Apply HTML markup to the given description. Apply HTML markup to the given description.
......
...@@ -57,7 +57,7 @@ class _MediaType(object): ...@@ -57,7 +57,7 @@ class _MediaType(object):
if key != 'q' and other.params.get(key, None) != self.params.get(key, None): if key != 'q' and other.params.get(key, None) != self.params.get(key, None):
return False return False
if self.sub_type != '*' and other.sub_type != '*' and other.sub_type != self.sub_type: if self.sub_type != '*' and other.sub_type != '*' and other.sub_type != self.sub_type:
return False return False
if self.main_type != '*' and other.main_type != '*' and other.main_type != self.main_type: if self.main_type != '*' and other.main_type != '*' and other.main_type != self.main_type:
......
...@@ -31,6 +31,7 @@ def get_view_name(view_cls, suffix=None): ...@@ -31,6 +31,7 @@ def get_view_name(view_cls, suffix=None):
return name return name
def get_view_description(view_cls, html=False): def get_view_description(view_cls, html=False):
""" """
Given a view class, return a textual description to represent the view. Given a view class, return a textual description to represent the view.
...@@ -119,7 +120,6 @@ class APIView(View): ...@@ -119,7 +120,6 @@ class APIView(View):
headers['Vary'] = 'Accept' headers['Vary'] = 'Accept'
return headers return headers
def http_method_not_allowed(self, request, *args, **kwargs): def http_method_not_allowed(self, request, *args, **kwargs):
""" """
If `request.method` does not correspond to a handler method, If `request.method` does not correspond to a handler method,
......
...@@ -127,11 +127,11 @@ class ReadOnlyModelViewSet(mixins.RetrieveModelMixin, ...@@ -127,11 +127,11 @@ class ReadOnlyModelViewSet(mixins.RetrieveModelMixin,
class ModelViewSet(mixins.CreateModelMixin, class ModelViewSet(mixins.CreateModelMixin,
mixins.RetrieveModelMixin, mixins.RetrieveModelMixin,
mixins.UpdateModelMixin, mixins.UpdateModelMixin,
mixins.DestroyModelMixin, mixins.DestroyModelMixin,
mixins.ListModelMixin, mixins.ListModelMixin,
GenericViewSet): GenericViewSet):
""" """
A viewset that provides default `create()`, `retrieve()`, `update()`, A viewset that provides default `create()`, `retrieve()`, `update()`,
`partial_update()`, `destroy()` and `list()` actions. `partial_update()`, `destroy()` and `list()` actions.
......
#! /usr/bin/env python
from __future__ import print_function
import pytest
import sys
import os
import subprocess
PYTEST_ARGS = {
'default': ['tests'],
'fast': ['tests', '-q'],
}
FLAKE8_ARGS = ['rest_framework', 'tests', '--ignore=E501']
sys.path.append(os.path.dirname(__file__))
def exit_on_failure(ret, message=None):
if ret:
sys.exit(ret)
def flake8_main(args):
print('Running flake8 code linting')
ret = subprocess.call(['flake8'] + args)
print('flake8 failed' if ret else 'flake8 passed')
return ret
def split_class_and_function(string):
class_string, function_string = string.split('.', 1)
return "%s and %s" % (class_string, function_string)
def is_function(string):
# `True` if it looks like a test function is included in the string.
return string.startswith('test_') or '.test_' in string
def is_class(string):
# `True` if first character is uppercase - assume it's a class name.
return string[0] == string[0].upper()
if __name__ == "__main__":
try:
sys.argv.remove('--nolint')
except ValueError:
run_flake8 = True
else:
run_flake8 = False
try:
sys.argv.remove('--lintonly')
except ValueError:
run_tests = True
else:
run_tests = False
try:
sys.argv.remove('--fast')
except ValueError:
style = 'default'
else:
style = 'fast'
run_flake8 = False
if len(sys.argv) > 1:
pytest_args = sys.argv[1:]
first_arg = pytest_args[0]
if first_arg.startswith('-'):
# `runtests.py [flags]`
pytest_args = ['tests'] + pytest_args
elif is_class(first_arg) and is_function(first_arg):
# `runtests.py TestCase.test_function [flags]`
expression = split_class_and_function(first_arg)
pytest_args = ['tests', '-k', expression] + pytest_args[1:]
elif is_class(first_arg) or is_function(first_arg):
# `runtests.py TestCase [flags]`
# `runtests.py test_function [flags]`
pytest_args = ['tests', '-k', pytest_args[0]] + pytest_args[1:]
else:
pytest_args = PYTEST_ARGS[style]
if run_tests:
exit_on_failure(pytest.main(pytest_args))
if run_flake8:
exit_on_failure(flake8_main(FLAKE8_ARGS))
from rest_framework import serializers from rest_framework import serializers
from tests.models import NullableForeignKeySource from tests.models import NullableForeignKeySource
......
...@@ -68,7 +68,6 @@ SECRET_KEY = 'u@x-aj9(hoh#rb-^ymf#g2jx_hp0vj7u5#b@ag1n^seu9e!%cy' ...@@ -68,7 +68,6 @@ SECRET_KEY = 'u@x-aj9(hoh#rb-^ymf#g2jx_hp0vj7u5#b@ag1n^seu9e!%cy'
TEMPLATE_LOADERS = ( TEMPLATE_LOADERS = (
'django.template.loaders.filesystem.Loader', 'django.template.loaders.filesystem.Loader',
'django.template.loaders.app_directories.Loader', 'django.template.loaders.app_directories.Loader',
# 'django.template.loaders.eggs.Loader',
) )
MIDDLEWARE_CLASSES = ( MIDDLEWARE_CLASSES = (
...@@ -104,8 +103,8 @@ INSTALLED_APPS = ( ...@@ -104,8 +103,8 @@ INSTALLED_APPS = (
# OAuth is optional and won't work if there is no oauth_provider & oauth2 # OAuth is optional and won't work if there is no oauth_provider & oauth2
try: try:
import oauth_provider import oauth_provider # NOQA
import oauth2 import oauth2 # NOQA
except ImportError: except ImportError:
pass pass
else: else:
...@@ -114,7 +113,7 @@ else: ...@@ -114,7 +113,7 @@ else:
) )
try: try:
import provider import provider # NOQA
except ImportError: except ImportError:
pass pass
else: else:
...@@ -125,13 +124,13 @@ else: ...@@ -125,13 +124,13 @@ else:
# guardian is optional # guardian is optional
try: try:
import guardian import guardian # NOQA
except ImportError: except ImportError:
pass pass
else: else:
ANONYMOUS_USER_ID = -1 ANONYMOUS_USER_ID = -1
AUTHENTICATION_BACKENDS = ( AUTHENTICATION_BACKENDS = (
'django.contrib.auth.backends.ModelBackend', # default 'django.contrib.auth.backends.ModelBackend', # default
'guardian.backends.ObjectPermissionBackend', 'guardian.backends.ObjectPermissionBackend',
) )
INSTALLED_APPS += ( INSTALLED_APPS += (
......
...@@ -45,26 +45,39 @@ class MockView(APIView): ...@@ -45,26 +45,39 @@ class MockView(APIView):
return HttpResponse({'a': 1, 'b': 2, 'c': 3}) return HttpResponse({'a': 1, 'b': 2, 'c': 3})
urlpatterns = patterns('', urlpatterns = patterns(
'',
(r'^session/$', MockView.as_view(authentication_classes=[SessionAuthentication])), (r'^session/$', MockView.as_view(authentication_classes=[SessionAuthentication])),
(r'^basic/$', MockView.as_view(authentication_classes=[BasicAuthentication])), (r'^basic/$', MockView.as_view(authentication_classes=[BasicAuthentication])),
(r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])), (r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])),
(r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'), (r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'),
(r'^oauth/$', MockView.as_view(authentication_classes=[OAuthAuthentication])), (r'^oauth/$', MockView.as_view(authentication_classes=[OAuthAuthentication])),
(r'^oauth-with-scope/$', MockView.as_view(authentication_classes=[OAuthAuthentication], (
permission_classes=[permissions.TokenHasReadWriteScope])) r'^oauth-with-scope/$',
MockView.as_view(
authentication_classes=[OAuthAuthentication],
permission_classes=[permissions.TokenHasReadWriteScope]
)
)
) )
class OAuth2AuthenticationDebug(OAuth2Authentication): class OAuth2AuthenticationDebug(OAuth2Authentication):
allow_query_params_token = True allow_query_params_token = True
if oauth2_provider is not None: if oauth2_provider is not None:
urlpatterns += patterns('', urlpatterns += patterns(
'',
url(r'^oauth2/', include('provider.oauth2.urls', namespace='oauth2')), url(r'^oauth2/', include('provider.oauth2.urls', namespace='oauth2')),
url(r'^oauth2-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication])), url(r'^oauth2-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication])),
url(r'^oauth2-test-debug/$', MockView.as_view(authentication_classes=[OAuth2AuthenticationDebug])), url(r'^oauth2-test-debug/$', MockView.as_view(authentication_classes=[OAuth2AuthenticationDebug])),
url(r'^oauth2-with-scope-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication], url(
permission_classes=[permissions.TokenHasReadWriteScope])), r'^oauth2-with-scope-test/$',
MockView.as_view(
authentication_classes=[OAuth2Authentication],
permission_classes=[permissions.TokenHasReadWriteScope]
)
)
) )
...@@ -278,12 +291,16 @@ class OAuthTests(TestCase): ...@@ -278,12 +291,16 @@ class OAuthTests(TestCase):
self.TOKEN_KEY = "token_key" self.TOKEN_KEY = "token_key"
self.TOKEN_SECRET = "token_secret" self.TOKEN_SECRET = "token_secret"
self.consumer = Consumer.objects.create(key=self.CONSUMER_KEY, secret=self.CONSUMER_SECRET, self.consumer = Consumer.objects.create(
name='example', user=self.user, status=self.consts.ACCEPTED) key=self.CONSUMER_KEY, secret=self.CONSUMER_SECRET,
name='example', user=self.user, status=self.consts.ACCEPTED
)
self.scope = Scope.objects.create(name="resource name", url="api/") self.scope = Scope.objects.create(name="resource name", url="api/")
self.token = OAuthToken.objects.create(user=self.user, consumer=self.consumer, scope=self.scope, self.token = OAuthToken.objects.create(
token_type=OAuthToken.ACCESS, key=self.TOKEN_KEY, secret=self.TOKEN_SECRET, is_approved=True user=self.user, consumer=self.consumer, scope=self.scope,
token_type=OAuthToken.ACCESS, key=self.TOKEN_KEY, secret=self.TOKEN_SECRET,
is_approved=True
) )
def _create_authorization_header(self): def _create_authorization_header(self):
...@@ -501,24 +518,24 @@ class OAuth2Tests(TestCase): ...@@ -501,24 +518,24 @@ class OAuth2Tests(TestCase):
self.REFRESH_TOKEN = "refresh_token" self.REFRESH_TOKEN = "refresh_token"
self.oauth2_client = oauth2_provider.oauth2.models.Client.objects.create( self.oauth2_client = oauth2_provider.oauth2.models.Client.objects.create(
client_id=self.CLIENT_ID, client_id=self.CLIENT_ID,
client_secret=self.CLIENT_SECRET, client_secret=self.CLIENT_SECRET,
redirect_uri='', redirect_uri='',
client_type=0, client_type=0,
name='example', name='example',
user=None, user=None,
) )
self.access_token = oauth2_provider.oauth2.models.AccessToken.objects.create( self.access_token = oauth2_provider.oauth2.models.AccessToken.objects.create(
token=self.ACCESS_TOKEN, token=self.ACCESS_TOKEN,
client=self.oauth2_client, client=self.oauth2_client,
user=self.user, user=self.user,
) )
self.refresh_token = oauth2_provider.oauth2.models.RefreshToken.objects.create( self.refresh_token = oauth2_provider.oauth2.models.RefreshToken.objects.create(
user=self.user, user=self.user,
access_token=self.access_token, access_token=self.access_token,
client=self.oauth2_client client=self.oauth2_client
) )
def _create_authorization_header(self, token=None): def _create_authorization_header(self, token=None):
return "Bearer {0}".format(token or self.access_token.token) return "Bearer {0}".format(token or self.access_token.token)
...@@ -569,8 +586,10 @@ class OAuth2Tests(TestCase): ...@@ -569,8 +586,10 @@ class OAuth2Tests(TestCase):
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
def test_post_form_passing_auth_url_transport(self): def test_post_form_passing_auth_url_transport(self):
"""Ensure GETing form over OAuth with correct client credentials in form data succeed""" """Ensure GETing form over OAuth with correct client credentials in form data succeed"""
response = self.csrf_client.post('/oauth2-test/', response = self.csrf_client.post(
data={'access_token': self.access_token.token}) '/oauth2-test/',
data={'access_token': self.access_token.token}
)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
......
...@@ -24,7 +24,8 @@ class NestedResourceRoot(APIView): ...@@ -24,7 +24,8 @@ class NestedResourceRoot(APIView):
class NestedResourceInstance(APIView): class NestedResourceInstance(APIView):
pass pass
urlpatterns = patterns('', urlpatterns = patterns(
'',
url(r'^$', Root.as_view()), url(r'^$', Root.as_view()),
url(r'^resource/$', ResourceRoot.as_view()), url(r'^resource/$', ResourceRoot.as_view()),
url(r'^resource/(?P<key>[0-9]+)$', ResourceInstance.as_view()), url(r'^resource/(?P<key>[0-9]+)$', ResourceInstance.as_view()),
...@@ -40,34 +41,60 @@ class BreadcrumbTests(TestCase): ...@@ -40,34 +41,60 @@ class BreadcrumbTests(TestCase):
def test_root_breadcrumbs(self): def test_root_breadcrumbs(self):
url = '/' url = '/'
self.assertEqual(get_breadcrumbs(url), [('Root', '/')]) self.assertEqual(
get_breadcrumbs(url),
[('Root', '/')]
)
def test_resource_root_breadcrumbs(self): def test_resource_root_breadcrumbs(self):
url = '/resource/' url = '/resource/'
self.assertEqual(get_breadcrumbs(url), [('Root', '/'), self.assertEqual(
('Resource Root', '/resource/')]) get_breadcrumbs(url),
[
('Root', '/'),
('Resource Root', '/resource/')
]
)
def test_resource_instance_breadcrumbs(self): def test_resource_instance_breadcrumbs(self):
url = '/resource/123' url = '/resource/123'
self.assertEqual(get_breadcrumbs(url), [('Root', '/'), self.assertEqual(
('Resource Root', '/resource/'), get_breadcrumbs(url),
('Resource Instance', '/resource/123')]) [
('Root', '/'),
('Resource Root', '/resource/'),
('Resource Instance', '/resource/123')
]
)
def test_nested_resource_breadcrumbs(self): def test_nested_resource_breadcrumbs(self):
url = '/resource/123/' url = '/resource/123/'
self.assertEqual(get_breadcrumbs(url), [('Root', '/'), self.assertEqual(
('Resource Root', '/resource/'), get_breadcrumbs(url),
('Resource Instance', '/resource/123'), [
('Nested Resource Root', '/resource/123/')]) ('Root', '/'),
('Resource Root', '/resource/'),
('Resource Instance', '/resource/123'),
('Nested Resource Root', '/resource/123/')
]
)
def test_nested_resource_instance_breadcrumbs(self): def test_nested_resource_instance_breadcrumbs(self):
url = '/resource/123/abc' url = '/resource/123/abc'
self.assertEqual(get_breadcrumbs(url), [('Root', '/'), self.assertEqual(
('Resource Root', '/resource/'), get_breadcrumbs(url),
('Resource Instance', '/resource/123'), [
('Nested Resource Root', '/resource/123/'), ('Root', '/'),
('Nested Resource Instance', '/resource/123/abc')]) ('Resource Root', '/resource/'),
('Resource Instance', '/resource/123'),
('Nested Resource Root', '/resource/123/'),
('Nested Resource Instance', '/resource/123/abc')
]
)
def test_broken_url_breadcrumbs_handled_gracefully(self): def test_broken_url_breadcrumbs_handled_gracefully(self):
url = '/foobar' url = '/foobar'
self.assertEqual(get_breadcrumbs(url), [('Root', '/')]) self.assertEqual(
get_breadcrumbs(url),
[('Root', '/')]
)
...@@ -648,7 +648,7 @@ class DecimalFieldTest(TestCase): ...@@ -648,7 +648,7 @@ class DecimalFieldTest(TestCase):
s = DecimalSerializer(data={'decimal_field': '123'}) s = DecimalSerializer(data={'decimal_field': '123'})
self.assertFalse(s.is_valid()) self.assertFalse(s.is_valid())
self.assertEqual(s.errors, {'decimal_field': ['Ensure this value is less than or equal to 100.']}) self.assertEqual(s.errors, {'decimal_field': ['Ensure this value is less than or equal to 100.']})
def test_raise_min_value(self): def test_raise_min_value(self):
""" """
...@@ -660,7 +660,7 @@ class DecimalFieldTest(TestCase): ...@@ -660,7 +660,7 @@ class DecimalFieldTest(TestCase):
s = DecimalSerializer(data={'decimal_field': '99'}) s = DecimalSerializer(data={'decimal_field': '99'})
self.assertFalse(s.is_valid()) self.assertFalse(s.is_valid())
self.assertEqual(s.errors, {'decimal_field': ['Ensure this value is greater than or equal to 100.']}) self.assertEqual(s.errors, {'decimal_field': ['Ensure this value is greater than or equal to 100.']})
def test_raise_max_digits(self): def test_raise_max_digits(self):
""" """
...@@ -672,7 +672,7 @@ class DecimalFieldTest(TestCase): ...@@ -672,7 +672,7 @@ class DecimalFieldTest(TestCase):
s = DecimalSerializer(data={'decimal_field': '123.456'}) s = DecimalSerializer(data={'decimal_field': '123.456'})
self.assertFalse(s.is_valid()) self.assertFalse(s.is_valid())
self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 5 digits in total.']}) self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 5 digits in total.']})
def test_raise_max_decimal_places(self): def test_raise_max_decimal_places(self):
""" """
...@@ -684,7 +684,7 @@ class DecimalFieldTest(TestCase): ...@@ -684,7 +684,7 @@ class DecimalFieldTest(TestCase):
s = DecimalSerializer(data={'decimal_field': '123.4567'}) s = DecimalSerializer(data={'decimal_field': '123.4567'})
self.assertFalse(s.is_valid()) self.assertFalse(s.is_valid())
self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 3 decimal places.']}) self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 3 decimal places.']})
def test_raise_max_whole_digits(self): def test_raise_max_whole_digits(self):
""" """
...@@ -696,7 +696,7 @@ class DecimalFieldTest(TestCase): ...@@ -696,7 +696,7 @@ class DecimalFieldTest(TestCase):
s = DecimalSerializer(data={'decimal_field': '12345.6'}) s = DecimalSerializer(data={'decimal_field': '12345.6'})
self.assertFalse(s.is_valid()) self.assertFalse(s.is_valid())
self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 4 digits in total.']}) self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 4 digits in total.']})
class ChoiceFieldTests(TestCase): class ChoiceFieldTests(TestCase):
...@@ -729,7 +729,7 @@ class ChoiceFieldTests(TestCase): ...@@ -729,7 +729,7 @@ class ChoiceFieldTests(TestCase):
def test_invalid_choice_model(self): def test_invalid_choice_model(self):
s = ChoiceFieldModelSerializer(data={'choice': 'wrong_value'}) s = ChoiceFieldModelSerializer(data={'choice': 'wrong_value'})
self.assertFalse(s.is_valid()) self.assertFalse(s.is_valid())
self.assertEqual(s.errors, {'choice': ['Select a valid choice. wrong_value is not one of the available choices.']}) self.assertEqual(s.errors, {'choice': ['Select a valid choice. wrong_value is not one of the available choices.']})
self.assertEqual(s.data['choice'], '') self.assertEqual(s.data['choice'], '')
def test_empty_choice_model(self): def test_empty_choice_model(self):
...@@ -875,7 +875,7 @@ class SlugFieldTests(TestCase): ...@@ -875,7 +875,7 @@ class SlugFieldTests(TestCase):
s = SlugFieldSerializer(data={'slug_field': 'a b'}) s = SlugFieldSerializer(data={'slug_field': 'a b'})
self.assertEqual(s.is_valid(), False) self.assertEqual(s.is_valid(), False)
self.assertEqual(s.errors, {'slug_field': ["Enter a valid 'slug' consisting of letters, numbers, underscores or hyphens."]}) self.assertEqual(s.errors, {'slug_field': ["Enter a valid 'slug' consisting of letters, numbers, underscores or hyphens."]})
class URLFieldTests(TestCase): class URLFieldTests(TestCase):
......
...@@ -85,11 +85,8 @@ class FileSerializerTests(TestCase): ...@@ -85,11 +85,8 @@ class FileSerializerTests(TestCase):
""" """
Validation should still function when no data dictionary is provided. Validation should still function when no data dictionary is provided.
""" """
now = datetime.datetime.now() uploaded_file = BytesIO(six.b('stuff'))
file = BytesIO(six.b('stuff')) uploaded_file.name = 'stuff.txt'
file.name = 'stuff.txt' uploaded_file.size = len(uploaded_file.getvalue())
file.size = len(file.getvalue()) serializer = UploadedFileSerializer(files={'file': uploaded_file})
uploaded_file = UploadedFile(file=file, created=now)
serializer = UploadedFileSerializer(files={'file': file})
self.assertFalse(serializer.is_valid()) self.assertFalse(serializer.is_valid())
...@@ -74,7 +74,8 @@ if django_filters: ...@@ -74,7 +74,8 @@ if django_filters:
def get_queryset(self): def get_queryset(self):
return FilterableItem.objects.all() return FilterableItem.objects.all()
urlpatterns = patterns('', urlpatterns = patterns(
'',
url(r'^(?P<pk>\d+)/$', FilterClassDetailView.as_view(), name='detail-view'), url(r'^(?P<pk>\d+)/$', FilterClassDetailView.as_view(), name='detail-view'),
url(r'^$', FilterClassRootView.as_view(), name='root-view'), url(r'^$', FilterClassRootView.as_view(), name='root-view'),
url(r'^get-queryset/$', GetQuerysetView.as_view(), url(r'^get-queryset/$', GetQuerysetView.as_view(),
...@@ -653,8 +654,8 @@ class SensitiveOrderingFilterTests(TestCase): ...@@ -653,8 +654,8 @@ class SensitiveOrderingFilterTests(TestCase):
self.assertEqual( self.assertEqual(
response.data, response.data,
[ [
{'id': 1, username_field: 'userA'}, # PassB {'id': 1, username_field: 'userA'}, # PassB
{'id': 2, username_field: 'userB'}, # PassC {'id': 2, username_field: 'userB'}, # PassC
{'id': 3, username_field: 'userC'}, # PassA {'id': 3, username_field: 'userC'}, # PassA
] ]
) )
...@@ -117,18 +117,18 @@ class TestGenericRelations(TestCase): ...@@ -117,18 +117,18 @@ class TestGenericRelations(TestCase):
serializer = TagSerializer(Tag.objects.all(), many=True) serializer = TagSerializer(Tag.objects.all(), many=True)
expected = [ expected = [
{ {
'tag': 'django', 'tag': 'django',
'tagged_item': 'Bookmark: https://www.djangoproject.com/' 'tagged_item': 'Bookmark: https://www.djangoproject.com/'
}, },
{ {
'tag': 'python', 'tag': 'python',
'tagged_item': 'Bookmark: https://www.djangoproject.com/' 'tagged_item': 'Bookmark: https://www.djangoproject.com/'
}, },
{ {
'tag': 'reminder', 'tag': 'reminder',
'tagged_item': 'Note: Remember the milk' 'tagged_item': 'Note: Remember the milk'
} }
] ]
self.assertEqual(serializer.data, expected) self.assertEqual(serializer.data, expected)
......
...@@ -34,7 +34,8 @@ def not_found(request): ...@@ -34,7 +34,8 @@ def not_found(request):
raise Http404() raise Http404()
urlpatterns = patterns('', urlpatterns = patterns(
'',
url(r'^$', example), url(r'^$', example),
url(r'^permission_denied$', permission_denied), url(r'^permission_denied$', permission_denied),
url(r'^not_found$', not_found), url(r'^not_found$', not_found),
......
...@@ -94,7 +94,8 @@ class OptionalRelationDetail(generics.RetrieveUpdateDestroyAPIView): ...@@ -94,7 +94,8 @@ class OptionalRelationDetail(generics.RetrieveUpdateDestroyAPIView):
model_serializer_class = serializers.HyperlinkedModelSerializer model_serializer_class = serializers.HyperlinkedModelSerializer
urlpatterns = patterns('', urlpatterns = patterns(
'',
url(r'^basic/$', BasicList.as_view(), name='basicmodel-list'), url(r'^basic/$', BasicList.as_view(), name='basicmodel-list'),
url(r'^basic/(?P<pk>\d+)/$', BasicDetail.as_view(), name='basicmodel-detail'), url(r'^basic/(?P<pk>\d+)/$', BasicDetail.as_view(), name='basicmodel-detail'),
url(r'^anchor/(?P<pk>\d+)/$', AnchorDetail.as_view(), name='anchor-detail'), url(r'^anchor/(?P<pk>\d+)/$', AnchorDetail.as_view(), name='anchor-detail'),
......
from __future__ import unicode_literals from __future__ import unicode_literals
import datetime import datetime
from decimal import Decimal from decimal import Decimal
from django.db import models
from django.core.paginator import Paginator from django.core.paginator import Paginator
from django.test import TestCase from django.test import TestCase
from django.utils import unittest from django.utils import unittest
...@@ -12,6 +11,7 @@ from .models import BasicModel, FilterableItem ...@@ -12,6 +11,7 @@ from .models import BasicModel, FilterableItem
factory = APIRequestFactory() factory = APIRequestFactory()
# Helper function to split arguments out of an url # Helper function to split arguments out of an url
def split_arguments_from_url(url): def split_arguments_from_url(url):
if '?' not in url: if '?' not in url:
...@@ -274,8 +274,8 @@ class TestUnpaginated(TestCase): ...@@ -274,8 +274,8 @@ class TestUnpaginated(TestCase):
BasicModel(text=i).save() BasicModel(text=i).save()
self.objects = BasicModel.objects self.objects = BasicModel.objects
self.data = [ self.data = [
{'id': obj.id, 'text': obj.text} {'id': obj.id, 'text': obj.text}
for obj in self.objects.all() for obj in self.objects.all()
] ]
self.view = DefaultPageSizeKwargView.as_view() self.view = DefaultPageSizeKwargView.as_view()
...@@ -302,8 +302,8 @@ class TestCustomPaginateByParam(TestCase): ...@@ -302,8 +302,8 @@ class TestCustomPaginateByParam(TestCase):
BasicModel(text=i).save() BasicModel(text=i).save()
self.objects = BasicModel.objects self.objects = BasicModel.objects
self.data = [ self.data = [
{'id': obj.id, 'text': obj.text} {'id': obj.id, 'text': obj.text}
for obj in self.objects.all() for obj in self.objects.all()
] ]
self.view = PaginateByParamView.as_view() self.view = PaginateByParamView.as_view()
...@@ -483,8 +483,6 @@ class NonIntegerPaginator(object): ...@@ -483,8 +483,6 @@ class NonIntegerPaginator(object):
class TestNonIntegerPagination(TestCase): class TestNonIntegerPagination(TestCase):
def test_custom_pagination_serializer(self): def test_custom_pagination_serializer(self):
objects = ['john', 'paul', 'george', 'ringo'] objects = ['john', 'paul', 'george', 'ringo']
paginator = NonIntegerPaginator(objects, 2) paginator = NonIntegerPaginator(objects, 2)
......
...@@ -12,6 +12,7 @@ import base64 ...@@ -12,6 +12,7 @@ import base64
factory = APIRequestFactory() factory = APIRequestFactory()
class RootView(generics.ListCreateAPIView): class RootView(generics.ListCreateAPIView):
model = BasicModel model = BasicModel
authentication_classes = [authentication.BasicAuthentication] authentication_classes = [authentication.BasicAuthentication]
...@@ -101,42 +102,54 @@ class ModelPermissionsIntegrationTests(TestCase): ...@@ -101,42 +102,54 @@ class ModelPermissionsIntegrationTests(TestCase):
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
def test_options_permitted(self): def test_options_permitted(self):
request = factory.options('/', request = factory.options(
HTTP_AUTHORIZATION=self.permitted_credentials) '/',
HTTP_AUTHORIZATION=self.permitted_credentials
)
response = root_view(request, pk='1') response = root_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertIn('actions', response.data) self.assertIn('actions', response.data)
self.assertEqual(list(response.data['actions'].keys()), ['POST']) self.assertEqual(list(response.data['actions'].keys()), ['POST'])
request = factory.options('/1', request = factory.options(
HTTP_AUTHORIZATION=self.permitted_credentials) '/1',
HTTP_AUTHORIZATION=self.permitted_credentials
)
response = instance_view(request, pk='1') response = instance_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertIn('actions', response.data) self.assertIn('actions', response.data)
self.assertEqual(list(response.data['actions'].keys()), ['PUT']) self.assertEqual(list(response.data['actions'].keys()), ['PUT'])
def test_options_disallowed(self): def test_options_disallowed(self):
request = factory.options('/', request = factory.options(
HTTP_AUTHORIZATION=self.disallowed_credentials) '/',
HTTP_AUTHORIZATION=self.disallowed_credentials
)
response = root_view(request, pk='1') response = root_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertNotIn('actions', response.data) self.assertNotIn('actions', response.data)
request = factory.options('/1', request = factory.options(
HTTP_AUTHORIZATION=self.disallowed_credentials) '/1',
HTTP_AUTHORIZATION=self.disallowed_credentials
)
response = instance_view(request, pk='1') response = instance_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertNotIn('actions', response.data) self.assertNotIn('actions', response.data)
def test_options_updateonly(self): def test_options_updateonly(self):
request = factory.options('/', request = factory.options(
HTTP_AUTHORIZATION=self.updateonly_credentials) '/',
HTTP_AUTHORIZATION=self.updateonly_credentials
)
response = root_view(request, pk='1') response = root_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertNotIn('actions', response.data) self.assertNotIn('actions', response.data)
request = factory.options('/1', request = factory.options(
HTTP_AUTHORIZATION=self.updateonly_credentials) '/1',
HTTP_AUTHORIZATION=self.updateonly_credentials
)
response = instance_view(request, pk='1') response = instance_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertIn('actions', response.data) self.assertIn('actions', response.data)
...@@ -153,6 +166,7 @@ class BasicPermModel(models.Model): ...@@ -153,6 +166,7 @@ class BasicPermModel(models.Model):
# add, change, delete built in to django # add, change, delete built in to django
) )
# Custom object-level permission, that includes 'view' permissions # Custom object-level permission, that includes 'view' permissions
class ViewObjectPermissions(permissions.DjangoObjectPermissions): class ViewObjectPermissions(permissions.DjangoObjectPermissions):
perms_map = { perms_map = {
...@@ -205,7 +219,7 @@ class ObjectPermissionsIntegrationTests(TestCase): ...@@ -205,7 +219,7 @@ class ObjectPermissionsIntegrationTests(TestCase):
app_label = BasicPermModel._meta.app_label app_label = BasicPermModel._meta.app_label
f = '{0}_{1}'.format f = '{0}_{1}'.format
perms = { perms = {
'view': f('view', model_name), 'view': f('view', model_name),
'change': f('change', model_name), 'change': f('change', model_name),
'delete': f('delete', model_name) 'delete': f('delete', model_name)
} }
...@@ -246,21 +260,27 @@ class ObjectPermissionsIntegrationTests(TestCase): ...@@ -246,21 +260,27 @@ class ObjectPermissionsIntegrationTests(TestCase):
# Update # Update
def test_can_update_permissions(self): def test_can_update_permissions(self):
request = factory.patch('/1', {'text': 'foobar'}, format='json', request = factory.patch(
HTTP_AUTHORIZATION=self.credentials['writeonly']) '/1', {'text': 'foobar'}, format='json',
HTTP_AUTHORIZATION=self.credentials['writeonly']
)
response = object_permissions_view(request, pk='1') response = object_permissions_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data.get('text'), 'foobar') self.assertEqual(response.data.get('text'), 'foobar')
def test_cannot_update_permissions(self): def test_cannot_update_permissions(self):
request = factory.patch('/1', {'text': 'foobar'}, format='json', request = factory.patch(
HTTP_AUTHORIZATION=self.credentials['deleteonly']) '/1', {'text': 'foobar'}, format='json',
HTTP_AUTHORIZATION=self.credentials['deleteonly']
)
response = object_permissions_view(request, pk='1') response = object_permissions_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
def test_cannot_update_permissions_non_existing(self): def test_cannot_update_permissions_non_existing(self):
request = factory.patch('/999', {'text': 'foobar'}, format='json', request = factory.patch(
HTTP_AUTHORIZATION=self.credentials['deleteonly']) '/999', {'text': 'foobar'}, format='json',
HTTP_AUTHORIZATION=self.credentials['deleteonly']
)
response = object_permissions_view(request, pk='999') response = object_permissions_view(request, pk='999')
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
......
...@@ -108,19 +108,25 @@ class RelatedFieldSourceTests(TestCase): ...@@ -108,19 +108,25 @@ class RelatedFieldSourceTests(TestCase):
doesn't exist. doesn't exist.
""" """
from tests.models import ManyToManySource from tests.models import ManyToManySource
class Meta: class Meta:
model = ManyToManySource model = ManyToManySource
attrs = { attrs = {
'name': serializers.SlugRelatedField( 'name': serializers.SlugRelatedField(
slug_field='name', source='banzai'), slug_field='name', source='banzai'),
'Meta': Meta, 'Meta': Meta,
} }
TestSerializer = type(str('TestSerializer'), TestSerializer = type(
(serializers.ModelSerializer,), attrs) str('TestSerializer'),
(serializers.ModelSerializer,),
attrs
)
with self.assertRaises(AttributeError): with self.assertRaises(AttributeError):
TestSerializer(data={'name': 'foo'}) TestSerializer(data={'name': 'foo'})
@unittest.skipIf(get_version() < '1.6.0', 'Upstream behaviour changed in v1.6') @unittest.skipIf(get_version() < '1.6.0', 'Upstream behaviour changed in v1.6')
class RelatedFieldChoicesTests(TestCase): class RelatedFieldChoicesTests(TestCase):
""" """
...@@ -141,4 +147,3 @@ class RelatedFieldChoicesTests(TestCase): ...@@ -141,4 +147,3 @@ class RelatedFieldChoicesTests(TestCase):
widget_count = len(field.widget.choices) widget_count = len(field.widget.choices)
self.assertEqual(widget_count, choice_count + 1, 'BLANK_CHOICE_DASH option should have been added') self.assertEqual(widget_count, choice_count + 1, 'BLANK_CHOICE_DASH option should have been added')
...@@ -16,7 +16,8 @@ request = factory.get('/') # Just to ensure we have a request in the serializer ...@@ -16,7 +16,8 @@ request = factory.get('/') # Just to ensure we have a request in the serializer
def dummy_view(request, pk): def dummy_view(request, pk):
pass pass
urlpatterns = patterns('', urlpatterns = patterns(
'',
url(r'^dummyurl/(?P<pk>[0-9]+)/$', dummy_view, name='dummy-url'), url(r'^dummyurl/(?P<pk>[0-9]+)/$', dummy_view, name='dummy-url'),
url(r'^manytomanysource/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanysource-detail'), url(r'^manytomanysource/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanysource-detail'),
url(r'^manytomanytarget/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanytarget-detail'), url(r'^manytomanytarget/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanytarget-detail'),
...@@ -86,9 +87,9 @@ class HyperlinkedManyToManyTests(TestCase): ...@@ -86,9 +87,9 @@ class HyperlinkedManyToManyTests(TestCase):
queryset = ManyToManySource.objects.all() queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request}) serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
expected = [ expected = [
{'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']}, {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']},
{'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']}, {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']},
{'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']} {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
] ]
self.assertEqual(serializer.data, expected) self.assertEqual(serializer.data, expected)
...@@ -114,9 +115,9 @@ class HyperlinkedManyToManyTests(TestCase): ...@@ -114,9 +115,9 @@ class HyperlinkedManyToManyTests(TestCase):
queryset = ManyToManySource.objects.all() queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request}) serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
expected = [ expected = [
{'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}, {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']},
{'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']}, {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']},
{'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']} {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
] ]
self.assertEqual(serializer.data, expected) self.assertEqual(serializer.data, expected)
......
...@@ -65,9 +65,9 @@ class PKManyToManyTests(TestCase): ...@@ -65,9 +65,9 @@ class PKManyToManyTests(TestCase):
queryset = ManyToManySource.objects.all() queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True) serializer = ManyToManySourceSerializer(queryset, many=True)
expected = [ expected = [
{'id': 1, 'name': 'source-1', 'targets': [1]}, {'id': 1, 'name': 'source-1', 'targets': [1]},
{'id': 2, 'name': 'source-2', 'targets': [1, 2]}, {'id': 2, 'name': 'source-2', 'targets': [1, 2]},
{'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]} {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]}
] ]
self.assertEqual(serializer.data, expected) self.assertEqual(serializer.data, expected)
...@@ -93,9 +93,9 @@ class PKManyToManyTests(TestCase): ...@@ -93,9 +93,9 @@ class PKManyToManyTests(TestCase):
queryset = ManyToManySource.objects.all() queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True) serializer = ManyToManySourceSerializer(queryset, many=True)
expected = [ expected = [
{'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]}, {'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]},
{'id': 2, 'name': 'source-2', 'targets': [1, 2]}, {'id': 2, 'name': 'source-2', 'targets': [1, 2]},
{'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]} {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]}
] ]
self.assertEqual(serializer.data, expected) self.assertEqual(serializer.data, expected)
......
...@@ -76,7 +76,6 @@ class MockGETView(APIView): ...@@ -76,7 +76,6 @@ class MockGETView(APIView):
return Response({'foo': ['bar', 'baz']}) return Response({'foo': ['bar', 'baz']})
class MockPOSTView(APIView): class MockPOSTView(APIView):
def post(self, request, **kwargs): def post(self, request, **kwargs):
return Response({'foo': request.DATA}) return Response({'foo': request.DATA})
...@@ -102,7 +101,8 @@ class HTMLView1(APIView): ...@@ -102,7 +101,8 @@ class HTMLView1(APIView):
def get(self, request, **kwargs): def get(self, request, **kwargs):
return Response('text') return Response('text')
urlpatterns = patterns('', urlpatterns = patterns(
'',
url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderer_classes=[RendererA, RendererB])), url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderer_classes=[RendererA, RendererB])),
url(r'^$', MockView.as_view(renderer_classes=[RendererA, RendererB])), url(r'^$', MockView.as_view(renderer_classes=[RendererA, RendererB])),
url(r'^cache$', MockGETView.as_view()), url(r'^cache$', MockGETView.as_view()),
...@@ -312,16 +312,22 @@ class JSONRendererTests(TestCase): ...@@ -312,16 +312,22 @@ class JSONRendererTests(TestCase):
class Dict(MutableMapping): class Dict(MutableMapping):
def __init__(self): def __init__(self):
self._dict = dict() self._dict = dict()
def __getitem__(self, key): def __getitem__(self, key):
return self._dict.__getitem__(key) return self._dict.__getitem__(key)
def __setitem__(self, key, value): def __setitem__(self, key, value):
return self._dict.__setitem__(key, value) return self._dict.__setitem__(key, value)
def __delitem__(self, key): def __delitem__(self, key):
return self._dict.__delitem__(key) return self._dict.__delitem__(key)
def __iter__(self): def __iter__(self):
return self._dict.__iter__() return self._dict.__iter__()
def __len__(self): def __len__(self):
return self._dict.__len__() return self._dict.__len__()
def keys(self): def keys(self):
return self._dict.keys() return self._dict.keys()
...@@ -330,22 +336,24 @@ class JSONRendererTests(TestCase): ...@@ -330,22 +336,24 @@ class JSONRendererTests(TestCase):
x[2] = 3 x[2] = 3
ret = JSONRenderer().render(x) ret = JSONRenderer().render(x)
data = json.loads(ret.decode('utf-8')) data = json.loads(ret.decode('utf-8'))
self.assertEquals(data, {'key': 'string value', '2': 3}) self.assertEquals(data, {'key': 'string value', '2': 3})
def test_render_obj_with_getitem(self): def test_render_obj_with_getitem(self):
class DictLike(object): class DictLike(object):
def __init__(self): def __init__(self):
self._dict = {} self._dict = {}
def set(self, value): def set(self, value):
self._dict = dict(value) self._dict = dict(value)
def __getitem__(self, key): def __getitem__(self, key):
return self._dict[key] return self._dict[key]
x = DictLike() x = DictLike()
x.set({'a': 1, 'b': 'string'}) x.set({'a': 1, 'b': 'string'})
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
JSONRenderer().render(x) JSONRenderer().render(x)
def test_without_content_type_args(self): def test_without_content_type_args(self):
""" """
Test basic JSON rendering. Test basic JSON rendering.
...@@ -394,35 +402,47 @@ class JSONPRendererTests(TestCase): ...@@ -394,35 +402,47 @@ class JSONPRendererTests(TestCase):
""" """
Test JSONP rendering with View JSON Renderer. Test JSONP rendering with View JSON Renderer.
""" """
resp = self.client.get('/jsonp/jsonrenderer', resp = self.client.get(
HTTP_ACCEPT='application/javascript') '/jsonp/jsonrenderer',
HTTP_ACCEPT='application/javascript'
)
self.assertEqual(resp.status_code, status.HTTP_200_OK) self.assertEqual(resp.status_code, status.HTTP_200_OK)
self.assertEqual(resp['Content-Type'], 'application/javascript; charset=utf-8') self.assertEqual(resp['Content-Type'], 'application/javascript; charset=utf-8')
self.assertEqual(resp.content, self.assertEqual(
('callback(%s);' % _flat_repr).encode('ascii')) resp.content,
('callback(%s);' % _flat_repr).encode('ascii')
)
def test_without_callback_without_json_renderer(self): def test_without_callback_without_json_renderer(self):
""" """
Test JSONP rendering without View JSON Renderer. Test JSONP rendering without View JSON Renderer.
""" """
resp = self.client.get('/jsonp/nojsonrenderer', resp = self.client.get(
HTTP_ACCEPT='application/javascript') '/jsonp/nojsonrenderer',
HTTP_ACCEPT='application/javascript'
)
self.assertEqual(resp.status_code, status.HTTP_200_OK) self.assertEqual(resp.status_code, status.HTTP_200_OK)
self.assertEqual(resp['Content-Type'], 'application/javascript; charset=utf-8') self.assertEqual(resp['Content-Type'], 'application/javascript; charset=utf-8')
self.assertEqual(resp.content, self.assertEqual(
('callback(%s);' % _flat_repr).encode('ascii')) resp.content,
('callback(%s);' % _flat_repr).encode('ascii')
)
def test_with_callback(self): def test_with_callback(self):
""" """
Test JSONP rendering with callback function name. Test JSONP rendering with callback function name.
""" """
callback_func = 'myjsonpcallback' callback_func = 'myjsonpcallback'
resp = self.client.get('/jsonp/nojsonrenderer?callback=' + callback_func, resp = self.client.get(
HTTP_ACCEPT='application/javascript') '/jsonp/nojsonrenderer?callback=' + callback_func,
HTTP_ACCEPT='application/javascript'
)
self.assertEqual(resp.status_code, status.HTTP_200_OK) self.assertEqual(resp.status_code, status.HTTP_200_OK)
self.assertEqual(resp['Content-Type'], 'application/javascript; charset=utf-8') self.assertEqual(resp['Content-Type'], 'application/javascript; charset=utf-8')
self.assertEqual(resp.content, self.assertEqual(
('%s(%s);' % (callback_func, _flat_repr)).encode('ascii')) resp.content,
('%s(%s);' % (callback_func, _flat_repr)).encode('ascii')
)
if yaml: if yaml:
...@@ -467,7 +487,6 @@ if yaml: ...@@ -467,7 +487,6 @@ if yaml:
def assertYAMLContains(self, content, string): def assertYAMLContains(self, content, string):
self.assertTrue(string in content, '%r not in %r' % (string, content)) self.assertTrue(string in content, '%r not in %r' % (string, content))
class UnicodeYAMLRendererTests(TestCase): class UnicodeYAMLRendererTests(TestCase):
""" """
Tests specific for the Unicode YAML Renderer Tests specific for the Unicode YAML Renderer
...@@ -592,13 +611,13 @@ class CacheRenderTest(TestCase): ...@@ -592,13 +611,13 @@ class CacheRenderTest(TestCase):
""" Return any errors that would be raised if `obj' is pickled """ Return any errors that would be raised if `obj' is pickled
Courtesy of koffie @ http://stackoverflow.com/a/7218986/109897 Courtesy of koffie @ http://stackoverflow.com/a/7218986/109897
""" """
if seen == None: if seen is None:
seen = [] seen = []
try: try:
state = obj.__getstate__() state = obj.__getstate__()
except AttributeError: except AttributeError:
return return
if state == None: if state is None:
return return
if isinstance(state, tuple): if isinstance(state, tuple):
if not isinstance(state[0], dict): if not isinstance(state[0], dict):
......
...@@ -272,7 +272,8 @@ class MockView(APIView): ...@@ -272,7 +272,8 @@ class MockView(APIView):
return Response(status=status.INTERNAL_SERVER_ERROR) return Response(status=status.INTERNAL_SERVER_ERROR)
urlpatterns = patterns('', urlpatterns = patterns(
'',
(r'^$', MockView.as_view()), (r'^$', MockView.as_view()),
) )
......
...@@ -100,7 +100,8 @@ new_model_viewset_router = routers.DefaultRouter() ...@@ -100,7 +100,8 @@ new_model_viewset_router = routers.DefaultRouter()
new_model_viewset_router.register(r'', HTMLNewModelViewSet) new_model_viewset_router.register(r'', HTMLNewModelViewSet)
urlpatterns = patterns('', urlpatterns = patterns(
'',
url(r'^setbyview$', MockViewSettingContentType.as_view(renderer_classes=[RendererA, RendererB, RendererC])), url(r'^setbyview$', MockViewSettingContentType.as_view(renderer_classes=[RendererA, RendererB, RendererC])),
url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderer_classes=[RendererA, RendererB, RendererC])), url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderer_classes=[RendererA, RendererB, RendererC])),
url(r'^$', MockView.as_view(renderer_classes=[RendererA, RendererB, RendererC])), url(r'^$', MockView.as_view(renderer_classes=[RendererA, RendererB, RendererC])),
......
...@@ -10,7 +10,8 @@ factory = APIRequestFactory() ...@@ -10,7 +10,8 @@ factory = APIRequestFactory()
def null_view(request): def null_view(request):
pass pass
urlpatterns = patterns('', urlpatterns = patterns(
'',
url(r'^view$', null_view, name='view'), url(r'^view$', null_view, name='view'),
) )
......
...@@ -93,7 +93,8 @@ class TestCustomLookupFields(TestCase): ...@@ -93,7 +93,8 @@ class TestCustomLookupFields(TestCase):
from tests import test_routers from tests import test_routers
urls = getattr(test_routers, 'urlpatterns') urls = getattr(test_routers, 'urlpatterns')
urls += patterns('', urls += patterns(
'',
url(r'^', include(self.router.urls)), url(r'^', include(self.router.urls)),
) )
...@@ -104,7 +105,8 @@ class TestCustomLookupFields(TestCase): ...@@ -104,7 +105,8 @@ class TestCustomLookupFields(TestCase):
def test_retrieve_lookup_field_list_view(self): def test_retrieve_lookup_field_list_view(self):
response = self.client.get('/notes/') response = self.client.get('/notes/')
self.assertEqual(response.data, self.assertEqual(
response.data,
[{ [{
"url": "http://testserver/notes/123/", "url": "http://testserver/notes/123/",
"uuid": "123", "text": "foo bar" "uuid": "123", "text": "foo bar"
...@@ -113,7 +115,8 @@ class TestCustomLookupFields(TestCase): ...@@ -113,7 +115,8 @@ class TestCustomLookupFields(TestCase):
def test_retrieve_lookup_field_detail_view(self): def test_retrieve_lookup_field_detail_view(self):
response = self.client.get('/notes/123/') response = self.client.get('/notes/123/')
self.assertEqual(response.data, self.assertEqual(
response.data,
{ {
"url": "http://testserver/notes/123/", "url": "http://testserver/notes/123/",
"uuid": "123", "text": "foo bar" "uuid": "123", "text": "foo bar"
......
...@@ -7,10 +7,12 @@ from django.utils import unittest ...@@ -7,10 +7,12 @@ from django.utils import unittest
from django.utils.datastructures import MultiValueDict from django.utils.datastructures import MultiValueDict
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from rest_framework import serializers, fields, relations from rest_framework import serializers, fields, relations
from tests.models import (HasPositiveIntegerAsChoice, Album, ActionItem, Anchor, BasicModel, from tests.models import (
BlankFieldModel, BlogPost, BlogPostComment, Book, CallableDefaultValueModel, DefaultValueModel, HasPositiveIntegerAsChoice, Album, ActionItem, Anchor, BasicModel,
ManyToManyModel, Person, ReadOnlyManyToManyModel, Photo, RESTFrameworkModel, BlankFieldModel, BlogPost, BlogPostComment, Book, CallableDefaultValueModel,
ForeignKeySource, ManyToManySource) DefaultValueModel, ManyToManyModel, Person, ReadOnlyManyToManyModel, Photo,
RESTFrameworkModel, ForeignKeySource
)
from tests.models import BasicModelSerializer from tests.models import BasicModelSerializer
import datetime import datetime
import pickle import pickle
...@@ -99,6 +101,7 @@ class ActionItemSerializer(serializers.ModelSerializer): ...@@ -99,6 +101,7 @@ class ActionItemSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = ActionItem model = ActionItem
class ActionItemSerializerOptionalFields(serializers.ModelSerializer): class ActionItemSerializerOptionalFields(serializers.ModelSerializer):
""" """
Intended to test that fields with `required=False` are excluded from validation. Intended to test that fields with `required=False` are excluded from validation.
...@@ -109,6 +112,7 @@ class ActionItemSerializerOptionalFields(serializers.ModelSerializer): ...@@ -109,6 +112,7 @@ class ActionItemSerializerOptionalFields(serializers.ModelSerializer):
model = ActionItem model = ActionItem
fields = ('title',) fields = ('title',)
class ActionItemSerializerCustomRestore(serializers.ModelSerializer): class ActionItemSerializerCustomRestore(serializers.ModelSerializer):
class Meta: class Meta:
...@@ -295,8 +299,10 @@ class BasicTests(TestCase): ...@@ -295,8 +299,10 @@ class BasicTests(TestCase):
in the Meta data in the Meta data
""" """
serializer = PersonSerializer(self.person) serializer = PersonSerializer(self.person)
self.assertEqual(set(serializer.data.keys()), self.assertEqual(
set(['name', 'age', 'info'])) set(serializer.data.keys()),
set(['name', 'age', 'info'])
)
def test_field_with_dictionary(self): def test_field_with_dictionary(self):
""" """
...@@ -331,9 +337,9 @@ class BasicTests(TestCase): ...@@ -331,9 +337,9 @@ class BasicTests(TestCase):
— id field is not populated if `data` is accessed prior to `save()` — id field is not populated if `data` is accessed prior to `save()`
""" """
serializer = ActionItemSerializer(self.actionitem) serializer = ActionItemSerializer(self.actionitem)
self.assertIsNone(serializer.data.get('id',None), 'New instance. `id` should not be set.') self.assertIsNone(serializer.data.get('id', None), 'New instance. `id` should not be set.')
serializer.save() serializer.save()
self.assertIsNotNone(serializer.data.get('id',None), 'Model is saved. `id` should be set.') self.assertIsNotNone(serializer.data.get('id', None), 'Model is saved. `id` should be set.')
def test_fields_marked_as_not_required_are_excluded_from_validation(self): def test_fields_marked_as_not_required_are_excluded_from_validation(self):
""" """
...@@ -660,10 +666,10 @@ class ModelValidationTests(TestCase): ...@@ -660,10 +666,10 @@ class ModelValidationTests(TestCase):
serializer.save() serializer.save()
second_serializer = AlbumsSerializer(data={'title': 'a'}) second_serializer = AlbumsSerializer(data={'title': 'a'})
self.assertFalse(second_serializer.is_valid()) self.assertFalse(second_serializer.is_valid())
self.assertEqual(second_serializer.errors, {'title': ['Album with this Title already exists.'],}) self.assertEqual(second_serializer.errors, {'title': ['Album with this Title already exists.']})
third_serializer = AlbumsSerializer(data=[{'title': 'b', 'ref': '1'}, {'title': 'c'}], many=True) third_serializer = AlbumsSerializer(data=[{'title': 'b', 'ref': '1'}, {'title': 'c'}], many=True)
self.assertFalse(third_serializer.is_valid()) self.assertFalse(third_serializer.is_valid())
self.assertEqual(third_serializer.errors, [{'ref': ['Album with this Ref already exists.']}, {}]) self.assertEqual(third_serializer.errors, [{'ref': ['Album with this Ref already exists.']}, {}])
def test_foreign_key_is_null_with_partial(self): def test_foreign_key_is_null_with_partial(self):
""" """
...@@ -959,7 +965,7 @@ class WritableFieldDefaultValueTests(TestCase): ...@@ -959,7 +965,7 @@ class WritableFieldDefaultValueTests(TestCase):
self.assertEqual(got, self.expected) self.assertEqual(got, self.expected)
def test_get_default_value_with_callable(self): def test_get_default_value_with_callable(self):
field = self.create_field(default=lambda : self.expected) field = self.create_field(default=lambda: self.expected)
got = field.get_default_value() got = field.get_default_value()
self.assertEqual(got, self.expected) self.assertEqual(got, self.expected)
...@@ -974,7 +980,7 @@ class WritableFieldDefaultValueTests(TestCase): ...@@ -974,7 +980,7 @@ class WritableFieldDefaultValueTests(TestCase):
self.assertIsNone(got) self.assertIsNone(got)
def test_get_default_value_returns_non_True_values(self): def test_get_default_value_returns_non_True_values(self):
values = [None, '', False, 0, [], (), {}] # values that assumed as 'False' in the 'if' clause values = [None, '', False, 0, [], (), {}] # values that assumed as 'False' in the 'if' clause
for expected in values: for expected in values:
field = self.create_field(default=expected) field = self.create_field(default=expected)
got = field.get_default_value() got = field.get_default_value()
......
...@@ -83,9 +83,9 @@ class BulkCreateSerializerTests(TestCase): ...@@ -83,9 +83,9 @@ class BulkCreateSerializerTests(TestCase):
self.assertEqual(serializer.is_valid(), False) self.assertEqual(serializer.is_valid(), False)
expected_errors = [ expected_errors = [
{'non_field_errors': ['Invalid data']}, {'non_field_errors': ['Invalid data']},
{'non_field_errors': ['Invalid data']}, {'non_field_errors': ['Invalid data']},
{'non_field_errors': ['Invalid data']} {'non_field_errors': ['Invalid data']}
] ]
self.assertEqual(serializer.errors, expected_errors) self.assertEqual(serializer.errors, expected_errors)
......
...@@ -328,12 +328,14 @@ class NestedModelSerializerUpdateTests(TestCase): ...@@ -328,12 +328,14 @@ class NestedModelSerializerUpdateTests(TestCase):
class BlogPostSerializer(serializers.ModelSerializer): class BlogPostSerializer(serializers.ModelSerializer):
comments = BlogPostCommentSerializer(many=True, source='blogpostcomment_set') comments = BlogPostCommentSerializer(many=True, source='blogpostcomment_set')
class Meta: class Meta:
model = models.BlogPost model = models.BlogPost
fields = ('id', 'title', 'comments') fields = ('id', 'title', 'comments')
class PersonSerializer(serializers.ModelSerializer): class PersonSerializer(serializers.ModelSerializer):
posts = BlogPostSerializer(many=True, source='blogpost_set') posts = BlogPostSerializer(many=True, source='blogpost_set')
class Meta: class Meta:
model = models.Person model = models.Person
fields = ('id', 'name', 'age', 'posts') fields = ('id', 'name', 'age', 'posts')
......
from django.db import models
from django.test import TestCase from django.test import TestCase
from rest_framework.compat import six
from rest_framework.serializers import _resolve_model from rest_framework.serializers import _resolve_model
from tests.models import BasicModel from tests.models import BasicModel
from rest_framework.compat import six
class ResolveModelTests(TestCase): class ResolveModelTests(TestCase):
......
...@@ -30,4 +30,4 @@ class TestStatus(TestCase): ...@@ -30,4 +30,4 @@ class TestStatus(TestCase):
self.assertFalse(is_server_error(499)) self.assertFalse(is_server_error(499))
self.assertTrue(is_server_error(500)) self.assertTrue(is_server_error(500))
self.assertTrue(is_server_error(599)) self.assertTrue(is_server_error(599))
self.assertFalse(is_server_error(600)) self.assertFalse(is_server_error(600))
\ No newline at end of file
...@@ -48,4 +48,4 @@ class Issue1386Tests(TestCase): ...@@ -48,4 +48,4 @@ class Issue1386Tests(TestCase):
self.assertEqual(i, res) self.assertEqual(i, res)
# example from issue #1386, this shouldn't raise an exception # example from issue #1386, this shouldn't raise an exception
_ = urlize_quoted_links("asdf:[/p]zxcv.com") urlize_quoted_links("asdf:[/p]zxcv.com")
...@@ -28,7 +28,8 @@ def session_view(request): ...@@ -28,7 +28,8 @@ def session_view(request):
}) })
urlpatterns = patterns('', urlpatterns = patterns(
'',
url(r'^view/$', view), url(r'^view/$', view),
url(r'^session-view/$', session_view), url(r'^session-view/$', session_view),
) )
...@@ -142,7 +143,8 @@ class TestAPIRequestFactory(TestCase): ...@@ -142,7 +143,8 @@ class TestAPIRequestFactory(TestCase):
assertion error. assertion error.
""" """
factory = APIRequestFactory() factory = APIRequestFactory()
self.assertRaises(AssertionError, factory.post, self.assertRaises(
AssertionError, factory.post,
path='/view/', data={'example': 1}, format='xml' path='/view/', data={'example': 1}, format='xml'
) )
......
...@@ -27,7 +27,7 @@ class NonTimeThrottle(BaseThrottle): ...@@ -27,7 +27,7 @@ class NonTimeThrottle(BaseThrottle):
if not hasattr(self.__class__, 'called'): if not hasattr(self.__class__, 'called'):
self.__class__.called = True self.__class__.called = True
return True return True
return False return False
class MockView(APIView): class MockView(APIView):
...@@ -125,36 +125,42 @@ class ThrottlingTests(TestCase): ...@@ -125,36 +125,42 @@ class ThrottlingTests(TestCase):
""" """
Ensure for second based throttles. Ensure for second based throttles.
""" """
self.ensure_response_header_contains_proper_throttle_field(MockView, self.ensure_response_header_contains_proper_throttle_field(
((0, None), MockView, (
(0, None), (0, None),
(0, None), (0, None),
(0, '1') (0, None),
)) (0, '1')
)
)
def test_minutes_fields(self): def test_minutes_fields(self):
""" """
Ensure for minute based throttles. Ensure for minute based throttles.
""" """
self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling, self.ensure_response_header_contains_proper_throttle_field(
((0, None), MockView_MinuteThrottling, (
(0, None), (0, None),
(0, None), (0, None),
(0, '60') (0, None),
)) (0, '60')
)
)
def test_next_rate_remains_constant_if_followed(self): def test_next_rate_remains_constant_if_followed(self):
""" """
If a client follows the recommended next request rate, If a client follows the recommended next request rate,
the throttling rate should stay constant. the throttling rate should stay constant.
""" """
self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling, self.ensure_response_header_contains_proper_throttle_field(
((0, None), MockView_MinuteThrottling, (
(20, None), (0, None),
(40, None), (20, None),
(60, None), (40, None),
(80, None) (60, None),
)) (80, None)
)
)
def test_non_time_throttle(self): def test_non_time_throttle(self):
""" """
...@@ -170,7 +176,7 @@ class ThrottlingTests(TestCase): ...@@ -170,7 +176,7 @@ class ThrottlingTests(TestCase):
self.assertTrue(MockView_NonTimeThrottling.throttle_classes[0].called) self.assertTrue(MockView_NonTimeThrottling.throttle_classes[0].called)
response = MockView_NonTimeThrottling.as_view()(request) response = MockView_NonTimeThrottling.as_view()(request)
self.assertFalse('X-Throttle-Wait-Seconds' in response) self.assertFalse('X-Throttle-Wait-Seconds' in response)
class ScopedRateThrottleTests(TestCase): class ScopedRateThrottleTests(TestCase):
......
from __future__ import unicode_literals from __future__ import unicode_literals
from django.test import TestCase from django.test import TestCase
from rest_framework.templatetags.rest_framework import urlize_quoted_links from rest_framework.templatetags.rest_framework import urlize_quoted_links
import sys
class URLizerTests(TestCase): class URLizerTests(TestCase):
......
[tox] [tox]
downloadcache = {toxworkdir}/cache/ downloadcache = {toxworkdir}/cache/
envlist = envlist =
flake8,
py3.4-django1.7,py3.3-django1.7,py3.2-django1.7,py2.7-django1.7, py3.4-django1.7,py3.3-django1.7,py3.2-django1.7,py2.7-django1.7,
py3.4-django1.6,py3.3-django1.6,py3.2-django1.6,py2.7-django1.6,py2.6-django1.6, py3.4-django1.6,py3.3-django1.6,py3.2-django1.6,py2.7-django1.6,py2.6-django1.6,
py3.4-django1.5,py3.3-django1.5,py3.2-django1.5,py2.7-django1.5,py2.6-django1.5, py3.4-django1.5,py3.3-django1.5,py3.2-django1.5,py2.7-django1.5,py2.6-django1.5,
py2.7-django1.4,py2.6-django1.4, py2.7-django1.4,py2.6-django1.4
[testenv] [testenv]
commands = py.test -q commands = ./runtests.py --fast
[testenv:flake8]
basepython = python2.7
deps = pytest==2.5.2
commands = ./runtests.py --lintonly
[testenv:py3.4-django1.7] [testenv:py3.4-django1.7]
basepython = python3.4 basepython = python3.4
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment