Commit 4ac4676a by Tom Christie

First pass

parent 371d30aa
""" from rest_framework.utils import html
Serializer fields perform validation on incoming data.
They are very similar to Django's form fields.
"""
from __future__ import unicode_literals
import copy class empty:
import datetime
import inspect
import re
import warnings
from decimal import Decimal, DecimalException
from django import forms
from django.core import validators
from django.core.exceptions import ValidationError
from django.conf import settings
from django.db.models.fields import BLANK_CHOICE_DASH
from django.http import QueryDict
from django.forms import widgets
from django.utils import six, timezone
from django.utils.encoding import is_protected_type
from django.utils.translation import ugettext_lazy as _
from django.utils.datastructures import SortedDict
from django.utils.dateparse import parse_date, parse_datetime, parse_time
from rest_framework import ISO_8601
from rest_framework.compat import (
BytesIO, smart_text,
force_text, is_non_str_iterable
)
from rest_framework.settings import api_settings
def is_simple_callable(obj):
""" """
True if the object is a callable that takes no arguments. This class is used to represent no data being provided for a given input
""" or output value.
function = inspect.isfunction(obj)
method = inspect.ismethod(obj)
if not (function or method):
return False
args, _, _, defaults = inspect.getargspec(obj) It is required because `None` may be a valid input or output value.
len_args = len(args) if function else len(args) - 1 """
len_defaults = len(defaults) if defaults else 0 pass
return len_args <= len_defaults
def get_component(obj, attr_name): def get_attribute(instance, attrs):
""" """
Given an object, and an attribute name, Similar to Python's built in `getattr(instance, attr)`,
return that attribute on the object. but takes a list of nested attributes, instead of a single attribute.
""" """
if isinstance(obj, dict): for attr in attrs:
val = obj.get(attr_name) instance = getattr(instance, attr)
else: return instance
val = getattr(obj, attr_name)
if is_simple_callable(val):
return val()
return val
def readable_datetime_formats(formats):
format = ', '.join(formats).replace(
ISO_8601,
'YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]'
)
return humanize_strptime(format)
def readable_date_formats(formats):
format = ', '.join(formats).replace(ISO_8601, 'YYYY[-MM[-DD]]')
return humanize_strptime(format)
def readable_time_formats(formats): def set_value(dictionary, keys, value):
format = ', '.join(formats).replace(ISO_8601, 'hh:mm[:ss[.uuuuuu]]')
return humanize_strptime(format)
def humanize_strptime(format_string):
# Note that we're missing some of the locale specific mappings that
# don't really make sense.
mapping = {
"%Y": "YYYY",
"%y": "YY",
"%m": "MM",
"%b": "[Jan-Dec]",
"%B": "[January-December]",
"%d": "DD",
"%H": "hh",
"%I": "hh", # Requires '%p' to differentiate from '%H'.
"%M": "mm",
"%S": "ss",
"%f": "uuuuuu",
"%a": "[Mon-Sun]",
"%A": "[Monday-Sunday]",
"%p": "[AM|PM]",
"%z": "[+HHMM|-HHMM]"
}
for key, val in mapping.items():
format_string = format_string.replace(key, val)
return format_string
def strip_multiple_choice_msg(help_text):
""" """
Remove the 'Hold down "control" ...' message that is Django enforces in Similar to Python's built in `dictionary[key] = value`,
select multiple fields on ModelForms. (Required for 1.5 and earlier) but takes a list of nested keys instead of a single key.
See https://code.djangoproject.com/ticket/9321 set_value({'a': 1}, [], {'b': 2}) -> {'a': 1, 'b': 2}
set_value({'a': 1}, ['x'], 2) -> {'a': 1, 'x': 2}
set_value({'a': 1}, ['x', 'y'], 2) -> {'a': 1, 'x': {'y': 2}}
""" """
multiple_choice_msg = _(' Hold down "Control", or "Command" on a Mac, to select more than one.') if not keys:
multiple_choice_msg = force_text(multiple_choice_msg) dictionary.update(value)
return
return help_text.replace(multiple_choice_msg, '') for key in keys[:-1]:
if key not in dictionary:
dictionary[key] = {}
dictionary = dictionary[key]
dictionary[keys[-1]] = value
class Field(object):
read_only = True
creation_counter = 0
empty = ''
type_name = None
partial = False
use_files = False
form_field_class = forms.CharField
type_label = 'field'
widget = None
def __init__(self, source=None, label=None, help_text=None): class ValidationError(Exception):
self.parent = None pass
self.creation_counter = Field.creation_counter
Field.creation_counter += 1
self.source = source class SkipField(Exception):
pass
if label is not None:
self.label = smart_text(label)
else:
self.label = None
if help_text is not None: class Field(object):
self.help_text = strip_multiple_choice_msg(smart_text(help_text)) _creation_counter = 0
else:
self.help_text = None
self._errors = [] MESSAGES = {
self._value = None 'required': 'This field is required.'
self._name = None }
@property _NOT_READ_ONLY_WRITE_ONLY = 'May not set both `read_only` and `write_only`'
def errors(self): _NOT_READ_ONLY_REQUIRED = 'May not set both `read_only` and `required`'
return self._errors _NOT_READ_ONLY_DEFAULT = 'May not set both `read_only` and `default`'
_NOT_REQUIRED_DEFAULT = 'May not set both `required` and `default`'
_MISSING_ERROR_MESSAGE = (
'ValidationError raised by `{class_name}`, but error key `{key}` does '
'not exist in the `MESSAGES` dictionary.'
)
def widget_html(self): def __init__(self, read_only=False, write_only=False,
if not self.widget: required=None, default=empty, initial=None, source=None,
return '' label=None, style=None):
self._creation_counter = Field._creation_counter
Field._creation_counter += 1
attrs = {} # If `required` is unset, then use `True` unless a default is provided.
if 'id' not in self.widget.attrs: if required is None:
attrs['id'] = self._name required = default is empty and not read_only
return self.widget.render(self._name, self._value, attrs=attrs) # Some combinations of keyword arguments do not make sense.
assert not (read_only and write_only), self._NOT_READ_ONLY_WRITE_ONLY
assert not (read_only and required), self._NOT_READ_ONLY_REQUIRED
assert not (read_only and default is not empty), self._NOT_READ_ONLY_DEFAULT
assert not (required and default is not empty), self._NOT_REQUIRED_DEFAULT
def label_tag(self): self.read_only = read_only
return '<label for="%s">%s:</label>' % (self._name, self.label) self.write_only = write_only
self.required = required
self.default = default
self.source = source
self.initial = initial
self.label = label
self.style = {} if style is None else style
def initialize(self, parent, field_name): def bind(self, field_name, parent, root):
""" """
Called to set up a field prior to field_to_native or field_from_native. Setup the context for the field instance.
parent - The parent serializer.
field_name - The name of the field being initialized.
""" """
self.field_name = field_name
self.parent = parent self.parent = parent
self.root = parent.root or parent self.root = root
self.context = self.root.context
self.partial = self.root.partial
if self.partial:
self.required = False
def field_from_native(self, data, files, field_name, into): # `self.label` should deafult to being based on the field name.
""" if self.label is None:
Given a dictionary and a field name, updates the dictionary `into`, self.label = self.field_name.replace('_', ' ').capitalize()
with the field and it's deserialized value.
"""
return
def field_to_native(self, obj, field_name): # self.source should default to being the same as the field name.
""" if self.source is None:
Given an object and a field name, returns the value that should be self.source = field_name
serialized for that field.
"""
if obj is None:
return self.empty
# self.source_attrs is a list of attributes that need to be looked up
# when serializing the instance, or populating the validated data.
if self.source == '*': if self.source == '*':
return self.to_native(obj) self.source_attrs = []
else:
source = self.source or field_name self.source_attrs = self.source.split('.')
value = obj
for component in source.split('.'):
value = get_component(value, component)
if value is None:
break
return self.to_native(value)
def to_native(self, value): def get_initial(self):
""" """
Converts the field's value into it's simple representation. Return a value to use when the field is being returned as a primative
value, without any object instance.
""" """
if is_simple_callable(value): return self.initial
value = value()
if is_protected_type(value):
return value
elif (is_non_str_iterable(value) and
not isinstance(value, (dict, six.string_types))):
return [self.to_native(item) for item in value]
elif isinstance(value, dict):
# Make sure we preserve field ordering, if it exists
ret = SortedDict()
for key, val in value.items():
ret[key] = self.to_native(val)
return ret
return force_text(value)
def attributes(self): def get_value(self, dictionary):
""" """
Returns a dictionary of attributes to be used when serializing to xml. Given the *incoming* primative data, return the value for this field
that should be validated and transformed to a native value.
""" """
if self.type_name: return dictionary.get(self.field_name, empty)
return {'type': self.type_name}
return {}
def metadata(self):
metadata = SortedDict()
metadata['type'] = self.type_label
metadata['required'] = getattr(self, 'required', False)
optional_attrs = ['read_only', 'label', 'help_text',
'min_length', 'max_length']
for attr in optional_attrs:
value = getattr(self, attr, None)
if value is not None and value != '':
metadata[attr] = force_text(value, strings_only=True)
return metadata
class WritableField(Field):
"""
Base for read/write fields.
"""
write_only = False
default_validators = []
default_error_messages = {
'required': _('This field is required.'),
'invalid': _('Invalid value.'),
}
widget = widgets.TextInput
default = None
def __init__(self, source=None, label=None, help_text=None,
read_only=False, write_only=False, required=None,
validators=[], error_messages=None, widget=None,
default=None, blank=None):
super(WritableField, self).__init__(source=source, label=label, help_text=help_text)
self.read_only = read_only
self.write_only = write_only
assert not (read_only and write_only), "Cannot set read_only=True and write_only=True"
if required is None:
self.required = not(read_only)
else:
assert not (read_only and required), "Cannot set required=True and read_only=True"
self.required = required
messages = {}
for c in reversed(self.__class__.__mro__):
messages.update(getattr(c, 'default_error_messages', {}))
messages.update(error_messages or {})
self.error_messages = messages
self.validators = self.default_validators + validators def get_attribute(self, instance):
self.default = default if default is not None else self.default """
Given the *outgoing* object instance, return the value for this field
# Widgets are only used for HTML forms. that should be returned as a primative value.
widget = widget or self.widget """
if isinstance(widget, type): return get_attribute(instance, self.source_attrs)
widget = widget()
self.widget = widget
def __deepcopy__(self, memo): def get_default(self):
result = copy.copy(self) """
memo[id(self)] = result Return the default value to use when validating data if no input
result.validators = self.validators[:] is provided for this field.
return result
def get_default_value(self): If a default has not been set for this field then this will simply
if is_simple_callable(self.default): return `empty`, indicating that no value should be set in the
return self.default() validated data for this field.
"""
if self.default is empty:
raise SkipField()
return self.default return self.default
def validate(self, value): def validate(self, data=empty):
if value in validators.EMPTY_VALUES and self.required: """
raise ValidationError(self.error_messages['required']) Validate a simple representation and return the internal value.
def run_validators(self, value): The provided data may be `empty` if no representation was included.
if value in validators.EMPTY_VALUES: May return `empty` if the field should not be included in the
return validated data.
errors = [] """
for v in self.validators: if data is empty:
try: if self.required:
v(value) self.fail('required')
except ValidationError as e: return self.get_default()
if hasattr(e, 'code') and e.code in self.error_messages:
message = self.error_messages[e.code]
if e.params:
message = message % e.params
errors.append(message)
else:
errors.extend(e.messages)
if errors:
raise ValidationError(errors)
def field_to_native(self, obj, field_name): return self.to_native(data)
if self.write_only:
return None
return super(WritableField, self).field_to_native(obj, field_name)
def field_from_native(self, data, files, field_name, into): def to_native(self, data):
""" """
Given a dictionary and a field name, updates the dictionary `into`, Transform the *incoming* primative data into a native value.
with the field and it's deserialized value.
""" """
if self.read_only: return data
return
try:
data = data or {}
if self.use_files:
files = files or {}
try:
native = files[field_name]
except KeyError:
native = data[field_name]
else:
native = data[field_name]
except KeyError:
if self.default is not None and not self.partial:
# Note: partial updates shouldn't set defaults
native = self.get_default_value()
else:
if self.required:
raise ValidationError(self.error_messages['required'])
return
value = self.from_native(native)
if self.source == '*':
if value:
into.update(value)
else:
self.validate(value)
self.run_validators(value)
into[self.source or field_name] = value
def from_native(self, value): def to_primative(self, value):
""" """
Reverts a simple representation back to the field's value. Transform the *outgoing* native value into primative data.
""" """
return value return value
def fail(self, key, **kwargs):
class ModelField(WritableField): """
""" A helper method that simply raises a validation error.
A generic field that can be used against an arbitrary model field. """
"""
def __init__(self, *args, **kwargs):
try: try:
self.model_field = kwargs.pop('model_field') raise ValidationError(self.MESSAGES[key].format(**kwargs))
except KeyError: except KeyError:
raise ValueError("ModelField requires 'model_field' kwarg") class_name = self.__class__.__name__
msg = self._MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key)
self.min_length = kwargs.pop('min_length', raise AssertionError(msg)
getattr(self.model_field, 'min_length', None))
self.max_length = kwargs.pop('max_length',
getattr(self.model_field, 'max_length', None))
self.min_value = kwargs.pop('min_value',
getattr(self.model_field, 'min_value', None))
self.max_value = kwargs.pop('max_value',
getattr(self.model_field, 'max_value', None))
super(ModelField, self).__init__(*args, **kwargs)
if self.min_length is not None:
self.validators.append(validators.MinLengthValidator(self.min_length))
if self.max_length is not None:
self.validators.append(validators.MaxLengthValidator(self.max_length))
if self.min_value is not None:
self.validators.append(validators.MinValueValidator(self.min_value))
if self.max_value is not None:
self.validators.append(validators.MaxValueValidator(self.max_value))
def from_native(self, value):
rel = getattr(self.model_field, "rel", None)
if rel is not None:
return rel.to._meta.get_field(rel.field_name).to_python(value)
else:
return self.model_field.to_python(value)
def field_to_native(self, obj, field_name):
value = self.model_field._get_val_from_obj(obj)
if is_protected_type(value):
return value
return self.model_field.value_to_string(obj)
def attributes(self): class BooleanField(Field):
return { MESSAGES = {
"type": self.model_field.get_internal_type() 'required': 'This field is required.',
} 'invalid_value': '`{input}` is not a valid boolean.'
# Typed Fields
class BooleanField(WritableField):
type_name = 'BooleanField'
type_label = 'boolean'
form_field_class = forms.BooleanField
widget = widgets.CheckboxInput
default_error_messages = {
'invalid': _("'%s' value must be either True or False."),
} }
empty = False TRUE_VALUES = {'t', 'T', 'true', 'True', 'TRUE', '1', 1, True}
FALSE_VALUES = {'f', 'F', 'false', 'False', 'FALSE', '0', 0, 0.0, False}
def field_from_native(self, data, files, field_name, into):
# HTML checkboxes do not explicitly represent unchecked as `False` def get_value(self, dictionary):
# we deal with that here... if html.is_html_input(dictionary):
if isinstance(data, QueryDict) and self.default is None: # HTML forms do not send a `False` value on an empty checkbox,
self.default = False # so we override the default empty value to be False.
return dictionary.get(self.field_name, False)
return super(BooleanField, self).field_from_native( return dictionary.get(self.field_name, empty)
data, files, field_name, into
) def to_native(self, data):
if data in self.TRUE_VALUES:
def from_native(self, value):
if value in ('true', 't', 'True', '1'):
return True return True
if value in ('false', 'f', 'False', '0'): elif data in self.FALSE_VALUES:
return False return False
return bool(value) self.fail('invalid_value', input=data)
class CharField(WritableField): class CharField(Field):
type_name = 'CharField' MESSAGES = {
type_label = 'string' 'required': 'This field is required.',
form_field_class = forms.CharField 'blank': 'This field may not be blank.'
}
def __init__(self, max_length=None, min_length=None, allow_none=False, *args, **kwargs): def __init__(self, *args, **kwargs):
self.max_length, self.min_length = max_length, min_length self.allow_blank = kwargs.pop('allow_blank', False)
self.allow_none = allow_none
super(CharField, self).__init__(*args, **kwargs) super(CharField, self).__init__(*args, **kwargs)
if min_length is not None:
self.validators.append(validators.MinLengthValidator(min_length))
if max_length is not None:
self.validators.append(validators.MaxLengthValidator(max_length))
def from_native(self, value):
if isinstance(value, six.string_types):
return value
if value is None and not self.allow_none:
return ''
return smart_text(value)
class URLField(CharField): def to_native(self, data):
type_name = 'URLField' if data == '' and not self.allow_blank:
type_label = 'url' self.fail('blank')
return str(data)
def __init__(self, **kwargs):
if 'validators' not in kwargs:
kwargs['validators'] = [validators.URLValidator()]
super(URLField, self).__init__(**kwargs)
class SlugField(CharField): class ChoiceField(Field):
type_name = 'SlugField' MESSAGES = {
type_label = 'slug' 'required': 'This field is required.',
form_field_class = forms.SlugField 'invalid_choice': '`{input}` is not a valid choice.'
default_error_messages = {
'invalid': _("Enter a valid 'slug' consisting of letters, numbers,"
" underscores or hyphens."),
} }
default_validators = [validators.validate_slug] coerce_to_type = str
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(SlugField, self).__init__(*args, **kwargs) choices = kwargs.pop('choices')
assert choices, '`choices` argument is required and may not be empty'
# Allow either single or paired choices style:
# choices = [1, 2, 3]
# choices = [(1, 'First'), (2, 'Second'), (3, 'Third')]
pairs = [
isinstance(item, (list, tuple)) and len(item) == 2
for item in choices
]
if all(pairs):
self.choices = {key: val for key, val in choices}
else:
self.choices = {item: item for item in choices}
class ChoiceField(WritableField): # Map the string representation of choices to the underlying value.
type_name = 'ChoiceField' # Allows us to deal with eg. integer choices while supporting either
type_label = 'choice' # integer or string input, but still get the correct datatype out.
form_field_class = forms.ChoiceField self.choice_strings_to_values = {
widget = widgets.Select str(key): key for key in self.choices.keys()
default_error_messages = { }
'invalid_choice': _('Select a valid choice. %(value)s is not one of '
'the available choices.'),
}
def __init__(self, choices=(), blank_display_value=None, *args, **kwargs):
self.empty = kwargs.pop('empty', '')
super(ChoiceField, self).__init__(*args, **kwargs) super(ChoiceField, self).__init__(*args, **kwargs)
self.choices = choices
if not self.required:
if blank_display_value is None:
blank_choice = BLANK_CHOICE_DASH
else:
blank_choice = [('', blank_display_value)]
self.choices = blank_choice + self.choices
def _get_choices(self):
return self._choices
def _set_choices(self, value):
# Setting choices also sets the choices on the widget.
# choices can be any iterable, but we call list() on it because
# it will be consumed more than once.
self._choices = self.widget.choices = list(value)
choices = property(_get_choices, _set_choices)
def metadata(self):
data = super(ChoiceField, self).metadata()
data['choices'] = [{'value': v, 'display_name': n} for v, n in self.choices]
return data
def validate(self, value):
"""
Validates that the input is in self.choices.
"""
super(ChoiceField, self).validate(value)
if value and not self.valid_value(value):
raise ValidationError(self.error_messages['invalid_choice'] % {'value': value})
def valid_value(self, value):
"""
Check to see if the provided value is a valid choice.
"""
for k, v in self.choices:
if isinstance(v, (list, tuple)):
# This is an optgroup, so look inside the group for options
for k2, v2 in v:
if value == smart_text(k2):
return True
else:
if value == smart_text(k) or value == k:
return True
return False
def from_native(self, value):
value = super(ChoiceField, self).from_native(value)
if value == self.empty or value in validators.EMPTY_VALUES:
return self.empty
return value
class EmailField(CharField):
type_name = 'EmailField'
type_label = 'email'
form_field_class = forms.EmailField
default_error_messages = {
'invalid': _('Enter a valid email address.'),
}
default_validators = [validators.validate_email]
def from_native(self, value):
ret = super(EmailField, self).from_native(value)
if ret is None:
return None
return ret.strip()
class RegexField(CharField):
type_name = 'RegexField'
type_label = 'regex'
form_field_class = forms.RegexField
def __init__(self, regex, max_length=None, min_length=None, *args, **kwargs):
super(RegexField, self).__init__(max_length, min_length, *args, **kwargs)
self.regex = regex
def _get_regex(self):
return self._regex
def _set_regex(self, regex):
if isinstance(regex, six.string_types):
regex = re.compile(regex)
self._regex = regex
if hasattr(self, '_regex_validator') and self._regex_validator in self.validators:
self.validators.remove(self._regex_validator)
self._regex_validator = validators.RegexValidator(regex=regex)
self.validators.append(self._regex_validator)
regex = property(_get_regex, _set_regex)
class DateField(WritableField):
type_name = 'DateField'
type_label = 'date'
widget = widgets.DateInput
form_field_class = forms.DateField
default_error_messages = {
'invalid': _("Date has wrong format. Use one of these formats instead: %s"),
}
empty = None
input_formats = api_settings.DATE_INPUT_FORMATS
format = api_settings.DATE_FORMAT
def __init__(self, input_formats=None, format=None, *args, **kwargs):
self.input_formats = input_formats if input_formats is not None else self.input_formats
self.format = format if format is not None else self.format
super(DateField, self).__init__(*args, **kwargs)
def from_native(self, value):
if value in validators.EMPTY_VALUES:
return None
if isinstance(value, datetime.datetime):
if timezone and settings.USE_TZ and timezone.is_aware(value):
# Convert aware datetimes to the default time zone
# before casting them to dates (#17742).
default_timezone = timezone.get_default_timezone()
value = timezone.make_naive(value, default_timezone)
return value.date()
if isinstance(value, datetime.date):
return value
for format in self.input_formats:
if format.lower() == ISO_8601:
try:
parsed = parse_date(value)
except (ValueError, TypeError):
pass
else:
if parsed is not None:
return parsed
else:
try:
parsed = datetime.datetime.strptime(value, format)
except (ValueError, TypeError):
pass
else:
return parsed.date()
msg = self.error_messages['invalid'] % readable_date_formats(self.input_formats)
raise ValidationError(msg)
def to_native(self, value):
if value is None or self.format is None:
return value
if isinstance(value, datetime.datetime):
value = value.date()
if self.format.lower() == ISO_8601:
return value.isoformat()
return value.strftime(self.format)
class DateTimeField(WritableField):
type_name = 'DateTimeField'
type_label = 'datetime'
widget = widgets.DateTimeInput
form_field_class = forms.DateTimeField
default_error_messages = {
'invalid': _("Datetime has wrong format. Use one of these formats instead: %s"),
}
empty = None
input_formats = api_settings.DATETIME_INPUT_FORMATS
format = api_settings.DATETIME_FORMAT
def __init__(self, input_formats=None, format=None, *args, **kwargs):
self.input_formats = input_formats if input_formats is not None else self.input_formats
self.format = format if format is not None else self.format
super(DateTimeField, self).__init__(*args, **kwargs)
def from_native(self, value):
if value in validators.EMPTY_VALUES:
return None
if isinstance(value, datetime.datetime):
return value
if isinstance(value, datetime.date):
value = datetime.datetime(value.year, value.month, value.day)
if settings.USE_TZ:
# For backwards compatibility, interpret naive datetimes in
# local time. This won't work during DST change, but we can't
# do much about it, so we let the exceptions percolate up the
# call stack.
warnings.warn("DateTimeField received a naive datetime (%s)"
" while time zone support is active." % value,
RuntimeWarning)
default_timezone = timezone.get_default_timezone()
value = timezone.make_aware(value, default_timezone)
return value
for format in self.input_formats:
if format.lower() == ISO_8601:
try:
parsed = parse_datetime(value)
except (ValueError, TypeError):
pass
else:
if parsed is not None:
return parsed
else:
try:
parsed = datetime.datetime.strptime(value, format)
except (ValueError, TypeError):
pass
else:
return parsed
msg = self.error_messages['invalid'] % readable_datetime_formats(self.input_formats)
raise ValidationError(msg)
def to_native(self, value):
if value is None or self.format is None:
return value
if self.format.lower() == ISO_8601:
ret = value.isoformat()
if ret.endswith('+00:00'):
ret = ret[:-6] + 'Z'
return ret
return value.strftime(self.format)
class TimeField(WritableField):
type_name = 'TimeField'
type_label = 'time'
widget = widgets.TimeInput
form_field_class = forms.TimeField
default_error_messages = {
'invalid': _("Time has wrong format. Use one of these formats instead: %s"),
}
empty = None
input_formats = api_settings.TIME_INPUT_FORMATS
format = api_settings.TIME_FORMAT
def __init__(self, input_formats=None, format=None, *args, **kwargs):
self.input_formats = input_formats if input_formats is not None else self.input_formats
self.format = format if format is not None else self.format
super(TimeField, self).__init__(*args, **kwargs)
def from_native(self, value):
if value in validators.EMPTY_VALUES:
return None
if isinstance(value, datetime.time):
return value
for format in self.input_formats:
if format.lower() == ISO_8601:
try:
parsed = parse_time(value)
except (ValueError, TypeError):
pass
else:
if parsed is not None:
return parsed
else:
try:
parsed = datetime.datetime.strptime(value, format)
except (ValueError, TypeError):
pass
else:
return parsed.time()
msg = self.error_messages['invalid'] % readable_time_formats(self.input_formats)
raise ValidationError(msg)
def to_native(self, value):
if value is None or self.format is None:
return value
if isinstance(value, datetime.datetime):
value = value.time()
if self.format.lower() == ISO_8601:
return value.isoformat()
return value.strftime(self.format)
class IntegerField(WritableField):
type_name = 'IntegerField'
type_label = 'integer'
form_field_class = forms.IntegerField
empty = 0
default_error_messages = {
'invalid': _('Enter a whole number.'),
'max_value': _('Ensure this value is less than or equal to %(limit_value)s.'),
'min_value': _('Ensure this value is greater than or equal to %(limit_value)s.'),
}
def __init__(self, max_value=None, min_value=None, *args, **kwargs):
self.max_value, self.min_value = max_value, min_value
super(IntegerField, self).__init__(*args, **kwargs)
if max_value is not None:
self.validators.append(validators.MaxValueValidator(max_value))
if min_value is not None:
self.validators.append(validators.MinValueValidator(min_value))
def from_native(self, value):
if value in validators.EMPTY_VALUES:
return None
try:
value = int(str(value))
except (ValueError, TypeError):
raise ValidationError(self.error_messages['invalid'])
return value
class FloatField(WritableField):
type_name = 'FloatField'
type_label = 'float'
form_field_class = forms.FloatField
empty = 0
default_error_messages = {
'invalid': _("'%s' value must be a float."),
}
def from_native(self, value):
if value in validators.EMPTY_VALUES:
return None
def to_native(self, data):
try: try:
return float(value) return self.choice_strings_to_values[str(data)]
except (TypeError, ValueError): except KeyError:
msg = self.error_messages['invalid'] % value self.fail('invalid_choice', input=data)
raise ValidationError(msg)
class DecimalField(WritableField):
type_name = 'DecimalField'
type_label = 'decimal'
form_field_class = forms.DecimalField
empty = Decimal('0')
default_error_messages = { class MultipleChoiceField(ChoiceField):
'invalid': _('Enter a number.'), MESSAGES = {
'max_value': _('Ensure this value is less than or equal to %(limit_value)s.'), 'required': 'This field is required.',
'min_value': _('Ensure this value is greater than or equal to %(limit_value)s.'), 'invalid_choice': '`{input}` is not a valid choice.',
'max_digits': _('Ensure that there are no more than %s digits in total.'), 'not_a_list': 'Expected a list of items but got type `{input_type}`'
'max_decimal_places': _('Ensure that there are no more than %s decimal places.'),
'max_whole_digits': _('Ensure that there are no more than %s digits before the decimal point.')
} }
def __init__(self, max_value=None, min_value=None, max_digits=None, decimal_places=None, *args, **kwargs): def to_native(self, data):
self.max_value, self.min_value = max_value, min_value if not hasattr(data, '__iter__'):
self.max_digits, self.decimal_places = max_digits, decimal_places self.fail('not_a_list', input_type=type(data).__name__)
super(DecimalField, self).__init__(*args, **kwargs) return set([
super(MultipleChoiceField, self).to_native(item)
if max_value is not None: for item in data
self.validators.append(validators.MaxValueValidator(max_value)) ])
if min_value is not None:
self.validators.append(validators.MinValueValidator(min_value))
def from_native(self, value):
"""
Validates that the input is a decimal number. Returns a Decimal
instance. Returns None for empty values. Ensures that there are no more
than max_digits in the number, and no more than decimal_places digits
after the decimal point.
"""
if value in validators.EMPTY_VALUES:
return None
value = smart_text(value).strip()
try:
value = Decimal(value)
except DecimalException:
raise ValidationError(self.error_messages['invalid'])
return value
def validate(self, value):
super(DecimalField, self).validate(value)
if value in validators.EMPTY_VALUES:
return
# Check for NaN, Inf and -Inf values. We can't compare directly for NaN,
# since it is never equal to itself. However, NaN is the only value that
# isn't equal to itself, so we can use this to identify NaN
if value != value or value == Decimal("Inf") or value == Decimal("-Inf"):
raise ValidationError(self.error_messages['invalid'])
sign, digittuple, exponent = value.as_tuple()
decimals = abs(exponent)
# digittuple doesn't include any leading zeros.
digits = len(digittuple)
if decimals > digits:
# We have leading zeros up to or past the decimal point. Count
# everything past the decimal point as a digit. We do not count
# 0 before the decimal point as a digit since that would mean
# we would not allow max_digits = decimal_places.
digits = decimals
whole_digits = digits - decimals
if self.max_digits is not None and digits > self.max_digits:
raise ValidationError(self.error_messages['max_digits'] % self.max_digits)
if self.decimal_places is not None and decimals > self.decimal_places:
raise ValidationError(self.error_messages['max_decimal_places'] % self.decimal_places)
if self.max_digits is not None and self.decimal_places is not None and whole_digits > (self.max_digits - self.decimal_places):
raise ValidationError(self.error_messages['max_whole_digits'] % (self.max_digits - self.decimal_places))
return value
class FileField(WritableField):
use_files = True
type_name = 'FileField'
type_label = 'file upload'
form_field_class = forms.FileField
widget = widgets.FileInput
default_error_messages = { class IntegerField(Field):
'invalid': _("No file was submitted. Check the encoding type on the form."), MESSAGES = {
'missing': _("No file was submitted."), 'required': 'This field is required.',
'empty': _("The submitted file is empty."), 'invalid_integer': 'A valid integer is required.'
'max_length': _('Ensure this filename has at most %(max)d characters (it has %(length)d).'),
'contradiction': _('Please either submit a file or check the clear checkbox, not both.')
} }
def __init__(self, *args, **kwargs): def to_native(self, data):
self.max_length = kwargs.pop('max_length', None)
self.allow_empty_file = kwargs.pop('allow_empty_file', False)
super(FileField, self).__init__(*args, **kwargs)
def from_native(self, data):
if data in validators.EMPTY_VALUES:
return None
# UploadedFile objects should have name and size attributes.
try: try:
file_name = data.name data = int(str(data))
file_size = data.size except (ValueError, TypeError):
except AttributeError: self.fail('invalid_integer')
raise ValidationError(self.error_messages['invalid'])
if self.max_length is not None and len(file_name) > self.max_length:
error_values = {'max': self.max_length, 'length': len(file_name)}
raise ValidationError(self.error_messages['max_length'] % error_values)
if not file_name:
raise ValidationError(self.error_messages['invalid'])
if not self.allow_empty_file and not file_size:
raise ValidationError(self.error_messages['empty'])
return data return data
def to_native(self, value):
return value.name
class ImageField(FileField):
use_files = True
type_name = 'ImageField'
type_label = 'image upload'
form_field_class = forms.ImageField
default_error_messages = {
'invalid_image': _("Upload a valid image. The file you uploaded was "
"either not an image or a corrupted image."),
}
def from_native(self, data):
"""
Checks that the file-upload field data contains a valid image (GIF, JPG,
PNG, possibly others -- whatever the Python Imaging Library supports).
"""
f = super(ImageField, self).from_native(data)
if f is None:
return None
from rest_framework.compat import Image
assert Image is not None, 'Either Pillow or PIL must be installed for ImageField support.'
# We need to get a file object for PIL. We might have a path or we might
# have to read the data into memory.
if hasattr(data, 'temporary_file_path'):
file = data.temporary_file_path()
else:
if hasattr(data, 'read'):
file = BytesIO(data.read())
else:
file = BytesIO(data['content'])
try:
# load() could spot a truncated JPEG, but it loads the entire
# image in memory, which is a DoS vector. See #3848 and #18520.
# verify() must be called immediately after the constructor.
Image.open(file).verify()
except ImportError:
# Under PyPy, it is possible to import PIL. However, the underlying
# _imaging C module isn't available, so an ImportError will be
# raised. Catch and re-raise.
raise
except Exception: # Python Imaging Library doesn't recognize it as an image
raise ValidationError(self.error_messages['invalid_image'])
if hasattr(f, 'seek') and callable(f.seek):
f.seek(0)
return f
class MethodField(Field):
class SerializerMethodField(Field): def __init__(self, **kwargs):
""" kwargs['source'] = '*'
A field that gets its value by calling a method on the serializer it's attached to. kwargs['read_only'] = True
""" super(MethodField, self).__init__(**kwargs)
def __init__(self, method_name, *args, **kwargs): def to_primative(self, value):
self.method_name = method_name attr = 'get_{field_name}'.format(field_name=self.field_name)
super(SerializerMethodField, self).__init__(*args, **kwargs) method = getattr(self.parent, attr)
return method(value)
def field_to_native(self, obj, field_name):
value = getattr(self.parent, self.method_name)(obj)
return self.to_native(value)
...@@ -79,18 +79,16 @@ class GenericAPIView(views.APIView): ...@@ -79,18 +79,16 @@ class GenericAPIView(views.APIView):
'view': self 'view': self
} }
def get_serializer(self, instance=None, data=None, files=None, many=False, def get_serializer(self, instance=None, data=None, many=False, partial=False):
partial=False, allow_add_remove=False):
""" """
Return the serializer instance that should be used for validating and Return the serializer instance that should be used for validating and
deserializing input, and for serializing output. deserializing input, and for serializing output.
""" """
serializer_class = self.get_serializer_class() serializer_class = self.get_serializer_class()
context = self.get_serializer_context() context = self.get_serializer_context()
return serializer_class(instance, data=data, files=files, return serializer_class(
many=many, partial=partial, instance, data=data, many=many, partial=partial, context=context
allow_add_remove=allow_add_remove, )
context=context)
def get_pagination_serializer(self, page): def get_pagination_serializer(self, page):
""" """
......
...@@ -36,12 +36,10 @@ class CreateModelMixin(object): ...@@ -36,12 +36,10 @@ class CreateModelMixin(object):
Create a model instance. Create a model instance.
""" """
def create(self, request, *args, **kwargs): def create(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.DATA, files=request.FILES) serializer = self.get_serializer(data=request.DATA)
if serializer.is_valid(): if serializer.is_valid():
self.pre_save(serializer.object) self.object = serializer.save()
self.object = serializer.save(force_insert=True)
self.post_save(self.object, created=True)
headers = self.get_success_headers(serializer.data) headers = self.get_success_headers(serializer.data)
return Response(serializer.data, status=status.HTTP_201_CREATED, return Response(serializer.data, status=status.HTTP_201_CREATED,
headers=headers) headers=headers)
...@@ -90,26 +88,20 @@ class UpdateModelMixin(object): ...@@ -90,26 +88,20 @@ class UpdateModelMixin(object):
partial = kwargs.pop('partial', False) partial = kwargs.pop('partial', False)
self.object = self.get_object_or_none() self.object = self.get_object_or_none()
serializer = self.get_serializer(self.object, data=request.DATA, serializer = self.get_serializer(self.object, data=request.DATA, partial=partial)
files=request.FILES, partial=partial)
if not serializer.is_valid(): if not serializer.is_valid():
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
try: lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
self.pre_save(serializer.object) lookup_value = self.kwargs[lookup_url_kwarg]
except ValidationError as err: extras = {self.lookup_field: lookup_value}
# full_clean on model instance may be called in pre_save,
# so we have to handle eventual errors.
return Response(err.message_dict, status=status.HTTP_400_BAD_REQUEST)
if self.object is None: if self.object is None:
self.object = serializer.save(force_insert=True) self.object = serializer.save(extras=extras)
self.post_save(self.object, created=True)
return Response(serializer.data, status=status.HTTP_201_CREATED) return Response(serializer.data, status=status.HTTP_201_CREATED)
self.object = serializer.save(force_update=True) self.object = serializer.save(extras=extras)
self.post_save(self.object, created=False)
return Response(serializer.data, status=status.HTTP_200_OK) return Response(serializer.data, status=status.HTTP_200_OK)
def partial_update(self, request, *args, **kwargs): def partial_update(self, request, *args, **kwargs):
......
...@@ -48,17 +48,17 @@ class DefaultObjectSerializer(serializers.Field): ...@@ -48,17 +48,17 @@ class DefaultObjectSerializer(serializers.Field):
super(DefaultObjectSerializer, self).__init__(source=source) super(DefaultObjectSerializer, self).__init__(source=source)
class PaginationSerializerOptions(serializers.SerializerOptions): # class PaginationSerializerOptions(serializers.SerializerOptions):
""" # """
An object that stores the options that may be provided to a # An object that stores the options that may be provided to a
pagination serializer by using the inner `Meta` class. # pagination serializer by using the inner `Meta` class.
Accessible on the instance as `serializer.opts`. # Accessible on the instance as `serializer.opts`.
""" # """
def __init__(self, meta): # def __init__(self, meta):
super(PaginationSerializerOptions, self).__init__(meta) # super(PaginationSerializerOptions, self).__init__(meta)
self.object_serializer_class = getattr(meta, 'object_serializer_class', # self.object_serializer_class = getattr(meta, 'object_serializer_class',
DefaultObjectSerializer) # DefaultObjectSerializer)
class BasePaginationSerializer(serializers.Serializer): class BasePaginationSerializer(serializers.Serializer):
...@@ -66,7 +66,7 @@ class BasePaginationSerializer(serializers.Serializer): ...@@ -66,7 +66,7 @@ class BasePaginationSerializer(serializers.Serializer):
A base class for pagination serializers to inherit from, A base class for pagination serializers to inherit from,
to make implementing custom serializers more easy. to make implementing custom serializers more easy.
""" """
_options_class = PaginationSerializerOptions # _options_class = PaginationSerializerOptions
results_field = 'results' results_field = 'results'
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
......
"""
Serializer fields that deal with relationships.
These fields allow you to specify the style that should be used to represent
model relationships, including hyperlinks, primary keys, or slugs.
"""
from __future__ import unicode_literals
from django.core.exceptions import ObjectDoesNotExist, ValidationError
from django.core.urlresolvers import resolve, get_script_prefix, NoReverseMatch
from django import forms
from django.db.models.fields import BLANK_CHOICE_DASH
from django.forms import widgets
from django.forms.models import ModelChoiceIterator
from django.utils.translation import ugettext_lazy as _
from rest_framework.fields import Field, WritableField, get_component, is_simple_callable
from rest_framework.reverse import reverse
from rest_framework.compat import urlparse
from rest_framework.compat import smart_text
# Relational fields
# Not actually Writable, but subclasses may need to be.
class RelatedField(WritableField):
"""
Base class for related model fields.
This represents a relationship using the unicode representation of the target.
"""
widget = widgets.Select
many_widget = widgets.SelectMultiple
form_field_class = forms.ChoiceField
many_form_field_class = forms.MultipleChoiceField
null_values = (None, '', 'None')
cache_choices = False
empty_label = None
read_only = True
many = False
def __init__(self, *args, **kwargs):
queryset = kwargs.pop('queryset', None)
self.many = kwargs.pop('many', self.many)
if self.many:
self.widget = self.many_widget
self.form_field_class = self.many_form_field_class
kwargs['read_only'] = kwargs.pop('read_only', self.read_only)
super(RelatedField, self).__init__(*args, **kwargs)
if not self.required:
# Accessed in ModelChoiceIterator django/forms/models.py:1034
# If set adds empty choice.
self.empty_label = BLANK_CHOICE_DASH[0][1]
self.queryset = queryset
def initialize(self, parent, field_name):
super(RelatedField, self).initialize(parent, field_name)
if self.queryset is None and not self.read_only:
manager = getattr(self.parent.opts.model, self.source or field_name)
if hasattr(manager, 'related'): # Forward
self.queryset = manager.related.model._default_manager.all()
else: # Reverse
self.queryset = manager.field.rel.to._default_manager.all()
# We need this stuff to make form choices work...
def prepare_value(self, obj):
return self.to_native(obj)
def label_from_instance(self, obj):
"""
Return a readable representation for use with eg. select widgets.
"""
desc = smart_text(obj)
ident = smart_text(self.to_native(obj))
if desc == ident:
return desc
return "%s - %s" % (desc, ident)
def _get_queryset(self):
return self._queryset
def _set_queryset(self, queryset):
self._queryset = queryset
self.widget.choices = self.choices
queryset = property(_get_queryset, _set_queryset)
def _get_choices(self):
# If self._choices is set, then somebody must have manually set
# the property self.choices. In this case, just return self._choices.
if hasattr(self, '_choices'):
return self._choices
# Otherwise, execute the QuerySet in self.queryset to determine the
# choices dynamically. Return a fresh ModelChoiceIterator that has not been
# consumed. Note that we're instantiating a new ModelChoiceIterator *each*
# time _get_choices() is called (and, thus, each time self.choices is
# accessed) so that we can ensure the QuerySet has not been consumed. This
# construct might look complicated but it allows for lazy evaluation of
# the queryset.
return ModelChoiceIterator(self)
def _set_choices(self, value):
# Setting choices also sets the choices on the widget.
# choices can be any iterable, but we call list() on it because
# it will be consumed more than once.
self._choices = self.widget.choices = list(value)
choices = property(_get_choices, _set_choices)
# Default value handling
def get_default_value(self):
default = super(RelatedField, self).get_default_value()
if self.many and default is None:
return []
return default
# Regular serializer stuff...
def field_to_native(self, obj, field_name):
try:
if self.source == '*':
return self.to_native(obj)
source = self.source or field_name
value = obj
for component in source.split('.'):
if value is None:
break
value = get_component(value, component)
except ObjectDoesNotExist:
return None
if value is None:
return None
if self.many:
if is_simple_callable(getattr(value, 'all', None)):
return [self.to_native(item) for item in value.all()]
else:
# Also support non-queryset iterables.
# This allows us to also support plain lists of related items.
return [self.to_native(item) for item in value]
return self.to_native(value)
def field_from_native(self, data, files, field_name, into):
if self.read_only:
return
try:
if self.many:
try:
# Form data
value = data.getlist(field_name)
if value == [''] or value == []:
raise KeyError
except AttributeError:
# Non-form data
value = data[field_name]
else:
value = data[field_name]
except KeyError:
if self.partial:
return
value = self.get_default_value()
if value in self.null_values:
if self.required:
raise ValidationError(self.error_messages['required'])
into[(self.source or field_name)] = None
elif self.many:
into[(self.source or field_name)] = [self.from_native(item) for item in value]
else:
into[(self.source or field_name)] = self.from_native(value)
# PrimaryKey relationships
class PrimaryKeyRelatedField(RelatedField):
"""
Represents a relationship as a pk value.
"""
read_only = False
default_error_messages = {
'does_not_exist': _("Invalid pk '%s' - object does not exist."),
'incorrect_type': _('Incorrect type. Expected pk value, received %s.'),
}
# TODO: Remove these field hacks...
def prepare_value(self, obj):
return self.to_native(obj.pk)
def label_from_instance(self, obj):
"""
Return a readable representation for use with eg. select widgets.
"""
desc = smart_text(obj)
ident = smart_text(self.to_native(obj.pk))
if desc == ident:
return desc
return "%s - %s" % (desc, ident)
# TODO: Possibly change this to just take `obj`, through prob less performant
def to_native(self, pk):
return pk
def from_native(self, data):
if self.queryset is None:
raise Exception('Writable related fields must include a `queryset` argument')
try:
return self.queryset.get(pk=data)
except ObjectDoesNotExist:
msg = self.error_messages['does_not_exist'] % smart_text(data)
raise ValidationError(msg)
except (TypeError, ValueError):
received = type(data).__name__
msg = self.error_messages['incorrect_type'] % received
raise ValidationError(msg)
def field_to_native(self, obj, field_name):
if self.many:
# To-many relationship
queryset = None
if not self.source:
# Prefer obj.serializable_value for performance reasons
try:
queryset = obj.serializable_value(field_name)
except AttributeError:
pass
if queryset is None:
# RelatedManager (reverse relationship)
source = self.source or field_name
queryset = obj
for component in source.split('.'):
if queryset is None:
return []
queryset = get_component(queryset, component)
# Forward relationship
if is_simple_callable(getattr(queryset, 'all', None)):
return [self.to_native(item.pk) for item in queryset.all()]
else:
# Also support non-queryset iterables.
# This allows us to also support plain lists of related items.
return [self.to_native(item.pk) for item in queryset]
# To-one relationship
try:
# Prefer obj.serializable_value for performance reasons
pk = obj.serializable_value(self.source or field_name)
except AttributeError:
# RelatedObject (reverse relationship)
try:
pk = getattr(obj, self.source or field_name).pk
except (ObjectDoesNotExist, AttributeError):
return None
# Forward relationship
return self.to_native(pk)
# Slug relationships
class SlugRelatedField(RelatedField):
"""
Represents a relationship using a unique field on the target.
"""
read_only = False
default_error_messages = {
'does_not_exist': _("Object with %s=%s does not exist."),
'invalid': _('Invalid value.'),
}
def __init__(self, *args, **kwargs):
self.slug_field = kwargs.pop('slug_field', None)
assert self.slug_field, 'slug_field is required'
super(SlugRelatedField, self).__init__(*args, **kwargs)
def to_native(self, obj):
return getattr(obj, self.slug_field)
def from_native(self, data):
if self.queryset is None:
raise Exception('Writable related fields must include a `queryset` argument')
try:
return self.queryset.get(**{self.slug_field: data})
except ObjectDoesNotExist:
raise ValidationError(self.error_messages['does_not_exist'] %
(self.slug_field, smart_text(data)))
except (TypeError, ValueError):
msg = self.error_messages['invalid']
raise ValidationError(msg)
# Hyperlinked relationships
class HyperlinkedRelatedField(RelatedField):
"""
Represents a relationship using hyperlinking.
"""
read_only = False
lookup_field = 'pk'
default_error_messages = {
'no_match': _('Invalid hyperlink - No URL match'),
'incorrect_match': _('Invalid hyperlink - Incorrect URL match'),
'configuration_error': _('Invalid hyperlink due to configuration error'),
'does_not_exist': _("Invalid hyperlink - object does not exist."),
'incorrect_type': _('Incorrect type. Expected url string, received %s.'),
}
def __init__(self, *args, **kwargs):
try:
self.view_name = kwargs.pop('view_name')
except KeyError:
raise ValueError("Hyperlinked field requires 'view_name' kwarg")
self.lookup_field = kwargs.pop('lookup_field', self.lookup_field)
self.format = kwargs.pop('format', None)
super(HyperlinkedRelatedField, self).__init__(*args, **kwargs)
def get_url(self, obj, view_name, request, format):
"""
Given an object, return the URL that hyperlinks to the object.
May raise a `NoReverseMatch` if the `view_name` and `lookup_field`
attributes are not configured to correctly match the URL conf.
"""
lookup_field = getattr(obj, self.lookup_field)
kwargs = {self.lookup_field: lookup_field}
return reverse(view_name, kwargs=kwargs, request=request, format=format)
def get_object(self, queryset, view_name, view_args, view_kwargs):
"""
Return the object corresponding to a matched URL.
Takes the matched URL conf arguments, and the queryset, and should
return an object instance, or raise an `ObjectDoesNotExist` exception.
"""
lookup_value = view_kwargs[self.lookup_field]
filter_kwargs = {self.lookup_field: lookup_value}
return queryset.get(**filter_kwargs)
def to_native(self, obj):
view_name = self.view_name
request = self.context.get('request', None)
format = self.format or self.context.get('format', None)
assert request is not None, (
"`HyperlinkedRelatedField` requires the request in the serializer "
"context. Add `context={'request': request}` when instantiating "
"the serializer."
)
# If the object has not yet been saved then we cannot hyperlink to it.
if getattr(obj, 'pk', None) is None:
return
# Return the hyperlink, or error if incorrectly configured.
try:
return self.get_url(obj, view_name, request, format)
except NoReverseMatch:
msg = (
'Could not resolve URL for hyperlinked relationship using '
'view name "%s". You may have failed to include the related '
'model in your API, or incorrectly configured the '
'`lookup_field` attribute on this field.'
)
raise Exception(msg % view_name)
def from_native(self, value):
# Convert URL -> model instance pk
# TODO: Use values_list
queryset = self.queryset
if queryset is None:
raise Exception('Writable related fields must include a `queryset` argument')
try:
http_prefix = value.startswith(('http:', 'https:'))
except AttributeError:
msg = self.error_messages['incorrect_type']
raise ValidationError(msg % type(value).__name__)
if http_prefix:
# If needed convert absolute URLs to relative path
value = urlparse.urlparse(value).path
prefix = get_script_prefix()
if value.startswith(prefix):
value = '/' + value[len(prefix):]
try:
match = resolve(value)
except Exception:
raise ValidationError(self.error_messages['no_match'])
if match.view_name != self.view_name:
raise ValidationError(self.error_messages['incorrect_match'])
try:
return self.get_object(queryset, match.view_name,
match.args, match.kwargs)
except (ObjectDoesNotExist, TypeError, ValueError):
raise ValidationError(self.error_messages['does_not_exist'])
class HyperlinkedIdentityField(Field):
"""
Represents the instance, or a property on the instance, using hyperlinking.
"""
lookup_field = 'pk'
read_only = True
def __init__(self, *args, **kwargs):
try:
self.view_name = kwargs.pop('view_name')
except KeyError:
msg = "HyperlinkedIdentityField requires 'view_name' argument"
raise ValueError(msg)
self.format = kwargs.pop('format', None)
lookup_field = kwargs.pop('lookup_field', None)
self.lookup_field = lookup_field or self.lookup_field
super(HyperlinkedIdentityField, self).__init__(*args, **kwargs)
def field_to_native(self, obj, field_name):
request = self.context.get('request', None)
format = self.context.get('format', None)
view_name = self.view_name
assert request is not None, (
"`HyperlinkedIdentityField` requires the request in the serializer"
" context. Add `context={'request': request}` when instantiating "
"the serializer."
)
# By default use whatever format is given for the current context
# unless the target is a different type to the source.
#
# Eg. Consider a HyperlinkedIdentityField pointing from a json
# representation to an html property of that representation...
#
# '/snippets/1/' should link to '/snippets/1/highlight/'
# ...but...
# '/snippets/1/.json' should link to '/snippets/1/highlight/.html'
if format and self.format and self.format != format:
format = self.format
# Return the hyperlink, or error if incorrectly configured.
try:
return self.get_url(obj, view_name, request, format)
except NoReverseMatch:
msg = (
'Could not resolve URL for hyperlinked relationship using '
'view name "%s". You may have failed to include the related '
'model in your API, or incorrectly configured the '
'`lookup_field` attribute on this field.'
)
raise Exception(msg % view_name)
def get_url(self, obj, view_name, request, format):
"""
Given an object, return the URL that hyperlinks to the object.
May raise a `NoReverseMatch` if the `view_name` and `lookup_field`
attributes are not configured to correctly match the URL conf.
"""
lookup_field = getattr(obj, self.lookup_field, None)
kwargs = {self.lookup_field: lookup_field}
# Handle unsaved object case
if lookup_field is None:
return None
return reverse(view_name, kwargs=kwargs, request=request, format=format)
...@@ -458,7 +458,7 @@ class BrowsableAPIRenderer(BaseRenderer): ...@@ -458,7 +458,7 @@ class BrowsableAPIRenderer(BaseRenderer):
): ):
return return
serializer = view.get_serializer(instance=obj, data=data, files=files) serializer = view.get_serializer(instance=obj, data=data)
serializer.is_valid() serializer.is_valid()
data = serializer.data data = serializer.data
...@@ -579,10 +579,10 @@ class BrowsableAPIRenderer(BaseRenderer): ...@@ -579,10 +579,10 @@ class BrowsableAPIRenderer(BaseRenderer):
'available_formats': [renderer_cls.format for renderer_cls in view.renderer_classes], 'available_formats': [renderer_cls.format for renderer_cls in view.renderer_classes],
'response_headers': response_headers, 'response_headers': response_headers,
'put_form': self.get_rendered_html_form(view, 'PUT', request), #'put_form': self.get_rendered_html_form(view, 'PUT', request),
'post_form': self.get_rendered_html_form(view, 'POST', request), #'post_form': self.get_rendered_html_form(view, 'POST', request),
'delete_form': self.get_rendered_html_form(view, 'DELETE', request), #'delete_form': self.get_rendered_html_form(view, 'DELETE', request),
'options_form': self.get_rendered_html_form(view, 'OPTIONS', request), #'options_form': self.get_rendered_html_form(view, 'OPTIONS', request),
'raw_data_put_form': raw_data_put_form, 'raw_data_put_form': raw_data_put_form,
'raw_data_post_form': raw_data_post_form, 'raw_data_post_form': raw_data_post_form,
......
...@@ -10,21 +10,14 @@ python primitives. ...@@ -10,21 +10,14 @@ python primitives.
2. The process of marshalling between python primitives and request and 2. The process of marshalling between python primitives and request and
response content is handled by parsers and renderers. response content is handled by parsers and renderers.
""" """
from __future__ import unicode_literals
import copy
import datetime
import inspect
import types
from decimal import Decimal
from django.contrib.contenttypes.generic import GenericForeignKey
from django.core.paginator import Page
from django.db import models from django.db import models
from django.forms import widgets
from django.utils import six from django.utils import six
from django.utils.datastructures import SortedDict from collections import namedtuple, OrderedDict
from django.core.exceptions import ObjectDoesNotExist from rest_framework.fields import empty, set_value, Field, SkipField, ValidationError
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.utils import html
import copy
import inspect
# Note: We do the following so that users of the framework can use this style: # Note: We do the following so that users of the framework can use this style:
# #
...@@ -37,635 +30,339 @@ from rest_framework.relations import * # NOQA ...@@ -37,635 +30,339 @@ from rest_framework.relations import * # NOQA
from rest_framework.fields import * # NOQA from rest_framework.fields import * # NOQA
def _resolve_model(obj): FieldResult = namedtuple('FieldResult', ['field', 'value', 'error'])
"""
Resolve supplied `obj` to a Django model class.
`obj` must be a Django model class itself, or a string
representation of one. Useful in situtations like GH #1225 where
Django may not have resolved a string-based reference to a model in
another model's foreign key definition.
String representations should have the format:
'appname.ModelName'
"""
if isinstance(obj, six.string_types) and len(obj.split('.')) == 2:
app_name, model_name = obj.split('.')
return models.get_model(app_name, model_name)
elif inspect.isclass(obj) and issubclass(obj, models.Model):
return obj
else:
raise ValueError("{0} is not a Django model".format(obj))
def pretty_name(name):
"""Converts 'first_name' to 'First name'"""
if not name:
return ''
return name.replace('_', ' ').capitalize()
class BaseSerializer(Field):
def __init__(self, instance=None, data=None, **kwargs):
super(BaseSerializer, self).__init__(**kwargs)
self.instance = instance
self._initial_data = data
class RelationsList(list): def to_native(self, data):
_deleted = [] raise NotImplementedError()
def to_primative(self, instance):
raise NotImplementedError()
class NestedValidationError(ValidationError): def update(self, instance):
""" raise NotImplementedError()
The default ValidationError behavior is to stringify each item in the list
if the messages are a list of error messages.
In the case of nested serializers, where the parent has many children, def create(self):
then the child's `serializer.errors` will be a list of dicts. In the case raise NotImplementedError()
of a single child, the `serializer.errors` will be a dict.
We need to override the default behavior to get properly nested error dicts. def save(self, extras=None):
""" if extras is not None:
self._validated_data.update(extras)
def __init__(self, message): if self.instance is not None:
if isinstance(message, dict): self.update(self.instance)
self._messages = [message]
else: else:
self._messages = message self.instance = self.create()
@property
def messages(self):
return self._messages
return self.instance
class DictWithMetadata(dict): def is_valid(self):
""" try:
A dict-like object, that can have additional properties attached. self._validated_data = self.to_native(self._initial_data)
""" except ValidationError as exc:
def __getstate__(self): self._validated_data = {}
""" self._errors = exc.args[0]
Used by pickle (e.g., caching). return False
Overridden to remove the metadata from the dict, since it shouldn't be self._errors = {}
pickled and may in some instances be unpickleable. return True
"""
return dict(self)
class SortedDictWithMetadata(SortedDict): @property
""" def data(self):
A sorted dict-like object, that can have additional properties attached. if not hasattr(self, '_data'):
""" if self.instance is not None:
def __getstate__(self): self._data = self.to_primative(self.instance)
""" elif self._initial_data is not None:
Used by pickle (e.g., caching). self._data = {
Overriden to remove the metadata from the dict, since it shouldn't be field_name: field.get_value(self._initial_data)
pickle and may in some instances be unpickleable. for field_name, field in self.fields.items()
""" }
return SortedDict(self).__dict__ else:
self._data = self.get_initial()
return self._data
@property
def errors(self):
if not hasattr(self, '_errors'):
msg = 'You must call `.is_valid()` before accessing `.errors`.'
raise AssertionError(msg)
return self._errors
def _is_protected_type(obj): @property
""" def validated_data(self):
True if the object is a native datatype that does not need to if not hasattr(self, '_validated_data'):
be serialized further. msg = 'You must call `.is_valid()` before accessing `.validated_data`.'
""" raise AssertionError(msg)
return isinstance(obj, ( return self._validated_data
types.NoneType,
int, long,
datetime.datetime, datetime.date, datetime.time,
float, Decimal,
basestring)
)
def _get_declared_fields(bases, attrs): class SerializerMetaclass(type):
""" """
Create a list of serializer field instances from the passed in 'attrs', This metaclass sets a dictionary named `base_fields` on the class.
plus any fields on the base classes (in 'bases').
Note that all fields from the base classes are used. Any fields included as attributes on either the class or it's superclasses
will be include in the `base_fields` dictionary.
""" """
fields = [(field_name, attrs.pop(field_name))
for field_name, obj in list(six.iteritems(attrs))
if isinstance(obj, Field)]
fields.sort(key=lambda x: x[1].creation_counter)
# If this class is subclassing another Serializer, add that Serializer's @classmethod
# fields. Note that we loop over the bases in *reverse*. This is necessary def _get_fields(cls, bases, attrs):
# in order to maintain the correct order of fields. fields = [(field_name, attrs.pop(field_name))
for base in bases[::-1]: for field_name, obj in list(attrs.items())
if hasattr(base, 'base_fields'): if isinstance(obj, Field)]
fields = list(base.base_fields.items()) + fields fields.sort(key=lambda x: x[1]._creation_counter)
return SortedDict(fields) # If this class is subclassing another Serializer, add that Serializer's
# fields. Note that we loop over the bases in *reverse*. This is necessary
# in order to maintain the correct order of fields.
for base in bases[::-1]:
if hasattr(base, 'base_fields'):
fields = list(base.base_fields.items()) + fields
return OrderedDict(fields)
class SerializerMetaclass(type):
def __new__(cls, name, bases, attrs): def __new__(cls, name, bases, attrs):
attrs['base_fields'] = _get_declared_fields(bases, attrs) attrs['base_fields'] = cls._get_fields(bases, attrs)
return super(SerializerMetaclass, cls).__new__(cls, name, bases, attrs) return super(SerializerMetaclass, cls).__new__(cls, name, bases, attrs)
class SerializerOptions(object): @six.add_metaclass(SerializerMetaclass)
""" class Serializer(BaseSerializer):
Meta class options for Serializer
"""
def __init__(self, meta):
self.depth = getattr(meta, 'depth', 0)
self.fields = getattr(meta, 'fields', ())
self.exclude = getattr(meta, 'exclude', ())
def __new__(cls, *args, **kwargs):
many = kwargs.pop('many', False)
if many:
class DynamicListSerializer(ListSerializer):
child = cls()
return DynamicListSerializer(*args, **kwargs)
return super(Serializer, cls).__new__(cls)
class BaseSerializer(WritableField): def __init__(self, *args, **kwargs):
""" kwargs.pop('context', None)
This is the Serializer implementation. kwargs.pop('partial', None)
We need to implement it as `BaseSerializer` due to metaclass magicks. kwargs.pop('many', False)
"""
class Meta(object):
pass
_options_class = SerializerOptions
_dict_class = SortedDictWithMetadata
def __init__(self, instance=None, data=None, files=None,
context=None, partial=False, many=False,
allow_add_remove=False, **kwargs):
super(BaseSerializer, self).__init__(**kwargs)
self.opts = self._options_class(self.Meta)
self.parent = None
self.root = None
self.partial = partial
self.many = many
self.allow_add_remove = allow_add_remove
self.context = context or {} super(Serializer, self).__init__(*args, **kwargs)
self.init_data = data # Every new serializer is created with a clone of the field instances.
self.init_files = files # This allows users to dynamically modify the fields on a serializer
self.object = instance # instance without affecting every other serializer class.
self.fields = self.get_fields() self.fields = self.get_fields()
self._data = None # Setup all the child fields, to provide them with the current context.
self._files = None
self._errors = None
if many and instance is not None and not hasattr(instance, '__iter__'):
raise ValueError('instance should be a queryset or other iterable with many=True')
if allow_add_remove and not many:
raise ValueError('allow_add_remove should only be used for bulk updates, but you have not set many=True')
#####
# Methods to determine which fields to use when (de)serializing objects.
def get_default_fields(self):
"""
Return the complete set of default fields for the object, as a dict.
"""
return {}
def get_fields(self):
"""
Returns the complete set of fields for the object as a dict.
This will be the set of any explicitly declared fields,
plus the set of fields returned by get_default_fields().
"""
ret = SortedDict()
# Get the explicitly declared fields
base_fields = copy.deepcopy(self.base_fields)
for key, field in base_fields.items():
ret[key] = field
# Add in the default fields
default_fields = self.get_default_fields()
for key, val in default_fields.items():
if key not in ret:
ret[key] = val
# If 'fields' is specified, use those fields, in that order.
if self.opts.fields:
assert isinstance(self.opts.fields, (list, tuple)), '`fields` must be a list or tuple'
new = SortedDict()
for key in self.opts.fields:
new[key] = ret[key]
ret = new
# Remove anything in 'exclude'
if self.opts.exclude:
assert isinstance(self.opts.exclude, (list, tuple)), '`exclude` must be a list or tuple'
for key in self.opts.exclude:
ret.pop(key, None)
for key, field in ret.items():
field.initialize(parent=self, field_name=key)
return ret
#####
# Methods to convert or revert from objects <--> primitive representations.
def get_field_key(self, field_name):
"""
Return the key that should be used for a given field.
"""
return field_name
def restore_fields(self, data, files):
"""
Core of deserialization, together with `restore_object`.
Converts a dictionary of data into a dictionary of deserialized fields.
"""
reverted_data = {}
if data is not None and not isinstance(data, dict):
self._errors['non_field_errors'] = ['Invalid data']
return None
for field_name, field in self.fields.items(): for field_name, field in self.fields.items():
field.initialize(parent=self, field_name=field_name) field.bind(field_name, self, self)
try:
field.field_from_native(data, files, field_name, reverted_data)
except ValidationError as err:
self._errors[field_name] = list(err.messages)
return reverted_data def get_fields(self):
return copy.deepcopy(self.base_fields)
def perform_validation(self, attrs): def bind(self, field_name, parent, root):
""" # If the serializer is used as a field then when it becomes bound
Run `validate_<fieldname>()` and `validate()` methods on the serializer # it also needs to bind all its child fields.
""" super(Serializer, self).bind(field_name, parent, root)
for field_name, field in self.fields.items(): for field_name, field in self.fields.items():
if field_name in self._errors: field.bind(field_name, self, root)
continue
source = field.source or field_name def get_initial(self):
if self.partial and source not in attrs: return {
continue field.field_name: field.get_initial()
try: for field in self.fields.values()
validate_method = getattr(self, 'validate_%s' % field_name, None) }
if validate_method:
attrs = validate_method(attrs, source)
except ValidationError as err:
self._errors[field_name] = self._errors.get(field_name, []) + list(err.messages)
# If there are already errors, we don't run .validate() because
# field-validation failed and thus `attrs` may not be complete.
# which in turn can cause inconsistent validation errors.
if not self._errors:
try:
attrs = self.validate(attrs)
except ValidationError as err:
if hasattr(err, 'message_dict'):
for field_name, error_messages in err.message_dict.items():
self._errors[field_name] = self._errors.get(field_name, []) + list(error_messages)
elif hasattr(err, 'messages'):
self._errors['non_field_errors'] = err.messages
return attrs
def validate(self, attrs): def get_value(self, dictionary):
""" # We override the default field access in order to support
Stub method, to be overridden in Serializer subclasses # nested HTML forms.
""" if html.is_html_input(dictionary):
return attrs return html.parse_html_dict(dictionary, prefix=self.field_name)
return dictionary.get(self.field_name, empty)
def restore_object(self, attrs, instance=None): def to_native(self, data):
""" """
Deserialize a dictionary of attributes into an object instance. Dict of native values <- Dict of primitive datatypes.
You should override this method to control how deserialized objects
are instantiated.
""" """
if instance is not None: ret = {}
instance.update(attrs) errors = {}
return instance fields = [field for field in self.fields.values() if not field.read_only]
return attrs
def to_native(self, obj): for field in fields:
""" primitive_value = field.get_value(data)
Serialize objects -> primitives. try:
""" validated_value = field.validate(primitive_value)
ret = self._dict_class() except ValidationError as exc:
ret.fields = self._dict_class() errors[field.field_name] = str(exc)
except SkipField:
pass
else:
set_value(ret, field.source_attrs, validated_value)
for field_name, field in self.fields.items(): if errors:
if field.read_only and obj is None: raise ValidationError(errors)
continue
field.initialize(parent=self, field_name=field_name)
key = self.get_field_key(field_name)
value = field.field_to_native(obj, field_name)
method = getattr(self, 'transform_%s' % field_name, None)
if callable(method):
value = method(obj, value)
if not getattr(field, 'write_only', False):
ret[key] = value
ret.fields[key] = self.augment_field(field, field_name, key, value)
return ret return ret
def from_native(self, data, files=None): def to_primative(self, instance):
"""
Deserialize primitives -> objects.
"""
self._errors = {}
if data is not None or files is not None:
attrs = self.restore_fields(data, files)
if attrs is not None:
attrs = self.perform_validation(attrs)
else:
self._errors['non_field_errors'] = ['No input provided']
if not self._errors:
return self.restore_object(attrs, instance=getattr(self, 'object', None))
def augment_field(self, field, field_name, key, value):
# This horrible stuff is to manage serializers rendering to HTML
field._errors = self._errors.get(key) if self._errors else None
field._name = field_name
field._value = self.init_data.get(key) if self._errors and self.init_data else value
if not field.label:
field.label = pretty_name(key)
return field
def field_to_native(self, obj, field_name):
""" """
Override default so that the serializer can be used as a nested field Object instance -> Dict of primitive datatypes.
across relationships.
""" """
if self.write_only: ret = OrderedDict()
return None fields = [field for field in self.fields.values() if not field.write_only]
if self.source == '*': for field in fields:
return self.to_native(obj) native_value = field.get_attribute(instance)
ret[field.field_name] = field.to_primative(native_value)
# Get the raw field value return ret
try:
source = self.source or field_name
value = obj
for component in source.split('.'):
if value is None:
break
value = get_component(value, component)
except ObjectDoesNotExist:
return None
if is_simple_callable(getattr(value, 'all', None)): def __iter__(self):
return [self.to_native(item) for item in value.all()] errors = self.errors if hasattr(self, '_errors') else {}
for field in self.fields.values():
value = self.data.get(field.field_name) if self.data else None
error = errors.get(field.field_name)
yield FieldResult(field, value, error)
if value is None:
return None
if self.many: class ListSerializer(BaseSerializer):
return [self.to_native(item) for item in value] child = None
return self.to_native(value) initial = []
def field_from_native(self, data, files, field_name, into): def __init__(self, *args, **kwargs):
""" self.child = kwargs.pop('child', copy.deepcopy(self.child))
Override default so that the serializer can be used as a writable assert self.child is not None, '`child` is a required argument.'
nested field across relationships.
"""
if self.read_only:
return
try: kwargs.pop('context', None)
value = data[field_name] kwargs.pop('partial', None)
except KeyError:
if self.default is not None and not self.partial:
# Note: partial updates shouldn't set defaults
value = copy.deepcopy(self.default)
else:
if self.required:
raise ValidationError(self.error_messages['required'])
return
if self.source == '*':
if value:
reverted_data = self.restore_fields(value, {})
if not self._errors:
into.update(reverted_data)
else:
if value in (None, ''):
into[(self.source or field_name)] = None
else:
# Set the serializer object if it exists
obj = get_component(self.parent.object, self.source or field_name) if self.parent.object else None
# If we have a model manager or similar object then we need
# to iterate through each instance.
if (
self.many and
not hasattr(obj, '__iter__') and
is_simple_callable(getattr(obj, 'all', None))
):
obj = obj.all()
kwargs = {
'instance': obj,
'data': value,
'context': self.context,
'partial': self.partial,
'many': self.many,
'allow_add_remove': self.allow_add_remove
}
serializer = self.__class__(**kwargs)
if serializer.is_valid(): super(ListSerializer, self).__init__(*args, **kwargs)
into[self.source or field_name] = serializer.object self.child.bind('', self, self)
else:
# Propagate errors up to our parent
raise NestedValidationError(serializer.errors)
def get_identity(self, data): def bind(self, field_name, parent, root):
""" # If the list is used as a field then it needs to provide
This hook is required for bulk update. # the current context to the child serializer.
It is used to determine the canonical identity of a given object. super(ListSerializer, self).bind(field_name, parent, root)
self.child.bind(field_name, self, root)
Note that the data has not been validated at this point, so we need def get_value(self, dictionary):
to make sure that we catch any cases of incorrect datatypes being # We override the default field access in order to support
passed to this method. # lists in HTML forms.
""" if is_html_input(dictionary):
try: return html.parse_html_list(dictionary, prefix=self.field_name)
return data.get('id', None) return dictionary.get(self.field_name, empty)
except AttributeError:
return None
@property def to_native(self, data):
def errors(self):
""" """
Run deserialization and return error data, List of dicts of native values <- List of dicts of primitive datatypes.
setting self.object if no errors occurred.
""" """
if self._errors is None: if html.is_html_input(data):
data, files = self.init_data, self.init_files data = html.parse_html_list(data)
if self.many is not None: return [self.child.validate(item) for item in data]
many = self.many
else:
many = hasattr(data, '__iter__') and not isinstance(data, (Page, dict, six.text_type))
if many:
warnings.warn('Implicit list/queryset serialization is deprecated. '
'Use the `many=True` flag when instantiating the serializer.',
DeprecationWarning, stacklevel=3)
if many:
ret = RelationsList()
errors = []
update = self.object is not None
if update:
# If this is a bulk update we need to map all the objects
# to a canonical identity so we can determine which
# individual object is being updated for each item in the
# incoming data
objects = self.object
identities = [self.get_identity(self.to_native(obj)) for obj in objects]
identity_to_objects = dict(zip(identities, objects))
if hasattr(data, '__iter__') and not isinstance(data, (dict, six.text_type)):
for item in data:
if update:
# Determine which object we're updating
identity = self.get_identity(item)
self.object = identity_to_objects.pop(identity, None)
if self.object is None and not self.allow_add_remove:
ret.append(None)
errors.append({'non_field_errors': ['Cannot create a new item, only existing items may be updated.']})
continue
ret.append(self.from_native(item, None))
errors.append(self._errors)
if update and self.allow_add_remove:
ret._deleted = identity_to_objects.values()
self._errors = any(errors) and errors or []
else:
self._errors = {'non_field_errors': ['Expected a list of items.']}
else:
ret = self.from_native(data, files)
if not self._errors:
self.object = ret
return self._errors
def is_valid(self):
return not self.errors
@property def to_primative(self, data):
def data(self):
""" """
Returns the serialized data on the serializer. List of object instances -> List of dicts of primitive datatypes.
""" """
if self._data is None: return [self.child.to_primative(item) for item in data]
obj = self.object
if self.many is not None: def create(self, attrs_list):
many = self.many return [self.child.create(attrs) for attrs in attrs_list]
else:
many = hasattr(obj, '__iter__') and not isinstance(obj, (Page, dict))
if many:
warnings.warn('Implicit list/queryset serialization is deprecated. '
'Use the `many=True` flag when instantiating the serializer.',
DeprecationWarning, stacklevel=2)
if many:
self._data = [self.to_native(item) for item in obj]
else:
self._data = self.to_native(obj)
return self._data def save(self):
if self.instance is not None:
self.update(self.instance, self.validated_data)
self.instance = self.create(self.validated_data)
return self.instance
def save_object(self, obj, **kwargs):
obj.save(**kwargs)
def delete_object(self, obj): def _resolve_model(obj):
obj.delete() """
Resolve supplied `obj` to a Django model class.
def save(self, **kwargs):
"""
Save the deserialized object and return it.
"""
# Clear cached _data, which may be invalidated by `save()`
self._data = None
if isinstance(self.object, list):
[self.save_object(item, **kwargs) for item in self.object]
if self.object._deleted:
[self.delete_object(item) for item in self.object._deleted]
else:
self.save_object(self.object, **kwargs)
return self.object
def metadata(self):
"""
Return a dictionary of metadata about the fields on the serializer.
Useful for things like responding to OPTIONS requests, or generating
API schemas for auto-documentation.
"""
return SortedDict(
[
(field_name, field.metadata())
for field_name, field in six.iteritems(self.fields)
]
)
`obj` must be a Django model class itself, or a string
representation of one. Useful in situtations like GH #1225 where
Django may not have resolved a string-based reference to a model in
another model's foreign key definition.
class Serializer(six.with_metaclass(SerializerMetaclass, BaseSerializer)): String representations should have the format:
pass 'appname.ModelName'
"""
if isinstance(obj, six.string_types) and len(obj.split('.')) == 2:
app_name, model_name = obj.split('.')
return models.get_model(app_name, model_name)
elif inspect.isclass(obj) and issubclass(obj, models.Model):
return obj
else:
raise ValueError("{0} is not a Django model".format(obj))
class ModelSerializerOptions(SerializerOptions): class ModelSerializerOptions(object):
""" """
Meta class options for ModelSerializer Meta class options for ModelSerializer
""" """
def __init__(self, meta): def __init__(self, meta):
super(ModelSerializerOptions, self).__init__(meta) self.model = getattr(meta, 'model')
self.model = getattr(meta, 'model', None) self.fields = getattr(meta, 'fields', ())
self.read_only_fields = getattr(meta, 'read_only_fields', ()) self.depth = getattr(meta, 'depth', 0)
self.write_only_fields = getattr(meta, 'write_only_fields', ())
class ModelSerializer(Serializer): class ModelSerializer(Serializer):
"""
A serializer that deals with model instances and querysets.
"""
_options_class = ModelSerializerOptions
field_mapping = { field_mapping = {
models.AutoField: IntegerField, models.AutoField: IntegerField,
models.FloatField: FloatField, # models.FloatField: FloatField,
models.IntegerField: IntegerField, models.IntegerField: IntegerField,
models.PositiveIntegerField: IntegerField, models.PositiveIntegerField: IntegerField,
models.SmallIntegerField: IntegerField, models.SmallIntegerField: IntegerField,
models.PositiveSmallIntegerField: IntegerField, models.PositiveSmallIntegerField: IntegerField,
models.DateTimeField: DateTimeField, # models.DateTimeField: DateTimeField,
models.DateField: DateField, # models.DateField: DateField,
models.TimeField: TimeField, # models.TimeField: TimeField,
models.DecimalField: DecimalField, # models.DecimalField: DecimalField,
models.EmailField: EmailField, # models.EmailField: EmailField,
models.CharField: CharField, models.CharField: CharField,
models.URLField: URLField, # models.URLField: URLField,
models.SlugField: SlugField, # models.SlugField: SlugField,
models.TextField: CharField, models.TextField: CharField,
models.CommaSeparatedIntegerField: CharField, models.CommaSeparatedIntegerField: CharField,
models.BooleanField: BooleanField, models.BooleanField: BooleanField,
models.NullBooleanField: BooleanField, models.NullBooleanField: BooleanField,
models.FileField: FileField, # models.FileField: FileField,
models.ImageField: ImageField, # models.ImageField: ImageField,
} }
_options_class = ModelSerializerOptions
def __init__(self, *args, **kwargs):
self.opts = self._options_class(self.Meta)
super(ModelSerializer, self).__init__(*args, **kwargs)
def get_fields(self):
# Get the explicitly declared fields.
fields = copy.deepcopy(self.base_fields)
# Add in the default fields.
for key, val in self.get_default_fields().items():
if key not in fields:
fields[key] = val
# If `fields` is set on the `Meta` class,
# then use only those fields, and in that order.
if self.opts.fields:
fields = OrderedDict([
(key, fields[key]) for key in self.opts.fields
])
return fields
def get_default_fields(self): def get_default_fields(self):
""" """
Return all the fields that should be serialized for the model. Return all the fields that should be serialized for the model.
""" """
cls = self.opts.model cls = self.opts.model
assert cls is not None, (
"Serializer class '%s' is missing 'model' Meta option" %
self.__class__.__name__
)
opts = cls._meta.concrete_model._meta opts = cls._meta.concrete_model._meta
ret = SortedDict() ret = OrderedDict()
nested = bool(self.opts.depth) nested = bool(self.opts.depth)
# Deal with adding the primary key field # Deal with adding the primary key field
...@@ -694,29 +391,9 @@ class ModelSerializer(Serializer): ...@@ -694,29 +391,9 @@ class ModelSerializer(Serializer):
has_through_model = True has_through_model = True
if model_field.rel and nested: if model_field.rel and nested:
if len(inspect.getargspec(self.get_nested_field).args) == 2: field = self.get_nested_field(model_field, related_model, to_many)
warnings.warn(
'The `get_nested_field(model_field)` call signature '
'is deprecated. '
'Use `get_nested_field(model_field, related_model, '
'to_many) instead',
DeprecationWarning
)
field = self.get_nested_field(model_field)
else:
field = self.get_nested_field(model_field, related_model, to_many)
elif model_field.rel: elif model_field.rel:
if len(inspect.getargspec(self.get_nested_field).args) == 3: field = self.get_related_field(model_field, related_model, to_many)
warnings.warn(
'The `get_related_field(model_field, to_many)` call '
'signature is deprecated. '
'Use `get_related_field(model_field, related_model, '
'to_many) instead',
DeprecationWarning
)
field = self.get_related_field(model_field, to_many=to_many)
else:
field = self.get_related_field(model_field, related_model, to_many)
else: else:
field = self.get_field(model_field) field = self.get_field(model_field)
...@@ -763,38 +440,6 @@ class ModelSerializer(Serializer): ...@@ -763,38 +440,6 @@ class ModelSerializer(Serializer):
ret[accessor_name] = field ret[accessor_name] = field
# Ensure that 'read_only_fields' is an iterable
assert isinstance(self.opts.read_only_fields, (list, tuple)), '`read_only_fields` must be a list or tuple'
# Add the `read_only` flag to any fields that have been specified
# in the `read_only_fields` option
for field_name in self.opts.read_only_fields:
assert field_name not in self.base_fields.keys(), (
"field '%s' on serializer '%s' specified in "
"`read_only_fields`, but also added "
"as an explicit field. Remove it from `read_only_fields`." %
(field_name, self.__class__.__name__))
assert field_name in ret, (
"Non-existant field '%s' specified in `read_only_fields` "
"on serializer '%s'." %
(field_name, self.__class__.__name__))
ret[field_name].read_only = True
# Ensure that 'write_only_fields' is an iterable
assert isinstance(self.opts.write_only_fields, (list, tuple)), '`write_only_fields` must be a list or tuple'
for field_name in self.opts.write_only_fields:
assert field_name not in self.base_fields.keys(), (
"field '%s' on serializer '%s' specified in "
"`write_only_fields`, but also added "
"as an explicit field. Remove it from `write_only_fields`." %
(field_name, self.__class__.__name__))
assert field_name in ret, (
"Non-existant field '%s' specified in `write_only_fields` "
"on serializer '%s'." %
(field_name, self.__class__.__name__))
ret[field_name].write_only = True
return ret return ret
def get_pk_field(self, model_field): def get_pk_field(self, model_field):
...@@ -825,28 +470,24 @@ class ModelSerializer(Serializer): ...@@ -825,28 +470,24 @@ class ModelSerializer(Serializer):
# TODO: filter queryset using: # TODO: filter queryset using:
# .using(db).complex_filter(self.rel.limit_choices_to) # .using(db).complex_filter(self.rel.limit_choices_to)
kwargs = { kwargs = {}
'queryset': related_model._default_manager, # 'queryset': related_model._default_manager,
'many': to_many # 'many': to_many
} # }
if model_field: if model_field:
kwargs['required'] = not(model_field.null or model_field.blank) kwargs['required'] = not(model_field.null or model_field.blank)
if model_field.help_text is not None: # if model_field.help_text is not None:
kwargs['help_text'] = model_field.help_text # kwargs['help_text'] = model_field.help_text
if model_field.verbose_name is not None: if model_field.verbose_name is not None:
kwargs['label'] = model_field.verbose_name kwargs['label'] = model_field.verbose_name
if not model_field.editable: if not model_field.editable:
kwargs['read_only'] = True kwargs['read_only'] = True
if model_field.verbose_name is not None: if model_field.verbose_name is not None:
kwargs['label'] = model_field.verbose_name kwargs['label'] = model_field.verbose_name
if model_field.help_text is not None: return IntegerField(**kwargs)
kwargs['help_text'] = model_field.help_text # TODO: return PrimaryKeyRelatedField(**kwargs)
return PrimaryKeyRelatedField(**kwargs)
def get_field(self, model_field): def get_field(self, model_field):
""" """
...@@ -869,8 +510,8 @@ class ModelSerializer(Serializer): ...@@ -869,8 +510,8 @@ class ModelSerializer(Serializer):
if model_field.verbose_name is not None: if model_field.verbose_name is not None:
kwargs['label'] = model_field.verbose_name kwargs['label'] = model_field.verbose_name
if model_field.help_text is not None: # if model_field.help_text is not None:
kwargs['help_text'] = model_field.help_text # kwargs['help_text'] = model_field.help_text
# TODO: TypedChoiceField? # TODO: TypedChoiceField?
if model_field.flatchoices: # This ModelField contains choices if model_field.flatchoices: # This ModelField contains choices
...@@ -880,7 +521,7 @@ class ModelSerializer(Serializer): ...@@ -880,7 +521,7 @@ class ModelSerializer(Serializer):
return ChoiceField(**kwargs) return ChoiceField(**kwargs)
# put this below the ChoiceField because min_value isn't a valid initializer # put this below the ChoiceField because min_value isn't a valid initializer
if issubclass(model_field.__class__, models.PositiveIntegerField) or\ if issubclass(model_field.__class__, models.PositiveIntegerField) or \
issubclass(model_field.__class__, models.PositiveSmallIntegerField): issubclass(model_field.__class__, models.PositiveSmallIntegerField):
kwargs['min_value'] = 0 kwargs['min_value'] = 0
...@@ -888,170 +529,27 @@ class ModelSerializer(Serializer): ...@@ -888,170 +529,27 @@ class ModelSerializer(Serializer):
issubclass(model_field.__class__, (models.CharField, models.TextField)): issubclass(model_field.__class__, (models.CharField, models.TextField)):
kwargs['allow_none'] = True kwargs['allow_none'] = True
attribute_dict = { # attribute_dict = {
models.CharField: ['max_length'], # models.CharField: ['max_length'],
models.CommaSeparatedIntegerField: ['max_length'], # models.CommaSeparatedIntegerField: ['max_length'],
models.DecimalField: ['max_digits', 'decimal_places'], # models.DecimalField: ['max_digits', 'decimal_places'],
models.EmailField: ['max_length'], # models.EmailField: ['max_length'],
models.FileField: ['max_length'], # models.FileField: ['max_length'],
models.ImageField: ['max_length'], # models.ImageField: ['max_length'],
models.SlugField: ['max_length'], # models.SlugField: ['max_length'],
models.URLField: ['max_length'], # models.URLField: ['max_length'],
} # }
if model_field.__class__ in attribute_dict: # if model_field.__class__ in attribute_dict:
attributes = attribute_dict[model_field.__class__] # attributes = attribute_dict[model_field.__class__]
for attribute in attributes: # for attribute in attributes:
kwargs.update({attribute: getattr(model_field, attribute)}) # kwargs.update({attribute: getattr(model_field, attribute)})
try: try:
return self.field_mapping[model_field.__class__](**kwargs) return self.field_mapping[model_field.__class__](**kwargs)
except KeyError: except KeyError:
return ModelField(model_field=model_field, **kwargs) # TODO: Change this to `return ModelField(model_field=model_field, **kwargs)`
return CharField(**kwargs)
def get_validation_exclusions(self, instance=None):
"""
Return a list of field names to exclude from model validation.
"""
cls = self.opts.model
opts = cls._meta.concrete_model._meta
exclusions = [field.name for field in opts.fields + opts.many_to_many]
for field_name, field in self.fields.items():
field_name = field.source or field_name
if (
field_name in exclusions
and not field.read_only
and (field.required or hasattr(instance, field_name))
and not isinstance(field, Serializer)
):
exclusions.remove(field_name)
return exclusions
def full_clean(self, instance):
"""
Perform Django's full_clean, and populate the `errors` dictionary
if any validation errors occur.
Note that we don't perform this inside the `.restore_object()` method,
so that subclasses can override `.restore_object()`, and still get
the full_clean validation checking.
"""
try:
instance.full_clean(exclude=self.get_validation_exclusions(instance))
except ValidationError as err:
self._errors = err.message_dict
return None
return instance
def restore_object(self, attrs, instance=None):
"""
Restore the model instance.
"""
m2m_data = {}
related_data = {}
nested_forward_relations = {}
meta = self.opts.model._meta
# Reverse fk or one-to-one relations
for (obj, model) in meta.get_all_related_objects_with_model():
field_name = obj.get_accessor_name()
if field_name in attrs:
related_data[field_name] = attrs.pop(field_name)
# Reverse m2m relations
for (obj, model) in meta.get_all_related_m2m_objects_with_model():
field_name = obj.get_accessor_name()
if field_name in attrs:
m2m_data[field_name] = attrs.pop(field_name)
# Forward m2m relations
for field in meta.many_to_many + meta.virtual_fields:
if isinstance(field, GenericForeignKey):
continue
if field.name in attrs:
m2m_data[field.name] = attrs.pop(field.name)
# Nested forward relations - These need to be marked so we can save
# them before saving the parent model instance.
for field_name in attrs.keys():
if isinstance(self.fields.get(field_name, None), Serializer):
nested_forward_relations[field_name] = attrs[field_name]
# Create an empty instance of the model
if instance is None:
instance = self.opts.model()
for key, val in attrs.items():
try:
setattr(instance, key, val)
except ValueError:
self._errors[key] = [self.error_messages['required']]
# Any relations that cannot be set until we've
# saved the model get hidden away on these
# private attributes, so we can deal with them
# at the point of save.
instance._related_data = related_data
instance._m2m_data = m2m_data
instance._nested_forward_relations = nested_forward_relations
return instance
def from_native(self, data, files):
"""
Override the default method to also include model field validation.
"""
instance = super(ModelSerializer, self).from_native(data, files)
if not self._errors:
return self.full_clean(instance)
def save_object(self, obj, **kwargs):
"""
Save the deserialized object.
"""
if getattr(obj, '_nested_forward_relations', None):
# Nested relationships need to be saved before we can save the
# parent instance.
for field_name, sub_object in obj._nested_forward_relations.items():
if sub_object:
self.save_object(sub_object)
setattr(obj, field_name, sub_object)
obj.save(**kwargs)
if getattr(obj, '_m2m_data', None):
for accessor_name, object_list in obj._m2m_data.items():
setattr(obj, accessor_name, object_list)
del(obj._m2m_data)
if getattr(obj, '_related_data', None):
related_fields = dict([
(field.get_accessor_name(), field)
for field, model
in obj._meta.get_all_related_objects_with_model()
])
for accessor_name, related in obj._related_data.items():
if isinstance(related, RelationsList):
# Nested reverse fk relationship
for related_item in related:
fk_field = related_fields[accessor_name].field.name
setattr(related_item, fk_field, obj)
self.save_object(related_item)
# Delete any removed objects
if related._deleted:
[self.delete_object(item) for item in related._deleted]
elif isinstance(related, models.Model):
# Nested reverse one-one relationship
fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name
setattr(related, fk_field, obj)
self.save_object(related)
else:
# Reverse FK or reverse one-one
setattr(obj, accessor_name, related)
del(obj._related_data)
class HyperlinkedModelSerializerOptions(ModelSerializerOptions): class HyperlinkedModelSerializerOptions(ModelSerializerOptions):
...@@ -1066,14 +564,10 @@ class HyperlinkedModelSerializerOptions(ModelSerializerOptions): ...@@ -1066,14 +564,10 @@ class HyperlinkedModelSerializerOptions(ModelSerializerOptions):
class HyperlinkedModelSerializer(ModelSerializer): class HyperlinkedModelSerializer(ModelSerializer):
"""
A subclass of ModelSerializer that uses hyperlinked relationships,
instead of primary key relationships.
"""
_options_class = HyperlinkedModelSerializerOptions _options_class = HyperlinkedModelSerializerOptions
_default_view_name = '%(model_name)s-detail' _default_view_name = '%(model_name)s-detail'
_hyperlink_field_class = HyperlinkedRelatedField #_hyperlink_field_class = HyperlinkedRelatedField
_hyperlink_identify_field_class = HyperlinkedIdentityField #_hyperlink_identify_field_class = HyperlinkedIdentityField
def get_default_fields(self): def get_default_fields(self):
fields = super(HyperlinkedModelSerializer, self).get_default_fields() fields = super(HyperlinkedModelSerializer, self).get_default_fields()
...@@ -1081,15 +575,15 @@ class HyperlinkedModelSerializer(ModelSerializer): ...@@ -1081,15 +575,15 @@ class HyperlinkedModelSerializer(ModelSerializer):
if self.opts.view_name is None: if self.opts.view_name is None:
self.opts.view_name = self._get_default_view_name(self.opts.model) self.opts.view_name = self._get_default_view_name(self.opts.model)
if self.opts.url_field_name not in fields: # if self.opts.url_field_name not in fields:
url_field = self._hyperlink_identify_field_class( # url_field = self._hyperlink_identify_field_class(
view_name=self.opts.view_name, # view_name=self.opts.view_name,
lookup_field=self.opts.lookup_field # lookup_field=self.opts.lookup_field
) # )
ret = self._dict_class() # ret = self._dict_class()
ret[self.opts.url_field_name] = url_field # ret[self.opts.url_field_name] = url_field
ret.update(fields) # ret.update(fields)
fields = ret # fields = ret
return fields return fields
...@@ -1103,33 +597,25 @@ class HyperlinkedModelSerializer(ModelSerializer): ...@@ -1103,33 +597,25 @@ class HyperlinkedModelSerializer(ModelSerializer):
""" """
# TODO: filter queryset using: # TODO: filter queryset using:
# .using(db).complex_filter(self.rel.limit_choices_to) # .using(db).complex_filter(self.rel.limit_choices_to)
kwargs = { # kwargs = {
'queryset': related_model._default_manager, # 'queryset': related_model._default_manager,
'view_name': self._get_default_view_name(related_model), # 'view_name': self._get_default_view_name(related_model),
'many': to_many # 'many': to_many
} # }
kwargs = {}
if model_field: if model_field:
kwargs['required'] = not(model_field.null or model_field.blank) kwargs['required'] = not(model_field.null or model_field.blank)
if model_field.help_text is not None: # if model_field.help_text is not None:
kwargs['help_text'] = model_field.help_text # kwargs['help_text'] = model_field.help_text
if model_field.verbose_name is not None: if model_field.verbose_name is not None:
kwargs['label'] = model_field.verbose_name kwargs['label'] = model_field.verbose_name
if self.opts.lookup_field: return IntegerField(**kwargs)
kwargs['lookup_field'] = self.opts.lookup_field # if self.opts.lookup_field:
# kwargs['lookup_field'] = self.opts.lookup_field
return self._hyperlink_field_class(**kwargs)
def get_identity(self, data): # return self._hyperlink_field_class(**kwargs)
"""
This hook is required for bulk update.
We need to override the default, to use the url as the identity.
"""
try:
return data.get(self.opts.url_field_name, None)
except AttributeError:
return None
def _get_default_view_name(self, model): def _get_default_view_name(self, model):
""" """
......
...@@ -7,7 +7,7 @@ from django.db.models.query import QuerySet ...@@ -7,7 +7,7 @@ from django.db.models.query import QuerySet
from django.utils.datastructures import SortedDict from django.utils.datastructures import SortedDict
from django.utils.functional import Promise from django.utils.functional import Promise
from rest_framework.compat import force_text from rest_framework.compat import force_text
from rest_framework.serializers import DictWithMetadata, SortedDictWithMetadata # from rest_framework.serializers import DictWithMetadata, SortedDictWithMetadata
import datetime import datetime
import decimal import decimal
import types import types
...@@ -106,14 +106,14 @@ else: ...@@ -106,14 +106,14 @@ else:
SortedDict, SortedDict,
yaml.representer.SafeRepresenter.represent_dict yaml.representer.SafeRepresenter.represent_dict
) )
SafeDumper.add_representer( # SafeDumper.add_representer(
DictWithMetadata, # DictWithMetadata,
yaml.representer.SafeRepresenter.represent_dict # yaml.representer.SafeRepresenter.represent_dict
) # )
SafeDumper.add_representer( # SafeDumper.add_representer(
SortedDictWithMetadata, # SortedDictWithMetadata,
yaml.representer.SafeRepresenter.represent_dict # yaml.representer.SafeRepresenter.represent_dict
) # )
SafeDumper.add_representer( SafeDumper.add_representer(
types.GeneratorType, types.GeneratorType,
yaml.representer.SafeRepresenter.represent_list yaml.representer.SafeRepresenter.represent_list
......
"""
Helpers for dealing with HTML input.
"""
def is_html_input(dictionary):
# MultiDict type datastructures are used to represent HTML form input,
# which may have more than one value for each key.
return hasattr(dictionary, 'getlist')
def parse_html_list(dictionary, prefix=''):
"""
Used to suport list values in HTML forms.
Supports lists of primitives and/or dictionaries.
* List of primitives.
{
'[0]': 'abc',
'[1]': 'def',
'[2]': 'hij'
}
-->
[
'abc',
'def',
'hij'
]
* List of dictionaries.
{
'[0]foo': 'abc',
'[0]bar': 'def',
'[1]foo': 'hij',
'[2]bar': 'klm',
}
-->
[
{'foo': 'abc', 'bar': 'def'},
{'foo': 'hij', 'bar': 'klm'}
]
"""
Dict = type(dictionary)
ret = {}
regex = re.compile(r'^%s\[([0-9]+)\](.*)$' % re.escape(prefix))
for field, value in dictionary.items():
match = regex.match(field)
if not match:
continue
index, key = match.groups()
index = int(index)
if not key:
ret[index] = value
elif isinstance(ret.get(index), dict):
ret[index][key] = value
else:
ret[index] = Dict({key: value})
return [ret[item] for item in sorted(ret.keys())]
def parse_html_dict(dictionary, prefix):
"""
Used to support dictionary values in HTML forms.
{
'profile.username': 'example',
'profile.email': 'example@example.com',
}
-->
{
'profile': {
'username': 'example,
'email': 'example@example.com'
}
}
"""
ret = {}
regex = re.compile(r'^%s\.(.+)$' % re.escape(prefix))
for field, value in dictionary.items():
match = regex.match(field)
if not match:
continue
key = match.groups()[0]
ret[key] = value
return ret
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment