Commit 303bc7cf by Tom Christie

Support nullable FKs, with blank=True

parent a5178e9a
...@@ -350,6 +350,12 @@ class RelatedField(WritableField): ...@@ -350,6 +350,12 @@ class RelatedField(WritableField):
return return
value = data.get(field_name) value = data.get(field_name)
if value is None and not self.blank:
raise ValidationError('Value may not be null')
elif value is None and self.blank:
into[(self.source or field_name)] = None
else:
into[(self.source or field_name)] = self.from_native(value) into[(self.source or field_name)] = self.from_native(value)
......
...@@ -431,10 +431,14 @@ class ModelSerializer(Serializer): ...@@ -431,10 +431,14 @@ class ModelSerializer(Serializer):
""" """
# TODO: filter queryset using: # TODO: filter queryset using:
# .using(db).complex_filter(self.rel.limit_choices_to) # .using(db).complex_filter(self.rel.limit_choices_to)
queryset = model_field.rel.to._default_manager kwargs = {
'blank': model_field.blank,
'queryset': model_field.rel.to._default_manager
}
if to_many: if to_many:
return ManyPrimaryKeyRelatedField(queryset=queryset) return ManyPrimaryKeyRelatedField(**kwargs)
return PrimaryKeyRelatedField(queryset=queryset) return PrimaryKeyRelatedField(**kwargs)
def get_field(self, model_field): def get_field(self, model_field):
""" """
...@@ -572,9 +576,9 @@ class HyperlinkedModelSerializer(ModelSerializer): ...@@ -572,9 +576,9 @@ class HyperlinkedModelSerializer(ModelSerializer):
# TODO: filter queryset using: # TODO: filter queryset using:
# .using(db).complex_filter(self.rel.limit_choices_to) # .using(db).complex_filter(self.rel.limit_choices_to)
rel = model_field.rel.to rel = model_field.rel.to
queryset = rel._default_manager
kwargs = { kwargs = {
'queryset': queryset, 'blank': model_field.blank,
'queryset': rel._default_manager,
'view_name': self._get_default_view_name(rel) 'view_name': self._get_default_view_name(rel)
} }
if to_many: if to_many:
......
from django.conf.urls.defaults import patterns, url from django.conf.urls.defaults import patterns, url
from django.test import TestCase from django.test import TestCase
from django.test.client import RequestFactory from django.test.client import RequestFactory
from django.utils import simplejson as json
from rest_framework import generics, status, serializers from rest_framework import generics, status, serializers
from rest_framework.tests.models import Anchor, BasicModel, ManyToManyModel, BlogPost, BlogPostComment, Album, Photo, OptionalRelationModel from rest_framework.tests.models import Anchor, BasicModel, ManyToManyModel, BlogPost, BlogPostComment, Album, Photo, OptionalRelationModel
...@@ -54,10 +55,12 @@ class BlogPostCommentListCreate(generics.ListCreateAPIView): ...@@ -54,10 +55,12 @@ class BlogPostCommentListCreate(generics.ListCreateAPIView):
model = BlogPostComment model = BlogPostComment
serializer_class = BlogPostCommentSerializer serializer_class = BlogPostCommentSerializer
class BlogPostCommentDetail(generics.RetrieveAPIView): class BlogPostCommentDetail(generics.RetrieveAPIView):
model = BlogPostComment model = BlogPostComment
serializer_class = BlogPostCommentSerializer serializer_class = BlogPostCommentSerializer
class BlogPostDetail(generics.RetrieveAPIView): class BlogPostDetail(generics.RetrieveAPIView):
model = BlogPost model = BlogPost
...@@ -71,7 +74,7 @@ class AlbumDetail(generics.RetrieveAPIView): ...@@ -71,7 +74,7 @@ class AlbumDetail(generics.RetrieveAPIView):
model = Album model = Album
class OptionalRelationDetail(generics.RetrieveAPIView): class OptionalRelationDetail(generics.RetrieveUpdateDestroyAPIView):
model = OptionalRelationModel model = OptionalRelationModel
model_serializer_class = serializers.HyperlinkedModelSerializer model_serializer_class = serializers.HyperlinkedModelSerializer
...@@ -162,7 +165,7 @@ class TestManyToManyHyperlinkedView(TestCase): ...@@ -162,7 +165,7 @@ class TestManyToManyHyperlinkedView(TestCase):
GET requests to ListCreateAPIView should return list of objects. GET requests to ListCreateAPIView should return list of objects.
""" """
request = factory.get('/manytomany/') request = factory.get('/manytomany/')
response = self.list_view(request).render() response = self.list_view(request)
self.assertEquals(response.status_code, status.HTTP_200_OK) self.assertEquals(response.status_code, status.HTTP_200_OK)
self.assertEquals(response.data, self.data) self.assertEquals(response.data, self.data)
...@@ -171,7 +174,7 @@ class TestManyToManyHyperlinkedView(TestCase): ...@@ -171,7 +174,7 @@ class TestManyToManyHyperlinkedView(TestCase):
GET requests to ListCreateAPIView should return list of objects. GET requests to ListCreateAPIView should return list of objects.
""" """
request = factory.get('/manytomany/1/') request = factory.get('/manytomany/1/')
response = self.detail_view(request, pk=1).render() response = self.detail_view(request, pk=1)
self.assertEquals(response.status_code, status.HTTP_200_OK) self.assertEquals(response.status_code, status.HTTP_200_OK)
self.assertEquals(response.data, self.data[0]) self.assertEquals(response.data, self.data[0])
...@@ -194,7 +197,7 @@ class TestCreateWithForeignKeys(TestCase): ...@@ -194,7 +197,7 @@ class TestCreateWithForeignKeys(TestCase):
} }
request = factory.post('/comments/', data=data) request = factory.post('/comments/', data=data)
response = self.create_view(request).render() response = self.create_view(request)
self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual(response['Location'], 'http://testserver/comments/1/') self.assertEqual(response['Location'], 'http://testserver/comments/1/')
self.assertEqual(self.post.blogpostcomment_set.count(), 1) self.assertEqual(self.post.blogpostcomment_set.count(), 1)
...@@ -219,7 +222,7 @@ class TestCreateWithForeignKeysAndCustomSlug(TestCase): ...@@ -219,7 +222,7 @@ class TestCreateWithForeignKeysAndCustomSlug(TestCase):
} }
request = factory.post('/photos/', data=data) request = factory.post('/photos/', data=data)
response = self.list_create_view(request).render() response = self.list_create_view(request)
self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertNotIn('Location', response, msg='Location should only be included if there is a "url" field on the serializer') self.assertNotIn('Location', response, msg='Location should only be included if there is a "url" field on the serializer')
self.assertEqual(self.post.photo_set.count(), 1) self.assertEqual(self.post.photo_set.count(), 1)
...@@ -244,6 +247,16 @@ class TestOptionalRelationHyperlinkedView(TestCase): ...@@ -244,6 +247,16 @@ class TestOptionalRelationHyperlinkedView(TestCase):
for non existing relations. for non existing relations.
""" """
request = factory.get('/optionalrelationmodel-detail/1') request = factory.get('/optionalrelationmodel-detail/1')
response = self.detail_view(request, pk=1).render() response = self.detail_view(request, pk=1)
self.assertEquals(response.status_code, status.HTTP_200_OK) self.assertEquals(response.status_code, status.HTTP_200_OK)
self.assertEquals(response.data, self.data) self.assertEquals(response.data, self.data)
def test_put_detail_view(self):
"""
PUT requests to RetrieveUpdateDestroyAPIView with optional relations
should accept None for non existing relations.
"""
response = self.client.put('/optionalrelation/1/',
data=json.dumps(self.data),
content_type='application/json')
self.assertEqual(response.status_code, status.HTTP_200_OK)
...@@ -49,9 +49,22 @@ class ForeignKeySourceSerializer(serializers.ModelSerializer): ...@@ -49,9 +49,22 @@ class ForeignKeySourceSerializer(serializers.ModelSerializer):
model = ForeignKeySource model = ForeignKeySource
# Nullable ForeignKey
class NullableForeignKeySource(models.Model):
name = models.CharField(max_length=100)
target = models.ForeignKey(ForeignKeyTarget, null=True, blank=True,
related_name='nullable_sources')
class NullableForeignKeySourceSerializer(serializers.ModelSerializer):
class Meta:
model = NullableForeignKeySource
# TODO: Add test that .data cannot be accessed prior to .is_valid # TODO: Add test that .data cannot be accessed prior to .is_valid
class PrimaryKeyManyToManyTests(TestCase): class PKManyToManyTests(TestCase):
def setUp(self): def setUp(self):
for idx in range(1, 4): for idx in range(1, 4):
target = ManyToManyTarget(name='target-%d' % idx) target = ManyToManyTarget(name='target-%d' % idx)
...@@ -137,7 +150,7 @@ class PrimaryKeyManyToManyTests(TestCase): ...@@ -137,7 +150,7 @@ class PrimaryKeyManyToManyTests(TestCase):
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
class PrimaryKeyForeignKeyTests(TestCase): class PKForeignKeyTests(TestCase):
def setUp(self): def setUp(self):
target = ForeignKeyTarget(name='target-1') target = ForeignKeyTarget(name='target-1')
target.save() target.save()
...@@ -174,7 +187,7 @@ class PrimaryKeyForeignKeyTests(TestCase): ...@@ -174,7 +187,7 @@ class PrimaryKeyForeignKeyTests(TestCase):
self.assertEquals(serializer.data, data) self.assertEquals(serializer.data, data)
serializer.save() serializer.save()
# # Ensure source 1 is updated, and everything else is as expected # Ensure source 1 is updated, and everything else is as expected
queryset = ForeignKeySource.objects.all() queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset) serializer = ForeignKeySourceSerializer(queryset)
expected = [ expected = [
...@@ -184,6 +197,40 @@ class PrimaryKeyForeignKeyTests(TestCase): ...@@ -184,6 +197,40 @@ class PrimaryKeyForeignKeyTests(TestCase):
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
def test_foreign_key_update_with_invalid_null(self):
data = {'id': 1, 'name': u'source-1', 'target': None}
instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data)
self.assertFalse(serializer.is_valid())
self.assertEquals(serializer.errors, {'target': [u'Value may not be null']})
class PKNullableForeignKeyTests(TestCase):
def setUp(self):
target = ForeignKeyTarget(name='target-1')
target.save()
for idx in range(1, 4):
source = NullableForeignKeySource(name='source-%d' % idx, target=target)
source.save()
def test_foreign_key_update_with_valid_null(self):
data = {'id': 1, 'name': u'source-1', 'target': None}
instance = NullableForeignKeySource.objects.get(pk=1)
serializer = NullableForeignKeySourceSerializer(instance, data=data)
self.assertTrue(serializer.is_valid())
self.assertEquals(serializer.data, data)
serializer.save()
# Ensure source 1 is updated, and everything else is as expected
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset)
expected = [
{'id': 1, 'name': u'source-1', 'target': None},
{'id': 2, 'name': u'source-2', 'target': 1},
{'id': 3, 'name': u'source-3', 'target': 1}
]
self.assertEquals(serializer.data, expected)
# reverse foreign keys MUST be read_only # reverse foreign keys MUST be read_only
# In the general case they do not provide .remove() or .clear() # In the general case they do not provide .remove() or .clear()
# and cannot be arbitrarily set. # and cannot be arbitrarily set.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment