Commit 6fa589fe by Tom Christie

Pagination support

parent 43d3634e
...@@ -139,7 +139,13 @@ class Field(object): ...@@ -139,7 +139,13 @@ class Field(object):
if hasattr(self, 'model_field'): if hasattr(self, 'model_field'):
return self.to_native(self.model_field._get_val_from_obj(obj)) return self.to_native(self.model_field._get_val_from_obj(obj))
return self.to_native(getattr(obj, self.source or field_name)) if self.source:
value = obj
for component in self.source.split('.'):
value = getattr(value, component)
else:
value = getattr(obj, field_name)
return self.to_native(value)
def to_native(self, value): def to_native(self, value):
""" """
...@@ -175,7 +181,7 @@ class RelatedField(Field): ...@@ -175,7 +181,7 @@ class RelatedField(Field):
""" """
def field_to_native(self, obj, field_name): def field_to_native(self, obj, field_name):
obj = getattr(obj, field_name) obj = getattr(obj, self.source or field_name)
if obj.__class__.__name__ in ('RelatedManager', 'ManyRelatedManager'): if obj.__class__.__name__ in ('RelatedManager', 'ManyRelatedManager'):
return [self.to_native(item) for item in obj.all()] return [self.to_native(item) for item in obj.all()]
return self.to_native(obj) return self.to_native(obj)
...@@ -215,10 +221,10 @@ class PrimaryKeyRelatedField(RelatedField): ...@@ -215,10 +221,10 @@ class PrimaryKeyRelatedField(RelatedField):
def field_to_native(self, obj, field_name): def field_to_native(self, obj, field_name):
try: try:
obj = obj.serializable_value(field_name) obj = obj.serializable_value(self.source or field_name)
except AttributeError: except AttributeError:
field = obj._meta.get_field_by_name(field_name)[0] field = obj._meta.get_field_by_name(field_name)[0]
obj = getattr(obj, field_name) obj = getattr(obj, self.source or field_name)
if obj.__class__.__name__ == 'RelatedManager': if obj.__class__.__name__ == 'RelatedManager':
return [self.to_native(item.pk) for item in obj.all()] return [self.to_native(item.pk) for item in obj.all()]
elif isinstance(field, RelatedObject): elif isinstance(field, RelatedObject):
......
...@@ -2,7 +2,8 @@ ...@@ -2,7 +2,8 @@
Generic views that provide commmonly needed behaviour. Generic views that provide commmonly needed behaviour.
""" """
from rest_framework import views, mixins, serializers from rest_framework import views, mixins
from rest_framework.settings import api_settings
from django.views.generic.detail import SingleObjectMixin from django.views.generic.detail import SingleObjectMixin
from django.views.generic.list import MultipleObjectMixin from django.views.generic.list import MultipleObjectMixin
...@@ -14,23 +15,37 @@ class BaseView(views.APIView): ...@@ -14,23 +15,37 @@ class BaseView(views.APIView):
Base class for all other generic views. Base class for all other generic views.
""" """
serializer_class = None serializer_class = None
model_serializer_class = api_settings.MODEL_SERIALIZER
pagination_serializer_class = api_settings.PAGINATION_SERIALIZER
paginate_by = api_settings.PAGINATE_BY
def get_serializer(self, data=None, files=None, instance=None): def get_serializer_context(self):
return {
'request': self.request,
'format': self.kwargs.get('format', None)
}
def get_serializer(self, data=None, files=None, instance=None, kwargs=None):
# TODO: add support for files # TODO: add support for files
# TODO: add support for seperate serializer/deserializer # TODO: add support for seperate serializer/deserializer
serializer_class = self.serializer_class serializer_class = self.serializer_class
kwargs = kwargs or {}
if serializer_class is None: if serializer_class is None:
class DefaultSerializer(serializers.ModelSerializer): class DefaultSerializer(self.model_serializer_class):
class Meta: class Meta:
model = self.model model = self.model
serializer_class = DefaultSerializer serializer_class = DefaultSerializer
context = { context = self.get_serializer_context()
'request': self.request, return serializer_class(data, instance=instance, context=context, **kwargs)
'format': self.kwargs.get('format', None)
} def get_pagination_serializer(self, page=None):
return serializer_class(data, instance=instance, context=context) serializer_class = self.pagination_serializer_class
context = self.get_serializer_context()
ret = serializer_class(instance=page, context=context)
ret.fields['results'] = self.get_serializer(kwargs={'source': 'object_list'})
return ret
class MultipleObjectBaseView(MultipleObjectMixin, BaseView): class MultipleObjectBaseView(MultipleObjectMixin, BaseView):
......
...@@ -7,6 +7,7 @@ which allows mixin classes to be composed in interesting ways. ...@@ -7,6 +7,7 @@ which allows mixin classes to be composed in interesting ways.
Eg. Use mixins to build a Resource class, and have a Router class Eg. Use mixins to build a Resource class, and have a Router class
perform the binding of http methods to actions for us. perform the binding of http methods to actions for us.
""" """
from django.http import Http404
from rest_framework import status from rest_framework import status
from rest_framework.response import Response from rest_framework.response import Response
...@@ -30,9 +31,27 @@ class ListModelMixin(object): ...@@ -30,9 +31,27 @@ class ListModelMixin(object):
List a queryset. List a queryset.
Should be mixed in with `MultipleObjectBaseView`. Should be mixed in with `MultipleObjectBaseView`.
""" """
empty_error = u"Empty list and '%(class_name)s.allow_empty' is False."
def list(self, request, *args, **kwargs): def list(self, request, *args, **kwargs):
self.object_list = self.get_queryset() self.object_list = self.get_queryset()
serializer = self.get_serializer(instance=self.object_list)
# Default is to allow empty querysets. This can be altered by setting
# `.allow_empty = False`, to raise 404 errors on empty querysets.
allow_empty = self.get_allow_empty()
if not allow_empty and len(self.object_list) == 0:
error_args = {'class_name': self.__class__.__name__}
raise Http404(self.empty_error % error_args)
# Pagination size is set by the `.paginate_by` attribute,
# which may be `None` to disable pagination.
page_size = self.get_paginate_by(self.object_list)
if page_size:
paginator, page, queryset, is_paginated = self.paginate_queryset(self.object_list, page_size)
serializer = self.get_pagination_serializer(page)
else:
serializer = self.get_serializer(instance=self.object_list)
return Response(serializer.data) return Response(serializer.data)
......
from rest_framework import serializers
# TODO: Support URLconf kwarg-style paging
class NextPageField(serializers.Field):
def to_native(self, value):
if not value.has_next():
return None
page = value.next_page_number()
request = self.context['request']
return request.build_absolute_uri('?page=%d' % page)
class PreviousPageField(serializers.Field):
def to_native(self, value):
if not value.has_previous():
return None
page = value.previous_page_number()
request = self.context['request']
return request.build_absolute_uri('?page=%d' % page)
class PaginationSerializer(serializers.Serializer):
count = serializers.Field(source='paginator.count')
next = NextPageField(source='*')
previous = PreviousPageField(source='*')
def to_native(self, obj):
"""
Prevent default behaviour of iterating over elements, and serializing
each in turn.
"""
return self.convert_object(obj)
...@@ -44,6 +44,10 @@ DEFAULTS = { ...@@ -44,6 +44,10 @@ DEFAULTS = {
'anon': None, 'anon': None,
}, },
'MODEL_SERIALIZER': 'rest_framework.serializers.ModelSerializer',
'PAGINATION_SERIALIZER': 'rest_framework.pagination.PaginationSerializer',
'PAGINATE_BY': 20,
'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser', 'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser',
'UNAUTHENTICATED_TOKEN': None, 'UNAUTHENTICATED_TOKEN': None,
...@@ -65,6 +69,8 @@ IMPORT_STRINGS = ( ...@@ -65,6 +69,8 @@ IMPORT_STRINGS = (
'DEFAULT_PERMISSIONS', 'DEFAULT_PERMISSIONS',
'DEFAULT_THROTTLES', 'DEFAULT_THROTTLES',
'DEFAULT_CONTENT_NEGOTIATION', 'DEFAULT_CONTENT_NEGOTIATION',
'MODEL_SERIALIZER',
'PAGINATION_SERIALIZER',
'UNAUTHENTICATED_USER', 'UNAUTHENTICATED_USER',
'UNAUTHENTICATED_TOKEN', 'UNAUTHENTICATED_TOKEN',
) )
......
from django import template from django import template
from django.core.urlresolvers import reverse, NoReverseMatch from django.core.urlresolvers import reverse
from django.http import QueryDict from django.http import QueryDict
from django.utils.encoding import force_unicode from django.utils.encoding import force_unicode
from django.utils.html import escape from django.utils.html import escape
......
from django.test import TestCase
from django.test.client import RequestFactory
from rest_framework import generics, status
from rest_framework.tests.models import BasicModel
factory = RequestFactory()
class RootView(generics.RootAPIView):
"""
Example description for OPTIONS.
"""
model = BasicModel
paginate_by = 10
class TestPaginatedView(TestCase):
def setUp(self):
"""
Create 26 BasicModel intances.
"""
for char in 'abcdefghijklmnopqrstuvwxyz':
BasicModel(text=char * 3).save()
self.objects = BasicModel.objects
self.data = [
{'id': obj.id, 'text': obj.text}
for obj in self.objects.all()
]
self.view = RootView.as_view()
def test_get_paginated_root_view(self):
"""
GET requests to paginated RootAPIView should return paginated results.
"""
request = factory.get('/')
response = self.view(request).render()
self.assertEquals(response.status_code, status.HTTP_200_OK)
self.assertEquals(response.data['count'], 26)
self.assertEquals(response.data['results'], self.data[:10])
self.assertNotEquals(response.data['next'], None)
self.assertEquals(response.data['previous'], None)
request = factory.get(response.data['next'])
response = self.view(request).render()
self.assertEquals(response.status_code, status.HTTP_200_OK)
self.assertEquals(response.data['count'], 26)
self.assertEquals(response.data['results'], self.data[10:20])
self.assertNotEquals(response.data['next'], None)
self.assertNotEquals(response.data['previous'], None)
request = factory.get(response.data['next'])
response = self.view(request).render()
self.assertEquals(response.status_code, status.HTTP_200_OK)
self.assertEquals(response.data['count'], 26)
self.assertEquals(response.data['results'], self.data[20:])
self.assertEquals(response.data['next'], None)
self.assertNotEquals(response.data['previous'], None)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment