Commit 2859eaf5 by Tom Christie

request.data attribute

parent 417fe1b6
......@@ -16,7 +16,7 @@ class ObtainAuthToken(APIView):
model = Token
def post(self, request):
serializer = self.serializer_class(data=request.DATA)
serializer = self.serializer_class(data=request.data)
if serializer.is_valid():
user = serializer.validated_data['user']
token, created = Token.objects.get_or_create(user=user)
......
......@@ -56,7 +56,7 @@ def get_attribute(instance, attrs):
except AttributeError as exc:
try:
return instance[attr]
except (KeyError, TypeError):
except (KeyError, TypeError, AttributeError):
raise exc
return instance
......@@ -90,6 +90,7 @@ NOT_READ_ONLY_WRITE_ONLY = 'May not set both `read_only` and `write_only`'
NOT_READ_ONLY_REQUIRED = 'May not set both `read_only` and `required`'
NOT_READ_ONLY_DEFAULT = 'May not set both `read_only` and `default`'
NOT_REQUIRED_DEFAULT = 'May not set both `required` and `default`'
USE_READONLYFIELD = 'Field(read_only=True) should be ReadOnlyField'
MISSING_ERROR_MESSAGE = (
'ValidationError raised by `{class_name}`, but error key `{key}` does '
'not exist in the `error_messages` dictionary.'
......@@ -105,9 +106,10 @@ class Field(object):
}
default_validators = []
default_empty_html = None
initial = None
def __init__(self, read_only=False, write_only=False,
required=None, default=empty, initial=None, source=None,
required=None, default=empty, initial=empty, source=None,
label=None, help_text=None, style=None,
error_messages=None, validators=[], allow_null=False):
self._creation_counter = Field._creation_counter
......@@ -122,13 +124,14 @@ class Field(object):
assert not (read_only and required), NOT_READ_ONLY_REQUIRED
assert not (read_only and default is not empty), NOT_READ_ONLY_DEFAULT
assert not (required and default is not empty), NOT_REQUIRED_DEFAULT
assert not (read_only and self.__class__ == Field), USE_READONLYFIELD
self.read_only = read_only
self.write_only = write_only
self.required = required
self.default = default
self.source = source
self.initial = initial
self.initial = self.initial if (initial is empty) else initial
self.label = label
self.help_text = help_text
self.style = {} if style is None else style
......@@ -146,24 +149,10 @@ class Field(object):
messages.update(error_messages or {})
self.error_messages = messages
def __new__(cls, *args, **kwargs):
"""
When a field is instantiated, we store the arguments that were used,
so that we can present a helpful representation of the object.
"""
instance = super(Field, cls).__new__(cls)
instance._args = args
instance._kwargs = kwargs
return instance
def __deepcopy__(self, memo):
args = copy.deepcopy(self._args)
kwargs = copy.deepcopy(self._kwargs)
return self.__class__(*args, **kwargs)
def bind(self, field_name, parent):
"""
Setup the context for the field instance.
Initializes the field name and parent for the field instance.
Called when a field is added to the parent serializer instance.
"""
# In order to enforce a consistent style, we error if a redundant
......@@ -244,9 +233,9 @@ class Field(object):
validated data.
"""
if data is empty:
if getattr(self.root, 'partial', False):
raise SkipField()
if self.required:
if getattr(self.root, 'partial', False):
raise SkipField()
self.fail('required')
return self.get_default()
......@@ -314,6 +303,25 @@ class Field(object):
"""
return getattr(self.root, '_context', {})
def __new__(cls, *args, **kwargs):
"""
When a field is instantiated, we store the arguments that were used,
so that we can present a helpful representation of the object.
"""
instance = super(Field, cls).__new__(cls)
instance._args = args
instance._kwargs = kwargs
return instance
def __deepcopy__(self, memo):
"""
When cloning fields we instantiate using the arguments it was
originally created with, rather than copying the complete state.
"""
args = copy.deepcopy(self._args)
kwargs = copy.deepcopy(self._kwargs)
return self.__class__(*args, **kwargs)
def __repr__(self):
"""
Fields are represented using their initial calling arguments.
......@@ -358,6 +366,7 @@ class NullBooleanField(Field):
'invalid': _('`{input}` is not a valid boolean.')
}
default_empty_html = None
initial = None
TRUE_VALUES = set(('t', 'T', 'true', 'True', 'TRUE', '1', 1, True))
FALSE_VALUES = set(('f', 'F', 'false', 'False', 'FALSE', '0', 0, 0.0, False))
NULL_VALUES = set(('n', 'N', 'null', 'Null', 'NULL', '', None))
......
......@@ -64,7 +64,7 @@ class DjangoFilterBackend(BaseFilterBackend):
filter_class = self.get_filter_class(view, queryset)
if filter_class:
return filter_class(request.QUERY_PARAMS, queryset=queryset).qs
return filter_class(request.query_params, queryset=queryset).qs
return queryset
......@@ -78,7 +78,7 @@ class SearchFilter(BaseFilterBackend):
Search terms are set by a ?search=... query parameter,
and may be comma and/or whitespace delimited.
"""
params = request.QUERY_PARAMS.get(self.search_param, '')
params = request.query_params.get(self.search_param, '')
return params.replace(',', ' ').split()
def construct_search(self, field_name):
......@@ -121,7 +121,7 @@ class OrderingFilter(BaseFilterBackend):
the `ordering_param` value on the OrderingFilter or by
specifying an `ORDERING_PARAM` value in the API settings.
"""
params = request.QUERY_PARAMS.get(self.ordering_param)
params = request.query_params.get(self.ordering_param)
if params:
return [param.strip() for param in params.split(',')]
......
......@@ -112,7 +112,7 @@ class GenericAPIView(views.APIView):
paginator = self.paginator_class(queryset, page_size)
page_kwarg = self.kwargs.get(self.page_kwarg)
page_query_param = self.request.QUERY_PARAMS.get(self.page_kwarg)
page_query_param = self.request.query_params.get(self.page_kwarg)
page = page_kwarg or page_query_param or 1
try:
page_number = paginator.validate_number(page)
......@@ -166,7 +166,7 @@ class GenericAPIView(views.APIView):
if self.paginate_by_param:
try:
return strict_positive_int(
self.request.QUERY_PARAMS[self.paginate_by_param],
self.request.query_params[self.paginate_by_param],
cutoff=self.max_paginate_by
)
except (KeyError, ValueError):
......
......@@ -18,7 +18,7 @@ class CreateModelMixin(object):
Create a model instance.
"""
def create(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.DATA)
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
serializer.save()
headers = self.get_success_headers(serializer.data)
......@@ -62,7 +62,7 @@ class UpdateModelMixin(object):
def update(self, request, *args, **kwargs):
partial = kwargs.pop('partial', False)
instance = self.get_object()
serializer = self.get_serializer(instance, data=request.DATA, partial=partial)
serializer = self.get_serializer(instance, data=request.data, partial=partial)
serializer.is_valid(raise_exception=True)
serializer.save()
return Response(serializer.data)
......@@ -95,7 +95,7 @@ class AllowPUTAsCreateMixin(object):
def update(self, request, *args, **kwargs):
partial = kwargs.pop('partial', False)
instance = self.get_object_or_none()
serializer = self.get_serializer(instance, data=request.DATA, partial=partial)
serializer = self.get_serializer(instance, data=request.data, partial=partial)
serializer.is_valid(raise_exception=True)
if instance is None:
......
......@@ -38,7 +38,7 @@ class DefaultContentNegotiation(BaseContentNegotiation):
"""
# Allow URL style format override. eg. "?format=json
format_query_param = self.settings.URL_FORMAT_OVERRIDE
format = format_suffix or request.QUERY_PARAMS.get(format_query_param)
format = format_suffix or request.query_params.get(format_query_param)
if format:
renderers = self.filter_renderers(renderers, format)
......@@ -87,5 +87,5 @@ class DefaultContentNegotiation(BaseContentNegotiation):
Allows URL style accept override. eg. "?accept=application/json"
"""
header = request.META.get('HTTP_ACCEPT', '*/*')
header = request.QUERY_PARAMS.get(self.settings.URL_ACCEPT_OVERRIDE, header)
header = request.query_params.get(self.settings.URL_ACCEPT_OVERRIDE, header)
return [token.strip() for token in header.split(',')]
......@@ -120,7 +120,7 @@ class JSONPRenderer(JSONRenderer):
Determine the name of the callback to wrap around the json output.
"""
request = renderer_context.get('request', None)
params = request and request.QUERY_PARAMS or {}
params = request and request.query_params or {}
return params.get(self.callback_parameter, self.default_callback)
def render(self, data, accepted_media_type=None, renderer_context=None):
......@@ -426,7 +426,7 @@ class BrowsableAPIRenderer(BaseRenderer):
"""
if request.method == method:
try:
data = request.DATA
data = request.data
# files = request.FILES
except ParseError:
data = None
......
......@@ -4,7 +4,7 @@ The Request class is used as a wrapper around the standard request object.
The wrapped request then offers a richer API, in particular :
- content automatically parsed according to `Content-Type` header,
and available as `request.DATA`
and available as `request.data`
- full support of PUT method, including support for file uploads
- form overloading of HTTP method, content type and content
"""
......@@ -13,6 +13,7 @@ from django.conf import settings
from django.http import QueryDict
from django.http.multipartparser import parse_header
from django.utils.datastructures import MultiValueDict
from django.utils.datastructures import MergeDict as DjangoMergeDict
from rest_framework import HTTP_HEADER_ENCODING
from rest_framework import exceptions
from rest_framework.compat import BytesIO
......@@ -58,6 +59,15 @@ class override_method(object):
self.view.action = self.action
class MergeDict(DjangoMergeDict, dict):
"""
Using this as a workaround until the parsers API is properly
addressed in 3.1.
"""
def __init__(self, *dicts):
self.dicts = dicts
class Empty(object):
"""
Placeholder for unset attributes.
......@@ -82,6 +92,7 @@ def clone_request(request, method):
parser_context=request.parser_context)
ret._data = request._data
ret._files = request._files
ret._full_data = request._full_data
ret._content_type = request._content_type
ret._stream = request._stream
ret._method = method
......@@ -133,6 +144,7 @@ class Request(object):
self.parser_context = parser_context
self._data = Empty
self._files = Empty
self._full_data = Empty
self._method = Empty
self._content_type = Empty
self._stream = Empty
......@@ -186,13 +198,26 @@ class Request(object):
return self._stream
@property
def QUERY_PARAMS(self):
def query_params(self):
"""
More semantically correct name for request.GET.
"""
return self._request.GET
@property
def QUERY_PARAMS(self):
"""
Synonym for `.query_params`, for backwards compatibility.
"""
return self._request.GET
@property
def data(self):
if not _hasattr(self, '_full_data'):
self._load_data_and_files()
return self._full_data
@property
def DATA(self):
"""
Parses the request body and returns the data.
......@@ -272,6 +297,10 @@ class Request(object):
if not _hasattr(self, '_data'):
self._data, self._files = self._parse()
if self._files:
self._full_data = MergeDict(self._data, self._files)
else:
self._full_data = self._data
def _load_method_and_content_type(self):
"""
......@@ -333,6 +362,7 @@ class Request(object):
# At this point we're committed to parsing the request as form data.
self._data = self._request.POST
self._files = self._request.FILES
self._full_data = MergeDict(self._data, self._files)
# Method overloading - change the method and remove the param from the content.
if (
......@@ -350,7 +380,7 @@ class Request(object):
):
self._content_type = self._data[self._CONTENTTYPE_PARAM]
self._stream = BytesIO(self._data[self._CONTENT_PARAM].encode(self.parser_context['encoding']))
self._data, self._files = (Empty, Empty)
self._data, self._files, self._full_data = (Empty, Empty, Empty)
def _parse(self):
"""
......@@ -380,6 +410,7 @@ class Request(object):
# logging the request or similar.
self._data = QueryDict('', encoding=self._request._encoding)
self._files = MultiValueDict()
self._full_data = self._data
raise
# Parser classes may return the raw data, or a
......
......@@ -57,21 +57,24 @@ class BaseSerializer(Field):
def to_representation(self, instance):
raise NotImplementedError('`to_representation()` must be implemented.')
def update(self, instance, attrs):
def update(self, instance, validated_data):
raise NotImplementedError('`update()` must be implemented.')
def create(self, attrs):
def create(self, validated_data):
raise NotImplementedError('`create()` must be implemented.')
def save(self, extras=None):
attrs = self.validated_data
validated_data = self.validated_data
if extras is not None:
attrs = dict(list(attrs.items()) + list(extras.items()))
validated_data = dict(
list(validated_data.items()) +
list(extras.items())
)
if self.instance is not None:
self.update(self.instance, attrs)
self.update(self.instance, validated_data)
else:
self.instance = self.create(attrs)
self.instance = self.create(validated_data)
return self.instance
......@@ -321,12 +324,6 @@ class ListSerializer(BaseSerializer):
def create(self, attrs_list):
return [self.child.create(attrs) for attrs in attrs_list]
def save(self):
if self.instance is not None:
self.update(self.instance, self.validated_data)
self.instance = self.create(self.validated_data)
return self.instance
def __repr__(self):
return representation.list_repr(self, indent=1)
......
......@@ -9,7 +9,10 @@ import pytest
# Tests for field keyword arguments and core functionality.
# ---------------------------------------------------------
class TestFieldOptions:
class TestEmpty:
"""
Tests for `required`, `allow_null`, `allow_blank`, `default`.
"""
def test_required(self):
"""
By default a field must be included in the input.
......@@ -69,6 +72,17 @@ class TestFieldOptions:
output = field.run_validation()
assert output is 123
class TestSource:
def test_source(self):
class ExampleSerializer(serializers.Serializer):
example_field = serializers.CharField(source='other')
serializer = ExampleSerializer(data={'example_field': 'abc'})
print serializer.is_valid()
print serializer.data
assert serializer.is_valid()
assert serializer.validated_data == {'other': 'abc'}
def test_redundant_source(self):
class ExampleSerializer(serializers.Serializer):
example_field = serializers.CharField(source='example_field')
......@@ -81,6 +95,128 @@ class TestFieldOptions:
)
class TestReadOnly:
def setup(self):
class TestSerializer(serializers.Serializer):
read_only = fields.ReadOnlyField()
writable = fields.IntegerField()
self.Serializer = TestSerializer
def test_validate_read_only(self):
"""
Read-only fields should not be included in validation.
"""
data = {'read_only': 123, 'writable': 456}
serializer = self.Serializer(data=data)
assert serializer.is_valid()
assert serializer.validated_data == {'writable': 456}
def test_serialize_read_only(self):
"""
Read-only fields should be serialized.
"""
instance = {'read_only': 123, 'writable': 456}
serializer = self.Serializer(instance)
assert serializer.data == {'read_only': 123, 'writable': 456}
class TestWriteOnly:
def setup(self):
class TestSerializer(serializers.Serializer):
write_only = fields.IntegerField(write_only=True)
readable = fields.IntegerField()
self.Serializer = TestSerializer
def test_validate_write_only(self):
"""
Write-only fields should be included in validation.
"""
data = {'write_only': 123, 'readable': 456}
serializer = self.Serializer(data=data)
assert serializer.is_valid()
assert serializer.validated_data == {'write_only': 123, 'readable': 456}
def test_serialize_write_only(self):
"""
Write-only fields should not be serialized.
"""
instance = {'write_only': 123, 'readable': 456}
serializer = self.Serializer(instance)
assert serializer.data == {'readable': 456}
class TestInitial:
def setup(self):
class TestSerializer(serializers.Serializer):
initial_field = fields.IntegerField(initial=123)
blank_field = fields.IntegerField()
self.serializer = TestSerializer()
def test_initial(self):
"""
Initial values should be included when serializing a new representation.
"""
assert self.serializer.data == {
'initial_field': 123,
'blank_field': None
}
class TestLabel:
def setup(self):
class TestSerializer(serializers.Serializer):
labeled = fields.IntegerField(label='My label')
self.serializer = TestSerializer()
def test_label(self):
"""
A field's label may be set with the `label` argument.
"""
fields = self.serializer.fields
assert fields['labeled'].label == 'My label'
class TestInvalidErrorKey:
def setup(self):
class ExampleField(serializers.Field):
def to_native(self, data):
self.fail('incorrect')
self.field = ExampleField()
def test_invalid_error_key(self):
"""
If a field raises a validation error, but does not have a corresponding
error message, then raise an appropriate assertion error.
"""
with pytest.raises(AssertionError) as exc_info:
self.field.to_native(123)
expected = (
'ValidationError raised by `ExampleField`, but error key '
'`incorrect` does not exist in the `error_messages` dictionary.'
)
assert str(exc_info.value) == expected
class TestBooleanHTMLInput:
def setup(self):
class TestSerializer(serializers.Serializer):
archived = fields.BooleanField()
self.Serializer = TestSerializer
def test_empty_html_checkbox(self):
"""
HTML checkboxes do not send any value, but should be treated
as `False` by BooleanField.
"""
# This class mocks up a dictionary like object, that behaves
# as if it was returned for multipart or urlencoded data.
class MockHTMLDict(dict):
getlist = None
serializer = self.Serializer(data=MockHTMLDict())
assert serializer.is_valid()
assert serializer.validated_data == {'archived': False}
# Tests for field input and output values.
# ----------------------------------------
......@@ -495,7 +631,7 @@ class TestDateTimeField(FieldValues):
'2001-01-01T13:00Z': datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC()),
datetime.datetime(2001, 1, 1, 13, 00): datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC()),
datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC()): datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC()),
# Note that 1.4 does not support timezone string parsing.
# Django 1.4 does not support timezone string parsing.
'2001-01-01T14:00+01:00' if (django.VERSION > (1, 4)) else '2001-01-01T13:00Z': datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC())
}
invalid_inputs = {
......
......@@ -38,7 +38,7 @@ class FKInstanceView(generics.RetrieveUpdateDestroyAPIView):
class SlugSerializer(serializers.ModelSerializer):
slug = serializers.Field(read_only=True)
slug = serializers.ReadOnlyField()
class Meta:
model = SlugBasedModel
......
from rest_framework import serializers
import pytest
# Tests for core functionality.
......@@ -29,6 +30,67 @@ class TestSerializer:
assert serializer.validated_data == {'char': 'abc'}
assert serializer.errors == {}
def test_empty_serializer(self):
serializer = self.Serializer()
assert serializer.data == {'char': '', 'integer': None}
def test_missing_attribute_during_serialization(self):
class MissingAttributes:
pass
instance = MissingAttributes()
serializer = self.Serializer(instance)
with pytest.raises(AttributeError):
serializer.data
class TestStarredSource:
"""
Tests for `source='*'` argument, which is used for nested representations.
For example:
nested_field = NestedField(source='*')
"""
data = {
'nested1': {'a': 1, 'b': 2},
'nested2': {'c': 3, 'd': 4}
}
def setup(self):
class NestedSerializer1(serializers.Serializer):
a = serializers.IntegerField()
b = serializers.IntegerField()
class NestedSerializer2(serializers.Serializer):
c = serializers.IntegerField()
d = serializers.IntegerField()
class TestSerializer(serializers.Serializer):
nested1 = NestedSerializer1(source='*')
nested2 = NestedSerializer2(source='*')
self.Serializer = TestSerializer
def test_nested_validate(self):
"""
A nested representation is validated into a flat internal object.
"""
serializer = self.Serializer(data=self.data)
assert serializer.is_valid()
assert serializer.validated_data == {
'a': 1,
'b': 2,
'c': 3,
'd': 4
}
def test_nested_serialize(self):
"""
An object can be serialized into a nested representation.
"""
instance = {'a': 1, 'b': 2, 'c': 3, 'd': 4}
serializer = self.Serializer(instance)
assert serializer.data == self.data
# # -*- coding: utf-8 -*-
# from __future__ import unicode_literals
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment