Commit 895731f5 by Greg Price

Add comment editing to discussion API

This is done via PATCH on a comment instance endpoint.
parent 9bdf64d3
...@@ -26,6 +26,7 @@ from django_comment_client.base.views import ( ...@@ -26,6 +26,7 @@ from django_comment_client.base.views import (
track_forum_event, track_forum_event,
) )
from django_comment_client.utils import get_accessible_discussion_modules from django_comment_client.utils import get_accessible_discussion_modules
from lms.lib.comment_client.comment import Comment
from lms.lib.comment_client.thread import Thread from lms.lib.comment_client.thread import Thread
from lms.lib.comment_client.utils import CommentClientRequestError from lms.lib.comment_client.utils import CommentClientRequestError
from openedx.core.djangoapps.course_groups.cohorts import get_cohort_id, is_commentable_cohorted from openedx.core.djangoapps.course_groups.cohorts import get_cohort_id, is_commentable_cohorted
...@@ -74,16 +75,36 @@ def _get_thread_and_context(request, thread_id, retrieve_kwargs=None): ...@@ -74,16 +75,36 @@ def _get_thread_and_context(request, thread_id, retrieve_kwargs=None):
raise Http404 raise Http404
def _is_user_author_or_privileged(cc_thread, context): def _get_comment_and_context(request, comment_id):
""" """
Check if the user is the author of a thread or a privileged user. Retrieve the given comment and build a serializer context for it, returning
both. This function also enforces access control for the comment (checking
both the user's access to the course and to the comment's thread's cohort if
applicable). Raises Http404 if the comment does not exist or the user cannot
access it.
"""
try:
cc_comment = Comment(id=comment_id).retrieve()
_, context = _get_thread_and_context(
request,
cc_comment["thread_id"],
cc_comment["parent_id"]
)
return cc_comment, context
except CommentClientRequestError:
raise Http404
def _is_user_author_or_privileged(cc_content, context):
"""
Check if the user is the author of a content object or a privileged user.
Returns: Returns:
Boolean Boolean
""" """
return ( return (
context["is_requester_privileged"] or context["is_requester_privileged"] or
context["cc_requester"]["id"] == cc_thread["user_id"] context["cc_requester"]["id"] == cc_content["user_id"]
) )
...@@ -442,6 +463,47 @@ def update_thread(request, thread_id, update_data): ...@@ -442,6 +463,47 @@ def update_thread(request, thread_id, update_data):
return api_thread return api_thread
def update_comment(request, comment_id, update_data):
"""
Update a comment.
Parameters:
request: The django request object used for build_absolute_uri and
determining the requesting user.
comment_id: The id for the comment to update.
update_data: The data to update in the comment.
Returns:
The updated comment; see discussion_api.views.CommentViewSet for more
detail.
Raises:
Http404: if the comment does not exist or is not accessible to the
requesting user
PermissionDenied: if the comment is accessible to but not editable by
the requesting user
ValidationError: if there is an error applying the update (e.g. raw_body
is empty or thread_id is included)
"""
cc_comment, context = _get_comment_and_context(request, comment_id)
if not _is_user_author_or_privileged(cc_comment, context):
raise PermissionDenied()
serializer = CommentSerializer(cc_comment, data=update_data, partial=True, context=context)
if not serializer.is_valid():
raise ValidationError(serializer.errors)
# Only save comment object if the comment is actually modified
if update_data:
serializer.save()
return serializer.data
def delete_thread(request, thread_id): def delete_thread(request, thread_id):
""" """
Delete a thread. Delete a thread.
......
...@@ -69,12 +69,23 @@ class _ContentSerializer(serializers.Serializer): ...@@ -69,12 +69,23 @@ class _ContentSerializer(serializers.Serializer):
voted = serializers.SerializerMethodField("get_voted") voted = serializers.SerializerMethodField("get_voted")
vote_count = serializers.SerializerMethodField("get_vote_count") vote_count = serializers.SerializerMethodField("get_vote_count")
non_updatable_fields = ()
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(_ContentSerializer, self).__init__(*args, **kwargs) super(_ContentSerializer, self).__init__(*args, **kwargs)
# id is an invalid class attribute name, so we must declare a different # id is an invalid class attribute name, so we must declare a different
# name above and modify it here # name above and modify it here
self.fields["id"] = self.fields.pop("id_") self.fields["id"] = self.fields.pop("id_")
for field in self.non_updatable_fields:
setattr(self, "validate_{}".format(field), self._validate_non_updatable)
def _validate_non_updatable(self, attrs, _source):
"""Ensure that a field is not edited in an update operation."""
if self.object:
raise ValidationError("This field is not allowed in an update.")
return attrs
def _is_user_privileged(self, user_id): def _is_user_privileged(self, user_id):
""" """
Returns a boolean indicating whether the given user_id identifies a Returns a boolean indicating whether the given user_id identifies a
...@@ -156,6 +167,8 @@ class ThreadSerializer(_ContentSerializer): ...@@ -156,6 +167,8 @@ class ThreadSerializer(_ContentSerializer):
endorsed_comment_list_url = serializers.SerializerMethodField("get_endorsed_comment_list_url") endorsed_comment_list_url = serializers.SerializerMethodField("get_endorsed_comment_list_url")
non_endorsed_comment_list_url = serializers.SerializerMethodField("get_non_endorsed_comment_list_url") non_endorsed_comment_list_url = serializers.SerializerMethodField("get_non_endorsed_comment_list_url")
non_updatable_fields = ("course_id",)
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(ThreadSerializer, self).__init__(*args, **kwargs) super(ThreadSerializer, self).__init__(*args, **kwargs)
# type is an invalid class attribute name, so we must declare a # type is an invalid class attribute name, so we must declare a
...@@ -199,12 +212,6 @@ class ThreadSerializer(_ContentSerializer): ...@@ -199,12 +212,6 @@ class ThreadSerializer(_ContentSerializer):
"""Returns the URL to retrieve the thread's non-endorsed comments.""" """Returns the URL to retrieve the thread's non-endorsed comments."""
return self.get_comment_list_url(obj, endorsed=False) return self.get_comment_list_url(obj, endorsed=False)
def validate_course_id(self, attrs, _source):
"""Ensure that course_id is not edited in an update operation."""
if self.object:
raise ValidationError("This field is not allowed in an update.")
return attrs
def restore_object(self, attrs, instance=None): def restore_object(self, attrs, instance=None):
if instance: if instance:
for key, val in attrs.items(): for key, val in attrs.items():
...@@ -230,6 +237,8 @@ class CommentSerializer(_ContentSerializer): ...@@ -230,6 +237,8 @@ class CommentSerializer(_ContentSerializer):
endorsed_at = serializers.SerializerMethodField("get_endorsed_at") endorsed_at = serializers.SerializerMethodField("get_endorsed_at")
children = serializers.SerializerMethodField("get_children") children = serializers.SerializerMethodField("get_children")
non_updatable_fields = ("thread_id", "parent_id")
def get_endorsed_by(self, obj): def get_endorsed_by(self, obj):
""" """
Returns the username of the endorsing user, if the information is Returns the username of the endorsing user, if the information is
...@@ -288,8 +297,10 @@ class CommentSerializer(_ContentSerializer): ...@@ -288,8 +297,10 @@ class CommentSerializer(_ContentSerializer):
return attrs return attrs
def restore_object(self, attrs, instance=None): def restore_object(self, attrs, instance=None):
if instance: # pragma: no cover if instance:
raise ValueError("CommentSerializer cannot be used for updates.") for key, val in attrs.items():
instance[key] = val
return instance
return Comment( return Comment(
course_id=self.context["thread"]["course_id"], course_id=self.context["thread"]["course_id"],
user_id=self.context["cc_requester"]["id"], user_id=self.context["cc_requester"]["id"],
......
...@@ -27,6 +27,7 @@ from discussion_api.api import ( ...@@ -27,6 +27,7 @@ from discussion_api.api import (
get_comment_list, get_comment_list,
get_course_topics, get_course_topics,
get_thread_list, get_thread_list,
update_comment,
update_thread, update_thread,
) )
from discussion_api.tests.utils import ( from discussion_api.tests.utils import (
...@@ -1638,6 +1639,173 @@ class UpdateThreadTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTestC ...@@ -1638,6 +1639,173 @@ class UpdateThreadTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTestC
@ddt.ddt @ddt.ddt
class UpdateCommentTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTestCase):
"""Tests for update_comment"""
@mock.patch.dict("django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True})
def setUp(self):
super(UpdateCommentTest, self).setUp()
httpretty.reset()
httpretty.enable()
self.addCleanup(httpretty.disable)
self.user = UserFactory.create()
self.register_get_user_response(self.user)
self.request = RequestFactory().get("/test_path")
self.request.user = self.user
self.course = CourseFactory.create()
CourseEnrollmentFactory.create(user=self.user, course_id=self.course.id)
def register_comment(self, overrides=None, thread_overrides=None):
"""
Make a comment with appropriate data overridden by the overrides
parameter and register mock responses for both GET and PUT on its
endpoint. Also mock GET for the related thread with thread_overrides.
"""
cs_thread_data = make_minimal_cs_thread({
"id": "test_thread",
"course_id": unicode(self.course.id)
})
cs_thread_data.update(thread_overrides or {})
self.register_get_thread_response(cs_thread_data)
cs_comment_data = make_minimal_cs_comment({
"id": "test_comment",
"course_id": cs_thread_data["course_id"],
"thread_id": cs_thread_data["id"],
"username": self.user.username,
"user_id": str(self.user.id),
"created_at": "2015-06-03T00:00:00Z",
"updated_at": "2015-06-03T00:00:00Z",
"body": "Original body",
})
cs_comment_data.update(overrides or {})
self.register_get_comment_response(cs_comment_data)
self.register_put_comment_response(cs_comment_data)
def test_empty(self):
"""Check that an empty update does not make any modifying requests."""
self.register_comment()
update_comment(self.request, "test_comment", {})
for request in httpretty.httpretty.latest_requests:
self.assertEqual(request.method, "GET")
def test_basic(self):
self.register_comment()
actual = update_comment(self.request, "test_comment", {"raw_body": "Edited body"})
expected = {
"id": "test_comment",
"thread_id": "test_thread",
"parent_id": None, # TODO: we can't get this without retrieving from the thread :-(
"author": self.user.username,
"author_label": None,
"created_at": "2015-06-03T00:00:00Z",
"updated_at": "2015-06-03T00:00:00Z",
"raw_body": "Edited body",
"endorsed": False,
"endorsed_by": None,
"endorsed_by_label": None,
"endorsed_at": None,
"abuse_flagged": False,
"voted": False,
"vote_count": 0,
"children": [],
}
self.assertEqual(actual, expected)
self.assertEqual(
httpretty.last_request().parsed_body,
{
"body": ["Edited body"],
"course_id": [unicode(self.course.id)],
"user_id": [str(self.user.id)],
"anonymous": ["False"],
"anonymous_to_peers": ["False"],
"endorsed": ["False"],
}
)
def test_nonexistent_comment(self):
self.register_get_comment_error_response("test_comment", 404)
with self.assertRaises(Http404):
update_comment(self.request, "test_comment", {})
def test_nonexistent_course(self):
self.register_comment(thread_overrides={"course_id": "non/existent/course"})
with self.assertRaises(Http404):
update_comment(self.request, "test_comment", {})
def test_unenrolled(self):
self.register_comment()
self.request.user = UserFactory.create()
with self.assertRaises(Http404):
update_comment(self.request, "test_comment", {})
def test_discussions_disabled(self):
_remove_discussion_tab(self.course, self.user.id)
self.register_comment()
with self.assertRaises(Http404):
update_comment(self.request, "test_comment", {})
@ddt.data(
*itertools.product(
[
FORUM_ROLE_ADMINISTRATOR,
FORUM_ROLE_MODERATOR,
FORUM_ROLE_COMMUNITY_TA,
FORUM_ROLE_STUDENT,
],
[True, False],
["no_group", "match_group", "different_group"],
)
)
@ddt.unpack
def test_group_access(self, role_name, course_is_cohorted, thread_group_state):
cohort_course = CourseFactory.create(cohort_config={"cohorted": course_is_cohorted})
CourseEnrollmentFactory.create(user=self.user, course_id=cohort_course.id)
cohort = CohortFactory.create(course_id=cohort_course.id, users=[self.user])
role = Role.objects.create(name=role_name, course_id=cohort_course.id)
role.users = [self.user]
self.register_get_thread_response(make_minimal_cs_thread())
self.register_comment(
{"thread_id": "test_thread"},
thread_overrides={
"id": "test_thread",
"course_id": unicode(cohort_course.id),
"group_id": (
None if thread_group_state == "no_group" else
cohort.id if thread_group_state == "match_group" else
cohort.id + 1
),
}
)
expected_error = (
role_name == FORUM_ROLE_STUDENT and
course_is_cohorted and
thread_group_state == "different_group"
)
try:
update_comment(self.request, "test_comment", {})
self.assertFalse(expected_error)
except Http404:
self.assertTrue(expected_error)
@ddt.data(
FORUM_ROLE_ADMINISTRATOR,
FORUM_ROLE_MODERATOR,
FORUM_ROLE_COMMUNITY_TA,
FORUM_ROLE_STUDENT,
)
def test_role_access(self, role_name):
role = Role.objects.create(name=role_name, course_id=self.course.id)
role.users = [self.user]
self.register_comment({"user_id": str(self.user.id + 1)})
expected_error = role_name == FORUM_ROLE_STUDENT
try:
update_comment(self.request, "test_comment", {"raw_body": "edited"})
self.assertFalse(expected_error)
except PermissionDenied:
self.assertTrue(expected_error)
@ddt.ddt
class DeleteThreadTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTestCase): class DeleteThreadTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTestCase):
"""Tests for delete_thread""" """Tests for delete_thread"""
@mock.patch.dict("django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True}) @mock.patch.dict("django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True})
......
...@@ -23,6 +23,7 @@ from django_comment_common.models import ( ...@@ -23,6 +23,7 @@ from django_comment_common.models import (
FORUM_ROLE_STUDENT, FORUM_ROLE_STUDENT,
Role, Role,
) )
from lms.lib.comment_client.comment import Comment
from lms.lib.comment_client.thread import Thread from lms.lib.comment_client.thread import Thread
from student.tests.factories import UserFactory from student.tests.factories import UserFactory
from util.testing import UrlResetMixin from util.testing import UrlResetMixin
...@@ -553,8 +554,15 @@ class CommentSerializerDeserializationTest(CommentsServiceMockMixin, ModuleStore ...@@ -553,8 +554,15 @@ class CommentSerializerDeserializationTest(CommentsServiceMockMixin, ModuleStore
"thread_id": "test_thread", "thread_id": "test_thread",
"raw_body": "Test body", "raw_body": "Test body",
} }
self.existing_comment = Comment(**make_minimal_cs_comment({
"id": "existing_comment",
"thread_id": "existing_thread",
"body": "Original body",
"user_id": str(self.user.id),
"course_id": unicode(self.course.id),
}))
def save_and_reserialize(self, data): def save_and_reserialize(self, data, instance=None):
""" """
Create a serializer with the given data, ensure that it is valid, save Create a serializer with the given data, ensure that it is valid, save
the result, and return the full comment data from the serializer. the result, and return the full comment data from the serializer.
...@@ -564,13 +572,18 @@ class CommentSerializerDeserializationTest(CommentsServiceMockMixin, ModuleStore ...@@ -564,13 +572,18 @@ class CommentSerializerDeserializationTest(CommentsServiceMockMixin, ModuleStore
self.request, self.request,
make_minimal_cs_thread({"course_id": unicode(self.course.id)}) make_minimal_cs_thread({"course_id": unicode(self.course.id)})
) )
serializer = CommentSerializer(data=data, context=context) serializer = CommentSerializer(
instance,
data=data,
partial=(instance is not None),
context=context
)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
serializer.save() serializer.save()
return serializer.data return serializer.data
@ddt.data(None, "test_parent") @ddt.data(None, "test_parent")
def test_success(self, parent_id): def test_create_success(self, parent_id):
data = self.minimal_data.copy() data = self.minimal_data.copy()
if parent_id: if parent_id:
data["parent_id"] = parent_id data["parent_id"] = parent_id
...@@ -597,7 +610,7 @@ class CommentSerializerDeserializationTest(CommentsServiceMockMixin, ModuleStore ...@@ -597,7 +610,7 @@ class CommentSerializerDeserializationTest(CommentsServiceMockMixin, ModuleStore
self.assertEqual(saved["id"], "test_comment") self.assertEqual(saved["id"], "test_comment")
self.assertEqual(saved["parent_id"], parent_id) self.assertEqual(saved["parent_id"], parent_id)
def test_parent_id_nonexistent(self): def test_create_parent_id_nonexistent(self):
self.register_get_comment_error_response("bad_parent", 404) self.register_get_comment_error_response("bad_parent", 404)
data = self.minimal_data.copy() data = self.minimal_data.copy()
data["parent_id"] = "bad_parent" data["parent_id"] = "bad_parent"
...@@ -613,7 +626,7 @@ class CommentSerializerDeserializationTest(CommentsServiceMockMixin, ModuleStore ...@@ -613,7 +626,7 @@ class CommentSerializerDeserializationTest(CommentsServiceMockMixin, ModuleStore
} }
) )
def test_parent_id_wrong_thread(self): def test_create_parent_id_wrong_thread(self):
self.register_get_comment_response({"thread_id": "different_thread", "id": "test_parent"}) self.register_get_comment_response({"thread_id": "different_thread", "id": "test_parent"})
data = self.minimal_data.copy() data = self.minimal_data.copy()
data["parent_id"] = "test_parent" data["parent_id"] = "test_parent"
...@@ -629,7 +642,7 @@ class CommentSerializerDeserializationTest(CommentsServiceMockMixin, ModuleStore ...@@ -629,7 +642,7 @@ class CommentSerializerDeserializationTest(CommentsServiceMockMixin, ModuleStore
} }
) )
def test_missing_field(self): def test_create_missing_field(self):
for field in self.minimal_data: for field in self.minimal_data:
data = self.minimal_data.copy() data = self.minimal_data.copy()
data.pop(field) data.pop(field)
...@@ -642,3 +655,60 @@ class CommentSerializerDeserializationTest(CommentsServiceMockMixin, ModuleStore ...@@ -642,3 +655,60 @@ class CommentSerializerDeserializationTest(CommentsServiceMockMixin, ModuleStore
serializer.errors, serializer.errors,
{field: ["This field is required."]} {field: ["This field is required."]}
) )
def test_update_empty(self):
self.register_put_comment_response(self.existing_comment.attributes)
self.save_and_reserialize({}, instance=self.existing_comment)
self.assertEqual(
httpretty.last_request().parsed_body,
{
"body": ["Original body"],
"course_id": [unicode(self.course.id)],
"user_id": [str(self.user.id)],
"anonymous": ["False"],
"anonymous_to_peers": ["False"],
"endorsed": ["False"],
}
)
def test_update_all(self):
self.register_put_comment_response(self.existing_comment.attributes)
data = {"raw_body": "Edited body"}
saved = self.save_and_reserialize(data, instance=self.existing_comment)
self.assertEqual(
httpretty.last_request().parsed_body,
{
"body": ["Edited body"],
"course_id": [unicode(self.course.id)],
"user_id": [str(self.user.id)],
"anonymous": ["False"],
"anonymous_to_peers": ["False"],
"endorsed": ["False"],
}
)
self.assertEqual(saved["raw_body"], data["raw_body"])
def test_update_empty_raw_body(self):
serializer = CommentSerializer(
self.existing_comment,
data={"raw_body": ""},
partial=True,
context=get_context(self.course, self.request)
)
self.assertEqual(
serializer.errors,
{"raw_body": ["This field is required."]}
)
@ddt.data("thread_id", "parent_id")
def test_update_non_updatable(self, field):
serializer = CommentSerializer(
self.existing_comment,
data={field: "different_value"},
partial=True,
context=get_context(self.course, self.request)
)
self.assertEqual(
serializer.errors,
{field: ["This field is not allowed in an update."]}
)
...@@ -13,7 +13,11 @@ from django.core.urlresolvers import reverse ...@@ -13,7 +13,11 @@ from django.core.urlresolvers import reverse
from rest_framework.test import APIClient from rest_framework.test import APIClient
from discussion_api.tests.utils import CommentsServiceMockMixin, make_minimal_cs_thread from discussion_api.tests.utils import (
CommentsServiceMockMixin,
make_minimal_cs_comment,
make_minimal_cs_thread,
)
from student.tests.factories import CourseEnrollmentFactory, UserFactory from student.tests.factories import CourseEnrollmentFactory, UserFactory
from util.testing import UrlResetMixin from util.testing import UrlResetMixin
from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase
...@@ -633,3 +637,85 @@ class CommentViewSetCreateTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase): ...@@ -633,3 +637,85 @@ class CommentViewSetCreateTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
self.assertEqual(response.status_code, 400) self.assertEqual(response.status_code, 400)
response_data = json.loads(response.content) response_data = json.loads(response.content)
self.assertEqual(response_data, expected_response_data) self.assertEqual(response_data, expected_response_data)
class CommentViewSetPartialUpdateTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
"""Tests for CommentViewSet partial_update"""
def setUp(self):
super(CommentViewSetPartialUpdateTest, self).setUp()
httpretty.reset()
httpretty.enable()
self.addCleanup(httpretty.disable)
self.register_get_user_response(self.user)
self.url = reverse("comment-detail", kwargs={"comment_id": "test_comment"})
cs_thread = make_minimal_cs_thread({
"id": "test_thread",
"course_id": unicode(self.course.id),
})
self.register_get_thread_response(cs_thread)
cs_comment = make_minimal_cs_comment({
"id": "test_comment",
"course_id": cs_thread["course_id"],
"thread_id": cs_thread["id"],
"username": self.user.username,
"user_id": str(self.user.id),
"created_at": "2015-06-03T00:00:00Z",
"updated_at": "2015-06-03T00:00:00Z",
"body": "Original body",
})
self.register_get_comment_response(cs_comment)
self.register_put_comment_response(cs_comment)
def test_basic(self):
request_data = {"raw_body": "Edited body"}
expected_response_data = {
"id": "test_comment",
"thread_id": "test_thread",
"parent_id": None,
"author": self.user.username,
"author_label": None,
"created_at": "2015-06-03T00:00:00Z",
"updated_at": "2015-06-03T00:00:00Z",
"raw_body": "Edited body",
"endorsed": False,
"endorsed_by": None,
"endorsed_by_label": None,
"endorsed_at": None,
"abuse_flagged": False,
"voted": False,
"vote_count": 0,
"children": [],
}
response = self.client.patch( # pylint: disable=no-member
self.url,
json.dumps(request_data),
content_type="application/json"
)
self.assertEqual(response.status_code, 200)
response_data = json.loads(response.content)
self.assertEqual(response_data, expected_response_data)
self.assertEqual(
httpretty.last_request().parsed_body,
{
"body": ["Edited body"],
"course_id": [unicode(self.course.id)],
"user_id": [str(self.user.id)],
"anonymous": ["False"],
"anonymous_to_peers": ["False"],
"endorsed": ["False"],
}
)
def test_error(self):
request_data = {"raw_body": ""}
response = self.client.patch( # pylint: disable=no-member
self.url,
json.dumps(request_data),
content_type="application/json"
)
expected_response_data = {
"field_errors": {"raw_body": {"developer_message": "This field is required."}}
}
self.assertEqual(response.status_code, 400)
response_data = json.loads(response.content)
self.assertEqual(response_data, expected_response_data)
...@@ -30,6 +30,32 @@ def _get_thread_callback(thread_data): ...@@ -30,6 +30,32 @@ def _get_thread_callback(thread_data):
return callback return callback
def _get_comment_callback(comment_data, thread_id, parent_id):
"""
Get a callback function that will return a comment containing the given data
plus necessary dummy data, overridden by the content of the POST/PUT
request.
"""
def callback(request, _uri, headers):
"""
Simulate the comment creation or update endpoint as described above.
"""
response_data = make_minimal_cs_comment(comment_data)
# thread_id and parent_id are not included in request payload but
# are returned by the comments service
response_data["thread_id"] = thread_id
response_data["parent_id"] = parent_id
for key, val_list in request.parsed_body.items():
val = val_list[0]
if key in ["anonymous", "anonymous_to_peers", "endorsed"]:
response_data[key] = val == "True"
else:
response_data[key] = val
return (200, headers, json.dumps(response_data))
return callback
class CommentsServiceMockMixin(object): class CommentsServiceMockMixin(object):
"""Mixin with utility methods for mocking the comments service""" """Mixin with utility methods for mocking the comments service"""
def register_get_threads_response(self, threads, page, num_pages): def register_get_threads_response(self, threads, page, num_pages):
...@@ -84,33 +110,35 @@ class CommentsServiceMockMixin(object): ...@@ -84,33 +110,35 @@ class CommentsServiceMockMixin(object):
status=200 status=200
) )
def register_post_comment_response(self, response_overrides, thread_id, parent_id=None): def register_post_comment_response(self, comment_data, thread_id, parent_id=None):
""" """
Register a mock response for POST on the CS comments endpoint for the Register a mock response for POST on the CS comments endpoint for the
given thread or parent; exactly one of thread_id and parent_id must be given thread or parent; exactly one of thread_id and parent_id must be
specified. specified.
""" """
def callback(request, _uri, headers):
"""
Simulate the comment creation endpoint by returning the provided data
along with the data from response_overrides.
"""
response_data = make_minimal_cs_comment(
{key: val[0] for key, val in request.parsed_body.items()}
)
response_data.update(response_overrides or {})
# thread_id and parent_id are not included in request payload but
# are returned by the comments service
response_data["thread_id"] = thread_id
response_data["parent_id"] = parent_id
return (200, headers, json.dumps(response_data))
if parent_id: if parent_id:
url = "http://localhost:4567/api/v1/comments/{}".format(parent_id) url = "http://localhost:4567/api/v1/comments/{}".format(parent_id)
else: else:
url = "http://localhost:4567/api/v1/threads/{}/comments".format(thread_id) url = "http://localhost:4567/api/v1/threads/{}/comments".format(thread_id)
httpretty.register_uri(httpretty.POST, url, body=callback) httpretty.register_uri(
httpretty.POST,
url,
body=_get_comment_callback(comment_data, thread_id, parent_id)
)
def register_put_comment_response(self, comment_data):
"""
Register a mock response for PUT on the CS endpoint for the given
comment data (which must include the key "id").
"""
thread_id = comment_data["thread_id"]
parent_id = comment_data.get("parent_id")
httpretty.register_uri(
httpretty.PUT,
"http://localhost:4567/api/v1/comments/{}".format(comment_data["id"]),
body=_get_comment_callback(comment_data, thread_id, parent_id)
)
def register_get_comment_error_response(self, comment_id, status_code): def register_get_comment_error_response(self, comment_id, status_code):
""" """
......
...@@ -18,6 +18,7 @@ from discussion_api.api import ( ...@@ -18,6 +18,7 @@ from discussion_api.api import (
get_comment_list, get_comment_list,
get_course_topics, get_course_topics,
get_thread_list, get_thread_list,
update_comment,
update_thread, update_thread,
) )
from discussion_api.forms import CommentListGetForm, ThreadListGetForm from discussion_api.forms import CommentListGetForm, ThreadListGetForm
...@@ -212,7 +213,8 @@ class CommentViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet): ...@@ -212,7 +213,8 @@ class CommentViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet):
""" """
**Use Cases** **Use Cases**
Retrieve the list of comments in a thread. Retrieve the list of comments in a thread, create a comment, or modify
an existing comment.
**Example Requests**: **Example Requests**:
...@@ -224,6 +226,9 @@ class CommentViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet): ...@@ -224,6 +226,9 @@ class CommentViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet):
"raw_body": "Body text" "raw_body": "Body text"
} }
PATCH /api/discussion/v1/comments/comment_id
{"raw_body": "Edited text"}
**GET Parameters**: **GET Parameters**:
* thread_id (required): The thread to retrieve comments for * thread_id (required): The thread to retrieve comments for
...@@ -245,6 +250,10 @@ class CommentViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet): ...@@ -245,6 +250,10 @@ class CommentViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet):
* raw_body: The comment's raw body text * raw_body: The comment's raw body text
**PATCH Parameters**:
raw_body is accepted with the same meaning as in a POST request
**GET Response Values**: **GET Response Values**:
* results: The list of comments; each item in the list has the same * results: The list of comments; each item in the list has the same
...@@ -254,7 +263,7 @@ class CommentViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet): ...@@ -254,7 +263,7 @@ class CommentViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet):
* previous: The URL of the previous page (or null if last page) * previous: The URL of the previous page (or null if last page)
**POST Response Values**: **POST/PATCH Response Values**:
* id: The id of the comment * id: The id of the comment
...@@ -298,6 +307,8 @@ class CommentViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet): ...@@ -298,6 +307,8 @@ class CommentViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet):
* children: The list of child comments (with the same format) * children: The list of child comments (with the same format)
""" """
lookup_field = "comment_id"
def list(self, request): def list(self, request):
""" """
Implements the GET method for the list endpoint as described in the Implements the GET method for the list endpoint as described in the
...@@ -322,3 +333,10 @@ class CommentViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet): ...@@ -322,3 +333,10 @@ class CommentViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet):
class docstring. class docstring.
""" """
return Response(create_comment(request, request.DATA)) return Response(create_comment(request, request.DATA))
def partial_update(self, request, comment_id):
"""
Implements the PATCH method for the instance endpoint as described in
the class docstring.
"""
return Response(update_comment(request, comment_id, request.DATA))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment