Commit 2859eaf5 by Tom Christie

request.data attribute

parent 417fe1b6
...@@ -16,7 +16,7 @@ class ObtainAuthToken(APIView): ...@@ -16,7 +16,7 @@ class ObtainAuthToken(APIView):
model = Token model = Token
def post(self, request): def post(self, request):
serializer = self.serializer_class(data=request.DATA) serializer = self.serializer_class(data=request.data)
if serializer.is_valid(): if serializer.is_valid():
user = serializer.validated_data['user'] user = serializer.validated_data['user']
token, created = Token.objects.get_or_create(user=user) token, created = Token.objects.get_or_create(user=user)
......
...@@ -56,7 +56,7 @@ def get_attribute(instance, attrs): ...@@ -56,7 +56,7 @@ def get_attribute(instance, attrs):
except AttributeError as exc: except AttributeError as exc:
try: try:
return instance[attr] return instance[attr]
except (KeyError, TypeError): except (KeyError, TypeError, AttributeError):
raise exc raise exc
return instance return instance
...@@ -90,6 +90,7 @@ NOT_READ_ONLY_WRITE_ONLY = 'May not set both `read_only` and `write_only`' ...@@ -90,6 +90,7 @@ NOT_READ_ONLY_WRITE_ONLY = 'May not set both `read_only` and `write_only`'
NOT_READ_ONLY_REQUIRED = 'May not set both `read_only` and `required`' NOT_READ_ONLY_REQUIRED = 'May not set both `read_only` and `required`'
NOT_READ_ONLY_DEFAULT = 'May not set both `read_only` and `default`' NOT_READ_ONLY_DEFAULT = 'May not set both `read_only` and `default`'
NOT_REQUIRED_DEFAULT = 'May not set both `required` and `default`' NOT_REQUIRED_DEFAULT = 'May not set both `required` and `default`'
USE_READONLYFIELD = 'Field(read_only=True) should be ReadOnlyField'
MISSING_ERROR_MESSAGE = ( MISSING_ERROR_MESSAGE = (
'ValidationError raised by `{class_name}`, but error key `{key}` does ' 'ValidationError raised by `{class_name}`, but error key `{key}` does '
'not exist in the `error_messages` dictionary.' 'not exist in the `error_messages` dictionary.'
...@@ -105,9 +106,10 @@ class Field(object): ...@@ -105,9 +106,10 @@ class Field(object):
} }
default_validators = [] default_validators = []
default_empty_html = None default_empty_html = None
initial = None
def __init__(self, read_only=False, write_only=False, def __init__(self, read_only=False, write_only=False,
required=None, default=empty, initial=None, source=None, required=None, default=empty, initial=empty, source=None,
label=None, help_text=None, style=None, label=None, help_text=None, style=None,
error_messages=None, validators=[], allow_null=False): error_messages=None, validators=[], allow_null=False):
self._creation_counter = Field._creation_counter self._creation_counter = Field._creation_counter
...@@ -122,13 +124,14 @@ class Field(object): ...@@ -122,13 +124,14 @@ class Field(object):
assert not (read_only and required), NOT_READ_ONLY_REQUIRED assert not (read_only and required), NOT_READ_ONLY_REQUIRED
assert not (read_only and default is not empty), NOT_READ_ONLY_DEFAULT assert not (read_only and default is not empty), NOT_READ_ONLY_DEFAULT
assert not (required and default is not empty), NOT_REQUIRED_DEFAULT assert not (required and default is not empty), NOT_REQUIRED_DEFAULT
assert not (read_only and self.__class__ == Field), USE_READONLYFIELD
self.read_only = read_only self.read_only = read_only
self.write_only = write_only self.write_only = write_only
self.required = required self.required = required
self.default = default self.default = default
self.source = source self.source = source
self.initial = initial self.initial = self.initial if (initial is empty) else initial
self.label = label self.label = label
self.help_text = help_text self.help_text = help_text
self.style = {} if style is None else style self.style = {} if style is None else style
...@@ -146,24 +149,10 @@ class Field(object): ...@@ -146,24 +149,10 @@ class Field(object):
messages.update(error_messages or {}) messages.update(error_messages or {})
self.error_messages = messages self.error_messages = messages
def __new__(cls, *args, **kwargs):
"""
When a field is instantiated, we store the arguments that were used,
so that we can present a helpful representation of the object.
"""
instance = super(Field, cls).__new__(cls)
instance._args = args
instance._kwargs = kwargs
return instance
def __deepcopy__(self, memo):
args = copy.deepcopy(self._args)
kwargs = copy.deepcopy(self._kwargs)
return self.__class__(*args, **kwargs)
def bind(self, field_name, parent): def bind(self, field_name, parent):
""" """
Setup the context for the field instance. Initializes the field name and parent for the field instance.
Called when a field is added to the parent serializer instance.
""" """
# In order to enforce a consistent style, we error if a redundant # In order to enforce a consistent style, we error if a redundant
...@@ -244,9 +233,9 @@ class Field(object): ...@@ -244,9 +233,9 @@ class Field(object):
validated data. validated data.
""" """
if data is empty: if data is empty:
if getattr(self.root, 'partial', False):
raise SkipField()
if self.required: if self.required:
if getattr(self.root, 'partial', False):
raise SkipField()
self.fail('required') self.fail('required')
return self.get_default() return self.get_default()
...@@ -314,6 +303,25 @@ class Field(object): ...@@ -314,6 +303,25 @@ class Field(object):
""" """
return getattr(self.root, '_context', {}) return getattr(self.root, '_context', {})
def __new__(cls, *args, **kwargs):
"""
When a field is instantiated, we store the arguments that were used,
so that we can present a helpful representation of the object.
"""
instance = super(Field, cls).__new__(cls)
instance._args = args
instance._kwargs = kwargs
return instance
def __deepcopy__(self, memo):
"""
When cloning fields we instantiate using the arguments it was
originally created with, rather than copying the complete state.
"""
args = copy.deepcopy(self._args)
kwargs = copy.deepcopy(self._kwargs)
return self.__class__(*args, **kwargs)
def __repr__(self): def __repr__(self):
""" """
Fields are represented using their initial calling arguments. Fields are represented using their initial calling arguments.
...@@ -358,6 +366,7 @@ class NullBooleanField(Field): ...@@ -358,6 +366,7 @@ class NullBooleanField(Field):
'invalid': _('`{input}` is not a valid boolean.') 'invalid': _('`{input}` is not a valid boolean.')
} }
default_empty_html = None default_empty_html = None
initial = None
TRUE_VALUES = set(('t', 'T', 'true', 'True', 'TRUE', '1', 1, True)) TRUE_VALUES = set(('t', 'T', 'true', 'True', 'TRUE', '1', 1, True))
FALSE_VALUES = set(('f', 'F', 'false', 'False', 'FALSE', '0', 0, 0.0, False)) FALSE_VALUES = set(('f', 'F', 'false', 'False', 'FALSE', '0', 0, 0.0, False))
NULL_VALUES = set(('n', 'N', 'null', 'Null', 'NULL', '', None)) NULL_VALUES = set(('n', 'N', 'null', 'Null', 'NULL', '', None))
......
...@@ -64,7 +64,7 @@ class DjangoFilterBackend(BaseFilterBackend): ...@@ -64,7 +64,7 @@ class DjangoFilterBackend(BaseFilterBackend):
filter_class = self.get_filter_class(view, queryset) filter_class = self.get_filter_class(view, queryset)
if filter_class: if filter_class:
return filter_class(request.QUERY_PARAMS, queryset=queryset).qs return filter_class(request.query_params, queryset=queryset).qs
return queryset return queryset
...@@ -78,7 +78,7 @@ class SearchFilter(BaseFilterBackend): ...@@ -78,7 +78,7 @@ class SearchFilter(BaseFilterBackend):
Search terms are set by a ?search=... query parameter, Search terms are set by a ?search=... query parameter,
and may be comma and/or whitespace delimited. and may be comma and/or whitespace delimited.
""" """
params = request.QUERY_PARAMS.get(self.search_param, '') params = request.query_params.get(self.search_param, '')
return params.replace(',', ' ').split() return params.replace(',', ' ').split()
def construct_search(self, field_name): def construct_search(self, field_name):
...@@ -121,7 +121,7 @@ class OrderingFilter(BaseFilterBackend): ...@@ -121,7 +121,7 @@ class OrderingFilter(BaseFilterBackend):
the `ordering_param` value on the OrderingFilter or by the `ordering_param` value on the OrderingFilter or by
specifying an `ORDERING_PARAM` value in the API settings. specifying an `ORDERING_PARAM` value in the API settings.
""" """
params = request.QUERY_PARAMS.get(self.ordering_param) params = request.query_params.get(self.ordering_param)
if params: if params:
return [param.strip() for param in params.split(',')] return [param.strip() for param in params.split(',')]
......
...@@ -112,7 +112,7 @@ class GenericAPIView(views.APIView): ...@@ -112,7 +112,7 @@ class GenericAPIView(views.APIView):
paginator = self.paginator_class(queryset, page_size) paginator = self.paginator_class(queryset, page_size)
page_kwarg = self.kwargs.get(self.page_kwarg) page_kwarg = self.kwargs.get(self.page_kwarg)
page_query_param = self.request.QUERY_PARAMS.get(self.page_kwarg) page_query_param = self.request.query_params.get(self.page_kwarg)
page = page_kwarg or page_query_param or 1 page = page_kwarg or page_query_param or 1
try: try:
page_number = paginator.validate_number(page) page_number = paginator.validate_number(page)
...@@ -166,7 +166,7 @@ class GenericAPIView(views.APIView): ...@@ -166,7 +166,7 @@ class GenericAPIView(views.APIView):
if self.paginate_by_param: if self.paginate_by_param:
try: try:
return strict_positive_int( return strict_positive_int(
self.request.QUERY_PARAMS[self.paginate_by_param], self.request.query_params[self.paginate_by_param],
cutoff=self.max_paginate_by cutoff=self.max_paginate_by
) )
except (KeyError, ValueError): except (KeyError, ValueError):
......
...@@ -18,7 +18,7 @@ class CreateModelMixin(object): ...@@ -18,7 +18,7 @@ class CreateModelMixin(object):
Create a model instance. Create a model instance.
""" """
def create(self, request, *args, **kwargs): def create(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.DATA) serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True) serializer.is_valid(raise_exception=True)
serializer.save() serializer.save()
headers = self.get_success_headers(serializer.data) headers = self.get_success_headers(serializer.data)
...@@ -62,7 +62,7 @@ class UpdateModelMixin(object): ...@@ -62,7 +62,7 @@ class UpdateModelMixin(object):
def update(self, request, *args, **kwargs): def update(self, request, *args, **kwargs):
partial = kwargs.pop('partial', False) partial = kwargs.pop('partial', False)
instance = self.get_object() instance = self.get_object()
serializer = self.get_serializer(instance, data=request.DATA, partial=partial) serializer = self.get_serializer(instance, data=request.data, partial=partial)
serializer.is_valid(raise_exception=True) serializer.is_valid(raise_exception=True)
serializer.save() serializer.save()
return Response(serializer.data) return Response(serializer.data)
...@@ -95,7 +95,7 @@ class AllowPUTAsCreateMixin(object): ...@@ -95,7 +95,7 @@ class AllowPUTAsCreateMixin(object):
def update(self, request, *args, **kwargs): def update(self, request, *args, **kwargs):
partial = kwargs.pop('partial', False) partial = kwargs.pop('partial', False)
instance = self.get_object_or_none() instance = self.get_object_or_none()
serializer = self.get_serializer(instance, data=request.DATA, partial=partial) serializer = self.get_serializer(instance, data=request.data, partial=partial)
serializer.is_valid(raise_exception=True) serializer.is_valid(raise_exception=True)
if instance is None: if instance is None:
......
...@@ -38,7 +38,7 @@ class DefaultContentNegotiation(BaseContentNegotiation): ...@@ -38,7 +38,7 @@ class DefaultContentNegotiation(BaseContentNegotiation):
""" """
# Allow URL style format override. eg. "?format=json # Allow URL style format override. eg. "?format=json
format_query_param = self.settings.URL_FORMAT_OVERRIDE format_query_param = self.settings.URL_FORMAT_OVERRIDE
format = format_suffix or request.QUERY_PARAMS.get(format_query_param) format = format_suffix or request.query_params.get(format_query_param)
if format: if format:
renderers = self.filter_renderers(renderers, format) renderers = self.filter_renderers(renderers, format)
...@@ -87,5 +87,5 @@ class DefaultContentNegotiation(BaseContentNegotiation): ...@@ -87,5 +87,5 @@ class DefaultContentNegotiation(BaseContentNegotiation):
Allows URL style accept override. eg. "?accept=application/json" Allows URL style accept override. eg. "?accept=application/json"
""" """
header = request.META.get('HTTP_ACCEPT', '*/*') header = request.META.get('HTTP_ACCEPT', '*/*')
header = request.QUERY_PARAMS.get(self.settings.URL_ACCEPT_OVERRIDE, header) header = request.query_params.get(self.settings.URL_ACCEPT_OVERRIDE, header)
return [token.strip() for token in header.split(',')] return [token.strip() for token in header.split(',')]
...@@ -120,7 +120,7 @@ class JSONPRenderer(JSONRenderer): ...@@ -120,7 +120,7 @@ class JSONPRenderer(JSONRenderer):
Determine the name of the callback to wrap around the json output. Determine the name of the callback to wrap around the json output.
""" """
request = renderer_context.get('request', None) request = renderer_context.get('request', None)
params = request and request.QUERY_PARAMS or {} params = request and request.query_params or {}
return params.get(self.callback_parameter, self.default_callback) return params.get(self.callback_parameter, self.default_callback)
def render(self, data, accepted_media_type=None, renderer_context=None): def render(self, data, accepted_media_type=None, renderer_context=None):
...@@ -426,7 +426,7 @@ class BrowsableAPIRenderer(BaseRenderer): ...@@ -426,7 +426,7 @@ class BrowsableAPIRenderer(BaseRenderer):
""" """
if request.method == method: if request.method == method:
try: try:
data = request.DATA data = request.data
# files = request.FILES # files = request.FILES
except ParseError: except ParseError:
data = None data = None
......
...@@ -4,7 +4,7 @@ The Request class is used as a wrapper around the standard request object. ...@@ -4,7 +4,7 @@ The Request class is used as a wrapper around the standard request object.
The wrapped request then offers a richer API, in particular : The wrapped request then offers a richer API, in particular :
- content automatically parsed according to `Content-Type` header, - content automatically parsed according to `Content-Type` header,
and available as `request.DATA` and available as `request.data`
- full support of PUT method, including support for file uploads - full support of PUT method, including support for file uploads
- form overloading of HTTP method, content type and content - form overloading of HTTP method, content type and content
""" """
...@@ -13,6 +13,7 @@ from django.conf import settings ...@@ -13,6 +13,7 @@ from django.conf import settings
from django.http import QueryDict from django.http import QueryDict
from django.http.multipartparser import parse_header from django.http.multipartparser import parse_header
from django.utils.datastructures import MultiValueDict from django.utils.datastructures import MultiValueDict
from django.utils.datastructures import MergeDict as DjangoMergeDict
from rest_framework import HTTP_HEADER_ENCODING from rest_framework import HTTP_HEADER_ENCODING
from rest_framework import exceptions from rest_framework import exceptions
from rest_framework.compat import BytesIO from rest_framework.compat import BytesIO
...@@ -58,6 +59,15 @@ class override_method(object): ...@@ -58,6 +59,15 @@ class override_method(object):
self.view.action = self.action self.view.action = self.action
class MergeDict(DjangoMergeDict, dict):
"""
Using this as a workaround until the parsers API is properly
addressed in 3.1.
"""
def __init__(self, *dicts):
self.dicts = dicts
class Empty(object): class Empty(object):
""" """
Placeholder for unset attributes. Placeholder for unset attributes.
...@@ -82,6 +92,7 @@ def clone_request(request, method): ...@@ -82,6 +92,7 @@ def clone_request(request, method):
parser_context=request.parser_context) parser_context=request.parser_context)
ret._data = request._data ret._data = request._data
ret._files = request._files ret._files = request._files
ret._full_data = request._full_data
ret._content_type = request._content_type ret._content_type = request._content_type
ret._stream = request._stream ret._stream = request._stream
ret._method = method ret._method = method
...@@ -133,6 +144,7 @@ class Request(object): ...@@ -133,6 +144,7 @@ class Request(object):
self.parser_context = parser_context self.parser_context = parser_context
self._data = Empty self._data = Empty
self._files = Empty self._files = Empty
self._full_data = Empty
self._method = Empty self._method = Empty
self._content_type = Empty self._content_type = Empty
self._stream = Empty self._stream = Empty
...@@ -186,13 +198,26 @@ class Request(object): ...@@ -186,13 +198,26 @@ class Request(object):
return self._stream return self._stream
@property @property
def QUERY_PARAMS(self): def query_params(self):
""" """
More semantically correct name for request.GET. More semantically correct name for request.GET.
""" """
return self._request.GET return self._request.GET
@property @property
def QUERY_PARAMS(self):
"""
Synonym for `.query_params`, for backwards compatibility.
"""
return self._request.GET
@property
def data(self):
if not _hasattr(self, '_full_data'):
self._load_data_and_files()
return self._full_data
@property
def DATA(self): def DATA(self):
""" """
Parses the request body and returns the data. Parses the request body and returns the data.
...@@ -272,6 +297,10 @@ class Request(object): ...@@ -272,6 +297,10 @@ class Request(object):
if not _hasattr(self, '_data'): if not _hasattr(self, '_data'):
self._data, self._files = self._parse() self._data, self._files = self._parse()
if self._files:
self._full_data = MergeDict(self._data, self._files)
else:
self._full_data = self._data
def _load_method_and_content_type(self): def _load_method_and_content_type(self):
""" """
...@@ -333,6 +362,7 @@ class Request(object): ...@@ -333,6 +362,7 @@ class Request(object):
# At this point we're committed to parsing the request as form data. # At this point we're committed to parsing the request as form data.
self._data = self._request.POST self._data = self._request.POST
self._files = self._request.FILES self._files = self._request.FILES
self._full_data = MergeDict(self._data, self._files)
# Method overloading - change the method and remove the param from the content. # Method overloading - change the method and remove the param from the content.
if ( if (
...@@ -350,7 +380,7 @@ class Request(object): ...@@ -350,7 +380,7 @@ class Request(object):
): ):
self._content_type = self._data[self._CONTENTTYPE_PARAM] self._content_type = self._data[self._CONTENTTYPE_PARAM]
self._stream = BytesIO(self._data[self._CONTENT_PARAM].encode(self.parser_context['encoding'])) self._stream = BytesIO(self._data[self._CONTENT_PARAM].encode(self.parser_context['encoding']))
self._data, self._files = (Empty, Empty) self._data, self._files, self._full_data = (Empty, Empty, Empty)
def _parse(self): def _parse(self):
""" """
...@@ -380,6 +410,7 @@ class Request(object): ...@@ -380,6 +410,7 @@ class Request(object):
# logging the request or similar. # logging the request or similar.
self._data = QueryDict('', encoding=self._request._encoding) self._data = QueryDict('', encoding=self._request._encoding)
self._files = MultiValueDict() self._files = MultiValueDict()
self._full_data = self._data
raise raise
# Parser classes may return the raw data, or a # Parser classes may return the raw data, or a
......
...@@ -57,21 +57,24 @@ class BaseSerializer(Field): ...@@ -57,21 +57,24 @@ class BaseSerializer(Field):
def to_representation(self, instance): def to_representation(self, instance):
raise NotImplementedError('`to_representation()` must be implemented.') raise NotImplementedError('`to_representation()` must be implemented.')
def update(self, instance, attrs): def update(self, instance, validated_data):
raise NotImplementedError('`update()` must be implemented.') raise NotImplementedError('`update()` must be implemented.')
def create(self, attrs): def create(self, validated_data):
raise NotImplementedError('`create()` must be implemented.') raise NotImplementedError('`create()` must be implemented.')
def save(self, extras=None): def save(self, extras=None):
attrs = self.validated_data validated_data = self.validated_data
if extras is not None: if extras is not None:
attrs = dict(list(attrs.items()) + list(extras.items())) validated_data = dict(
list(validated_data.items()) +
list(extras.items())
)
if self.instance is not None: if self.instance is not None:
self.update(self.instance, attrs) self.update(self.instance, validated_data)
else: else:
self.instance = self.create(attrs) self.instance = self.create(validated_data)
return self.instance return self.instance
...@@ -321,12 +324,6 @@ class ListSerializer(BaseSerializer): ...@@ -321,12 +324,6 @@ class ListSerializer(BaseSerializer):
def create(self, attrs_list): def create(self, attrs_list):
return [self.child.create(attrs) for attrs in attrs_list] return [self.child.create(attrs) for attrs in attrs_list]
def save(self):
if self.instance is not None:
self.update(self.instance, self.validated_data)
self.instance = self.create(self.validated_data)
return self.instance
def __repr__(self): def __repr__(self):
return representation.list_repr(self, indent=1) return representation.list_repr(self, indent=1)
......
...@@ -9,7 +9,10 @@ import pytest ...@@ -9,7 +9,10 @@ import pytest
# Tests for field keyword arguments and core functionality. # Tests for field keyword arguments and core functionality.
# --------------------------------------------------------- # ---------------------------------------------------------
class TestFieldOptions: class TestEmpty:
"""
Tests for `required`, `allow_null`, `allow_blank`, `default`.
"""
def test_required(self): def test_required(self):
""" """
By default a field must be included in the input. By default a field must be included in the input.
...@@ -69,6 +72,17 @@ class TestFieldOptions: ...@@ -69,6 +72,17 @@ class TestFieldOptions:
output = field.run_validation() output = field.run_validation()
assert output is 123 assert output is 123
class TestSource:
def test_source(self):
class ExampleSerializer(serializers.Serializer):
example_field = serializers.CharField(source='other')
serializer = ExampleSerializer(data={'example_field': 'abc'})
print serializer.is_valid()
print serializer.data
assert serializer.is_valid()
assert serializer.validated_data == {'other': 'abc'}
def test_redundant_source(self): def test_redundant_source(self):
class ExampleSerializer(serializers.Serializer): class ExampleSerializer(serializers.Serializer):
example_field = serializers.CharField(source='example_field') example_field = serializers.CharField(source='example_field')
...@@ -81,6 +95,128 @@ class TestFieldOptions: ...@@ -81,6 +95,128 @@ class TestFieldOptions:
) )
class TestReadOnly:
def setup(self):
class TestSerializer(serializers.Serializer):
read_only = fields.ReadOnlyField()
writable = fields.IntegerField()
self.Serializer = TestSerializer
def test_validate_read_only(self):
"""
Read-only fields should not be included in validation.
"""
data = {'read_only': 123, 'writable': 456}
serializer = self.Serializer(data=data)
assert serializer.is_valid()
assert serializer.validated_data == {'writable': 456}
def test_serialize_read_only(self):
"""
Read-only fields should be serialized.
"""
instance = {'read_only': 123, 'writable': 456}
serializer = self.Serializer(instance)
assert serializer.data == {'read_only': 123, 'writable': 456}
class TestWriteOnly:
def setup(self):
class TestSerializer(serializers.Serializer):
write_only = fields.IntegerField(write_only=True)
readable = fields.IntegerField()
self.Serializer = TestSerializer
def test_validate_write_only(self):
"""
Write-only fields should be included in validation.
"""
data = {'write_only': 123, 'readable': 456}
serializer = self.Serializer(data=data)
assert serializer.is_valid()
assert serializer.validated_data == {'write_only': 123, 'readable': 456}
def test_serialize_write_only(self):
"""
Write-only fields should not be serialized.
"""
instance = {'write_only': 123, 'readable': 456}
serializer = self.Serializer(instance)
assert serializer.data == {'readable': 456}
class TestInitial:
def setup(self):
class TestSerializer(serializers.Serializer):
initial_field = fields.IntegerField(initial=123)
blank_field = fields.IntegerField()
self.serializer = TestSerializer()
def test_initial(self):
"""
Initial values should be included when serializing a new representation.
"""
assert self.serializer.data == {
'initial_field': 123,
'blank_field': None
}
class TestLabel:
def setup(self):
class TestSerializer(serializers.Serializer):
labeled = fields.IntegerField(label='My label')
self.serializer = TestSerializer()
def test_label(self):
"""
A field's label may be set with the `label` argument.
"""
fields = self.serializer.fields
assert fields['labeled'].label == 'My label'
class TestInvalidErrorKey:
def setup(self):
class ExampleField(serializers.Field):
def to_native(self, data):
self.fail('incorrect')
self.field = ExampleField()
def test_invalid_error_key(self):
"""
If a field raises a validation error, but does not have a corresponding
error message, then raise an appropriate assertion error.
"""
with pytest.raises(AssertionError) as exc_info:
self.field.to_native(123)
expected = (
'ValidationError raised by `ExampleField`, but error key '
'`incorrect` does not exist in the `error_messages` dictionary.'
)
assert str(exc_info.value) == expected
class TestBooleanHTMLInput:
def setup(self):
class TestSerializer(serializers.Serializer):
archived = fields.BooleanField()
self.Serializer = TestSerializer
def test_empty_html_checkbox(self):
"""
HTML checkboxes do not send any value, but should be treated
as `False` by BooleanField.
"""
# This class mocks up a dictionary like object, that behaves
# as if it was returned for multipart or urlencoded data.
class MockHTMLDict(dict):
getlist = None
serializer = self.Serializer(data=MockHTMLDict())
assert serializer.is_valid()
assert serializer.validated_data == {'archived': False}
# Tests for field input and output values. # Tests for field input and output values.
# ---------------------------------------- # ----------------------------------------
...@@ -495,7 +631,7 @@ class TestDateTimeField(FieldValues): ...@@ -495,7 +631,7 @@ class TestDateTimeField(FieldValues):
'2001-01-01T13:00Z': datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC()), '2001-01-01T13:00Z': datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC()),
datetime.datetime(2001, 1, 1, 13, 00): datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC()), datetime.datetime(2001, 1, 1, 13, 00): datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC()),
datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC()): datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC()), datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC()): datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC()),
# Note that 1.4 does not support timezone string parsing. # Django 1.4 does not support timezone string parsing.
'2001-01-01T14:00+01:00' if (django.VERSION > (1, 4)) else '2001-01-01T13:00Z': datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC()) '2001-01-01T14:00+01:00' if (django.VERSION > (1, 4)) else '2001-01-01T13:00Z': datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC())
} }
invalid_inputs = { invalid_inputs = {
......
...@@ -38,7 +38,7 @@ class FKInstanceView(generics.RetrieveUpdateDestroyAPIView): ...@@ -38,7 +38,7 @@ class FKInstanceView(generics.RetrieveUpdateDestroyAPIView):
class SlugSerializer(serializers.ModelSerializer): class SlugSerializer(serializers.ModelSerializer):
slug = serializers.Field(read_only=True) slug = serializers.ReadOnlyField()
class Meta: class Meta:
model = SlugBasedModel model = SlugBasedModel
......
from rest_framework import serializers from rest_framework import serializers
import pytest
# Tests for core functionality. # Tests for core functionality.
...@@ -29,6 +30,67 @@ class TestSerializer: ...@@ -29,6 +30,67 @@ class TestSerializer:
assert serializer.validated_data == {'char': 'abc'} assert serializer.validated_data == {'char': 'abc'}
assert serializer.errors == {} assert serializer.errors == {}
def test_empty_serializer(self):
serializer = self.Serializer()
assert serializer.data == {'char': '', 'integer': None}
def test_missing_attribute_during_serialization(self):
class MissingAttributes:
pass
instance = MissingAttributes()
serializer = self.Serializer(instance)
with pytest.raises(AttributeError):
serializer.data
class TestStarredSource:
"""
Tests for `source='*'` argument, which is used for nested representations.
For example:
nested_field = NestedField(source='*')
"""
data = {
'nested1': {'a': 1, 'b': 2},
'nested2': {'c': 3, 'd': 4}
}
def setup(self):
class NestedSerializer1(serializers.Serializer):
a = serializers.IntegerField()
b = serializers.IntegerField()
class NestedSerializer2(serializers.Serializer):
c = serializers.IntegerField()
d = serializers.IntegerField()
class TestSerializer(serializers.Serializer):
nested1 = NestedSerializer1(source='*')
nested2 = NestedSerializer2(source='*')
self.Serializer = TestSerializer
def test_nested_validate(self):
"""
A nested representation is validated into a flat internal object.
"""
serializer = self.Serializer(data=self.data)
assert serializer.is_valid()
assert serializer.validated_data == {
'a': 1,
'b': 2,
'c': 3,
'd': 4
}
def test_nested_serialize(self):
"""
An object can be serialized into a nested representation.
"""
instance = {'a': 1, 'b': 2, 'c': 3, 'd': 4}
serializer = self.Serializer(instance)
assert serializer.data == self.data
# # -*- coding: utf-8 -*- # # -*- coding: utf-8 -*-
# from __future__ import unicode_literals # from __future__ import unicode_literals
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment