Commit ee20cf80 by Tom Christie

Merge pull request #735 from tomchristie/one-to-one-nested-wip

One to one writable nested model serializers (wip)
parents 1aedf57f 66bdd608
...@@ -26,13 +26,17 @@ class NestedValidationError(ValidationError): ...@@ -26,13 +26,17 @@ class NestedValidationError(ValidationError):
if the messages are a list of error messages. if the messages are a list of error messages.
In the case of nested serializers, where the parent has many children, In the case of nested serializers, where the parent has many children,
then the child's `serializer.errors` will be a list of dicts. then the child's `serializer.errors` will be a list of dicts. In the case
of a single child, the `serializer.errors` will be a dict.
We need to override the default behavior to get properly nested error dicts. We need to override the default behavior to get properly nested error dicts.
""" """
def __init__(self, message): def __init__(self, message):
self.messages = message if isinstance(message, dict):
self.messages = [message]
else:
self.messages = message
class DictWithMetadata(dict): class DictWithMetadata(dict):
...@@ -311,8 +315,8 @@ class BaseSerializer(WritableField): ...@@ -311,8 +315,8 @@ class BaseSerializer(WritableField):
def field_to_native(self, obj, field_name): def field_to_native(self, obj, field_name):
""" """
Override default so that we can apply ModelSerializer as a nested Override default so that the serializer can be used as a nested field
field to relationships. across relationships.
""" """
if self.source == '*': if self.source == '*':
return self.to_native(obj) return self.to_native(obj)
...@@ -344,6 +348,10 @@ class BaseSerializer(WritableField): ...@@ -344,6 +348,10 @@ class BaseSerializer(WritableField):
return self.to_native(value) return self.to_native(value)
def field_from_native(self, data, files, field_name, into): def field_from_native(self, data, files, field_name, into):
"""
Override default so that the serializer can be used as a writable
nested field across relationships.
"""
if self.read_only: if self.read_only:
return return
...@@ -354,15 +362,14 @@ class BaseSerializer(WritableField): ...@@ -354,15 +362,14 @@ class BaseSerializer(WritableField):
raise ValidationError(self.error_messages['required']) raise ValidationError(self.error_messages['required'])
return return
if self.parent.object: # Set the serializer object if it exists
# Set the serializer object if it exists obj = getattr(self.parent.object, field_name) if self.parent.object else None
obj = getattr(self.parent.object, field_name)
self.object = obj
if value in (None, ''): if value in (None, ''):
into[(self.source or field_name)] = None into[(self.source or field_name)] = None
else: else:
kwargs = { kwargs = {
'instance': obj,
'data': value, 'data': value,
'context': self.context, 'context': self.context,
'partial': self.partial, 'partial': self.partial,
...@@ -371,7 +378,6 @@ class BaseSerializer(WritableField): ...@@ -371,7 +378,6 @@ class BaseSerializer(WritableField):
serializer = self.__class__(**kwargs) serializer = self.__class__(**kwargs)
if serializer.is_valid(): if serializer.is_valid():
self.object = serializer.object
into[self.source or field_name] = serializer.object into[self.source or field_name] = serializer.object
else: else:
# Propagate errors up to our parent # Propagate errors up to our parent
...@@ -630,33 +636,43 @@ class ModelSerializer(Serializer): ...@@ -630,33 +636,43 @@ class ModelSerializer(Serializer):
""" """
Restore the model instance. Restore the model instance.
""" """
self.m2m_data = {} m2m_data = {}
self.related_data = {} related_data = {}
meta = self.opts.model._meta
# Reverse fk relations # Reverse fk or one-to-one relations
for (obj, model) in self.opts.model._meta.get_all_related_objects_with_model(): for (obj, model) in meta.get_all_related_objects_with_model():
field_name = obj.field.related_query_name() field_name = obj.field.related_query_name()
if field_name in attrs: if field_name in attrs:
self.related_data[field_name] = attrs.pop(field_name) related_data[field_name] = attrs.pop(field_name)
# Reverse m2m relations # Reverse m2m relations
for (obj, model) in self.opts.model._meta.get_all_related_m2m_objects_with_model(): for (obj, model) in meta.get_all_related_m2m_objects_with_model():
field_name = obj.field.related_query_name() field_name = obj.field.related_query_name()
if field_name in attrs: if field_name in attrs:
self.m2m_data[field_name] = attrs.pop(field_name) m2m_data[field_name] = attrs.pop(field_name)
# Forward m2m relations # Forward m2m relations
for field in self.opts.model._meta.many_to_many: for field in meta.many_to_many:
if field.name in attrs: if field.name in attrs:
self.m2m_data[field.name] = attrs.pop(field.name) m2m_data[field.name] = attrs.pop(field.name)
# Update an existing instance...
if instance is not None: if instance is not None:
for key, val in attrs.items(): for key, val in attrs.items():
setattr(instance, key, val) setattr(instance, key, val)
# ...or create a new instance
else: else:
instance = self.opts.model(**attrs) instance = self.opts.model(**attrs)
# Any relations that cannot be set until we've
# saved the model get hidden away on these
# private attributes, so we can deal with them
# at the point of save.
instance._related_data = related_data
instance._m2m_data = m2m_data
return instance return instance
def from_native(self, data, files): def from_native(self, data, files):
...@@ -673,15 +689,24 @@ class ModelSerializer(Serializer): ...@@ -673,15 +689,24 @@ class ModelSerializer(Serializer):
""" """
obj.save(**kwargs) obj.save(**kwargs)
if getattr(self, 'm2m_data', None): if getattr(obj, '_m2m_data', None):
for accessor_name, object_list in self.m2m_data.items(): for accessor_name, object_list in obj._m2m_data.items():
setattr(self.object, accessor_name, object_list) setattr(obj, accessor_name, object_list)
self.m2m_data = {} del(obj._m2m_data)
if getattr(self, 'related_data', None): if getattr(obj, '_related_data', None):
for accessor_name, object_list in self.related_data.items(): for accessor_name, related in obj._related_data.items():
setattr(self.object, accessor_name, object_list) if related is None:
self.related_data = {} previous = getattr(obj, accessor_name, related)
if previous:
previous.delete()
elif isinstance(related, models.Model):
fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name
setattr(related, fk_field, obj)
self.save_object(related)
else:
setattr(obj, accessor_name, related)
del(obj._related_data)
class HyperlinkedModelSerializerOptions(ModelSerializerOptions): class HyperlinkedModelSerializerOptions(ModelSerializerOptions):
......
from __future__ import unicode_literals from __future__ import unicode_literals
from django.db import models
from django.test import TestCase from django.test import TestCase
from rest_framework import serializers from rest_framework import serializers
from rest_framework.tests.models import ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource
class ForeignKeySourceSerializer(serializers.ModelSerializer): class OneToOneTarget(models.Model):
class Meta: name = models.CharField(max_length=100)
depth = 1
model = ForeignKeySource
class FlatForeignKeySourceSerializer(serializers.ModelSerializer):
class Meta:
model = ForeignKeySource
class OneToOneTargetSource(models.Model):
name = models.CharField(max_length=100)
target = models.OneToOneField(OneToOneTarget, null=True, blank=True,
related_name='target_source')
class ForeignKeyTargetSerializer(serializers.ModelSerializer):
sources = FlatForeignKeySourceSerializer(many=True)
class Meta: class OneToOneSource(models.Model):
model = ForeignKeyTarget name = models.CharField(max_length=100)
target_source = models.OneToOneField(OneToOneTargetSource, related_name='source')
class NullableForeignKeySourceSerializer(serializers.ModelSerializer): class OneToOneSourceSerializer(serializers.ModelSerializer):
class Meta: class Meta:
depth = 1 model = OneToOneSource
model = NullableForeignKeySource exclude = ('target_source', )
class NullableOneToOneSourceSerializer(serializers.ModelSerializer): class OneToOneTargetSourceSerializer(serializers.ModelSerializer):
source = OneToOneSourceSerializer()
class Meta: class Meta:
model = NullableOneToOneSource model = OneToOneTargetSource
exclude = ('target', )
class NullableOneToOneTargetSerializer(serializers.ModelSerializer): class OneToOneTargetSerializer(serializers.ModelSerializer):
nullable_source = NullableOneToOneSourceSerializer() target_source = OneToOneTargetSourceSerializer()
class Meta: class Meta:
model = OneToOneTarget model = OneToOneTarget
class ReverseForeignKeyTests(TestCase): class NestedOneToOneTests(TestCase):
def setUp(self): def setUp(self):
target = ForeignKeyTarget(name='target-1')
target.save()
new_target = ForeignKeyTarget(name='target-2')
new_target.save()
for idx in range(1, 4): for idx in range(1, 4):
source = ForeignKeySource(name='source-%d' % idx, target=target) target = OneToOneTarget(name='target-%d' % idx)
target.save()
target_source = OneToOneTargetSource(name='target-source-%d' % idx, target=target)
target_source.save()
source = OneToOneSource(name='source-%d' % idx, target_source=target_source)
source.save() source.save()
def test_foreign_key_retrieve(self): def test_one_to_one_retrieve(self):
queryset = ForeignKeySource.objects.all() queryset = OneToOneTarget.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True) serializer = OneToOneTargetSerializer(queryset)
expected = [ expected = [
{'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, {'id': 1, 'name': 'target-1', 'target_source': {'id': 1, 'name': 'target-source-1', 'source': {'id': 1, 'name': 'source-1'}}},
{'id': 2, 'name': 'source-2', 'target': {'id': 1, 'name': 'target-1'}}, {'id': 2, 'name': 'target-2', 'target_source': {'id': 2, 'name': 'target-source-2', 'source': {'id': 2, 'name': 'source-2'}}},
{'id': 3, 'name': 'source-3', 'target': {'id': 1, 'name': 'target-1'}}, {'id': 3, 'name': 'target-3', 'target_source': {'id': 3, 'name': 'target-source-3', 'source': {'id': 3, 'name': 'source-3'}}}
] ]
self.assertEqual(serializer.data, expected) self.assertEqual(serializer.data, expected)
def test_reverse_foreign_key_retrieve(self): def test_one_to_one_create(self):
queryset = ForeignKeyTarget.objects.all() data = {'id': 4, 'name': 'target-4', 'target_source': {'id': 4, 'name': 'target-source-4', 'source': {'id': 4, 'name': 'source-4'}}}
serializer = ForeignKeyTargetSerializer(queryset, many=True) serializer = OneToOneTargetSerializer(data=data)
self.assertTrue(serializer.is_valid())
obj = serializer.save()
self.assertEqual(serializer.data, data)
self.assertEqual(obj.name, 'target-4')
# Ensure (target 4, target_source 4, source 4) are added, and
# everything else is as expected.
queryset = OneToOneTarget.objects.all()
serializer = OneToOneTargetSerializer(queryset)
expected = [ expected = [
{'id': 1, 'name': 'target-1', 'sources': [ {'id': 1, 'name': 'target-1', 'target_source': {'id': 1, 'name': 'target-source-1', 'source': {'id': 1, 'name': 'source-1'}}},
{'id': 1, 'name': 'source-1', 'target': 1}, {'id': 2, 'name': 'target-2', 'target_source': {'id': 2, 'name': 'target-source-2', 'source': {'id': 2, 'name': 'source-2'}}},
{'id': 2, 'name': 'source-2', 'target': 1}, {'id': 3, 'name': 'target-3', 'target_source': {'id': 3, 'name': 'target-source-3', 'source': {'id': 3, 'name': 'source-3'}}},
{'id': 3, 'name': 'source-3', 'target': 1}, {'id': 4, 'name': 'target-4', 'target_source': {'id': 4, 'name': 'target-source-4', 'source': {'id': 4, 'name': 'source-4'}}}
]},
{'id': 2, 'name': 'target-2', 'sources': [
]}
] ]
self.assertEqual(serializer.data, expected) self.assertEqual(serializer.data, expected)
def test_one_to_one_create_with_invalid_data(self):
class NestedNullableForeignKeyTests(TestCase): data = {'id': 4, 'name': 'target-4', 'target_source': {'id': 4, 'name': 'target-source-4', 'source': {'id': 4}}}
def setUp(self): serializer = OneToOneTargetSerializer(data=data)
target = ForeignKeyTarget(name='target-1') self.assertFalse(serializer.is_valid())
target.save() self.assertEqual(serializer.errors, {'target_source': [{'source': [{'name': ['This field is required.']}]}]})
for idx in range(1, 4):
if idx == 3: def test_one_to_one_update(self):
target = None data = {'id': 3, 'name': 'target-3-updated', 'target_source': {'id': 3, 'name': 'target-source-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}}}
source = NullableForeignKeySource(name='source-%d' % idx, target=target) instance = OneToOneTarget.objects.get(pk=3)
source.save() serializer = OneToOneTargetSerializer(instance, data=data)
self.assertTrue(serializer.is_valid())
def test_foreign_key_retrieve_with_null(self): obj = serializer.save()
queryset = NullableForeignKeySource.objects.all() self.assertEqual(serializer.data, data)
serializer = NullableForeignKeySourceSerializer(queryset, many=True) self.assertEqual(obj.name, 'target-3-updated')
# Ensure (target 3, target_source 3, source 3) are updated,
# and everything else is as expected.
queryset = OneToOneTarget.objects.all()
serializer = OneToOneTargetSerializer(queryset)
expected = [ expected = [
{'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, {'id': 1, 'name': 'target-1', 'target_source': {'id': 1, 'name': 'target-source-1', 'source': {'id': 1, 'name': 'source-1'}}},
{'id': 2, 'name': 'source-2', 'target': {'id': 1, 'name': 'target-1'}}, {'id': 2, 'name': 'target-2', 'target_source': {'id': 2, 'name': 'target-source-2', 'source': {'id': 2, 'name': 'source-2'}}},
{'id': 3, 'name': 'source-3', 'target': None}, {'id': 3, 'name': 'target-3-updated', 'target_source': {'id': 3, 'name': 'target-source-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}}}
] ]
self.assertEqual(serializer.data, expected) self.assertEqual(serializer.data, expected)
def test_one_to_one_delete(self):
data = {'id': 3, 'name': 'target-3', 'target_source': None}
instance = OneToOneTarget.objects.get(pk=3)
serializer = OneToOneTargetSerializer(instance, data=data)
self.assertTrue(serializer.is_valid())
serializer.save()
class NestedNullableOneToOneTests(TestCase): # Ensure (target_source 3, source 3) are deleted,
def setUp(self): # and everything else is as expected.
target = OneToOneTarget(name='target-1')
target.save()
new_target = OneToOneTarget(name='target-2')
new_target.save()
source = NullableOneToOneSource(name='source-1', target=target)
source.save()
def test_reverse_foreign_key_retrieve_with_null(self):
queryset = OneToOneTarget.objects.all() queryset = OneToOneTarget.objects.all()
serializer = NullableOneToOneTargetSerializer(queryset, many=True) serializer = OneToOneTargetSerializer(queryset)
expected = [ expected = [
{'id': 1, 'name': 'target-1', 'nullable_source': {'id': 1, 'name': 'source-1', 'target': 1}}, {'id': 1, 'name': 'target-1', 'target_source': {'id': 1, 'name': 'target-source-1', 'source': {'id': 1, 'name': 'source-1'}}},
{'id': 2, 'name': 'target-2', 'nullable_source': None}, {'id': 2, 'name': 'target-2', 'target_source': {'id': 2, 'name': 'target-source-2', 'source': {'id': 2, 'name': 'source-2'}}},
{'id': 3, 'name': 'target-3', 'target_source': None}
] ]
self.assertEqual(serializer.data, expected) self.assertEqual(serializer.data, expected)
""" """
Tests to cover nested serializers. Tests to cover nested serializers.
Doesn't cover model serializers.
""" """
from __future__ import unicode_literals from __future__ import unicode_literals
from django.test import TestCase from django.test import TestCase
...@@ -124,7 +126,7 @@ class WritableNestedSerializerObjectTests(TestCase): ...@@ -124,7 +126,7 @@ class WritableNestedSerializerObjectTests(TestCase):
def __init__(self, order, title, duration): def __init__(self, order, title, duration):
self.order, self.title, self.duration = order, title, duration self.order, self.title, self.duration = order, title, duration
def __cmp__(self, other): def __eq__(self, other):
return ( return (
self.order == other.order and self.order == other.order and
self.title == other.title and self.title == other.title and
...@@ -135,7 +137,7 @@ class WritableNestedSerializerObjectTests(TestCase): ...@@ -135,7 +137,7 @@ class WritableNestedSerializerObjectTests(TestCase):
def __init__(self, album_name, artist, tracks): def __init__(self, album_name, artist, tracks):
self.album_name, self.artist, self.tracks = album_name, artist, tracks self.album_name, self.artist, self.tracks = album_name, artist, tracks
def __cmp__(self, other): def __eq__(self, other):
return ( return (
self.album_name == other.album_name and self.album_name == other.album_name and
self.artist == other.artist and self.artist == other.artist and
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment