Commit 52ba2e33 by Tom Christie

Fix #285

parent 4fd8ab17
...@@ -77,16 +77,15 @@ class UpdateModelMixin(object): ...@@ -77,16 +77,15 @@ class UpdateModelMixin(object):
self.object = None self.object = None
serializer = self.get_serializer(data=request.DATA, instance=self.object) serializer = self.get_serializer(data=request.DATA, instance=self.object)
if serializer.is_valid(): if serializer.is_valid():
if self.object is None: if self.object is None:
obj = serializer.object # If PUT occurs to a non existant object, we need to set any
# TODO: Make ModelSerializers return regular instances, # attributes on the object that are implicit in the URL.
# not DeserializedObject self.update_urlconf_attributes(serializer.object)
if hasattr(obj, 'object'):
obj = obj.object
self.update_urlconf_attributes(serializer.object.object)
self.object = serializer.save() self.object = serializer.save()
return Response(serializer.data) return Response(serializer.data)
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
def update_urlconf_attributes(self, obj): def update_urlconf_attributes(self, obj):
......
...@@ -2,7 +2,6 @@ import copy ...@@ -2,7 +2,6 @@ import copy
import datetime import datetime
import types import types
from decimal import Decimal from decimal import Decimal
from django.core.serializers.base import DeserializedObject
from django.db import models from django.db import models
from django.utils.datastructures import SortedDict from django.utils.datastructures import SortedDict
from rest_framework.compat import get_concrete_model from rest_framework.compat import get_concrete_model
...@@ -224,9 +223,6 @@ class BaseSerializer(Field): ...@@ -224,9 +223,6 @@ class BaseSerializer(Field):
""" """
Serialize objects -> primatives. Serialize objects -> primatives.
""" """
if isinstance(obj, DeserializedObject):
obj = obj.object
if isinstance(obj, dict): if isinstance(obj, dict):
return dict([(key, self.to_native(val)) return dict([(key, self.to_native(val))
for (key, val) in obj.items()]) for (key, val) in obj.items()])
...@@ -383,23 +379,30 @@ class ModelSerializer(Serializer): ...@@ -383,23 +379,30 @@ class ModelSerializer(Serializer):
""" """
Restore the model instance. Restore the model instance.
""" """
self.m2m_data = {}
if instance: if instance:
for key, val in attrs.items(): for key, val in attrs.items():
setattr(instance, key, val) setattr(instance, key, val)
return DeserializedObject(instance) return instance
m2m_data = {}
for field in self.opts.model._meta.many_to_many: for field in self.opts.model._meta.many_to_many:
if field.name in attrs: if field.name in attrs:
m2m_data[field.name] = attrs.pop(field.name) self.m2m_data[field.name] = attrs.pop(field.name)
return DeserializedObject(self.opts.model(**attrs), m2m_data) return self.opts.model(**attrs)
def save(self): def save(self, save_m2m=True):
""" """
Save the deserialized object and return it. Save the deserialized object and return it.
""" """
self.object.save() self.object.save()
return self.object.object
if self.m2m_data and save_m2m:
for accessor_name, object_list in self.m2m_data.items():
setattr(self.object, accessor_name, object_list)
self.m2m_data = {}
return self.object
class HyperlinkedModelSerializerOptions(ModelSerializerOptions): class HyperlinkedModelSerializerOptions(ModelSerializerOptions):
......
from django.test import TestCase from django.test import TestCase
from django.test.client import RequestFactory from django.test.client import RequestFactory
from django.utils import simplejson as json from django.utils import simplejson as json
from rest_framework import generics, status from rest_framework import generics, serializers, status
from rest_framework.tests.models import BasicModel from rest_framework.tests.models import BasicModel, Comment
factory = RequestFactory() factory = RequestFactory()
...@@ -223,3 +223,36 @@ class TestInstanceView(TestCase): ...@@ -223,3 +223,36 @@ class TestInstanceView(TestCase):
self.assertEquals(response.data, {'id': 1, 'text': 'foobar'}) self.assertEquals(response.data, {'id': 1, 'text': 'foobar'})
updated = self.objects.get(id=1) updated = self.objects.get(id=1)
self.assertEquals(updated.text, 'foobar') self.assertEquals(updated.text, 'foobar')
# Regression test for #285
class CommentSerializer(serializers.ModelSerializer):
class Meta:
model = Comment
exclude = ('created',)
class CommentView(generics.ListCreateAPIView):
serializer_class = CommentSerializer
model = Comment
class TestCreateModelWithAutoNowAddField(TestCase):
def setUp(self):
self.objects = Comment.objects
self.view = CommentView.as_view()
def test_create_model_with_auto_now_add_field(self):
"""
Regression test for #285
https://github.com/tomchristie/django-rest-framework/issues/285
"""
content = {'email': 'foobar@example.com', 'content': 'foobar'}
request = factory.post('/', json.dumps(content),
content_type='application/json')
response = self.view(request).render()
self.assertEquals(response.status_code, status.HTTP_201_CREATED)
created = self.objects.get(id=1)
self.assertEquals(created.content, 'foobar')
...@@ -83,3 +83,11 @@ class TaggedItem(RESTFrameworkModel): ...@@ -83,3 +83,11 @@ class TaggedItem(RESTFrameworkModel):
class Bookmark(RESTFrameworkModel): class Bookmark(RESTFrameworkModel):
url = models.URLField() url = models.URLField()
tags = GenericRelation(TaggedItem) tags = GenericRelation(TaggedItem)
# Model for regression test for #285
class Comment(RESTFrameworkModel):
email = models.EmailField()
content = models.CharField(max_length=200)
created = models.DateTimeField(auto_now_add=True)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment