Commit 37dc2520 by Tom Christie

Merge pull request #2428 from tomchristie/cursor-pagination

Cursor pagination
parents 9ec08ce5 43d983fa
......@@ -114,7 +114,7 @@ class OrderingFilter(BaseFilterBackend):
ordering_param = api_settings.ORDERING_PARAM
ordering_fields = None
def get_ordering(self, request):
def get_ordering(self, request, queryset, view):
"""
Ordering is set by a comma delimited ?ordering=... query parameter.
......@@ -124,7 +124,13 @@ class OrderingFilter(BaseFilterBackend):
"""
params = request.query_params.get(self.ordering_param)
if params:
return [param.strip() for param in params.split(',')]
fields = [param.strip() for param in params.split(',')]
ordering = self.remove_invalid_fields(queryset, fields, view)
if ordering:
return ordering
# No ordering was included, or all the ordering fields were invalid
return self.get_default_ordering(view)
def get_default_ordering(self, view):
ordering = getattr(view, 'ordering', None)
......@@ -132,7 +138,7 @@ class OrderingFilter(BaseFilterBackend):
return (ordering,)
return ordering
def remove_invalid_fields(self, queryset, ordering, view):
def remove_invalid_fields(self, queryset, fields, view):
valid_fields = getattr(view, 'ordering_fields', self.ordering_fields)
if valid_fields is None:
......@@ -152,18 +158,10 @@ class OrderingFilter(BaseFilterBackend):
valid_fields = [field.name for field in queryset.model._meta.fields]
valid_fields += queryset.query.aggregates.keys()
return [term for term in ordering if term.lstrip('-') in valid_fields]
return [term for term in fields if term.lstrip('-') in valid_fields]
def filter_queryset(self, request, queryset, view):
ordering = self.get_ordering(request)
if ordering:
# Skip any incorrect parameters
ordering = self.remove_invalid_fields(queryset, ordering, view)
if not ordering:
# Use 'ordering' attribute by default
ordering = self.get_default_ordering(view)
ordering = self.get_ordering(request, queryset, view)
if ordering:
return queryset.order_by(*ordering)
......
......@@ -63,10 +63,20 @@ a single block in the template.
.pagination>.disabled>a,
.pagination>.disabled>a:hover,
.pagination>.disabled>a:focus {
cursor: default;
cursor: not-allowed;
pointer-events: none;
}
.pager>.disabled>a,
.pager>.disabled>a:hover,
.pager>.disabled>a:focus {
pointer-events: none;
}
.pager .next {
margin-left: 10px;
}
/*=== dabapps bootstrap styles ====*/
html {
......
<ul class="pager">
{% if previous_url %}
<li class="previous"><a href="{{ previous_url }}">&laquo; Previous</a></li>
{% else %}
<li class="previous disabled"><a href="#">&laquo; Previous</a></li>
{% endif %}
{% if next_url %}
<li class="next"><a href="{{ next_url }}">Next &raquo;</a></li>
{% else %}
<li class="next disabled"><a href="#">Next &raquo;</li>
{% endif %}
</ul>
# coding: utf-8
from __future__ import unicode_literals
from rest_framework import exceptions, generics, pagination, serializers, status, filters
from rest_framework.request import Request
......@@ -77,6 +78,20 @@ class TestPaginationIntegration:
'count': 50
}
def test_setting_page_size_to_zero(self):
"""
When page_size parameter is invalid it should return to the default.
"""
request = factory.get('/', {'page_size': 0})
response = self.view(request)
assert response.status_code == status.HTTP_200_OK
assert response.data == {
'results': [2, 4, 6, 8, 10],
'previous': None,
'next': 'http://testserver/?page=2&page_size=0',
'count': 50
}
def test_additional_query_params_are_preserved(self):
request = factory.get('/', {'page': 2, 'filter': 'even'})
response = self.view(request)
......@@ -88,6 +103,14 @@ class TestPaginationIntegration:
'count': 50
}
def test_404_not_found_for_zero_page(self):
request = factory.get('/', {'page': '0'})
response = self.view(request)
assert response.status_code == status.HTTP_404_NOT_FOUND
assert response.data == {
'detail': 'Invalid page "0": That page number is less than 1.'
}
def test_404_not_found_for_invalid_page(self):
request = factory.get('/', {'page': 'invalid'})
response = self.view(request)
......@@ -422,6 +445,179 @@ class TestLimitOffset:
assert queryset == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
class TestCursorPagination:
"""
Unit tests for `pagination.CursorPagination`.
"""
def setup(self):
class MockObject(object):
def __init__(self, idx):
self.created = idx
class MockQuerySet(object):
def __init__(self, items):
self.items = items
def filter(self, created__gt=None, created__lt=None):
if created__gt is not None:
return MockQuerySet([
item for item in self.items
if item.created > int(created__gt)
])
assert created__lt is not None
return MockQuerySet([
item for item in self.items
if item.created < int(created__lt)
])
def order_by(self, *ordering):
if ordering[0].startswith('-'):
return MockQuerySet(list(reversed(self.items)))
return self
def __getitem__(self, sliced):
return self.items[sliced]
class ExamplePagination(pagination.CursorPagination):
page_size = 5
ordering = 'created'
self.pagination = ExamplePagination()
self.queryset = MockQuerySet([
MockObject(idx) for idx in [
1, 1, 1, 1, 1,
1, 2, 3, 4, 4,
4, 4, 5, 6, 7,
7, 7, 7, 7, 7,
7, 7, 7, 8, 9,
9, 9, 9, 9, 9
]
])
def get_pages(self, url):
"""
Given a URL return a tuple of:
(previous page, current page, next page, previous url, next url)
"""
request = Request(factory.get(url))
queryset = self.pagination.paginate_queryset(self.queryset, request)
current = [item.created for item in queryset]
next_url = self.pagination.get_next_link()
previous_url = self.pagination.get_previous_link()
if next_url is not None:
request = Request(factory.get(next_url))
queryset = self.pagination.paginate_queryset(self.queryset, request)
next = [item.created for item in queryset]
else:
next = None
if previous_url is not None:
request = Request(factory.get(previous_url))
queryset = self.pagination.paginate_queryset(self.queryset, request)
previous = [item.created for item in queryset]
else:
previous = None
return (previous, current, next, previous_url, next_url)
def test_invalid_cursor(self):
request = Request(factory.get('/', {'cursor': '123'}))
with pytest.raises(exceptions.NotFound):
self.pagination.paginate_queryset(self.queryset, request)
def test_use_with_ordering_filter(self):
class MockView:
filter_backends = (filters.OrderingFilter,)
ordering_fields = ['username', 'created']
ordering = 'created'
request = Request(factory.get('/', {'ordering': 'username'}))
ordering = self.pagination.get_ordering(request, [], MockView())
assert ordering == ('username',)
request = Request(factory.get('/', {'ordering': '-username'}))
ordering = self.pagination.get_ordering(request, [], MockView())
assert ordering == ('-username',)
request = Request(factory.get('/', {'ordering': 'invalid'}))
ordering = self.pagination.get_ordering(request, [], MockView())
assert ordering == ('created',)
def test_cursor_pagination(self):
(previous, current, next, previous_url, next_url) = self.get_pages('/')
assert previous is None
assert current == [1, 1, 1, 1, 1]
assert next == [1, 2, 3, 4, 4]
(previous, current, next, previous_url, next_url) = self.get_pages(next_url)
assert previous == [1, 1, 1, 1, 1]
assert current == [1, 2, 3, 4, 4]
assert next == [4, 4, 5, 6, 7]
(previous, current, next, previous_url, next_url) = self.get_pages(next_url)
assert previous == [1, 2, 3, 4, 4]
assert current == [4, 4, 5, 6, 7]
assert next == [7, 7, 7, 7, 7]
(previous, current, next, previous_url, next_url) = self.get_pages(next_url)
assert previous == [4, 4, 4, 5, 6] # Paging artifact
assert current == [7, 7, 7, 7, 7]
assert next == [7, 7, 7, 8, 9]
(previous, current, next, previous_url, next_url) = self.get_pages(next_url)
assert previous == [7, 7, 7, 7, 7]
assert current == [7, 7, 7, 8, 9]
assert next == [9, 9, 9, 9, 9]
(previous, current, next, previous_url, next_url) = self.get_pages(next_url)
assert previous == [7, 7, 7, 8, 9]
assert current == [9, 9, 9, 9, 9]
assert next is None
(previous, current, next, previous_url, next_url) = self.get_pages(previous_url)
assert previous == [7, 7, 7, 7, 7]
assert current == [7, 7, 7, 8, 9]
assert next == [9, 9, 9, 9, 9]
(previous, current, next, previous_url, next_url) = self.get_pages(previous_url)
assert previous == [4, 4, 5, 6, 7]
assert current == [7, 7, 7, 7, 7]
assert next == [8, 9, 9, 9, 9] # Paging artifact
(previous, current, next, previous_url, next_url) = self.get_pages(previous_url)
assert previous == [1, 2, 3, 4, 4]
assert current == [4, 4, 5, 6, 7]
assert next == [7, 7, 7, 7, 7]
(previous, current, next, previous_url, next_url) = self.get_pages(previous_url)
assert previous == [1, 1, 1, 1, 1]
assert current == [1, 2, 3, 4, 4]
assert next == [4, 4, 5, 6, 7]
(previous, current, next, previous_url, next_url) = self.get_pages(previous_url)
assert previous is None
assert current == [1, 1, 1, 1, 1]
assert next == [1, 2, 3, 4, 4]
assert isinstance(self.pagination.to_html(), type(''))
def test_get_displayed_page_numbers():
"""
Test our contextual page display function.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment