Commit 5d80f7f9 by Tom Christie

allow_blank, allow_null

parent 5a95baf2
...@@ -98,14 +98,15 @@ class Field(object): ...@@ -98,14 +98,15 @@ class Field(object):
_creation_counter = 0 _creation_counter = 0
default_error_messages = { default_error_messages = {
'required': _('This field is required.') 'required': _('This field is required.'),
'null': _('This field may not be null.')
} }
default_validators = [] default_validators = []
def __init__(self, read_only=False, write_only=False, def __init__(self, read_only=False, write_only=False,
required=None, default=empty, initial=None, source=None, required=None, default=empty, initial=None, source=None,
label=None, help_text=None, style=None, label=None, help_text=None, style=None,
error_messages=None, validators=[]): error_messages=None, validators=[], allow_null=False):
self._creation_counter = Field._creation_counter self._creation_counter = Field._creation_counter
Field._creation_counter += 1 Field._creation_counter += 1
...@@ -129,6 +130,7 @@ class Field(object): ...@@ -129,6 +130,7 @@ class Field(object):
self.help_text = help_text self.help_text = help_text
self.style = {} if style is None else style self.style = {} if style is None else style
self.validators = validators or self.default_validators[:] self.validators = validators or self.default_validators[:]
self.allow_null = allow_null
# Collect default error message from self and parent classes # Collect default error message from self and parent classes
messages = {} messages = {}
...@@ -220,6 +222,11 @@ class Field(object): ...@@ -220,6 +222,11 @@ class Field(object):
self.fail('required') self.fail('required')
return self.get_default() return self.get_default()
if data is None:
if not self.allow_null:
self.fail('null')
return None
value = self.to_internal_value(data) value = self.to_internal_value(data)
self.run_validators(value) self.run_validators(value)
return value return value
...@@ -315,11 +322,14 @@ class CharField(Field): ...@@ -315,11 +322,14 @@ class CharField(Field):
self.min_length = kwargs.pop('min_length', None) self.min_length = kwargs.pop('min_length', None)
super(CharField, self).__init__(**kwargs) super(CharField, self).__init__(**kwargs)
def run_validation(self, data=empty):
if data == '':
if not self.allow_blank:
self.fail('blank')
return ''
return super(CharField, self).run_validation(data)
def to_internal_value(self, data): def to_internal_value(self, data):
if data == '' and not self.allow_blank:
self.fail('blank')
if data is None:
return None
return str(data) return str(data)
def to_representation(self, value): def to_representation(self, value):
...@@ -339,10 +349,6 @@ class EmailField(CharField): ...@@ -339,10 +349,6 @@ class EmailField(CharField):
self.validators.append(validator) self.validators.append(validator)
def to_internal_value(self, data): def to_internal_value(self, data):
if data == '' and not self.allow_blank:
self.fail('blank')
if data is None:
return None
return str(data).strip() return str(data).strip()
def to_representation(self, value): def to_representation(self, value):
...@@ -437,8 +443,6 @@ class FloatField(Field): ...@@ -437,8 +443,6 @@ class FloatField(Field):
self.validators.append(MinValueValidator(min_value, message=message)) self.validators.append(MinValueValidator(min_value, message=message))
def to_internal_value(self, value): def to_internal_value(self, value):
if value is None:
return None
try: try:
return float(value) return float(value)
except (TypeError, ValueError): except (TypeError, ValueError):
...@@ -481,9 +485,6 @@ class DecimalField(Field): ...@@ -481,9 +485,6 @@ class DecimalField(Field):
than max_digits in the number, and no more than decimal_places digits than max_digits in the number, and no more than decimal_places digits
after the decimal point. after the decimal point.
""" """
if value in (None, ''):
return None
value = smart_text(value).strip() value = smart_text(value).strip()
try: try:
value = decimal.Decimal(value) value = decimal.Decimal(value)
...@@ -554,9 +555,6 @@ class DateField(Field): ...@@ -554,9 +555,6 @@ class DateField(Field):
super(DateField, self).__init__(*args, **kwargs) super(DateField, self).__init__(*args, **kwargs)
def to_internal_value(self, value): def to_internal_value(self, value):
if value in (None, ''):
return None
if isinstance(value, datetime.datetime): if isinstance(value, datetime.datetime):
self.fail('datetime') self.fail('datetime')
...@@ -622,9 +620,6 @@ class DateTimeField(Field): ...@@ -622,9 +620,6 @@ class DateTimeField(Field):
return value return value
def to_internal_value(self, value): def to_internal_value(self, value):
if value in (None, ''):
return None
if isinstance(value, datetime.date) and not isinstance(value, datetime.datetime): if isinstance(value, datetime.date) and not isinstance(value, datetime.datetime):
self.fail('date') self.fail('date')
...@@ -676,9 +671,6 @@ class TimeField(Field): ...@@ -676,9 +671,6 @@ class TimeField(Field):
super(TimeField, self).__init__(*args, **kwargs) super(TimeField, self).__init__(*args, **kwargs)
def to_internal_value(self, value): def to_internal_value(self, value):
if value in (None, ''):
return None
if isinstance(value, datetime.time): if isinstance(value, datetime.time):
return value return value
......
from rest_framework import fields
import pytest
class TestFieldOptions:
def test_required(self):
"""
By default a field must be included in the input.
"""
field = fields.IntegerField()
with pytest.raises(fields.ValidationError) as exc_info:
field.run_validation()
assert exc_info.value.messages == ['This field is required.']
def test_not_required(self):
"""
If `required=False` then a field may be omitted from the input.
"""
field = fields.IntegerField(required=False)
with pytest.raises(fields.SkipField):
field.run_validation()
def test_disallow_null(self):
"""
By default `None` is not a valid input.
"""
field = fields.IntegerField()
with pytest.raises(fields.ValidationError) as exc_info:
field.run_validation(None)
assert exc_info.value.messages == ['This field may not be null.']
def test_allow_null(self):
"""
If `allow_null=True` then `None` is a valid input.
"""
field = fields.IntegerField(allow_null=True)
output = field.run_validation(None)
assert output is None
def test_disallow_blank(self):
"""
By default '' is not a valid input.
"""
field = fields.CharField()
with pytest.raises(fields.ValidationError) as exc_info:
field.run_validation('')
assert exc_info.value.messages == ['This field may not be blank.']
def test_allow_blank(self):
"""
If `allow_blank=True` then '' is a valid input.
"""
field = fields.CharField(allow_blank=True)
output = field.run_validation('')
assert output is ''
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment