Commit 8586290d by Tom Christie

Apply defaults and requiredness to unique_together fields. Closes #2092.

parent 6cb65101
...@@ -720,49 +720,60 @@ class ModelSerializer(Serializer): ...@@ -720,49 +720,60 @@ class ModelSerializer(Serializer):
# Determine if we need any additional `HiddenField` or extra keyword # Determine if we need any additional `HiddenField` or extra keyword
# arguments to deal with `unique_for` dates that are required to # arguments to deal with `unique_for` dates that are required to
# be in the input data in order to validate it. # be in the input data in order to validate it.
unique_fields = {} hidden_fields = {}
for model_field_name, field_name in model_field_mapping.items(): for model_field_name, field_name in model_field_mapping.items():
try: try:
model_field = model._meta.get_field(model_field_name) model_field = model._meta.get_field(model_field_name)
except FieldDoesNotExist: except FieldDoesNotExist:
continue continue
# Deal with each of the `unique_for_*` cases. # Include each of the `unique_for_*` field names.
for date_field_name in ( unique_constraint_names = set([
model_field.unique_for_date, model_field.unique_for_date,
model_field.unique_for_month, model_field.unique_for_month,
model_field.unique_for_year model_field.unique_for_year
): ])
if date_field_name is None: unique_constraint_names -= set([None])
continue
# Include each of the `unique_together` field names,
# Get the model field that is refered too. # so long as all the field names are included on the serializer.
date_field = model._meta.get_field(date_field_name) for parent_class in [model] + list(model._meta.parents.keys()):
for unique_together_list in parent_class._meta.unique_together:
if date_field.auto_now_add: if set(fields).issuperset(set(unique_together_list)):
default = CreateOnlyDefault(timezone.now) unique_constraint_names |= set(unique_together_list)
elif date_field.auto_now:
default = timezone.now # Now we have all the field names that have uniqueness constraints
elif date_field.has_default(): # applied, we can add the extra 'required=...' or 'default=...'
default = model_field.default # arguments that are appropriate to these fields, or add a `HiddenField` for it.
else: for unique_constraint_name in unique_constraint_names:
default = empty # Get the model field that is refered too.
unique_constraint_field = model._meta.get_field(unique_constraint_name)
if date_field_name in model_field_mapping:
# The corresponding date field is present in the serializer if getattr(unique_constraint_field, 'auto_now_add', None):
if date_field_name not in extra_kwargs: default = CreateOnlyDefault(timezone.now)
extra_kwargs[date_field_name] = {} elif getattr(unique_constraint_field, 'auto_now', None):
if default is empty: default = timezone.now
if 'required' not in extra_kwargs[date_field_name]: elif unique_constraint_field.has_default():
extra_kwargs[date_field_name]['required'] = True default = model_field.default
else: else:
if 'default' not in extra_kwargs[date_field_name]: default = empty
extra_kwargs[date_field_name]['default'] = default
if unique_constraint_name in model_field_mapping:
# The corresponding field is present in the serializer
if unique_constraint_name not in extra_kwargs:
extra_kwargs[unique_constraint_name] = {}
if default is empty:
if 'required' not in extra_kwargs[unique_constraint_name]:
extra_kwargs[unique_constraint_name]['required'] = True
else: else:
# The corresponding date field is not present in the, if 'default' not in extra_kwargs[unique_constraint_name]:
# serializer. We have a default to use for the date, so extra_kwargs[unique_constraint_name]['default'] = default
# add in a hidden field that populates it. elif default is not empty:
unique_fields[date_field_name] = HiddenField(default=default) # The corresponding field is not present in the,
# serializer. We have a default to use for it, so
# add in a hidden field that populates it.
hidden_fields[unique_constraint_name] = HiddenField(default=default)
# Now determine the fields that should be included on the serializer. # Now determine the fields that should be included on the serializer.
for field_name in fields: for field_name in fields:
...@@ -838,12 +849,16 @@ class ModelSerializer(Serializer): ...@@ -838,12 +849,16 @@ class ModelSerializer(Serializer):
'validators', 'queryset' 'validators', 'queryset'
]: ]:
kwargs.pop(attr, None) kwargs.pop(attr, None)
if extras.get('default') and kwargs.get('required') is False:
kwargs.pop('required')
kwargs.update(extras) kwargs.update(extras)
# Create the serializer field. # Create the serializer field.
ret[field_name] = field_cls(**kwargs) ret[field_name] = field_cls(**kwargs)
for field_name, field in unique_fields.items(): for field_name, field in hidden_fields.items():
ret[field_name] = field ret[field_name] = field
return ret return ret
......
...@@ -93,6 +93,9 @@ class UniqueTogetherValidator: ...@@ -93,6 +93,9 @@ class UniqueTogetherValidator:
The `UniqueTogetherValidator` always forces an implied 'required' The `UniqueTogetherValidator` always forces an implied 'required'
state on the fields it applies to. state on the fields it applies to.
""" """
if self.instance is not None:
return
missing = dict([ missing = dict([
(field_name, self.missing_message) (field_name, self.missing_message)
for field_name in self.fields for field_name in self.fields
...@@ -105,8 +108,17 @@ class UniqueTogetherValidator: ...@@ -105,8 +108,17 @@ class UniqueTogetherValidator:
""" """
Filter the queryset to all instances matching the given attributes. Filter the queryset to all instances matching the given attributes.
""" """
# If this is an update, then any unprovided field should
# have it's value set based on the existing instance attribute.
if self.instance is not None:
for field_name in self.fields:
if field_name not in attrs:
attrs[field_name] = getattr(self.instance, field_name)
# Determine the filter keyword arguments and filter the queryset.
filter_kwargs = dict([ filter_kwargs = dict([
(field_name, attrs[field_name]) for field_name in self.fields (field_name, attrs[field_name])
for field_name in self.fields
]) ])
return queryset.filter(**filter_kwargs) return queryset.filter(**filter_kwargs)
......
...@@ -88,8 +88,8 @@ class TestUniquenessTogetherValidation(TestCase): ...@@ -88,8 +88,8 @@ class TestUniquenessTogetherValidation(TestCase):
expected = dedent(""" expected = dedent("""
UniquenessTogetherSerializer(): UniquenessTogetherSerializer():
id = IntegerField(label='ID', read_only=True) id = IntegerField(label='ID', read_only=True)
race_name = CharField(max_length=100) race_name = CharField(max_length=100, required=True)
position = IntegerField() position = IntegerField(required=True)
class Meta: class Meta:
validators = [<UniqueTogetherValidator(queryset=UniquenessTogetherModel.objects.all(), fields=('race_name', 'position'))>] validators = [<UniqueTogetherValidator(queryset=UniquenessTogetherModel.objects.all(), fields=('race_name', 'position'))>]
""") """)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment