Commit 2451e067 by Greg Price

Add comment list endpoint to Discussion API

parent e124fb06
""" """
Discussion API internal interface Discussion API internal interface
""" """
from django.core.exceptions import ValidationError
from django.http import Http404 from django.http import Http404
from collections import defaultdict from collections import defaultdict
from opaque_keys.edx.locator import CourseLocator
from courseware.courses import get_course_with_access from courseware.courses import get_course_with_access
from discussion_api.pagination import get_paginated_data from discussion_api.pagination import get_paginated_data
from discussion_api.serializers import ThreadSerializer, get_context from discussion_api.serializers import CommentSerializer, ThreadSerializer, get_context
from django_comment_client.utils import get_accessible_discussion_modules from django_comment_client.utils import get_accessible_discussion_modules
from lms.lib.comment_client.thread import Thread from lms.lib.comment_client.thread import Thread
from lms.lib.comment_client.utils import CommentClientRequestError
from openedx.core.djangoapps.course_groups.cohorts import get_cohort_id from openedx.core.djangoapps.course_groups.cohorts import get_cohort_id
from xmodule.tabs import DiscussionTab from xmodule.tabs import DiscussionTab
...@@ -123,3 +127,84 @@ def get_thread_list(request, course_key, page, page_size): ...@@ -123,3 +127,84 @@ def get_thread_list(request, course_key, page, page_size):
results = [ThreadSerializer(thread, context=context).data for thread in threads] results = [ThreadSerializer(thread, context=context).data for thread in threads]
return get_paginated_data(request, results, page, num_pages) return get_paginated_data(request, results, page, num_pages)
def get_comment_list(request, thread_id, endorsed, page, page_size):
"""
Return the list of comments in the given thread.
Parameters:
request: The django request object used for build_absolute_uri and
determining the requesting user.
thread_id: The id of the thread to get comments for.
endorsed: Boolean indicating whether to get endorsed or non-endorsed
comments (or None for all comments). Must be None for a discussion
thread and non-None for a question thread.
page: The page number (1-indexed) to retrieve
page_size: The number of comments to retrieve per page
Returns:
A paginated result containing a list of comments; see
discussion_api.views.CommentViewSet for more detail.
"""
response_skip = page_size * (page - 1)
try:
cc_thread = Thread(id=thread_id).retrieve(
recursive=True,
user_id=request.user.id,
mark_as_read=True,
response_skip=response_skip,
response_limit=page_size
)
except CommentClientRequestError:
# page and page_size are validated at a higher level, so the only
# possible request error is if the thread doesn't exist
raise Http404
course_key = CourseLocator.from_string(cc_thread["course_id"])
course = _get_course_or_404(course_key, request.user)
context = get_context(course, request.user)
# Ensure user has access to the thread
if not context["is_requester_privileged"] and cc_thread["group_id"]:
requester_cohort = get_cohort_id(request.user, course_key)
if requester_cohort is not None and cc_thread["group_id"] != requester_cohort:
raise Http404
# Responses to discussion threads cannot be separated by endorsed, but
# responses to question threads must be separated by endorsed due to the
# existing comments service interface
if cc_thread["thread_type"] == "question":
if endorsed is None:
raise ValidationError({"endorsed": ["This field is required for question threads."]})
elif endorsed:
# CS does not apply resp_skip and resp_limit to endorsed responses
# of a question post
responses = cc_thread["endorsed_responses"][response_skip:(response_skip + page_size)]
resp_total = len(cc_thread["endorsed_responses"])
else:
responses = cc_thread["non_endorsed_responses"]
resp_total = cc_thread["non_endorsed_resp_total"]
else:
if endorsed is not None:
raise ValidationError(
{"endorsed": ["This field may not be specified for discussion threads."]}
)
responses = cc_thread["children"]
resp_total = cc_thread["resp_total"]
# The comments service returns the last page of results if the requested
# page is beyond the last page, but we want be consistent with DRF's general
# behavior and return a 404 in that case
if not responses and page != 1:
raise Http404
num_pages = (resp_total + page_size - 1) / page_size if resp_total else 1
results = [CommentSerializer(response, context=context).data for response in responses]
return get_paginated_data(request, results, page, num_pages)
...@@ -2,19 +2,31 @@ ...@@ -2,19 +2,31 @@
Discussion API forms Discussion API forms
""" """
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.forms import Form, CharField, IntegerField from django.forms import CharField, Form, IntegerField, NullBooleanField
from opaque_keys import InvalidKeyError from opaque_keys import InvalidKeyError
from opaque_keys.edx.locator import CourseLocator from opaque_keys.edx.locator import CourseLocator
class ThreadListGetForm(Form): class _PaginationForm(Form):
"""A form that includes pagination fields"""
page = IntegerField(required=False, min_value=1)
page_size = IntegerField(required=False, min_value=1)
def clean_page(self):
"""Return given valid page or default of 1"""
return self.cleaned_data.get("page") or 1
def clean_page_size(self):
"""Return given valid page_size (capped at 100) or default of 10"""
return min(self.cleaned_data.get("page_size") or 10, 100)
class ThreadListGetForm(_PaginationForm):
""" """
A form to validate query parameters in the thread list retrieval endpoint A form to validate query parameters in the thread list retrieval endpoint
""" """
course_id = CharField() course_id = CharField()
page = IntegerField(required=False, min_value=1)
page_size = IntegerField(required=False, min_value=1)
def clean_course_id(self): def clean_course_id(self):
"""Validate course_id""" """Validate course_id"""
...@@ -24,10 +36,12 @@ class ThreadListGetForm(Form): ...@@ -24,10 +36,12 @@ class ThreadListGetForm(Form):
except InvalidKeyError: except InvalidKeyError:
raise ValidationError("'{}' is not a valid course id".format(value)) raise ValidationError("'{}' is not a valid course id".format(value))
def clean_page(self):
"""Return given valid page or default of 1"""
return self.cleaned_data.get("page") or 1
def clean_page_size(self): class CommentListGetForm(_PaginationForm):
"""Return given valid page_size (capped at 100) or default of 10""" """
return min(self.cleaned_data.get("page_size") or 10, 100) A form to validate query parameters in the comment list retrieval endpoint
"""
thread_id = CharField()
# TODO: should we use something better here? This only accepts "True",
# "False", "1", and "0"
endorsed = NullBooleanField(required=False)
...@@ -14,7 +14,10 @@ from openedx.core.djangoapps.course_groups.cohorts import get_cohort_names ...@@ -14,7 +14,10 @@ from openedx.core.djangoapps.course_groups.cohorts import get_cohort_names
def get_context(course, requester): def get_context(course, requester):
"""Returns a context appropriate for use with ThreadSerializer.""" """
Returns a context appropriate for use with ThreadSerializer or
CommentSerializer.
"""
# TODO: cache staff_user_ids and ta_user_ids if we need to improve perf # TODO: cache staff_user_ids and ta_user_ids if we need to improve perf
staff_user_ids = { staff_user_ids = {
user.id user.id
...@@ -39,49 +42,27 @@ def get_context(course, requester): ...@@ -39,49 +42,27 @@ def get_context(course, requester):
} }
class ThreadSerializer(serializers.Serializer): class _ContentSerializer(serializers.Serializer):
""" """A base class for thread and comment serializers."""
A serializer for thread data.
N.B. This should not be used with a comment_client Thread object that has
not had retrieve() called, because of the interaction between DRF's attempts
at introspection and Thread's __getattr__.
"""
id_ = serializers.CharField(read_only=True) id_ = serializers.CharField(read_only=True)
course_id = serializers.CharField()
topic_id = serializers.CharField(source="commentable_id")
group_id = serializers.IntegerField()
group_name = serializers.SerializerMethodField("get_group_name")
author = serializers.SerializerMethodField("get_author") author = serializers.SerializerMethodField("get_author")
author_label = serializers.SerializerMethodField("get_author_label") author_label = serializers.SerializerMethodField("get_author_label")
created_at = serializers.CharField(read_only=True) created_at = serializers.CharField(read_only=True)
updated_at = serializers.CharField(read_only=True) updated_at = serializers.CharField(read_only=True)
type_ = serializers.ChoiceField(source="thread_type", choices=("discussion", "question"))
title = serializers.CharField()
raw_body = serializers.CharField(source="body") raw_body = serializers.CharField(source="body")
pinned = serializers.BooleanField()
closed = serializers.BooleanField()
following = serializers.SerializerMethodField("get_following")
abuse_flagged = serializers.SerializerMethodField("get_abuse_flagged") abuse_flagged = serializers.SerializerMethodField("get_abuse_flagged")
voted = serializers.SerializerMethodField("get_voted") voted = serializers.SerializerMethodField("get_voted")
vote_count = serializers.SerializerMethodField("get_vote_count") vote_count = serializers.SerializerMethodField("get_vote_count")
comment_count = serializers.IntegerField(source="comments_count")
unread_comment_count = serializers.IntegerField(source="unread_comments_count")
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(ThreadSerializer, self).__init__(*args, **kwargs) super(_ContentSerializer, self).__init__(*args, **kwargs)
# type and id are invalid class attribute names, so we must declare # id is an invalid class attribute name, so we must declare a different
# different names above and modify them here # name above and modify it here
self.fields["id"] = self.fields.pop("id_") self.fields["id"] = self.fields.pop("id_")
self.fields["type"] = self.fields.pop("type_")
def get_group_name(self, obj):
"""Returns the name of the group identified by the thread's group_id."""
return self.context["group_ids_to_names"].get(obj["group_id"])
def _is_anonymous(self, obj): def _is_anonymous(self, obj):
""" """
Returns a boolean indicating whether the thread should be anonymous to Returns a boolean indicating whether the content should be anonymous to
the requester. the requester.
""" """
return ( return (
...@@ -90,7 +71,7 @@ class ThreadSerializer(serializers.Serializer): ...@@ -90,7 +71,7 @@ class ThreadSerializer(serializers.Serializer):
) )
def get_author(self, obj): def get_author(self, obj):
"""Returns the author's username, or None if the thread is anonymous.""" """Returns the author's username, or None if the content is anonymous."""
return None if self._is_anonymous(obj) else obj["username"] return None if self._is_anonymous(obj) else obj["username"]
def _get_user_label(self, user_id): def _get_user_label(self, user_id):
...@@ -105,30 +86,84 @@ class ThreadSerializer(serializers.Serializer): ...@@ -105,30 +86,84 @@ class ThreadSerializer(serializers.Serializer):
) )
def get_author_label(self, obj): def get_author_label(self, obj):
"""Returns the role label for the thread author.""" """Returns the role label for the content author."""
return None if self._is_anonymous(obj) else self._get_user_label(int(obj["user_id"])) return None if self._is_anonymous(obj) else self._get_user_label(int(obj["user_id"]))
def get_following(self, obj):
"""
Returns a boolean indicating whether the requester is following the
thread.
"""
return obj["id"] in self.context["cc_requester"]["subscribed_thread_ids"]
def get_abuse_flagged(self, obj): def get_abuse_flagged(self, obj):
""" """
Returns a boolean indicating whether the requester has flagged the Returns a boolean indicating whether the requester has flagged the
thread as abusive. content as abusive.
""" """
return self.context["cc_requester"]["id"] in obj["abuse_flaggers"] return self.context["cc_requester"]["id"] in obj["abuse_flaggers"]
def get_voted(self, obj): def get_voted(self, obj):
""" """
Returns a boolean indicating whether the requester has voted for the Returns a boolean indicating whether the requester has voted for the
thread. content.
""" """
return obj["id"] in self.context["cc_requester"]["upvoted_ids"] return obj["id"] in self.context["cc_requester"]["upvoted_ids"]
def get_vote_count(self, obj): def get_vote_count(self, obj):
"""Returns the number of votes for the thread.""" """Returns the number of votes for the content."""
return obj["votes"]["up_count"] return obj["votes"]["up_count"]
class ThreadSerializer(_ContentSerializer):
"""
A serializer for thread data.
N.B. This should not be used with a comment_client Thread object that has
not had retrieve() called, because of the interaction between DRF's attempts
at introspection and Thread's __getattr__.
"""
course_id = serializers.CharField()
topic_id = serializers.CharField(source="commentable_id")
group_id = serializers.IntegerField()
group_name = serializers.SerializerMethodField("get_group_name")
type_ = serializers.ChoiceField(source="thread_type", choices=("discussion", "question"))
title = serializers.CharField()
pinned = serializers.BooleanField()
closed = serializers.BooleanField()
following = serializers.SerializerMethodField("get_following")
comment_count = serializers.IntegerField(source="comments_count")
unread_comment_count = serializers.IntegerField(source="unread_comments_count")
def __init__(self, *args, **kwargs):
super(ThreadSerializer, self).__init__(*args, **kwargs)
# type is an invalid class attribute name, so we must declare a
# different name above and modify it here
self.fields["type"] = self.fields.pop("type_")
def get_group_name(self, obj):
"""Returns the name of the group identified by the thread's group_id."""
return self.context["group_ids_to_names"].get(obj["group_id"])
def get_following(self, obj):
"""
Returns a boolean indicating whether the requester is following the
thread.
"""
return obj["id"] in self.context["cc_requester"]["subscribed_thread_ids"]
class CommentSerializer(_ContentSerializer):
"""
A serializer for comment data.
N.B. This should not be used with a comment_client Comment object that has
not had retrieve() called, because of the interaction between DRF's attempts
at introspection and Comment's __getattr__.
"""
thread_id = serializers.CharField()
parent_id = serializers.SerializerMethodField("get_parent_id")
children = serializers.SerializerMethodField("get_children")
def get_parent_id(self, _obj):
"""Returns the comment's parent's id (taken from the context)."""
return self.context.get("parent_id")
def get_children(self, obj):
"""Returns the list of the comment's children, serialized."""
child_context = dict(self.context)
child_context["parent_id"] = obj["id"]
return [CommentSerializer(child, context=child_context).data for child in obj["children"]]
...@@ -9,14 +9,19 @@ import httpretty ...@@ -9,14 +9,19 @@ import httpretty
import mock import mock
from pytz import UTC from pytz import UTC
from django.core.exceptions import ValidationError
from django.http import Http404 from django.http import Http404
from django.test.client import RequestFactory from django.test.client import RequestFactory
from opaque_keys.edx.locator import CourseLocator from opaque_keys.edx.locator import CourseLocator
from courseware.tests.factories import BetaTesterFactory, StaffFactory from courseware.tests.factories import BetaTesterFactory, StaffFactory
from discussion_api.api import get_course_topics, get_thread_list from discussion_api.api import get_comment_list, get_course_topics, get_thread_list
from discussion_api.tests.utils import CommentsServiceMockMixin from discussion_api.tests.utils import (
CommentsServiceMockMixin,
make_minimal_cs_comment,
make_minimal_cs_thread,
)
from django_comment_common.models import ( from django_comment_common.models import (
FORUM_ROLE_ADMINISTRATOR, FORUM_ROLE_ADMINISTRATOR,
FORUM_ROLE_COMMUNITY_TA, FORUM_ROLE_COMMUNITY_TA,
...@@ -378,12 +383,6 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase): ...@@ -378,12 +383,6 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase):
ret = get_thread_list(self.request, course.id, page, page_size) ret = get_thread_list(self.request, course.id, page, page_size)
return ret return ret
def create_role(self, role_name, users):
"""Create a Role in self.course with the given name and users"""
role = Role.objects.create(name=role_name, course_id=self.course.id)
role.users = users
role.save()
def test_nonexistent_course(self): def test_nonexistent_course(self):
with self.assertRaises(Http404): with self.assertRaises(Http404):
get_thread_list(self.request, CourseLocator.from_string("non/existent/course"), 1, 1) get_thread_list(self.request, CourseLocator.from_string("non/existent/course"), 1, 1)
...@@ -573,3 +572,368 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase): ...@@ -573,3 +572,368 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase):
self.register_get_threads_response([], page=3, num_pages=3) self.register_get_threads_response([], page=3, num_pages=3)
with self.assertRaises(Http404): with self.assertRaises(Http404):
get_thread_list(self.request, self.course.id, page=4, page_size=10) get_thread_list(self.request, self.course.id, page=4, page_size=10)
@ddt.ddt
class GetCommentListTest(CommentsServiceMockMixin, ModuleStoreTestCase):
"""Test for get_comment_list"""
def setUp(self):
super(GetCommentListTest, self).setUp()
httpretty.reset()
httpretty.enable()
self.addCleanup(httpretty.disable)
self.maxDiff = None # pylint: disable=invalid-name
self.user = UserFactory.create()
self.register_get_user_response(self.user)
self.request = RequestFactory().get("/test_path")
self.request.user = self.user
self.course = CourseFactory.create()
CourseEnrollmentFactory.create(user=self.user, course_id=self.course.id)
self.author = UserFactory.create()
def make_minimal_cs_thread(self, overrides=None):
"""
Create a thread with the given overrides, plus the course_id if not
already in overrides.
"""
overrides = overrides.copy() if overrides else {}
overrides.setdefault("course_id", unicode(self.course.id))
return make_minimal_cs_thread(overrides)
def get_comment_list(self, thread, endorsed=None, page=1, page_size=1):
"""
Register the appropriate comments service response, then call
get_comment_list and return the result.
"""
self.register_get_thread_response(thread)
return get_comment_list(self.request, thread["id"], endorsed, page, page_size)
def test_nonexistent_thread(self):
thread_id = "nonexistent_thread"
self.register_get_thread_error_response(thread_id, 404)
with self.assertRaises(Http404):
get_comment_list(self.request, thread_id, endorsed=False, page=1, page_size=1)
def test_nonexistent_course(self):
with self.assertRaises(Http404):
self.get_comment_list(self.make_minimal_cs_thread({"course_id": "non/existent/course"}))
def test_not_enrolled(self):
self.request.user = UserFactory.create()
with self.assertRaises(Http404):
self.get_comment_list(self.make_minimal_cs_thread())
def test_discussions_disabled(self):
_remove_discussion_tab(self.course, self.user.id)
with self.assertRaises(Http404):
self.get_comment_list(self.make_minimal_cs_thread())
@ddt.data(
*itertools.product(
[
FORUM_ROLE_ADMINISTRATOR,
FORUM_ROLE_MODERATOR,
FORUM_ROLE_COMMUNITY_TA,
FORUM_ROLE_STUDENT,
],
[True, False],
["no_group", "match_group", "different_group"],
)
)
@ddt.unpack
def test_group_access(self, role_name, course_is_cohorted, thread_group_state):
cohort_course = CourseFactory.create(cohort_config={"cohorted": course_is_cohorted})
CourseEnrollmentFactory.create(user=self.user, course_id=cohort_course.id)
cohort = CohortFactory.create(course_id=cohort_course.id, users=[self.user])
role = Role.objects.create(name=role_name, course_id=cohort_course.id)
role.users = [self.user]
thread = self.make_minimal_cs_thread({
"course_id": unicode(cohort_course.id),
"group_id": (
None if thread_group_state == "no_group" else
cohort.id if thread_group_state == "match_group" else
cohort.id + 1
),
})
expected_error = (
role_name == FORUM_ROLE_STUDENT and
course_is_cohorted and
thread_group_state == "different_group"
)
try:
self.get_comment_list(thread)
self.assertFalse(expected_error)
except Http404:
self.assertTrue(expected_error)
@ddt.data(True, False)
def test_discussion_endorsed(self, endorsed_value):
with self.assertRaises(ValidationError) as assertion:
self.get_comment_list(
self.make_minimal_cs_thread({"thread_type": "discussion"}),
endorsed=endorsed_value
)
self.assertEqual(
assertion.exception.message_dict,
{"endorsed": ["This field may not be specified for discussion threads."]}
)
def test_question_without_endorsed(self):
with self.assertRaises(ValidationError) as assertion:
self.get_comment_list(
self.make_minimal_cs_thread({"thread_type": "question"}),
endorsed=None
)
self.assertEqual(
assertion.exception.message_dict,
{"endorsed": ["This field is required for question threads."]}
)
def test_empty(self):
discussion_thread = self.make_minimal_cs_thread(
{"thread_type": "discussion", "children": [], "resp_total": 0}
)
self.assertEqual(
self.get_comment_list(discussion_thread),
{"results": [], "next": None, "previous": None}
)
question_thread = self.make_minimal_cs_thread({
"thread_type": "question",
"endorsed_responses": [],
"non_endorsed_responses": [],
"non_endorsed_resp_total": 0
})
self.assertEqual(
self.get_comment_list(question_thread, endorsed=False),
{"results": [], "next": None, "previous": None}
)
self.assertEqual(
self.get_comment_list(question_thread, endorsed=True),
{"results": [], "next": None, "previous": None}
)
def test_basic_query_params(self):
self.get_comment_list(
self.make_minimal_cs_thread({
"children": [make_minimal_cs_comment()],
"resp_total": 71
}),
page=6,
page_size=14
)
self.assert_query_params_equal(
httpretty.httpretty.latest_requests[-2],
{
"recursive": ["True"],
"user_id": [str(self.user.id)],
"mark_as_read": ["True"],
"resp_skip": ["70"],
"resp_limit": ["14"],
}
)
def test_discussion_content(self):
source_comments = [
{
"id": "test_comment_1",
"thread_id": "test_thread",
"user_id": str(self.author.id),
"username": self.author.username,
"anonymous": False,
"anonymous_to_peers": False,
"created_at": "2015-05-11T00:00:00Z",
"updated_at": "2015-05-11T11:11:11Z",
"body": "Test body",
"abuse_flaggers": [],
"votes": {"up_count": 4},
"children": [],
},
{
"id": "test_comment_2",
"thread_id": "test_thread",
"user_id": str(self.author.id),
"username": self.author.username,
"anonymous": True,
"anonymous_to_peers": False,
"created_at": "2015-05-11T22:22:22Z",
"updated_at": "2015-05-11T33:33:33Z",
"body": "More content",
"abuse_flaggers": [str(self.user.id)],
"votes": {"up_count": 7},
"children": [],
}
]
expected_comments = [
{
"id": "test_comment_1",
"thread_id": "test_thread",
"parent_id": None,
"author": self.author.username,
"author_label": None,
"created_at": "2015-05-11T00:00:00Z",
"updated_at": "2015-05-11T11:11:11Z",
"raw_body": "Test body",
"abuse_flagged": False,
"voted": False,
"vote_count": 4,
"children": [],
},
{
"id": "test_comment_2",
"thread_id": "test_thread",
"parent_id": None,
"author": None,
"author_label": None,
"created_at": "2015-05-11T22:22:22Z",
"updated_at": "2015-05-11T33:33:33Z",
"raw_body": "More content",
"abuse_flagged": True,
"voted": False,
"vote_count": 7,
"children": [],
},
]
actual_comments = self.get_comment_list(
self.make_minimal_cs_thread({"children": source_comments})
)["results"]
self.assertEqual(actual_comments, expected_comments)
def test_question_content(self):
thread = self.make_minimal_cs_thread({
"thread_type": "question",
"endorsed_responses": [make_minimal_cs_comment({"id": "endorsed_comment"})],
"non_endorsed_responses": [make_minimal_cs_comment({"id": "non_endorsed_comment"})],
"non_endorsed_resp_total": 1,
})
endorsed_actual = self.get_comment_list(thread, endorsed=True)
self.assertEqual(endorsed_actual["results"][0]["id"], "endorsed_comment")
non_endorsed_actual = self.get_comment_list(thread, endorsed=False)
self.assertEqual(non_endorsed_actual["results"][0]["id"], "non_endorsed_comment")
@ddt.data(
("discussion", None, "children", "resp_total"),
("question", False, "non_endorsed_responses", "non_endorsed_resp_total"),
)
@ddt.unpack
def test_cs_pagination(self, thread_type, endorsed_arg, response_field, response_total_field):
"""
Test cases in which pagination is done by the comments service.
thread_type is the type of thread (question or discussion).
endorsed_arg is the value of the endorsed argument.
repsonse_field is the field in which responses are returned for the
given thread type.
response_total_field is the field in which the total number of responses
is returned for the given thread type.
"""
# N.B. The mismatch between the number of children and the listed total
# number of responses is unrealistic but convenient for this test
thread = self.make_minimal_cs_thread({
"thread_type": thread_type,
response_field: [make_minimal_cs_comment()],
response_total_field: 5,
})
# Only page
actual = self.get_comment_list(thread, endorsed=endorsed_arg, page=1, page_size=5)
self.assertIsNone(actual["next"])
self.assertIsNone(actual["previous"])
# First page of many
actual = self.get_comment_list(thread, endorsed=endorsed_arg, page=1, page_size=2)
self.assertEqual(actual["next"], "http://testserver/test_path?page=2")
self.assertIsNone(actual["previous"])
# Middle page of many
actual = self.get_comment_list(thread, endorsed=endorsed_arg, page=2, page_size=2)
self.assertEqual(actual["next"], "http://testserver/test_path?page=3")
self.assertEqual(actual["previous"], "http://testserver/test_path?page=1")
# Last page of many
actual = self.get_comment_list(thread, endorsed=endorsed_arg, page=3, page_size=2)
self.assertIsNone(actual["next"])
self.assertEqual(actual["previous"], "http://testserver/test_path?page=2")
# Page past the end
thread = self.make_minimal_cs_thread({
"thread_type": thread_type,
response_field: [],
response_total_field: 5
})
with self.assertRaises(Http404):
self.get_comment_list(thread, endorsed=endorsed_arg, page=2, page_size=5)
def test_question_endorsed_pagination(self):
thread = self.make_minimal_cs_thread({
"thread_type": "question",
"endorsed_responses": [
make_minimal_cs_comment({"id": "comment_{}".format(i)}) for i in range(10)
]
})
def assert_page_correct(page, page_size, expected_start, expected_stop, expected_next, expected_prev):
"""
Check that requesting the given page/page_size returns the expected
output
"""
actual = self.get_comment_list(thread, endorsed=True, page=page, page_size=page_size)
result_ids = [result["id"] for result in actual["results"]]
self.assertEqual(
result_ids,
["comment_{}".format(i) for i in range(expected_start, expected_stop)]
)
self.assertEqual(
actual["next"],
"http://testserver/test_path?page={}".format(expected_next) if expected_next else None
)
self.assertEqual(
actual["previous"],
"http://testserver/test_path?page={}".format(expected_prev) if expected_prev else None
)
# Only page
assert_page_correct(
page=1,
page_size=10,
expected_start=0,
expected_stop=10,
expected_next=None,
expected_prev=None
)
# First page of many
assert_page_correct(
page=1,
page_size=4,
expected_start=0,
expected_stop=4,
expected_next=2,
expected_prev=None
)
# Middle page of many
assert_page_correct(
page=2,
page_size=4,
expected_start=4,
expected_stop=8,
expected_next=3,
expected_prev=1
)
# Last page of many
assert_page_correct(
page=3,
page_size=4,
expected_start=8,
expected_stop=10,
expected_next=None,
expected_prev=2
)
# Page past the end
with self.assertRaises(Http404):
self.get_comment_list(thread, endorsed=True, page=2, page_size=10)
...@@ -5,25 +5,17 @@ from unittest import TestCase ...@@ -5,25 +5,17 @@ from unittest import TestCase
from opaque_keys.edx.locator import CourseLocator from opaque_keys.edx.locator import CourseLocator
from discussion_api.forms import ThreadListGetForm from discussion_api.forms import CommentListGetForm, ThreadListGetForm
class ThreadListGetFormTest(TestCase): class FormTestMixin(object):
"""Tests for ThreadListGetForm""" """A mixin for testing forms"""
def setUp(self):
super(ThreadListGetFormTest, self).setUp()
self.form_data = {
"course_id": "Foo/Bar/Baz",
"page": "2",
"page_size": "13",
}
def get_form(self, expected_valid): def get_form(self, expected_valid):
""" """
Return a form bound to self.form_data, asserting its validity (or lack Return a form bound to self.form_data, asserting its validity (or lack
thereof) according to expected_valid thereof) according to expected_valid
""" """
form = ThreadListGetForm(self.form_data) form = self.FORM_CLASS(self.form_data)
self.assertEqual(form.is_valid(), expected_valid) self.assertEqual(form.is_valid(), expected_valid)
return form return form
...@@ -44,6 +36,42 @@ class ThreadListGetFormTest(TestCase): ...@@ -44,6 +36,42 @@ class ThreadListGetFormTest(TestCase):
form = self.get_form(expected_valid=True) form = self.get_form(expected_valid=True)
self.assertEqual(form.cleaned_data[field], expected_value) self.assertEqual(form.cleaned_data[field], expected_value)
class PaginationTestMixin(object):
"""A mixin for testing forms with pagination fields"""
def test_missing_page(self):
self.form_data.pop("page")
self.assert_field_value("page", 1)
def test_invalid_page(self):
self.form_data["page"] = "0"
self.assert_error("page", "Ensure this value is greater than or equal to 1.")
def test_missing_page_size(self):
self.form_data.pop("page_size")
self.assert_field_value("page_size", 10)
def test_zero_page_size(self):
self.form_data["page_size"] = "0"
self.assert_error("page_size", "Ensure this value is greater than or equal to 1.")
def test_excessive_page_size(self):
self.form_data["page_size"] = "101"
self.assert_field_value("page_size", 100)
class ThreadListGetFormTest(FormTestMixin, PaginationTestMixin, TestCase):
"""Tests for ThreadListGetForm"""
FORM_CLASS = ThreadListGetForm
def setUp(self):
super(ThreadListGetFormTest, self).setUp()
self.form_data = {
"course_id": "Foo/Bar/Baz",
"page": "2",
"page_size": "13",
}
def test_basic(self): def test_basic(self):
form = self.get_form(expected_valid=True) form = self.get_form(expected_valid=True)
self.assertEqual( self.assertEqual(
...@@ -63,22 +91,36 @@ class ThreadListGetFormTest(TestCase): ...@@ -63,22 +91,36 @@ class ThreadListGetFormTest(TestCase):
self.form_data["course_id"] = "invalid course id" self.form_data["course_id"] = "invalid course id"
self.assert_error("course_id", "'invalid course id' is not a valid course id") self.assert_error("course_id", "'invalid course id' is not a valid course id")
def test_missing_page(self):
self.form_data.pop("page")
self.assert_field_value("page", 1)
def test_invalid_page(self): class CommentListGetFormTest(FormTestMixin, PaginationTestMixin, TestCase):
self.form_data["page"] = "0" """Tests for CommentListGetForm"""
self.assert_error("page", "Ensure this value is greater than or equal to 1.") FORM_CLASS = CommentListGetForm
def test_missing_page_size(self): def setUp(self):
self.form_data.pop("page_size") super(CommentListGetFormTest, self).setUp()
self.assert_field_value("page_size", 10) self.form_data = {
"thread_id": "deadbeef",
"endorsed": "False",
"page": "2",
"page_size": "13",
}
def test_zero_page_size(self): def test_basic(self):
self.form_data["page_size"] = "0" form = self.get_form(expected_valid=True)
self.assert_error("page_size", "Ensure this value is greater than or equal to 1.") self.assertEqual(
form.cleaned_data,
{
"thread_id": "deadbeef",
"endorsed": False,
"page": 2,
"page_size": 13,
}
)
def test_excessive_page_size(self): def test_missing_thread_id(self):
self.form_data["page_size"] = "101" self.form_data.pop("thread_id")
self.assert_field_value("page_size", 100) self.assert_error("thread_id", "This field is required.")
def test_missing_endorsed(self):
self.form_data.pop("endorsed")
self.assert_field_value("endorsed", None)
...@@ -4,8 +4,12 @@ Tests for Discussion API serializers ...@@ -4,8 +4,12 @@ Tests for Discussion API serializers
import ddt import ddt
import httpretty import httpretty
from discussion_api.serializers import ThreadSerializer, get_context from discussion_api.serializers import CommentSerializer, ThreadSerializer, get_context
from discussion_api.tests.utils import CommentsServiceMockMixin from discussion_api.tests.utils import (
CommentsServiceMockMixin,
make_minimal_cs_thread,
make_minimal_cs_comment,
)
from django_comment_common.models import ( from django_comment_common.models import (
FORUM_ROLE_ADMINISTRATOR, FORUM_ROLE_ADMINISTRATOR,
FORUM_ROLE_COMMUNITY_TA, FORUM_ROLE_COMMUNITY_TA,
...@@ -20,10 +24,9 @@ from openedx.core.djangoapps.course_groups.tests.helpers import CohortFactory ...@@ -20,10 +24,9 @@ from openedx.core.djangoapps.course_groups.tests.helpers import CohortFactory
@ddt.ddt @ddt.ddt
class ThreadSerializerTest(CommentsServiceMockMixin, ModuleStoreTestCase): class SerializerTestMixin(CommentsServiceMockMixin):
"""Tests for ThreadSerializer."""
def setUp(self): def setUp(self):
super(ThreadSerializerTest, self).setUp() super(SerializerTestMixin, self).setUp()
httpretty.reset() httpretty.reset()
httpretty.enable() httpretty.enable()
self.addCleanup(httpretty.disable) self.addCleanup(httpretty.disable)
...@@ -39,37 +42,93 @@ class ThreadSerializerTest(CommentsServiceMockMixin, ModuleStoreTestCase): ...@@ -39,37 +42,93 @@ class ThreadSerializerTest(CommentsServiceMockMixin, ModuleStoreTestCase):
role = Role.objects.create(name=role_name, course_id=course.id) role = Role.objects.create(name=role_name, course_id=course.id)
role.users = users role.users = users
def make_cs_thread(self, thread_data=None): @ddt.data(
(FORUM_ROLE_ADMINISTRATOR, True, False, True),
(FORUM_ROLE_ADMINISTRATOR, False, True, False),
(FORUM_ROLE_MODERATOR, True, False, True),
(FORUM_ROLE_MODERATOR, False, True, False),
(FORUM_ROLE_COMMUNITY_TA, True, False, True),
(FORUM_ROLE_COMMUNITY_TA, False, True, False),
(FORUM_ROLE_STUDENT, True, False, True),
(FORUM_ROLE_STUDENT, False, True, True),
)
@ddt.unpack
def test_anonymity(self, role_name, anonymous, anonymous_to_peers, expected_serialized_anonymous):
""" """
Create a dictionary containing all needed thread fields as returned by Test that content is properly made anonymous.
the comments service with dummy data overridden by thread_data
Content should be anonymous iff the anonymous field is true or the
anonymous_to_peers field is true and the requester does not have a
privileged role.
role_name is the name of the requester's role.
anonymous is the value of the anonymous field in the content.
anonymous_to_peers is the value of the anonymous_to_peers field in the
content.
expected_serialized_anonymous is whether the content should actually be
anonymous in the API output when requested by a user with the given
role.
""" """
ret = { self.create_role(role_name, [self.user])
"id": "dummy", serialized = self.serialize(
self.make_cs_content({"anonymous": anonymous, "anonymous_to_peers": anonymous_to_peers})
)
actual_serialized_anonymous = serialized["author"] is None
self.assertEqual(actual_serialized_anonymous, expected_serialized_anonymous)
@ddt.data(
(FORUM_ROLE_ADMINISTRATOR, False, "staff"),
(FORUM_ROLE_ADMINISTRATOR, True, None),
(FORUM_ROLE_MODERATOR, False, "staff"),
(FORUM_ROLE_MODERATOR, True, None),
(FORUM_ROLE_COMMUNITY_TA, False, "community_ta"),
(FORUM_ROLE_COMMUNITY_TA, True, None),
(FORUM_ROLE_STUDENT, False, None),
(FORUM_ROLE_STUDENT, True, None),
)
@ddt.unpack
def test_author_labels(self, role_name, anonymous, expected_label):
"""
Test correctness of the author_label field.
The label should be "staff", "staff", or "community_ta" for the
Administrator, Moderator, and Community TA roles, respectively, but
the label should not be present if the content is anonymous.
role_name is the name of the author's role.
anonymous is the value of the anonymous field in the content.
expected_label is the expected value of the author_label field in the
API output.
"""
self.create_role(role_name, [self.author])
serialized = self.serialize(self.make_cs_content({"anonymous": anonymous}))
self.assertEqual(serialized["author_label"], expected_label)
def test_abuse_flagged(self):
serialized = self.serialize(self.make_cs_content({"abuse_flaggers": [str(self.user.id)]}))
self.assertEqual(serialized["abuse_flagged"], True)
def test_voted(self):
thread_id = "test_thread"
self.register_get_user_response(self.user, upvoted_ids=[thread_id])
serialized = self.serialize(self.make_cs_content({"id": thread_id}))
self.assertEqual(serialized["voted"], True)
@ddt.ddt
class ThreadSerializerTest(SerializerTestMixin, ModuleStoreTestCase):
"""Tests for ThreadSerializer."""
def make_cs_content(self, overrides):
"""
Create a thread with the given overrides, plus some useful test data.
"""
merged_overrides = {
"course_id": unicode(self.course.id), "course_id": unicode(self.course.id),
"commentable_id": "dummy",
"group_id": None,
"user_id": str(self.author.id), "user_id": str(self.author.id),
"username": self.author.username, "username": self.author.username,
"anonymous": False,
"anonymous_to_peers": False,
"created_at": "1970-01-01T00:00:00Z",
"updated_at": "1970-01-01T00:00:00Z",
"thread_type": "discussion",
"title": "dummy",
"body": "dummy",
"pinned": False,
"closed": False,
"abuse_flaggers": [],
"votes": {"up_count": 0},
"comments_count": 0,
"unread_comments_count": 0,
"children": [],
"resp_total": 0,
} }
if thread_data: merged_overrides.update(overrides)
ret.update(thread_data) return make_minimal_cs_thread(merged_overrides)
return ret
def serialize(self, thread): def serialize(self, thread):
""" """
...@@ -126,84 +185,86 @@ class ThreadSerializerTest(CommentsServiceMockMixin, ModuleStoreTestCase): ...@@ -126,84 +185,86 @@ class ThreadSerializerTest(CommentsServiceMockMixin, ModuleStoreTestCase):
def test_group(self): def test_group(self):
cohort = CohortFactory.create(course_id=self.course.id) cohort = CohortFactory.create(course_id=self.course.id)
serialized = self.serialize(self.make_cs_thread({"group_id": cohort.id})) serialized = self.serialize(self.make_cs_content({"group_id": cohort.id}))
self.assertEqual(serialized["group_id"], cohort.id) self.assertEqual(serialized["group_id"], cohort.id)
self.assertEqual(serialized["group_name"], cohort.name) self.assertEqual(serialized["group_name"], cohort.name)
@ddt.data( def test_following(self):
(FORUM_ROLE_ADMINISTRATOR, True, False, True), thread_id = "test_thread"
(FORUM_ROLE_ADMINISTRATOR, False, True, False), self.register_get_user_response(self.user, subscribed_thread_ids=[thread_id])
(FORUM_ROLE_MODERATOR, True, False, True), serialized = self.serialize(self.make_cs_content({"id": thread_id}))
(FORUM_ROLE_MODERATOR, False, True, False), self.assertEqual(serialized["following"], True)
(FORUM_ROLE_COMMUNITY_TA, True, False, True),
(FORUM_ROLE_COMMUNITY_TA, False, True, False),
(FORUM_ROLE_STUDENT, True, False, True),
(FORUM_ROLE_STUDENT, False, True, True),
)
@ddt.unpack
def test_anonymity(self, role_name, anonymous, anonymous_to_peers, expected_serialized_anonymous):
"""
Test that content is properly made anonymous.
Content should be anonymous iff the anonymous field is true or the
anonymous_to_peers field is true and the requester does not have a
privileged role.
role_name is the name of the requester's role. @ddt.ddt
anonymous is the value of the anonymous field in the content. class CommentSerializerTest(SerializerTestMixin, ModuleStoreTestCase):
anonymous_to_peers is the value of the anonymous_to_peers field in the """Tests for CommentSerializer."""
content. def make_cs_content(self, overrides):
expected_serialized_anonymous is whether the content should actually be
anonymous in the API output when requested by a user with the given
role.
""" """
self.create_role(role_name, [self.user]) Create a comment with the given overrides, plus some useful test data.
serialized = self.serialize(
self.make_cs_thread({"anonymous": anonymous, "anonymous_to_peers": anonymous_to_peers})
)
actual_serialized_anonymous = serialized["author"] is None
self.assertEqual(actual_serialized_anonymous, expected_serialized_anonymous)
@ddt.data(
(FORUM_ROLE_ADMINISTRATOR, False, "staff"),
(FORUM_ROLE_ADMINISTRATOR, True, None),
(FORUM_ROLE_MODERATOR, False, "staff"),
(FORUM_ROLE_MODERATOR, True, None),
(FORUM_ROLE_COMMUNITY_TA, False, "community_ta"),
(FORUM_ROLE_COMMUNITY_TA, True, None),
(FORUM_ROLE_STUDENT, False, None),
(FORUM_ROLE_STUDENT, True, None),
)
@ddt.unpack
def test_author_labels(self, role_name, anonymous, expected_label):
""" """
Test correctness of the author_label field. merged_overrides = {
"user_id": str(self.author.id),
The label should be "staff", "staff", or "community_ta" for the "username": self.author.username
Administrator, Moderator, and Community TA roles, respectively, but }
the label should not be present if the thread is anonymous. merged_overrides.update(overrides)
return make_minimal_cs_comment(merged_overrides)
role_name is the name of the author's role. def serialize(self, comment):
anonymous is the value of the anonymous field in the content.
expected_label is the expected value of the author_label field in the
API output.
""" """
self.create_role(role_name, [self.author]) Create a serializer with an appropriate context and use it to serialize
serialized = self.serialize(self.make_cs_thread({"anonymous": anonymous})) the given comment, returning the result.
self.assertEqual(serialized["author_label"], expected_label) """
return CommentSerializer(comment, context=get_context(self.course, self.user)).data
def test_following(self):
thread_id = "test_thread"
self.register_get_user_response(self.user, subscribed_thread_ids=[thread_id])
serialized = self.serialize(self.make_cs_thread({"id": thread_id}))
self.assertEqual(serialized["following"], True)
def test_abuse_flagged(self): def test_basic(self):
serialized = self.serialize(self.make_cs_thread({"abuse_flaggers": [str(self.user.id)]})) comment = {
self.assertEqual(serialized["abuse_flagged"], True) "id": "test_comment",
"thread_id": "test_thread",
"user_id": str(self.author.id),
"username": self.author.username,
"anonymous": False,
"anonymous_to_peers": False,
"created_at": "2015-04-28T00:00:00Z",
"updated_at": "2015-04-28T11:11:11Z",
"body": "Test body",
"abuse_flaggers": [],
"votes": {"up_count": 4},
"children": [],
}
expected = {
"id": "test_comment",
"thread_id": "test_thread",
"parent_id": None,
"author": self.author.username,
"author_label": None,
"created_at": "2015-04-28T00:00:00Z",
"updated_at": "2015-04-28T11:11:11Z",
"raw_body": "Test body",
"abuse_flagged": False,
"voted": False,
"vote_count": 4,
"children": [],
}
self.assertEqual(self.serialize(comment), expected)
def test_voted(self): def test_children(self):
thread_id = "test_thread" comment = self.make_cs_content({
self.register_get_user_response(self.user, upvoted_ids=[thread_id]) "id": "test_root",
serialized = self.serialize(self.make_cs_thread({"id": thread_id})) "children": [
self.assertEqual(serialized["voted"], True) self.make_cs_content({
"id": "test_child_1",
}),
self.make_cs_content({
"id": "test_child_2",
"children": [self.make_cs_content({"id": "test_grandchild"})],
}),
],
})
serialized = self.serialize(comment)
self.assertEqual(serialized["children"][0]["id"], "test_child_1")
self.assertEqual(serialized["children"][0]["parent_id"], "test_root")
self.assertEqual(serialized["children"][1]["id"], "test_child_2")
self.assertEqual(serialized["children"][1]["parent_id"], "test_root")
self.assertEqual(serialized["children"][1]["children"][0]["id"], "test_grandchild")
self.assertEqual(serialized["children"][1]["children"][0]["parent_id"], "test_child_2")
...@@ -10,13 +10,11 @@ from pytz import UTC ...@@ -10,13 +10,11 @@ from pytz import UTC
from django.core.urlresolvers import reverse from django.core.urlresolvers import reverse
from discussion_api.tests.utils import CommentsServiceMockMixin from discussion_api.tests.utils import CommentsServiceMockMixin, make_minimal_cs_thread
from student.tests.factories import CourseEnrollmentFactory, UserFactory from student.tests.factories import CourseEnrollmentFactory, UserFactory
from util.testing import UrlResetMixin from util.testing import UrlResetMixin
from xmodule.modulestore.django import modulestore
from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase
from xmodule.modulestore.tests.factories import CourseFactory from xmodule.modulestore.tests.factories import CourseFactory
from xmodule.tabs import DiscussionTab
class DiscussionAPIViewTestMixin(CommentsServiceMockMixin, UrlResetMixin): class DiscussionAPIViewTestMixin(CommentsServiceMockMixin, UrlResetMixin):
...@@ -201,3 +199,125 @@ class ThreadViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase): ...@@ -201,3 +199,125 @@ class ThreadViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
"per_page": ["4"], "per_page": ["4"],
"recursive": ["False"], "recursive": ["False"],
}) })
@httpretty.activate
class CommentViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
"""Tests for CommentViewSet list"""
def setUp(self):
super(CommentViewSetListTest, self).setUp()
self.author = UserFactory.create()
self.url = reverse("comment-list")
self.thread_id = "test_thread"
def test_thread_id_missing(self):
response = self.client.get(self.url)
self.assert_response_correct(
response,
400,
{"field_errors": {"thread_id": "This field is required."}}
)
def test_404(self):
self.register_get_thread_error_response(self.thread_id, 404)
response = self.client.get(self.url, {"thread_id": self.thread_id})
self.assert_response_correct(
response,
404,
{"developer_message": "Not found."}
)
def test_basic(self):
self.register_get_user_response(self.user, upvoted_ids=["test_comment"])
source_comments = [{
"id": "test_comment",
"thread_id": self.thread_id,
"parent_id": None,
"user_id": str(self.author.id),
"username": self.author.username,
"anonymous": False,
"anonymous_to_peers": False,
"created_at": "2015-05-11T00:00:00Z",
"updated_at": "2015-05-11T11:11:11Z",
"body": "Test body",
"abuse_flaggers": [],
"votes": {"up_count": 4},
"children": [],
}]
expected_comments = [{
"id": "test_comment",
"thread_id": self.thread_id,
"parent_id": None,
"author": self.author.username,
"author_label": None,
"created_at": "2015-05-11T00:00:00Z",
"updated_at": "2015-05-11T11:11:11Z",
"raw_body": "Test body",
"abuse_flagged": False,
"voted": True,
"vote_count": 4,
"children": [],
}]
self.register_get_thread_response({
"id": self.thread_id,
"course_id": unicode(self.course.id),
"thread_type": "discussion",
"children": source_comments,
"resp_total": 100,
})
response = self.client.get(self.url, {"thread_id": self.thread_id})
self.assert_response_correct(
response,
200,
{
"results": expected_comments,
"next": "http://testserver/api/discussion/v1/comments/?thread_id={}&page=2".format(
self.thread_id
),
"previous": None,
}
)
self.assert_query_params_equal(
httpretty.httpretty.latest_requests[-2],
{
"recursive": ["True"],
"resp_skip": ["0"],
"resp_limit": ["10"],
"user_id": [str(self.user.id)],
"mark_as_read": ["True"],
}
)
def test_pagination(self):
"""
Test that pagination parameters are correctly plumbed through to the
comments service and that a 404 is correctly returned if a page past the
end is requested
"""
self.register_get_user_response(self.user)
self.register_get_thread_response(make_minimal_cs_thread({
"id": self.thread_id,
"course_id": unicode(self.course.id),
"thread_type": "discussion",
"children": [],
"resp_total": 10,
}))
response = self.client.get(
self.url,
{"thread_id": self.thread_id, "page": "18", "page_size": "4"}
)
self.assert_response_correct(
response,
404,
{"developer_message": "Not found."}
)
self.assert_query_params_equal(
httpretty.httpretty.latest_requests[-2],
{
"recursive": ["True"],
"resp_skip": ["68"],
"resp_limit": ["4"],
"user_id": [str(self.user.id)],
"mark_as_read": ["True"],
}
)
...@@ -21,6 +21,26 @@ class CommentsServiceMockMixin(object): ...@@ -21,6 +21,26 @@ class CommentsServiceMockMixin(object):
status=200 status=200
) )
def register_get_thread_error_response(self, thread_id, status_code):
"""Register a mock error response for GET on the CS thread endpoint."""
httpretty.register_uri(
httpretty.GET,
"http://localhost:4567/api/v1/threads/{id}".format(id=thread_id),
body="",
status=status_code
)
def register_get_thread_response(self, thread):
"""
Register a mock response for GET on the CS thread instance endpoint.
"""
httpretty.register_uri(
httpretty.GET,
"http://localhost:4567/api/v1/threads/{id}".format(id=thread["id"]),
body=json.dumps(thread),
status=200
)
def register_get_user_response(self, user, subscribed_thread_ids=None, upvoted_ids=None): def register_get_user_response(self, user, subscribed_thread_ids=None, upvoted_ids=None):
"""Register a mock response for GET on the CS user instance endpoint""" """Register a mock response for GET on the CS user instance endpoint"""
httpretty.register_uri( httpretty.register_uri(
...@@ -34,10 +54,73 @@ class CommentsServiceMockMixin(object): ...@@ -34,10 +54,73 @@ class CommentsServiceMockMixin(object):
status=200 status=200
) )
def assert_last_query_params(self, expected_params): def assert_query_params_equal(self, httpretty_request, expected_params):
""" """
Assert that the last mock request had the expected query parameters Assert that the given mock request had the expected query parameters
""" """
actual_params = dict(httpretty.last_request().querystring) actual_params = dict(httpretty_request.querystring)
actual_params.pop("request_id") # request_id is random actual_params.pop("request_id") # request_id is random
self.assertEqual(actual_params, expected_params) self.assertEqual(actual_params, expected_params)
def assert_last_query_params(self, expected_params):
"""
Assert that the last mock request had the expected query parameters
"""
self.assert_query_params_equal(httpretty.last_request(), expected_params)
def make_minimal_cs_thread(overrides=None):
"""
Create a dictionary containing all needed thread fields as returned by the
comments service with dummy data and optional overrides
"""
ret = {
"id": "dummy",
"course_id": "dummy/dummy/dummy",
"commentable_id": "dummy",
"group_id": None,
"user_id": "0",
"username": "dummy",
"anonymous": False,
"anonymous_to_peers": False,
"created_at": "1970-01-01T00:00:00Z",
"updated_at": "1970-01-01T00:00:00Z",
"thread_type": "discussion",
"title": "dummy",
"body": "dummy",
"pinned": False,
"closed": False,
"abuse_flaggers": [],
"votes": {"up_count": 0},
"comments_count": 0,
"unread_comments_count": 0,
"children": [],
"resp_total": 0,
}
ret.update(overrides or {})
return ret
def make_minimal_cs_comment(overrides=None):
"""
Create a dictionary containing all needed comment fields as returned by the
comments service with dummy data and optional overrides
"""
ret = {
"id": "dummy",
"thread_id": "dummy",
"user_id": "0",
"username": "dummy",
"anonymous": False,
"anonymous_to_peers": False,
"created_at": "1970-01-01T00:00:00Z",
"updated_at": "1970-01-01T00:00:00Z",
"body": "dummy",
"abuse_flaggers": [],
"votes": {"up_count": 0},
"endorsed": False,
"endorsement": None,
"children": [],
}
ret.update(overrides or {})
return ret
...@@ -6,11 +6,12 @@ from django.conf.urls import include, patterns, url ...@@ -6,11 +6,12 @@ from django.conf.urls import include, patterns, url
from rest_framework.routers import SimpleRouter from rest_framework.routers import SimpleRouter
from discussion_api.views import CourseTopicsView, ThreadViewSet from discussion_api.views import CommentViewSet, CourseTopicsView, ThreadViewSet
ROUTER = SimpleRouter() ROUTER = SimpleRouter()
ROUTER.register("threads", ThreadViewSet, base_name="thread") ROUTER.register("threads", ThreadViewSet, base_name="thread")
ROUTER.register("comments", CommentViewSet, base_name="comment")
urlpatterns = patterns( urlpatterns = patterns(
"discussion_api", "discussion_api",
......
...@@ -11,8 +11,8 @@ from rest_framework.viewsets import ViewSet ...@@ -11,8 +11,8 @@ from rest_framework.viewsets import ViewSet
from opaque_keys.edx.locator import CourseLocator from opaque_keys.edx.locator import CourseLocator
from discussion_api.api import get_course_topics, get_thread_list from discussion_api.api import get_comment_list, get_course_topics, get_thread_list
from discussion_api.forms import ThreadListGetForm from discussion_api.forms import CommentListGetForm, ThreadListGetForm
from openedx.core.lib.api.view_utils import DeveloperErrorViewMixin from openedx.core.lib.api.view_utils import DeveloperErrorViewMixin
...@@ -126,3 +126,82 @@ class ThreadViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet): ...@@ -126,3 +126,82 @@ class ThreadViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet):
form.cleaned_data["page_size"] form.cleaned_data["page_size"]
) )
) )
class CommentViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet):
"""
**Use Cases**
Retrieve the list of comments in a thread.
**Example Requests**:
GET /api/discussion/v1/comments/?thread_id=0123456789abcdef01234567
**GET Parameters**:
* thread_id (required): The thread to retrieve comments for
* endorsed: If specified, only retrieve the endorsed or non-endorsed
comments accordingly. Required for a question thread, must be absent
for a discussion thread.
* page: The (1-indexed) page to retrieve (default is 1)
* page_size: The number of items per page (default is 10, max is 100)
**Response Values**:
* results: The list of comments. Each item in the list includes:
* id: The id of the comment
* thread_id: The id of the comment's thread
* parent_id: The id of the comment's parent
* author: The username of the comment's author, or None if the
comment is anonymous
* author_label: A label indicating whether the author has a special
role in the course, either "staff" for moderators and
administrators or "community_ta" for community TAs
* created_at: The ISO 8601 timestamp for the creation of the comment
* updated_at: The ISO 8601 timestamp for the last modification of
the comment, which may not have been an update of the body
* raw_body: The comment's raw body text without any rendering applied
* abuse_flagged: Boolean indicating whether the requesting user has
flagged the comment for abuse
* voted: Boolean indicating whether the requesting user has voted
for the comment
* vote_count: The number of votes for the comment
* children: The list of child comments (with the same format)
* next: The URL of the next page (or null if first page)
* previous: The URL of the previous page (or null if last page)
"""
def list(self, request):
"""
Implements the GET method for the list endpoint as described in the
class docstring.
"""
form = CommentListGetForm(request.GET)
if not form.is_valid():
raise ValidationError(form.errors)
return Response(
get_comment_list(
request,
form.cleaned_data["thread_id"],
form.cleaned_data["endorsed"],
form.cleaned_data["page"],
form.cleaned_data["page_size"]
)
)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment