Commit 05cbec9d by Tom Christie

Use serializers.ValidationError

parent 5882a7a9
...@@ -144,11 +144,15 @@ The corresponding code would now look like this: ...@@ -144,11 +144,15 @@ The corresponding code would now look like this:
logging.info('Creating ticket "%s"' % name) logging.info('Creating ticket "%s"' % name)
serializer.save(user=request.user) # Include the user when saving. serializer.save(user=request.user) # Include the user when saving.
#### Use `rest_framework.exceptions.ValidationFailed`. #### Using `serializers.ValidationError`.
Django's `ValidationError` class is intended for use with HTML forms and it's API makes its use slightly awkward with nested validation errors as can occur in serializers. Previously `serializers.ValidationError` error was simply a synonym for `django.core.exceptions.ValidationError`. This has now been altered so that it inherits from the standard `APIException` base class.
We now include a simpler `ValidationFailed` exception class in REST framework that you should use when raising validation failures. The reason behind this is that Django's `ValidationError` class is intended for use with HTML forms and its API makes using it slightly awkward with nested validation errors that can occur in serializers.
For most users this change shouldn't require any updates to your codebase, but it is worth ensuring that whenever raising validation errors you are always using the `serializers.ValidationError` exception class, and not Django's built-in exception.
We strongly recommend that you use the namespaced import style of `import serializers` and not `from serializers import ValidationError` in order to avoid any potential confusion.
#### Change to `validate_<field_name>`. #### Change to `validate_<field_name>`.
...@@ -156,14 +160,14 @@ The `validate_<field_name>` method hooks that can be attached to serializer clas ...@@ -156,14 +160,14 @@ The `validate_<field_name>` method hooks that can be attached to serializer clas
def validate_score(self, attrs, source): def validate_score(self, attrs, source):
if attrs[score] % 10 != 0: if attrs[score] % 10 != 0:
raise ValidationError('This field should be a multiple of ten.') raise serializers.ValidationError('This field should be a multiple of ten.')
return attrs return attrs
This is now simplified slightly, and the method hooks simply take the value to be validated, and return it's validated value. This is now simplified slightly, and the method hooks simply take the value to be validated, and return it's validated value.
def validate_score(self, value): def validate_score(self, value):
if value % 10 != 0: if value % 10 != 0:
raise ValidationError('This field should be a multiple of ten.') raise serializers.ValidationError('This field should be a multiple of ten.')
return value return value
Any ad-hoc validation that applies to more than one field should go in the `.validate(self, attrs)` method as usual. Any ad-hoc validation that applies to more than one field should go in the `.validate(self, attrs)` method as usual.
......
...@@ -18,13 +18,13 @@ class AuthTokenSerializer(serializers.Serializer): ...@@ -18,13 +18,13 @@ class AuthTokenSerializer(serializers.Serializer):
if user: if user:
if not user.is_active: if not user.is_active:
msg = _('User account is disabled.') msg = _('User account is disabled.')
raise exceptions.ValidationFailed(msg) raise exceptions.ValidationError(msg)
else: else:
msg = _('Unable to log in with provided credentials.') msg = _('Unable to log in with provided credentials.')
raise exceptions.ValidationFailed(msg) raise exceptions.ValidationError(msg)
else: else:
msg = _('Must include "username" and "password"') msg = _('Must include "username" and "password"')
raise exceptions.ValidationFailed(msg) raise exceptions.ValidationError(msg)
attrs['user'] = user attrs['user'] = user
return attrs return attrs
...@@ -24,7 +24,14 @@ class APIException(Exception): ...@@ -24,7 +24,14 @@ class APIException(Exception):
return self.detail return self.detail
class ValidationFailed(APIException): # The recommended style for using `ValidationError` is to keep it namespaced
# under `serializers`, in order to minimize potential confusion with Django's
# built in `ValidationError`. For example:
#
# from rest_framework import serializers
# raise serializers.ValidationError('Value was invalid')
class ValidationError(APIException):
status_code = status.HTTP_400_BAD_REQUEST status_code = status.HTTP_400_BAD_REQUEST
def __init__(self, detail): def __init__(self, detail):
......
...@@ -13,7 +13,7 @@ from rest_framework.compat import ( ...@@ -13,7 +13,7 @@ from rest_framework.compat import (
smart_text, EmailValidator, MinValueValidator, MaxValueValidator, smart_text, EmailValidator, MinValueValidator, MaxValueValidator,
MinLengthValidator, MaxLengthValidator, URLValidator MinLengthValidator, MaxLengthValidator, URLValidator
) )
from rest_framework.exceptions import ValidationFailed from rest_framework.exceptions import ValidationError
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.utils import html, representation, humanize_datetime from rest_framework.utils import html, representation, humanize_datetime
import copy import copy
...@@ -102,7 +102,7 @@ NOT_READ_ONLY_DEFAULT = 'May not set both `read_only` and `default`' ...@@ -102,7 +102,7 @@ NOT_READ_ONLY_DEFAULT = 'May not set both `read_only` and `default`'
NOT_REQUIRED_DEFAULT = 'May not set both `required` and `default`' NOT_REQUIRED_DEFAULT = 'May not set both `required` and `default`'
USE_READONLYFIELD = 'Field(read_only=True) should be ReadOnlyField' USE_READONLYFIELD = 'Field(read_only=True) should be ReadOnlyField'
MISSING_ERROR_MESSAGE = ( MISSING_ERROR_MESSAGE = (
'ValidationFailed raised by `{class_name}`, but error key `{key}` does ' 'ValidationError raised by `{class_name}`, but error key `{key}` does '
'not exist in the `error_messages` dictionary.' 'not exist in the `error_messages` dictionary.'
) )
...@@ -264,7 +264,7 @@ class Field(object): ...@@ -264,7 +264,7 @@ class Field(object):
def run_validators(self, value): def run_validators(self, value):
""" """
Test the given value against all the validators on the field, Test the given value against all the validators on the field,
and either raise a `ValidationFailed` or simply return. and either raise a `ValidationError` or simply return.
""" """
errors = [] errors = []
for validator in self.validators: for validator in self.validators:
...@@ -272,12 +272,12 @@ class Field(object): ...@@ -272,12 +272,12 @@ class Field(object):
validator.serializer_field = self validator.serializer_field = self
try: try:
validator(value) validator(value)
except ValidationFailed as exc: except ValidationError as exc:
errors.extend(exc.detail) errors.extend(exc.detail)
except DjangoValidationError as exc: except DjangoValidationError as exc:
errors.extend(exc.messages) errors.extend(exc.messages)
if errors: if errors:
raise ValidationFailed(errors) raise ValidationError(errors)
def validate(self, value): def validate(self, value):
pass pass
...@@ -305,7 +305,7 @@ class Field(object): ...@@ -305,7 +305,7 @@ class Field(object):
msg = MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key) msg = MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key)
raise AssertionError(msg) raise AssertionError(msg)
message_string = msg.format(**kwargs) message_string = msg.format(**kwargs)
raise ValidationFailed(message_string) raise ValidationError(message_string)
@property @property
def root(self): def root(self):
......
...@@ -14,7 +14,7 @@ from django.core.exceptions import ImproperlyConfigured ...@@ -14,7 +14,7 @@ from django.core.exceptions import ImproperlyConfigured
from django.db import models from django.db import models
from django.utils import six from django.utils import six
from django.utils.datastructures import SortedDict from django.utils.datastructures import SortedDict
from rest_framework.exceptions import ValidationFailed from rest_framework.exceptions import ValidationError
from rest_framework.fields import empty, set_value, Field, SkipField from rest_framework.fields import empty, set_value, Field, SkipField
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.utils import html, model_meta, representation from rest_framework.utils import html, model_meta, representation
...@@ -77,13 +77,6 @@ class BaseSerializer(Field): ...@@ -77,13 +77,6 @@ class BaseSerializer(Field):
raise NotImplementedError('`create()` must be implemented.') raise NotImplementedError('`create()` must be implemented.')
def save(self, **kwargs): def save(self, **kwargs):
assert not hasattr(self, 'restore_object'), (
'Serializer %s has old-style version 2 `.restore_object()` '
'that is no longer compatible with REST framework 3. '
'Use the new-style `.create()` and `.update()` methods instead.' %
self.__class__.__name__
)
validated_data = self.validated_data validated_data = self.validated_data
if kwargs: if kwargs:
validated_data = dict( validated_data = dict(
...@@ -105,17 +98,24 @@ class BaseSerializer(Field): ...@@ -105,17 +98,24 @@ class BaseSerializer(Field):
return self.instance return self.instance
def is_valid(self, raise_exception=False): def is_valid(self, raise_exception=False):
assert not hasattr(self, 'restore_object'), (
'Serializer %s has old-style version 2 `.restore_object()` '
'that is no longer compatible with REST framework 3. '
'Use the new-style `.create()` and `.update()` methods instead.' %
self.__class__.__name__
)
if not hasattr(self, '_validated_data'): if not hasattr(self, '_validated_data'):
try: try:
self._validated_data = self.run_validation(self._initial_data) self._validated_data = self.run_validation(self._initial_data)
except ValidationFailed as exc: except ValidationError as exc:
self._validated_data = {} self._validated_data = {}
self._errors = exc.detail self._errors = exc.detail
else: else:
self._errors = {} self._errors = {}
if self._errors and raise_exception: if self._errors and raise_exception:
raise ValidationFailed(self._errors) raise ValidationError(self._errors)
return not bool(self._errors) return not bool(self._errors)
...@@ -124,6 +124,8 @@ class BaseSerializer(Field): ...@@ -124,6 +124,8 @@ class BaseSerializer(Field):
if not hasattr(self, '_data'): if not hasattr(self, '_data'):
if self.instance is not None and not getattr(self, '_errors', None): if self.instance is not None and not getattr(self, '_errors', None):
self._data = self.to_representation(self.instance) self._data = self.to_representation(self.instance)
elif hasattr(self, '_validated_data') and not getattr(self, '_errors', None):
self._data = self.to_representation(self.validated_data)
else: else:
self._data = self.get_initial() self._data = self.get_initial()
return self._data return self._data
...@@ -329,7 +331,7 @@ class Serializer(BaseSerializer): ...@@ -329,7 +331,7 @@ class Serializer(BaseSerializer):
return None return None
if not isinstance(data, dict): if not isinstance(data, dict):
raise ValidationFailed({ raise ValidationError({
api_settings.NON_FIELD_ERRORS_KEY: ['Invalid data'] api_settings.NON_FIELD_ERRORS_KEY: ['Invalid data']
}) })
...@@ -338,8 +340,8 @@ class Serializer(BaseSerializer): ...@@ -338,8 +340,8 @@ class Serializer(BaseSerializer):
self.run_validators(value) self.run_validators(value)
value = self.validate(value) value = self.validate(value)
assert value is not None, '.validate() should return the validated data' assert value is not None, '.validate() should return the validated data'
except ValidationFailed as exc: except ValidationError as exc:
raise ValidationFailed({ raise ValidationError({
api_settings.NON_FIELD_ERRORS_KEY: exc.detail api_settings.NON_FIELD_ERRORS_KEY: exc.detail
}) })
return value return value
...@@ -359,7 +361,7 @@ class Serializer(BaseSerializer): ...@@ -359,7 +361,7 @@ class Serializer(BaseSerializer):
validated_value = field.run_validation(primitive_value) validated_value = field.run_validation(primitive_value)
if validate_method is not None: if validate_method is not None:
validated_value = validate_method(validated_value) validated_value = validate_method(validated_value)
except ValidationFailed as exc: except ValidationError as exc:
errors[field.field_name] = exc.detail errors[field.field_name] = exc.detail
except SkipField: except SkipField:
pass pass
...@@ -367,7 +369,7 @@ class Serializer(BaseSerializer): ...@@ -367,7 +369,7 @@ class Serializer(BaseSerializer):
set_value(ret, field.source_attrs, validated_value) set_value(ret, field.source_attrs, validated_value)
if errors: if errors:
raise ValidationFailed(errors) raise ValidationError(errors)
return ret return ret
......
from .utils import mock_reverse, fail_reverse, BadType, MockObject, MockQueryset from .utils import mock_reverse, fail_reverse, BadType, MockObject, MockQueryset
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from rest_framework import exceptions, serializers from rest_framework import serializers
from rest_framework.test import APISimpleTestCase from rest_framework.test import APISimpleTestCase
import pytest import pytest
...@@ -30,13 +30,13 @@ class TestPrimaryKeyRelatedField(APISimpleTestCase): ...@@ -30,13 +30,13 @@ class TestPrimaryKeyRelatedField(APISimpleTestCase):
assert instance is self.instance assert instance is self.instance
def test_pk_related_lookup_does_not_exist(self): def test_pk_related_lookup_does_not_exist(self):
with pytest.raises(exceptions.ValidationFailed) as excinfo: with pytest.raises(serializers.ValidationError) as excinfo:
self.field.to_internal_value(4) self.field.to_internal_value(4)
msg = excinfo.value.detail[0] msg = excinfo.value.detail[0]
assert msg == "Invalid pk '4' - object does not exist." assert msg == "Invalid pk '4' - object does not exist."
def test_pk_related_lookup_invalid_type(self): def test_pk_related_lookup_invalid_type(self):
with pytest.raises(exceptions.ValidationFailed) as excinfo: with pytest.raises(serializers.ValidationError) as excinfo:
self.field.to_internal_value(BadType()) self.field.to_internal_value(BadType())
msg = excinfo.value.detail[0] msg = excinfo.value.detail[0]
assert msg == 'Incorrect type. Expected pk value, received BadType.' assert msg == 'Incorrect type. Expected pk value, received BadType.'
...@@ -120,13 +120,13 @@ class TestSlugRelatedField(APISimpleTestCase): ...@@ -120,13 +120,13 @@ class TestSlugRelatedField(APISimpleTestCase):
assert instance is self.instance assert instance is self.instance
def test_slug_related_lookup_does_not_exist(self): def test_slug_related_lookup_does_not_exist(self):
with pytest.raises(exceptions.ValidationFailed) as excinfo: with pytest.raises(serializers.ValidationError) as excinfo:
self.field.to_internal_value('doesnotexist') self.field.to_internal_value('doesnotexist')
msg = excinfo.value.detail[0] msg = excinfo.value.detail[0]
assert msg == 'Object with name=doesnotexist does not exist.' assert msg == 'Object with name=doesnotexist does not exist.'
def test_slug_related_lookup_invalid_type(self): def test_slug_related_lookup_invalid_type(self):
with pytest.raises(exceptions.ValidationFailed) as excinfo: with pytest.raises(serializers.ValidationError) as excinfo:
self.field.to_internal_value(BadType()) self.field.to_internal_value(BadType())
msg = excinfo.value.detail[0] msg = excinfo.value.detail[0]
assert msg == 'Invalid value.' assert msg == 'Invalid value.'
......
...@@ -2,7 +2,7 @@ from __future__ import unicode_literals ...@@ -2,7 +2,7 @@ from __future__ import unicode_literals
from django.core.validators import MaxValueValidator from django.core.validators import MaxValueValidator
from django.db import models from django.db import models
from django.test import TestCase from django.test import TestCase
from rest_framework import exceptions, generics, serializers, status from rest_framework import generics, serializers, status
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
factory = APIRequestFactory() factory = APIRequestFactory()
...@@ -37,7 +37,7 @@ class ShouldValidateModelSerializer(serializers.ModelSerializer): ...@@ -37,7 +37,7 @@ class ShouldValidateModelSerializer(serializers.ModelSerializer):
def validate_renamed(self, value): def validate_renamed(self, value):
if len(value) < 3: if len(value) < 3:
raise exceptions.ValidationFailed('Minimum 3 characters.') raise serializers.ValidationError('Minimum 3 characters.')
return value return value
class Meta: class Meta:
...@@ -73,10 +73,10 @@ class ValidationSerializer(serializers.Serializer): ...@@ -73,10 +73,10 @@ class ValidationSerializer(serializers.Serializer):
foo = serializers.CharField() foo = serializers.CharField()
def validate_foo(self, attrs, source): def validate_foo(self, attrs, source):
raise exceptions.ValidationFailed("foo invalid") raise serializers.ValidationError("foo invalid")
def validate(self, attrs): def validate(self, attrs):
raise exceptions.ValidationFailed("serializer invalid") raise serializers.ValidationError("serializer invalid")
class TestAvoidValidation(TestCase): class TestAvoidValidation(TestCase):
...@@ -158,7 +158,7 @@ class TestChoiceFieldChoicesValidate(TestCase): ...@@ -158,7 +158,7 @@ class TestChoiceFieldChoicesValidate(TestCase):
value = self.CHOICES[0][0] value = self.CHOICES[0][0]
try: try:
f.to_internal_value(value) f.to_internal_value(value)
except exceptions.ValidationFailed: except serializers.ValidationError:
self.fail("Value %s does not validate" % str(value)) self.fail("Value %s does not validate" % str(value))
# def test_nested_choices(self): # def test_nested_choices(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment