Commit 9d6df1e1 by wajeeha-khalid

Merge pull request #11189 from edx/ekafeel/discussion-api-namespace-pagination

Ekafeel/discussion api namespace pagination
parents cf5dc107 83206535
...@@ -18,7 +18,6 @@ from courseware.courses import get_course_with_access ...@@ -18,7 +18,6 @@ from courseware.courses import get_course_with_access
from discussion_api.exceptions import ThreadNotFoundError, CommentNotFoundError, DiscussionDisabledError from discussion_api.exceptions import ThreadNotFoundError, CommentNotFoundError, DiscussionDisabledError
from discussion_api.forms import CommentActionsForm, ThreadActionsForm from discussion_api.forms import CommentActionsForm, ThreadActionsForm
from discussion_api.pagination import get_paginated_data
from discussion_api.permissions import ( from discussion_api.permissions import (
can_delete, can_delete,
get_editable_fields, get_editable_fields,
...@@ -38,6 +37,7 @@ from django_comment_common.signals import ( ...@@ -38,6 +37,7 @@ from django_comment_common.signals import (
comment_deleted, comment_deleted,
) )
from django_comment_client.utils import get_accessible_discussion_modules, is_commentable_cohorted from django_comment_client.utils import get_accessible_discussion_modules, is_commentable_cohorted
from lms.djangoapps.discussion_api.pagination import DiscussionAPIPagination
from lms.lib.comment_client.comment import Comment from lms.lib.comment_client.comment import Comment
from lms.lib.comment_client.thread import Thread from lms.lib.comment_client.thread import Thread
from lms.lib.comment_client.utils import CommentClientRequestError from lms.lib.comment_client.utils import CommentClientRequestError
...@@ -328,22 +328,30 @@ def get_thread_list( ...@@ -328,22 +328,30 @@ def get_thread_list(
}) })
if following: if following:
threads, result_page, num_pages = context["cc_requester"].subscribed_threads(query_params) paginated_results = context["cc_requester"].subscribed_threads(query_params)
else: else:
query_params["course_id"] = unicode(course.id) query_params["course_id"] = unicode(course.id)
query_params["commentable_ids"] = ",".join(topic_id_list) if topic_id_list else None query_params["commentable_ids"] = ",".join(topic_id_list) if topic_id_list else None
query_params["text"] = text_search query_params["text"] = text_search
threads, result_page, num_pages, text_search_rewrite = Thread.search(query_params) paginated_results = Thread.search(query_params)
# The comments service returns the last page of results if the requested # The comments service returns the last page of results if the requested
# page is beyond the last page, but we want be consistent with DRF's general # page is beyond the last page, but we want be consistent with DRF's general
# behavior and return a PageNotFoundError in that case # behavior and return a PageNotFoundError in that case
if result_page != page: if paginated_results.page != page:
raise PageNotFoundError("Page not found (No results on this page).") raise PageNotFoundError("Page not found (No results on this page).")
results = [ThreadSerializer(thread, context=context).data for thread in threads] results = [ThreadSerializer(thread, context=context).data for thread in paginated_results.collection]
ret = get_paginated_data(request, results, page, num_pages)
ret["text_search_rewrite"] = text_search_rewrite paginator = DiscussionAPIPagination(
return ret request,
paginated_results.page,
paginated_results.num_pages,
paginated_results.thread_count
)
return paginator.get_paginated_response({
"results": results,
"text_search_rewrite": paginated_results.corrected_text,
})
def get_comment_list(request, thread_id, endorsed, page, page_size): def get_comment_list(request, thread_id, endorsed, page, page_size):
...@@ -412,7 +420,8 @@ def get_comment_list(request, thread_id, endorsed, page, page_size): ...@@ -412,7 +420,8 @@ def get_comment_list(request, thread_id, endorsed, page, page_size):
num_pages = (resp_total + page_size - 1) / page_size if resp_total else 1 num_pages = (resp_total + page_size - 1) / page_size if resp_total else 1
results = [CommentSerializer(response, context=context).data for response in responses] results = [CommentSerializer(response, context=context).data for response in responses]
return get_paginated_data(request, results, page, num_pages) paginator = DiscussionAPIPagination(request, page, num_pages, resp_total)
return paginator.get_paginated_response(results)
def _check_fields(allowed_fields, data, message): def _check_fields(allowed_fields, data, message):
...@@ -747,11 +756,15 @@ def get_response_comments(request, comment_id, page, page_size): ...@@ -747,11 +756,15 @@ def get_response_comments(request, comment_id, page, page_size):
response_skip = page_size * (page - 1) response_skip = page_size * (page - 1)
paged_response_comments = response_comments[response_skip:(response_skip + page_size)] paged_response_comments = response_comments[response_skip:(response_skip + page_size)]
if len(paged_response_comments) == 0 and page != 1:
raise PageNotFoundError("Page not found (No results on this page).")
results = [CommentSerializer(comment, context=context).data for comment in paged_response_comments] results = [CommentSerializer(comment, context=context).data for comment in paged_response_comments]
comments_count = len(response_comments) comments_count = len(response_comments)
num_pages = (comments_count + page_size - 1) / page_size if comments_count else 1 num_pages = (comments_count + page_size - 1) / page_size if comments_count else 1
return get_paginated_data(request, results, page, num_pages) paginator = DiscussionAPIPagination(request, page, num_pages, comments_count)
return paginator.get_paginated_response(results)
except CommentClientRequestError: except CommentClientRequestError:
raise CommentNotFoundError("Comment not found") raise CommentNotFoundError("Comment not found")
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
Discussion API pagination support Discussion API pagination support
""" """
from rest_framework.utils.urls import replace_query_param from rest_framework.utils.urls import replace_query_param
from openedx.core.lib.api.paginators import NamespacedPageNumberPagination
class _Page(object): class _Page(object):
...@@ -9,12 +10,11 @@ class _Page(object): ...@@ -9,12 +10,11 @@ class _Page(object):
Implements just enough of the django.core.paginator.Page interface to allow Implements just enough of the django.core.paginator.Page interface to allow
PaginationSerializer to work. PaginationSerializer to work.
""" """
def __init__(self, object_list, page_num, num_pages): def __init__(self, page_num, num_pages):
""" """
Create a new page containing the given objects, with the given page Create a new page containing the given objects, with the given page
number and number of pages number and number of pages
""" """
self.object_list = object_list
self.page_num = page_num self.page_num = page_num
self.num_pages = num_pages self.num_pages = num_pages
...@@ -35,33 +35,50 @@ class _Page(object): ...@@ -35,33 +35,50 @@ class _Page(object):
return self.page_num - 1 return self.page_num - 1
def get_paginated_data(request, results, page_num, per_page): class DiscussionAPIPagination(NamespacedPageNumberPagination):
""" """
Return a dict with the following values: Subclasses NamespacedPageNumberPagination to provide custom implementation of pagination metadata
by overriding it's methods
next: The URL for the next page
previous: The URL for the previous page
results: The results on this page
""" """
# Note: Previous versions of this function used Django Rest Framework's def __init__(self, request, page_num, num_pages, result_count=0):
# paginated serializer. With the upgrade to DRF 3.1, paginated serializers """
# have been removed. We *could* use DRF's paginator classes, but there are Overrides parent constructor to take information from discussion api
# some slight differences between how DRF does pagination and how we're doing essential for the parent method
# pagination here. (For example, we respond with a next_url param even if """
# there is only one result on the current page.) To maintain backwards self.page = _Page(page_num, num_pages)
# compatability, we simulate the behavior that DRF used to provide. self.base_url = request.build_absolute_uri()
page = _Page(results, page_num, per_page) self.count = result_count
next_url, previous_url = None, None
base_url = request.build_absolute_uri()
if page.has_next(): super(DiscussionAPIPagination, self).__init__()
next_url = replace_query_param(base_url, "page", page.next_page_number())
if page.has_previous(): def get_result_count(self):
previous_url = replace_query_param(base_url, "page", page.previous_page_number()) """
Returns total number of results
"""
return self.count
return { def get_num_pages(self):
"next": next_url, """
"previous": previous_url, Returns total number of pages the response is divided into
"results": results, """
} return self.page.num_pages
def get_next_link(self):
"""
Returns absolute url of the next page if there's a next page available
otherwise returns None
"""
next_url = None
if self.page.has_next():
next_url = replace_query_param(self.base_url, "page", self.page.next_page_number())
return next_url
def get_previous_link(self):
"""
Returns absolute url of the previous page if there's a previous page available
otherwise returns None
"""
previous_url = None
if self.page.has_previous():
previous_url = replace_query_param(self.base_url, "page", self.page.previous_page_number())
return previous_url
...@@ -39,6 +39,7 @@ from discussion_api.tests.utils import ( ...@@ -39,6 +39,7 @@ from discussion_api.tests.utils import (
CommentsServiceMockMixin, CommentsServiceMockMixin,
make_minimal_cs_comment, make_minimal_cs_comment,
make_minimal_cs_thread, make_minimal_cs_thread,
make_paginated_api_response
) )
from django_comment_common.models import ( from django_comment_common.models import (
FORUM_ROLE_ADMINISTRATOR, FORUM_ROLE_ADMINISTRATOR,
...@@ -538,11 +539,15 @@ class GetThreadListTest(CommentsServiceMockMixin, UrlResetMixin, SharedModuleSto ...@@ -538,11 +539,15 @@ class GetThreadListTest(CommentsServiceMockMixin, UrlResetMixin, SharedModuleSto
def test_empty(self): def test_empty(self):
self.assertEqual( self.assertEqual(
self.get_thread_list([]), self.get_thread_list([], num_pages=0).data,
{ {
"pagination": {
"next": None,
"previous": None,
"num_pages": 0,
"count": 0
},
"results": [], "results": [],
"next": None,
"previous": None,
"text_search_rewrite": None, "text_search_rewrite": None,
} }
) )
...@@ -688,14 +693,14 @@ class GetThreadListTest(CommentsServiceMockMixin, UrlResetMixin, SharedModuleSto ...@@ -688,14 +693,14 @@ class GetThreadListTest(CommentsServiceMockMixin, UrlResetMixin, SharedModuleSto
"read": False, "read": False,
}, },
] ]
expected_result = make_paginated_api_response(
results=expected_threads, count=2, num_pages=1, next_link=None, previous_link=None
)
expected_result.update({"text_search_rewrite": None})
self.assertEqual( self.assertEqual(
self.get_thread_list(source_threads), self.get_thread_list(source_threads).data,
{ expected_result
"results": expected_threads,
"next": None,
"previous": None,
"text_search_rewrite": None,
}
) )
@ddt.data( @ddt.data(
...@@ -723,32 +728,35 @@ class GetThreadListTest(CommentsServiceMockMixin, UrlResetMixin, SharedModuleSto ...@@ -723,32 +728,35 @@ class GetThreadListTest(CommentsServiceMockMixin, UrlResetMixin, SharedModuleSto
def test_pagination(self): def test_pagination(self):
# N.B. Empty thread list is not realistic but convenient for this test # N.B. Empty thread list is not realistic but convenient for this test
expected_result = make_paginated_api_response(
results=[], count=0, num_pages=3, next_link="http://testserver/test_path?page=2", previous_link=None
)
expected_result.update({"text_search_rewrite": None})
self.assertEqual( self.assertEqual(
self.get_thread_list([], page=1, num_pages=3), self.get_thread_list([], page=1, num_pages=3).data,
{ expected_result
"results": [],
"next": "http://testserver/test_path?page=2",
"previous": None,
"text_search_rewrite": None,
}
) )
expected_result = make_paginated_api_response(
results=[],
count=0,
num_pages=3,
next_link="http://testserver/test_path?page=3",
previous_link="http://testserver/test_path?page=1"
)
expected_result.update({"text_search_rewrite": None})
self.assertEqual( self.assertEqual(
self.get_thread_list([], page=2, num_pages=3), self.get_thread_list([], page=2, num_pages=3).data,
{ expected_result
"results": [], )
"next": "http://testserver/test_path?page=3",
"previous": "http://testserver/test_path?page=1", expected_result = make_paginated_api_response(
"text_search_rewrite": None, results=[], count=0, num_pages=3, next_link=None, previous_link="http://testserver/test_path?page=2"
}
) )
expected_result.update({"text_search_rewrite": None})
self.assertEqual( self.assertEqual(
self.get_thread_list([], page=3, num_pages=3), self.get_thread_list([], page=3, num_pages=3).data,
{ expected_result
"results": [],
"next": None,
"previous": "http://testserver/test_path?page=2",
"text_search_rewrite": None,
}
) )
# Test page past the last one # Test page past the last one
...@@ -758,7 +766,11 @@ class GetThreadListTest(CommentsServiceMockMixin, UrlResetMixin, SharedModuleSto ...@@ -758,7 +766,11 @@ class GetThreadListTest(CommentsServiceMockMixin, UrlResetMixin, SharedModuleSto
@ddt.data(None, "rewritten search string") @ddt.data(None, "rewritten search string")
def test_text_search(self, text_search_rewrite): def test_text_search(self, text_search_rewrite):
self.register_get_threads_search_response([], text_search_rewrite) expected_result = make_paginated_api_response(
results=[], count=0, num_pages=0, next_link=None, previous_link=None
)
expected_result.update({"text_search_rewrite": text_search_rewrite})
self.register_get_threads_search_response([], text_search_rewrite, num_pages=0)
self.assertEqual( self.assertEqual(
get_thread_list( get_thread_list(
self.request, self.request,
...@@ -766,13 +778,8 @@ class GetThreadListTest(CommentsServiceMockMixin, UrlResetMixin, SharedModuleSto ...@@ -766,13 +778,8 @@ class GetThreadListTest(CommentsServiceMockMixin, UrlResetMixin, SharedModuleSto
page=1, page=1,
page_size=10, page_size=10,
text_search="test search string" text_search="test search string"
), ).data,
{ expected_result
"results": [],
"next": None,
"previous": None,
"text_search_rewrite": text_search_rewrite,
}
) )
self.assert_last_query_params({ self.assert_last_query_params({
"user_id": [unicode(self.user.id)], "user_id": [unicode(self.user.id)],
...@@ -786,17 +793,22 @@ class GetThreadListTest(CommentsServiceMockMixin, UrlResetMixin, SharedModuleSto ...@@ -786,17 +793,22 @@ class GetThreadListTest(CommentsServiceMockMixin, UrlResetMixin, SharedModuleSto
}) })
def test_following(self): def test_following(self):
self.register_subscribed_threads_response(self.user, [], page=1, num_pages=1) self.register_subscribed_threads_response(self.user, [], page=1, num_pages=0)
result = get_thread_list( result = get_thread_list(
self.request, self.request,
self.course.id, self.course.id,
page=1, page=1,
page_size=11, page_size=11,
following=True, following=True,
).data
expected_result = make_paginated_api_response(
results=[], count=0, num_pages=0, next_link=None, previous_link=None
) )
expected_result.update({"text_search_rewrite": None})
self.assertEqual( self.assertEqual(
result, result,
{"results": [], "next": None, "previous": None, "text_search_rewrite": None} expected_result
) )
self.assertEqual( self.assertEqual(
urlparse(httpretty.last_request().path).path, urlparse(httpretty.last_request().path).path,
...@@ -813,17 +825,22 @@ class GetThreadListTest(CommentsServiceMockMixin, UrlResetMixin, SharedModuleSto ...@@ -813,17 +825,22 @@ class GetThreadListTest(CommentsServiceMockMixin, UrlResetMixin, SharedModuleSto
@ddt.data("unanswered", "unread") @ddt.data("unanswered", "unread")
def test_view_query(self, query): def test_view_query(self, query):
self.register_get_threads_response([], page=1, num_pages=1) self.register_get_threads_response([], page=1, num_pages=0)
result = get_thread_list( result = get_thread_list(
self.request, self.request,
self.course.id, self.course.id,
page=1, page=1,
page_size=11, page_size=11,
view=query, view=query,
).data
expected_result = make_paginated_api_response(
results=[], count=0, num_pages=0, next_link=None, previous_link=None
) )
expected_result.update({"text_search_rewrite": None})
self.assertEqual( self.assertEqual(
result, result,
{"results": [], "next": None, "previous": None, "text_search_rewrite": None} expected_result
) )
self.assertEqual( self.assertEqual(
urlparse(httpretty.last_request().path).path, urlparse(httpretty.last_request().path).path,
...@@ -854,18 +871,20 @@ class GetThreadListTest(CommentsServiceMockMixin, UrlResetMixin, SharedModuleSto ...@@ -854,18 +871,20 @@ class GetThreadListTest(CommentsServiceMockMixin, UrlResetMixin, SharedModuleSto
http_query (str): Query string sent in the http request http_query (str): Query string sent in the http request
cc_query (str): Query string used for the comments client service cc_query (str): Query string used for the comments client service
""" """
self.register_get_threads_response([], page=1, num_pages=1) self.register_get_threads_response([], page=1, num_pages=0)
result = get_thread_list( result = get_thread_list(
self.request, self.request,
self.course.id, self.course.id,
page=1, page=1,
page_size=11, page_size=11,
order_by=http_query, order_by=http_query,
).data
expected_result = make_paginated_api_response(
results=[], count=0, num_pages=0, next_link=None, previous_link=None
) )
self.assertEqual( expected_result.update({"text_search_rewrite": None})
result, self.assertEqual(result, expected_result)
{"results": [], "next": None, "previous": None, "text_search_rewrite": None}
)
self.assertEqual( self.assertEqual(
urlparse(httpretty.last_request().path).path, urlparse(httpretty.last_request().path).path,
"/api/v1/threads" "/api/v1/threads"
...@@ -882,18 +901,20 @@ class GetThreadListTest(CommentsServiceMockMixin, UrlResetMixin, SharedModuleSto ...@@ -882,18 +901,20 @@ class GetThreadListTest(CommentsServiceMockMixin, UrlResetMixin, SharedModuleSto
@ddt.data("asc", "desc") @ddt.data("asc", "desc")
def test_order_direction_query(self, http_query): def test_order_direction_query(self, http_query):
self.register_get_threads_response([], page=1, num_pages=1) self.register_get_threads_response([], page=1, num_pages=0)
result = get_thread_list( result = get_thread_list(
self.request, self.request,
self.course.id, self.course.id,
page=1, page=1,
page_size=11, page_size=11,
order_direction=http_query, order_direction=http_query,
).data
expected_result = make_paginated_api_response(
results=[], count=0, num_pages=0, next_link=None, previous_link=None
) )
self.assertEqual( expected_result.update({"text_search_rewrite": None})
result, self.assertEqual(result, expected_result)
{"results": [], "next": None, "previous": None, "text_search_rewrite": None}
)
self.assertEqual( self.assertEqual(
urlparse(httpretty.last_request().path).path, urlparse(httpretty.last_request().path).path,
"/api/v1/threads" "/api/v1/threads"
...@@ -1055,8 +1076,8 @@ class GetCommentListTest(CommentsServiceMockMixin, SharedModuleStoreTestCase): ...@@ -1055,8 +1076,8 @@ class GetCommentListTest(CommentsServiceMockMixin, SharedModuleStoreTestCase):
{"thread_type": "discussion", "children": [], "resp_total": 0} {"thread_type": "discussion", "children": [], "resp_total": 0}
) )
self.assertEqual( self.assertEqual(
self.get_comment_list(discussion_thread), self.get_comment_list(discussion_thread).data,
{"results": [], "next": None, "previous": None} make_paginated_api_response(results=[], count=0, num_pages=1, next_link=None, previous_link=None)
) )
question_thread = self.make_minimal_cs_thread({ question_thread = self.make_minimal_cs_thread({
...@@ -1066,12 +1087,12 @@ class GetCommentListTest(CommentsServiceMockMixin, SharedModuleStoreTestCase): ...@@ -1066,12 +1087,12 @@ class GetCommentListTest(CommentsServiceMockMixin, SharedModuleStoreTestCase):
"non_endorsed_resp_total": 0 "non_endorsed_resp_total": 0
}) })
self.assertEqual( self.assertEqual(
self.get_comment_list(question_thread, endorsed=False), self.get_comment_list(question_thread, endorsed=False).data,
{"results": [], "next": None, "previous": None} make_paginated_api_response(results=[], count=0, num_pages=1, next_link=None, previous_link=None)
) )
self.assertEqual( self.assertEqual(
self.get_comment_list(question_thread, endorsed=True), self.get_comment_list(question_thread, endorsed=True).data,
{"results": [], "next": None, "previous": None} make_paginated_api_response(results=[], count=0, num_pages=1, next_link=None, previous_link=None)
) )
def test_basic_query_params(self): def test_basic_query_params(self):
...@@ -1173,7 +1194,7 @@ class GetCommentListTest(CommentsServiceMockMixin, SharedModuleStoreTestCase): ...@@ -1173,7 +1194,7 @@ class GetCommentListTest(CommentsServiceMockMixin, SharedModuleStoreTestCase):
] ]
actual_comments = self.get_comment_list( actual_comments = self.get_comment_list(
self.make_minimal_cs_thread({"children": source_comments}) self.make_minimal_cs_thread({"children": source_comments})
)["results"] ).data["results"]
self.assertEqual(actual_comments, expected_comments) self.assertEqual(actual_comments, expected_comments)
def test_question_content(self): def test_question_content(self):
...@@ -1184,10 +1205,10 @@ class GetCommentListTest(CommentsServiceMockMixin, SharedModuleStoreTestCase): ...@@ -1184,10 +1205,10 @@ class GetCommentListTest(CommentsServiceMockMixin, SharedModuleStoreTestCase):
"non_endorsed_resp_total": 1, "non_endorsed_resp_total": 1,
}) })
endorsed_actual = self.get_comment_list(thread, endorsed=True) endorsed_actual = self.get_comment_list(thread, endorsed=True).data
self.assertEqual(endorsed_actual["results"][0]["id"], "endorsed_comment") self.assertEqual(endorsed_actual["results"][0]["id"], "endorsed_comment")
non_endorsed_actual = self.get_comment_list(thread, endorsed=False) non_endorsed_actual = self.get_comment_list(thread, endorsed=False).data
self.assertEqual(non_endorsed_actual["results"][0]["id"], "non_endorsed_comment") self.assertEqual(non_endorsed_actual["results"][0]["id"], "non_endorsed_comment")
def test_endorsed_by_anonymity(self): def test_endorsed_by_anonymity(self):
...@@ -1203,7 +1224,7 @@ class GetCommentListTest(CommentsServiceMockMixin, SharedModuleStoreTestCase): ...@@ -1203,7 +1224,7 @@ class GetCommentListTest(CommentsServiceMockMixin, SharedModuleStoreTestCase):
}) })
] ]
}) })
actual_comments = self.get_comment_list(thread)["results"] actual_comments = self.get_comment_list(thread).data["results"]
self.assertIsNone(actual_comments[0]["endorsed_by"]) self.assertIsNone(actual_comments[0]["endorsed_by"])
@ddt.data( @ddt.data(
...@@ -1231,24 +1252,24 @@ class GetCommentListTest(CommentsServiceMockMixin, SharedModuleStoreTestCase): ...@@ -1231,24 +1252,24 @@ class GetCommentListTest(CommentsServiceMockMixin, SharedModuleStoreTestCase):
}) })
# Only page # Only page
actual = self.get_comment_list(thread, endorsed=endorsed_arg, page=1, page_size=5) actual = self.get_comment_list(thread, endorsed=endorsed_arg, page=1, page_size=5).data
self.assertIsNone(actual["next"]) self.assertIsNone(actual["pagination"]["next"])
self.assertIsNone(actual["previous"]) self.assertIsNone(actual["pagination"]["previous"])
# First page of many # First page of many
actual = self.get_comment_list(thread, endorsed=endorsed_arg, page=1, page_size=2) actual = self.get_comment_list(thread, endorsed=endorsed_arg, page=1, page_size=2).data
self.assertEqual(actual["next"], "http://testserver/test_path?page=2") self.assertEqual(actual["pagination"]["next"], "http://testserver/test_path?page=2")
self.assertIsNone(actual["previous"]) self.assertIsNone(actual["pagination"]["previous"])
# Middle page of many # Middle page of many
actual = self.get_comment_list(thread, endorsed=endorsed_arg, page=2, page_size=2) actual = self.get_comment_list(thread, endorsed=endorsed_arg, page=2, page_size=2).data
self.assertEqual(actual["next"], "http://testserver/test_path?page=3") self.assertEqual(actual["pagination"]["next"], "http://testserver/test_path?page=3")
self.assertEqual(actual["previous"], "http://testserver/test_path?page=1") self.assertEqual(actual["pagination"]["previous"], "http://testserver/test_path?page=1")
# Last page of many # Last page of many
actual = self.get_comment_list(thread, endorsed=endorsed_arg, page=3, page_size=2) actual = self.get_comment_list(thread, endorsed=endorsed_arg, page=3, page_size=2).data
self.assertIsNone(actual["next"]) self.assertIsNone(actual["pagination"]["next"])
self.assertEqual(actual["previous"], "http://testserver/test_path?page=2") self.assertEqual(actual["pagination"]["previous"], "http://testserver/test_path?page=2")
# Page past the end # Page past the end
thread = self.make_minimal_cs_thread({ thread = self.make_minimal_cs_thread({
...@@ -1272,18 +1293,18 @@ class GetCommentListTest(CommentsServiceMockMixin, SharedModuleStoreTestCase): ...@@ -1272,18 +1293,18 @@ class GetCommentListTest(CommentsServiceMockMixin, SharedModuleStoreTestCase):
Check that requesting the given page/page_size returns the expected Check that requesting the given page/page_size returns the expected
output output
""" """
actual = self.get_comment_list(thread, endorsed=True, page=page, page_size=page_size) actual = self.get_comment_list(thread, endorsed=True, page=page, page_size=page_size).data
result_ids = [result["id"] for result in actual["results"]] result_ids = [result["id"] for result in actual["results"]]
self.assertEqual( self.assertEqual(
result_ids, result_ids,
["comment_{}".format(i) for i in range(expected_start, expected_stop)] ["comment_{}".format(i) for i in range(expected_start, expected_stop)]
) )
self.assertEqual( self.assertEqual(
actual["next"], actual["pagination"]["next"],
"http://testserver/test_path?page={}".format(expected_next) if expected_next else None "http://testserver/test_path?page={}".format(expected_next) if expected_next else None
) )
self.assertEqual( self.assertEqual(
actual["previous"], actual["pagination"]["previous"],
"http://testserver/test_path?page={}".format(expected_prev) if expected_prev else None "http://testserver/test_path?page={}".format(expected_prev) if expected_prev else None
) )
......
...@@ -5,7 +5,8 @@ from unittest import TestCase ...@@ -5,7 +5,8 @@ from unittest import TestCase
from django.test import RequestFactory from django.test import RequestFactory
from discussion_api.pagination import get_paginated_data from discussion_api.pagination import DiscussionAPIPagination
from discussion_api.tests.utils import make_paginated_api_response
class PaginationSerializerTest(TestCase): class PaginationSerializerTest(TestCase):
...@@ -16,55 +17,45 @@ class PaginationSerializerTest(TestCase): ...@@ -16,55 +17,45 @@ class PaginationSerializerTest(TestCase):
parameters returns the expected result parameters returns the expected result
""" """
request = RequestFactory().get("/test") request = RequestFactory().get("/test")
actual = get_paginated_data(request, objects, page_num, num_pages) paginator = DiscussionAPIPagination(request, page_num, num_pages)
self.assertEqual(actual, expected) actual = paginator.get_paginated_response(objects)
self.assertEqual(actual.data, expected)
def test_empty(self): def test_empty(self):
self.do_case( self.do_case(
[], 1, 0, [], 1, 0, make_paginated_api_response(
{ results=[], count=0, num_pages=0, next_link=None, previous_link=None
"next": None, )
"previous": None,
"results": [],
}
) )
def test_only_page(self): def test_only_page(self):
self.do_case( self.do_case(
["foo"], 1, 1, ["foo"], 1, 1, make_paginated_api_response(
{ results=["foo"], count=0, num_pages=1, next_link=None, previous_link=None
"next": None, )
"previous": None,
"results": ["foo"],
}
) )
def test_first_of_many(self): def test_first_of_many(self):
self.do_case( self.do_case(
["foo"], 1, 3, ["foo"], 1, 3, make_paginated_api_response(
{ results=["foo"], count=0, num_pages=3, next_link="http://testserver/test?page=2", previous_link=None
"next": "http://testserver/test?page=2", )
"previous": None,
"results": ["foo"],
}
) )
def test_last_of_many(self): def test_last_of_many(self):
self.do_case( self.do_case(
["foo"], 3, 3, ["foo"], 3, 3, make_paginated_api_response(
{ results=["foo"], count=0, num_pages=3, next_link=None, previous_link="http://testserver/test?page=2"
"next": None, )
"previous": "http://testserver/test?page=2",
"results": ["foo"],
}
) )
def test_middle_of_many(self): def test_middle_of_many(self):
self.do_case( self.do_case(
["foo"], 2, 3, ["foo"], 2, 3, make_paginated_api_response(
{ results=["foo"],
"next": "http://testserver/test?page=3", count=0,
"previous": "http://testserver/test?page=1", num_pages=3,
"results": ["foo"], next_link="http://testserver/test?page=3",
} previous_link="http://testserver/test?page=1"
)
) )
...@@ -23,6 +23,7 @@ from discussion_api.tests.utils import ( ...@@ -23,6 +23,7 @@ from discussion_api.tests.utils import (
CommentsServiceMockMixin, CommentsServiceMockMixin,
make_minimal_cs_comment, make_minimal_cs_comment,
make_minimal_cs_thread, make_minimal_cs_thread,
make_paginated_api_response
) )
from student.tests.factories import CourseEnrollmentFactory, UserFactory from student.tests.factories import CourseEnrollmentFactory, UserFactory
from util.testing import UrlResetMixin, PatchMediaTypeMixin from util.testing import UrlResetMixin, PatchMediaTypeMixin
...@@ -307,15 +308,18 @@ class ThreadViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase): ...@@ -307,15 +308,18 @@ class ThreadViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
}] }]
self.register_get_threads_response(source_threads, page=1, num_pages=2) self.register_get_threads_response(source_threads, page=1, num_pages=2)
response = self.client.get(self.url, {"course_id": unicode(self.course.id), "following": ""}) response = self.client.get(self.url, {"course_id": unicode(self.course.id), "following": ""})
expected_response = make_paginated_api_response(
results=expected_threads,
count=1,
num_pages=2,
next_link="http://testserver/api/discussion/v1/threads/?course_id=x%2Fy%2Fz&page=2",
previous_link=None
)
expected_response.update({"text_search_rewrite": None})
self.assert_response_correct( self.assert_response_correct(
response, response,
200, 200,
{ expected_response
"results": expected_threads,
"next": "http://testserver/api/discussion/v1/threads/?course_id=x%2Fy%2Fz&page=2",
"previous": None,
"text_search_rewrite": None,
}
) )
self.assert_last_query_params({ self.assert_last_query_params({
"user_id": [unicode(self.user.id)], "user_id": [unicode(self.user.id)],
...@@ -374,15 +378,20 @@ class ThreadViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase): ...@@ -374,15 +378,20 @@ class ThreadViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
def test_text_search(self): def test_text_search(self):
self.register_get_user_response(self.user) self.register_get_user_response(self.user)
self.register_get_threads_search_response([], None) self.register_get_threads_search_response([], None, num_pages=0)
response = self.client.get( response = self.client.get(
self.url, self.url,
{"course_id": unicode(self.course.id), "text_search": "test search string"} {"course_id": unicode(self.course.id), "text_search": "test search string"}
) )
expected_response = make_paginated_api_response(
results=[], count=0, num_pages=0, next_link=None, previous_link=None
)
expected_response.update({"text_search_rewrite": None})
self.assert_response_correct( self.assert_response_correct(
response, response,
200, 200,
{"results": [], "next": None, "previous": None, "text_search_rewrite": None} expected_response
) )
self.assert_last_query_params({ self.assert_last_query_params({
"user_id": [unicode(self.user.id)], "user_id": [unicode(self.user.id)],
...@@ -398,7 +407,7 @@ class ThreadViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase): ...@@ -398,7 +407,7 @@ class ThreadViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
@ddt.data(True, "true", "1") @ddt.data(True, "true", "1")
def test_following_true(self, following): def test_following_true(self, following):
self.register_get_user_response(self.user) self.register_get_user_response(self.user)
self.register_subscribed_threads_response(self.user, [], page=1, num_pages=1) self.register_subscribed_threads_response(self.user, [], page=1, num_pages=0)
response = self.client.get( response = self.client.get(
self.url, self.url,
{ {
...@@ -406,10 +415,15 @@ class ThreadViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase): ...@@ -406,10 +415,15 @@ class ThreadViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
"following": following, "following": following,
} }
) )
expected_response = make_paginated_api_response(
results=[], count=0, num_pages=0, next_link=None, previous_link=None
)
expected_response.update({"text_search_rewrite": None})
self.assert_response_correct( self.assert_response_correct(
response, response,
200, 200,
{"results": [], "next": None, "previous": None, "text_search_rewrite": None} expected_response
) )
self.assertEqual( self.assertEqual(
urlparse(httpretty.last_request().path).path, urlparse(httpretty.last_request().path).path,
...@@ -933,16 +947,15 @@ class CommentViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase): ...@@ -933,16 +947,15 @@ class CommentViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
"resp_total": 100, "resp_total": 100,
}) })
response = self.client.get(self.url, {"thread_id": self.thread_id}) response = self.client.get(self.url, {"thread_id": self.thread_id})
next_link = "http://testserver/api/discussion/v1/comments/?page=2&thread_id={}".format(
self.thread_id
)
self.assert_response_correct( self.assert_response_correct(
response, response,
200, 200,
{ make_paginated_api_response(
"results": expected_comments, results=expected_comments, count=100, num_pages=10, next_link=next_link, previous_link=None
"next": "http://testserver/api/discussion/v1/comments/?page=2&thread_id={}".format( )
self.thread_id
),
"previous": None,
}
) )
self.assert_query_params_equal( self.assert_query_params_equal(
httpretty.httpretty.latest_requests[-2], httpretty.httpretty.latest_requests[-2],
...@@ -1427,3 +1440,29 @@ class CommentViewSetRetrieveTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase ...@@ -1427,3 +1440,29 @@ class CommentViewSetRetrieveTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase
self.register_get_comment_error_response(self.comment_id, 404) self.register_get_comment_error_response(self.comment_id, 404)
response = self.client.get(self.url) response = self.client.get(self.url)
self.assertEqual(response.status_code, 404) self.assertEqual(response.status_code, 404)
def test_pagination(self):
"""
Test that pagination parameters are correctly plumbed through to the
comments service and that a 404 is correctly returned if a page past the
end is requested
"""
self.register_get_user_response(self.user)
cs_comment_child = self.make_comment_data("test_child_comment", self.comment_id, children=[])
cs_comment = self.make_comment_data(self.comment_id, None, [cs_comment_child])
cs_thread = make_minimal_cs_thread({
"id": self.thread_id,
"course_id": unicode(self.course.id),
"children": [cs_comment],
})
self.register_get_thread_response(cs_thread)
self.register_get_comment_response(cs_comment)
response = self.client.get(
self.url,
{"comment_id": self.comment_id, "page": "18", "page_size": "4"}
)
self.assert_response_correct(
response,
404,
{"developer_message": "Page not found (No results on this page)."}
)
...@@ -67,11 +67,12 @@ class CommentsServiceMockMixin(object): ...@@ -67,11 +67,12 @@ class CommentsServiceMockMixin(object):
"collection": threads, "collection": threads,
"page": page, "page": page,
"num_pages": num_pages, "num_pages": num_pages,
"thread_count": len(threads),
}), }),
status=200 status=200
) )
def register_get_threads_search_response(self, threads, rewrite): def register_get_threads_search_response(self, threads, rewrite, num_pages=1):
"""Register a mock response for GET on the CS thread search endpoint""" """Register a mock response for GET on the CS thread search endpoint"""
httpretty.register_uri( httpretty.register_uri(
httpretty.GET, httpretty.GET,
...@@ -79,8 +80,9 @@ class CommentsServiceMockMixin(object): ...@@ -79,8 +80,9 @@ class CommentsServiceMockMixin(object):
body=json.dumps({ body=json.dumps({
"collection": threads, "collection": threads,
"page": 1, "page": 1,
"num_pages": 1, "num_pages": num_pages,
"corrected_text": rewrite, "corrected_text": rewrite,
"thread_count": len(threads),
}), }),
status=200 status=200
) )
...@@ -200,6 +202,7 @@ class CommentsServiceMockMixin(object): ...@@ -200,6 +202,7 @@ class CommentsServiceMockMixin(object):
"collection": threads, "collection": threads,
"page": page, "page": page,
"num_pages": num_pages, "num_pages": num_pages,
"thread_count": len(threads),
}), }),
status=200 status=200
) )
...@@ -371,3 +374,18 @@ def make_minimal_cs_comment(overrides=None): ...@@ -371,3 +374,18 @@ def make_minimal_cs_comment(overrides=None):
} }
ret.update(overrides or {}) ret.update(overrides or {})
return ret return ret
def make_paginated_api_response(results=None, count=0, num_pages=0, next_link=None, previous_link=None):
"""
Generates the response dictionary of paginated APIs with passed data
"""
return {
"pagination": {
"next": next_link,
"previous": previous_link,
"count": count,
"num_pages": num_pages,
},
"results": results or []
}
...@@ -261,19 +261,17 @@ class ThreadViewSet(DeveloperErrorViewMixin, ViewSet): ...@@ -261,19 +261,17 @@ class ThreadViewSet(DeveloperErrorViewMixin, ViewSet):
form = ThreadListGetForm(request.GET) form = ThreadListGetForm(request.GET)
if not form.is_valid(): if not form.is_valid():
raise ValidationError(form.errors) raise ValidationError(form.errors)
return Response( return get_thread_list(
get_thread_list( request,
request, form.cleaned_data["course_id"],
form.cleaned_data["course_id"], form.cleaned_data["page"],
form.cleaned_data["page"], form.cleaned_data["page_size"],
form.cleaned_data["page_size"], form.cleaned_data["topic_id"],
form.cleaned_data["topic_id"], form.cleaned_data["text_search"],
form.cleaned_data["text_search"], form.cleaned_data["following"],
form.cleaned_data["following"], form.cleaned_data["view"],
form.cleaned_data["view"], form.cleaned_data["order_by"],
form.cleaned_data["order_by"], form.cleaned_data["order_direction"],
form.cleaned_data["order_direction"],
)
) )
def retrieve(self, request, thread_id=None): def retrieve(self, request, thread_id=None):
...@@ -443,14 +441,12 @@ class CommentViewSet(DeveloperErrorViewMixin, ViewSet): ...@@ -443,14 +441,12 @@ class CommentViewSet(DeveloperErrorViewMixin, ViewSet):
form = CommentListGetForm(request.GET) form = CommentListGetForm(request.GET)
if not form.is_valid(): if not form.is_valid():
raise ValidationError(form.errors) raise ValidationError(form.errors)
return Response( return get_comment_list(
get_comment_list( request,
request, form.cleaned_data["thread_id"],
form.cleaned_data["thread_id"], form.cleaned_data["endorsed"],
form.cleaned_data["endorsed"], form.cleaned_data["page"],
form.cleaned_data["page"], form.cleaned_data["page_size"]
form.cleaned_data["page_size"]
)
) )
def retrieve(self, request, comment_id=None): def retrieve(self, request, comment_id=None):
...@@ -460,13 +456,11 @@ class CommentViewSet(DeveloperErrorViewMixin, ViewSet): ...@@ -460,13 +456,11 @@ class CommentViewSet(DeveloperErrorViewMixin, ViewSet):
form = _PaginationForm(request.GET) form = _PaginationForm(request.GET)
if not form.is_valid(): if not form.is_valid():
raise ValidationError(form.errors) raise ValidationError(form.errors)
return Response( return get_response_comments(
get_response_comments( request,
request, comment_id,
comment_id, form.cleaned_data["page"],
form.cleaned_data["page"], form.cleaned_data["page_size"]
form.cleaned_data["page_size"]
)
) )
def create(self, request): def create(self, request):
......
...@@ -6,6 +6,7 @@ from django.core.urlresolvers import reverse ...@@ -6,6 +6,7 @@ from django.core.urlresolvers import reverse
from django.http import Http404 from django.http import Http404
from django.test.client import Client, RequestFactory from django.test.client import Client, RequestFactory
from django.test.utils import override_settings from django.test.utils import override_settings
from lms.lib.comment_client.utils import CommentClientPaginatedResult
from edxmako.tests import mako_middleware_process_request from edxmako.tests import mako_middleware_process_request
from django_comment_common.utils import ThreadContext from django_comment_common.utils import ThreadContext
...@@ -96,7 +97,7 @@ class ViewsExceptionTestCase(UrlResetMixin, ModuleStoreTestCase): ...@@ -96,7 +97,7 @@ class ViewsExceptionTestCase(UrlResetMixin, ModuleStoreTestCase):
# Mock the code that makes the HTTP requests to the cs_comment_service app # Mock the code that makes the HTTP requests to the cs_comment_service app
# for the profiled user's active threads # for the profiled user's active threads
mock_threads.return_value = [], 1, 1 mock_threads.return_value = CommentClientPaginatedResult(collection=[], page=1, num_pages=1)
# Mock the code that makes the HTTP request to the cs_comment_service app # Mock the code that makes the HTTP request to the cs_comment_service app
# that gets the current user's info # that gets the current user's info
......
...@@ -146,7 +146,8 @@ def get_threads(request, course, discussion_id=None, per_page=THREADS_PER_PAGE): ...@@ -146,7 +146,8 @@ def get_threads(request, course, discussion_id=None, per_page=THREADS_PER_PAGE):
) )
) )
threads, page, num_pages, corrected_text = cc.Thread.search(query_params) paginated_results = cc.Thread.search(query_params)
threads = paginated_results.collection
# If not provided with a discussion id, filter threads by commentable ids # If not provided with a discussion id, filter threads by commentable ids
# which are accessible to the current user. # which are accessible to the current user.
...@@ -162,9 +163,9 @@ def get_threads(request, course, discussion_id=None, per_page=THREADS_PER_PAGE): ...@@ -162,9 +163,9 @@ def get_threads(request, course, discussion_id=None, per_page=THREADS_PER_PAGE):
if 'pinned' not in thread: if 'pinned' not in thread:
thread['pinned'] = False thread['pinned'] = False
query_params['page'] = page query_params['page'] = paginated_results.page
query_params['num_pages'] = num_pages query_params['num_pages'] = paginated_results.num_pages
query_params['corrected_text'] = corrected_text query_params['corrected_text'] = paginated_results.corrected_text
return threads, query_params return threads, query_params
...@@ -336,7 +337,12 @@ def single_thread(request, course_key, discussion_id, thread_id): ...@@ -336,7 +337,12 @@ def single_thread(request, course_key, discussion_id, thread_id):
is_staff = has_permission(request.user, 'openclose_thread', course.id) is_staff = has_permission(request.user, 'openclose_thread', course.id)
if request.is_ajax(): if request.is_ajax():
with newrelic.agent.FunctionTrace(nr_transaction, "get_annotated_content_infos"): with newrelic.agent.FunctionTrace(nr_transaction, "get_annotated_content_infos"):
annotated_content_info = utils.get_annotated_content_infos(course_key, thread, request.user, user_info=user_info) annotated_content_info = utils.get_annotated_content_infos(
course_key,
thread,
request.user,
user_info=user_info
)
content = utils.prepare_content(thread.to_dict(), course_key, is_staff) content = utils.prepare_content(thread.to_dict(), course_key, is_staff)
with newrelic.agent.FunctionTrace(nr_transaction, "add_courseware_context"): with newrelic.agent.FunctionTrace(nr_transaction, "add_courseware_context"):
add_courseware_context([content], course, request.user) add_courseware_context([content], course, request.user)
...@@ -511,18 +517,26 @@ def followed_threads(request, course_key, user_id): ...@@ -511,18 +517,26 @@ def followed_threads(request, course_key, user_id):
if group_id is not None: if group_id is not None:
query_params['group_id'] = group_id query_params['group_id'] = group_id
threads, page, num_pages = profiled_user.subscribed_threads(query_params) paginated_results = profiled_user.subscribed_threads(query_params)
query_params['page'] = page print "\n \n \n paginated results \n \n \n "
query_params['num_pages'] = num_pages print paginated_results
query_params['page'] = paginated_results.page
query_params['num_pages'] = paginated_results.num_pages
user_info = cc.User.from_django_user(request.user).to_dict() user_info = cc.User.from_django_user(request.user).to_dict()
with newrelic.agent.FunctionTrace(nr_transaction, "get_metadata_for_threads"): with newrelic.agent.FunctionTrace(nr_transaction, "get_metadata_for_threads"):
annotated_content_info = utils.get_metadata_for_threads(course_key, threads, request.user, user_info) annotated_content_info = utils.get_metadata_for_threads(
course_key,
paginated_results.collection,
request.user, user_info
)
if request.is_ajax(): if request.is_ajax():
is_staff = has_permission(request.user, 'openclose_thread', course.id) is_staff = has_permission(request.user, 'openclose_thread', course.id)
return utils.JsonResponse({ return utils.JsonResponse({
'annotated_content_info': annotated_content_info, 'annotated_content_info': annotated_content_info,
'discussion_data': [utils.prepare_content(thread, course_key, is_staff) for thread in threads], 'discussion_data': [
utils.prepare_content(thread, course_key, is_staff) for thread in paginated_results.collection
],
'page': query_params['page'], 'page': query_params['page'],
'num_pages': query_params['num_pages'], 'num_pages': query_params['num_pages'],
}) })
...@@ -533,7 +547,7 @@ def followed_threads(request, course_key, user_id): ...@@ -533,7 +547,7 @@ def followed_threads(request, course_key, user_id):
'user': request.user, 'user': request.user,
'django_user': User.objects.get(id=user_id), 'django_user': User.objects.get(id=user_id),
'profiled_user': profiled_user.to_dict(), 'profiled_user': profiled_user.to_dict(),
'threads': json.dumps(threads), 'threads': json.dumps(paginated_results.collection),
'user_info': json.dumps(user_info), 'user_info': json.dumps(user_info),
'annotated_content_info': json.dumps(annotated_content_info), 'annotated_content_info': json.dumps(annotated_content_info),
# 'content': content, # 'content': content,
......
import logging import logging
from eventtracking import tracker from eventtracking import tracker
from .utils import merge_dict, strip_blank, strip_none, extract, perform_request from .utils import merge_dict, strip_blank, strip_none, extract, perform_request, CommentClientPaginatedResult
from .utils import CommentClientRequestError from .utils import CommentClientRequestError
import models import models
import settings import settings
...@@ -94,7 +94,14 @@ class Thread(models.Model): ...@@ -94,7 +94,14 @@ class Thread(models.Model):
total_results=total_results total_results=total_results
) )
) )
return response.get('collection', []), response.get('page', 1), response.get('num_pages', 1), response.get('corrected_text')
return CommentClientPaginatedResult(
collection=response.get('collection', []),
page=response.get('page', 1),
num_pages=response.get('num_pages', 1),
thread_count=response.get('thread_count', 0),
corrected_text=response.get('corrected_text', None)
)
@classmethod @classmethod
def url_for_threads(cls, params={}): def url_for_threads(cls, params={}):
......
from .utils import merge_dict, perform_request, CommentClientRequestError """ User model wrapper for comment service"""
from .utils import merge_dict, perform_request, CommentClientRequestError, CommentClientPaginatedResult
import models import models
import settings import settings
...@@ -113,7 +114,12 @@ class User(models.Model): ...@@ -113,7 +114,12 @@ class User(models.Model):
metric_tags=self._metric_tags, metric_tags=self._metric_tags,
paged_results=True paged_results=True
) )
return response.get('collection', []), response.get('page', 1), response.get('num_pages', 1) return CommentClientPaginatedResult(
collection=response.get('collection', []),
page=response.get('page', 1),
num_pages=response.get('num_pages', 1),
thread_count=response.get('thread_count', 0)
)
def _retrieve(self, *args, **kwargs): def _retrieve(self, *args, **kwargs):
url = self.url(action='get', params=self.attributes) url = self.url(action='get', params=self.attributes)
......
"""" Common utilities for comment client wrapper """
from contextlib import contextmanager from contextlib import contextmanager
import dogstats_wrapper as dog_stats_api import dogstats_wrapper as dog_stats_api
import logging import logging
...@@ -141,9 +142,9 @@ class CommentClientError(Exception): ...@@ -141,9 +142,9 @@ class CommentClientError(Exception):
class CommentClientRequestError(CommentClientError): class CommentClientRequestError(CommentClientError):
def __init__(self, msg, status_code=400): def __init__(self, msg, status_codes=400):
super(CommentClientRequestError, self).__init__(msg) super(CommentClientRequestError, self).__init__(msg)
self.status_code = status_code self.status_code = status_codes
class CommentClient500Error(CommentClientError): class CommentClient500Error(CommentClientError):
...@@ -152,3 +153,14 @@ class CommentClient500Error(CommentClientError): ...@@ -152,3 +153,14 @@ class CommentClient500Error(CommentClientError):
class CommentClientMaintenanceError(CommentClientError): class CommentClientMaintenanceError(CommentClientError):
pass pass
class CommentClientPaginatedResult(object):
""" class for paginated results returned from comment services"""
def __init__(self, collection, page, num_pages, thread_count=0, corrected_text=None):
self.collection = collection
self.page = page
self.num_pages = num_pages
self.thread_count = thread_count
self.corrected_text = corrected_text
...@@ -39,6 +39,18 @@ class NamespacedPageNumberPagination(pagination.PageNumberPagination): ...@@ -39,6 +39,18 @@ class NamespacedPageNumberPagination(pagination.PageNumberPagination):
page_size_query_param = "page_size" page_size_query_param = "page_size"
def get_result_count(self):
"""
Returns total number of results
"""
return self.page.paginator.count
def get_num_pages(self):
"""
Returns total number of pages the results are divided into
"""
return self.page.paginator.num_pages
def get_paginated_response(self, data): def get_paginated_response(self, data):
""" """
Annotate the response with pagination information Annotate the response with pagination information
...@@ -46,8 +58,8 @@ class NamespacedPageNumberPagination(pagination.PageNumberPagination): ...@@ -46,8 +58,8 @@ class NamespacedPageNumberPagination(pagination.PageNumberPagination):
metadata = { metadata = {
'next': self.get_next_link(), 'next': self.get_next_link(),
'previous': self.get_previous_link(), 'previous': self.get_previous_link(),
'count': self.page.paginator.count, 'count': self.get_result_count(),
'num_pages': self.page.paginator.num_pages, 'num_pages': self.get_num_pages(),
} }
if isinstance(data, dict): if isinstance(data, dict):
if 'results' not in data: if 'results' not in data:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment