Commit 42b3fdbd by Tom Christie

Merge pull request #279 from tomchristie/hyperlinked-relationships

Hyperlinked relationships
parents ad214976 eaebb397
...@@ -123,16 +123,8 @@ class BaseSerializer(Field): ...@@ -123,16 +123,8 @@ class BaseSerializer(Field):
# Get the explicitly declared fields # Get the explicitly declared fields
for key, field in self.fields.items(): for key, field in self.fields.items():
ret[key] = field ret[key] = field
# Determine if the declared field corrosponds to a model field.
try:
if key == 'pk':
model_field = obj._meta.pk
else:
model_field = obj._meta.get_field_by_name(key)[0]
except:
model_field = None
# Set up the field # Set up the field
field.initialize(parent=self, model_field=model_field) field.initialize(parent=self)
# Add in the default fields # Add in the default fields
fields = self.default_fields(serialize, obj, data, nested) fields = self.default_fields(serialize, obj, data, nested)
...@@ -157,12 +149,12 @@ class BaseSerializer(Field): ...@@ -157,12 +149,12 @@ class BaseSerializer(Field):
##### #####
# Field methods - used when the serializer class is itself used as a field. # Field methods - used when the serializer class is itself used as a field.
def initialize(self, parent, model_field=None): def initialize(self, parent):
""" """
Same behaviour as usual Field, except that we need to keep track Same behaviour as usual Field, except that we need to keep track
of state so that we can deal with handling maximum depth and recursion. of state so that we can deal with handling maximum depth and recursion.
""" """
super(BaseSerializer, self).initialize(parent, model_field) super(BaseSerializer, self).initialize(parent)
self.stack = parent.stack[:] self.stack = parent.stack[:]
if parent.opts.nested and not isinstance(parent.opts.nested, bool): if parent.opts.nested and not isinstance(parent.opts.nested, bool):
self.opts.nested = parent.opts.nested - 1 self.opts.nested = parent.opts.nested - 1
...@@ -296,12 +288,22 @@ class ModelSerializerOptions(SerializerOptions): ...@@ -296,12 +288,22 @@ class ModelSerializerOptions(SerializerOptions):
self.model = getattr(meta, 'model', None) self.model = getattr(meta, 'model', None)
class ModelSerializer(RelatedField, Serializer): class ModelSerializer(Serializer):
""" """
A serializer that deals with model instances and querysets. A serializer that deals with model instances and querysets.
""" """
_options_class = ModelSerializerOptions _options_class = ModelSerializerOptions
def field_to_native(self, obj, field_name):
"""
Override default so that we can apply ModelSerializer as a nested
field to relationships.
"""
obj = getattr(obj, self.source or field_name)
if obj.__class__.__name__ in ('RelatedManager', 'ManyRelatedManager'):
return [self.to_native(item) for item in obj.all()]
return self.to_native(obj)
def default_fields(self, serialize, obj=None, data=None, nested=False): def default_fields(self, serialize, obj=None, data=None, nested=False):
""" """
Return all the fields that should be serialized for the model. Return all the fields that should be serialized for the model.
...@@ -330,7 +332,7 @@ class ModelSerializer(RelatedField, Serializer): ...@@ -330,7 +332,7 @@ class ModelSerializer(RelatedField, Serializer):
field = self.get_field(model_field) field = self.get_field(model_field)
if field: if field:
field.initialize(parent=self, model_field=model_field) field.initialize(parent=self)
ret[model_field.name] = field ret[model_field.name] = field
return ret return ret
...@@ -339,7 +341,7 @@ class ModelSerializer(RelatedField, Serializer): ...@@ -339,7 +341,7 @@ class ModelSerializer(RelatedField, Serializer):
""" """
Returns a default instance of the pk field. Returns a default instance of the pk field.
""" """
return Field(readonly=True) return Field()
def get_nested_field(self, model_field): def get_nested_field(self, model_field):
""" """
...@@ -373,7 +375,7 @@ class ModelSerializer(RelatedField, Serializer): ...@@ -373,7 +375,7 @@ class ModelSerializer(RelatedField, Serializer):
try: try:
return field_mapping[model_field.__class__]() return field_mapping[model_field.__class__]()
except KeyError: except KeyError:
return Field() return ModelField(model_field=model_field)
def restore_object(self, attrs, instance=None): def restore_object(self, attrs, instance=None):
""" """
...@@ -396,3 +398,40 @@ class ModelSerializer(RelatedField, Serializer): ...@@ -396,3 +398,40 @@ class ModelSerializer(RelatedField, Serializer):
""" """
self.object.save() self.object.save()
return self.object.object return self.object.object
class HyperlinkedModelSerializerOptions(ModelSerializerOptions):
"""
Options for HyperlinkedModelSerializer
"""
def __init__(self, meta):
super(HyperlinkedModelSerializerOptions, self).__init__(meta)
self.view_name = getattr(meta, 'view_name', None)
class HyperlinkedModelSerializer(ModelSerializer):
"""
"""
_options_class = HyperlinkedModelSerializerOptions
_default_view_name = '%(model_name)s-detail'
url = HyperlinkedIdentityField()
def __init__(self, *args, **kwargs):
super(HyperlinkedModelSerializer, self).__init__(*args, **kwargs)
if self.opts.view_name is None:
self.opts.view_name = self._get_default_view_name()
def _get_default_view_name(self):
"""
Return the view name to use if 'view_name' is not specified in 'Meta'
"""
model_meta = self.opts.model._meta
format_kwargs = {
'app_label': model_meta.app_label,
'model_name': model_meta.object_name.lower()
}
return self._default_view_name % format_kwargs
def get_pk_field(self, model_field):
return None
...@@ -46,7 +46,7 @@ DEFAULTS = { ...@@ -46,7 +46,7 @@ DEFAULTS = {
'MODEL_SERIALIZER': 'rest_framework.serializers.ModelSerializer', 'MODEL_SERIALIZER': 'rest_framework.serializers.ModelSerializer',
'PAGINATION_SERIALIZER': 'rest_framework.pagination.PaginationSerializer', 'PAGINATION_SERIALIZER': 'rest_framework.pagination.PaginationSerializer',
'PAGINATE_BY': 20, 'PAGINATE_BY': None,
'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser', 'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser',
'UNAUTHENTICATED_TOKEN': None, 'UNAUTHENTICATED_TOKEN': None,
......
...@@ -13,7 +13,6 @@ class RootView(generics.ListCreateAPIView): ...@@ -13,7 +13,6 @@ class RootView(generics.ListCreateAPIView):
Example description for OPTIONS. Example description for OPTIONS.
""" """
model = BasicModel model = BasicModel
paginate_by = None
class InstanceView(generics.RetrieveUpdateDestroyAPIView): class InstanceView(generics.RetrieveUpdateDestroyAPIView):
......
from django.conf.urls.defaults import patterns, url
from django.test import TestCase
from django.test.client import RequestFactory
from rest_framework import generics, status, serializers
from rest_framework.tests.models import BasicModel
factory = RequestFactory()
class BasicList(generics.ListCreateAPIView):
model = BasicModel
model_serializer_class = serializers.HyperlinkedModelSerializer
class BasicDetail(generics.RetrieveUpdateDestroyAPIView):
model = BasicModel
model_serializer_class = serializers.HyperlinkedModelSerializer
urlpatterns = patterns('',
url(r'^basic/$', BasicList.as_view(), name='basicmodel-list'),
url(r'^basic/(?P<pk>\d+)/$', BasicDetail.as_view(), name='basicmodel-detail'),
)
class TestHyperlinkedView(TestCase):
urls = 'rest_framework.tests.hyperlinkedserializers'
def setUp(self):
"""
Create 3 BasicModel intances.
"""
items = ['foo', 'bar', 'baz']
for item in items:
BasicModel(text=item).save()
self.objects = BasicModel.objects
self.data = [
{'url': 'http://testserver/basic/%d/' % obj.id, 'text': obj.text}
for obj in self.objects.all()
]
self.list_view = BasicList.as_view()
self.detail_view = BasicDetail.as_view()
def test_get_list_view(self):
"""
GET requests to ListCreateAPIView should return list of objects.
"""
request = factory.get('/')
response = self.list_view(request).render()
self.assertEquals(response.status_code, status.HTTP_200_OK)
self.assertEquals(response.data, self.data)
def test_get_detail_view(self):
"""
GET requests to ListCreateAPIView should return list of objects.
"""
request = factory.get('/1')
response = self.detail_view(request, pk=1).render()
self.assertEquals(response.status_code, status.HTTP_200_OK)
self.assertEquals(response.data, self.data[0])
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment