Commit acc8c1fa by Tom Christie

force_insert, force_update arguments. Closes #484.

Confirmed by `assertNumQueries(…)` in tests.
parent a53596ce
...@@ -40,6 +40,10 @@ You can determine your currently installed version using `pip freeze`: ...@@ -40,6 +40,10 @@ You can determine your currently installed version using `pip freeze`:
## 2.2.x series ## 2.2.x series
### Master
* `Serializer.save()` now supports arbitrary keyword args which are passed through to the object `.save()` method. Mixins use `force_insert` and `force_update` where appropriate, resulting in one less database query.
### 2.2.4 ### 2.2.4
**Date**: 13th March 2013 **Date**: 13th March 2013
......
...@@ -44,7 +44,7 @@ class CreateModelMixin(object): ...@@ -44,7 +44,7 @@ class CreateModelMixin(object):
if serializer.is_valid(): if serializer.is_valid():
self.pre_save(serializer.object) self.pre_save(serializer.object)
self.object = serializer.save() self.object = serializer.save(force_insert=True)
self.post_save(self.object, created=True) self.post_save(self.object, created=True)
headers = self.get_success_headers(serializer.data) headers = self.get_success_headers(serializer.data)
return Response(serializer.data, status=status.HTTP_201_CREATED, return Response(serializer.data, status=status.HTTP_201_CREATED,
...@@ -119,9 +119,11 @@ class UpdateModelMixin(object): ...@@ -119,9 +119,11 @@ class UpdateModelMixin(object):
# we have relevant permissions, as if this was a POST request. # we have relevant permissions, as if this was a POST request.
self.check_permissions(clone_request(request, 'POST')) self.check_permissions(clone_request(request, 'POST'))
created = True created = True
save_kwargs = {'force_insert': True}
success_status_code = status.HTTP_201_CREATED success_status_code = status.HTTP_201_CREATED
else: else:
created = False created = False
save_kwargs = {'force_update': True}
success_status_code = status.HTTP_200_OK success_status_code = status.HTTP_200_OK
serializer = self.get_serializer(self.object, data=request.DATA, serializer = self.get_serializer(self.object, data=request.DATA,
...@@ -129,7 +131,7 @@ class UpdateModelMixin(object): ...@@ -129,7 +131,7 @@ class UpdateModelMixin(object):
if serializer.is_valid(): if serializer.is_valid():
self.pre_save(serializer.object) self.pre_save(serializer.object)
self.object = serializer.save() self.object = serializer.save(**save_kwargs)
self.post_save(self.object, created=created) self.post_save(self.object, created=created)
return Response(serializer.data, status=success_status_code) return Response(serializer.data, status=success_status_code)
......
...@@ -391,17 +391,17 @@ class BaseSerializer(Field): ...@@ -391,17 +391,17 @@ class BaseSerializer(Field):
return self._data return self._data
def save_object(self, obj): def save_object(self, obj, **kwargs):
obj.save() obj.save(**kwargs)
def save(self): def save(self, **kwargs):
""" """
Save the deserialized object and return it. Save the deserialized object and return it.
""" """
if isinstance(self.object, list): if isinstance(self.object, list):
[self.save_object(item) for item in self.object] [self.save_object(item, **kwargs) for item in self.object]
else: else:
self.save_object(self.object) self.save_object(self.object, **kwargs)
return self.object return self.object
...@@ -621,11 +621,11 @@ class ModelSerializer(Serializer): ...@@ -621,11 +621,11 @@ class ModelSerializer(Serializer):
if instance: if instance:
return self.full_clean(instance) return self.full_clean(instance)
def save_object(self, obj): def save_object(self, obj, **kwargs):
""" """
Save the deserialized object and return it. Save the deserialized object and return it.
""" """
obj.save() obj.save(**kwargs)
if getattr(self, 'm2m_data', None): if getattr(self, 'm2m_data', None):
for accessor_name, object_list in self.m2m_data.items(): for accessor_name, object_list in self.m2m_data.items():
......
...@@ -184,7 +184,7 @@ class TestInstanceView(TestCase): ...@@ -184,7 +184,7 @@ class TestInstanceView(TestCase):
content = {'text': 'foobar'} content = {'text': 'foobar'}
request = factory.put('/1', json.dumps(content), request = factory.put('/1', json.dumps(content),
content_type='application/json') content_type='application/json')
with self.assertNumQueries(3): with self.assertNumQueries(2):
response = self.view(request, pk='1').render() response = self.view(request, pk='1').render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, {'id': 1, 'text': 'foobar'}) self.assertEqual(response.data, {'id': 1, 'text': 'foobar'})
...@@ -199,7 +199,7 @@ class TestInstanceView(TestCase): ...@@ -199,7 +199,7 @@ class TestInstanceView(TestCase):
request = factory.patch('/1', json.dumps(content), request = factory.patch('/1', json.dumps(content),
content_type='application/json') content_type='application/json')
with self.assertNumQueries(3): with self.assertNumQueries(2):
response = self.view(request, pk=1).render() response = self.view(request, pk=1).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, {'id': 1, 'text': 'foobar'}) self.assertEqual(response.data, {'id': 1, 'text': 'foobar'})
...@@ -248,7 +248,7 @@ class TestInstanceView(TestCase): ...@@ -248,7 +248,7 @@ class TestInstanceView(TestCase):
content = {'id': 999, 'text': 'foobar'} content = {'id': 999, 'text': 'foobar'}
request = factory.put('/1', json.dumps(content), request = factory.put('/1', json.dumps(content),
content_type='application/json') content_type='application/json')
with self.assertNumQueries(3): with self.assertNumQueries(2):
response = self.view(request, pk=1).render() response = self.view(request, pk=1).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, {'id': 1, 'text': 'foobar'}) self.assertEqual(response.data, {'id': 1, 'text': 'foobar'})
...@@ -264,7 +264,7 @@ class TestInstanceView(TestCase): ...@@ -264,7 +264,7 @@ class TestInstanceView(TestCase):
content = {'text': 'foobar'} content = {'text': 'foobar'}
request = factory.put('/1', json.dumps(content), request = factory.put('/1', json.dumps(content),
content_type='application/json') content_type='application/json')
with self.assertNumQueries(4): with self.assertNumQueries(3):
response = self.view(request, pk=1).render() response = self.view(request, pk=1).render()
self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual(response.data, {'id': 1, 'text': 'foobar'}) self.assertEqual(response.data, {'id': 1, 'text': 'foobar'})
...@@ -280,7 +280,7 @@ class TestInstanceView(TestCase): ...@@ -280,7 +280,7 @@ class TestInstanceView(TestCase):
# pk fields can not be created on demand, only the database can set the pk for a new object # pk fields can not be created on demand, only the database can set the pk for a new object
request = factory.put('/5', json.dumps(content), request = factory.put('/5', json.dumps(content),
content_type='application/json') content_type='application/json')
with self.assertNumQueries(4): with self.assertNumQueries(3):
response = self.view(request, pk=5).render() response = self.view(request, pk=5).render()
self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertEqual(response.status_code, status.HTTP_201_CREATED)
new_obj = self.objects.get(pk=5) new_obj = self.objects.get(pk=5)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment