Commit 94c1bf49 by Greg Price

Merge pull request #8437 from edx/gprice/discussion-api-edit-comment

Add comment editing to discussion API
parents 26f1ecb6 895731f5
......@@ -26,6 +26,7 @@ from django_comment_client.base.views import (
track_forum_event,
)
from django_comment_client.utils import get_accessible_discussion_modules
from lms.lib.comment_client.comment import Comment
from lms.lib.comment_client.thread import Thread
from lms.lib.comment_client.utils import CommentClientRequestError
from openedx.core.djangoapps.course_groups.cohorts import get_cohort_id, is_commentable_cohorted
......@@ -74,16 +75,36 @@ def _get_thread_and_context(request, thread_id, retrieve_kwargs=None):
raise Http404
def _is_user_author_or_privileged(cc_thread, context):
def _get_comment_and_context(request, comment_id):
"""
Check if the user is the author of a thread or a privileged user.
Retrieve the given comment and build a serializer context for it, returning
both. This function also enforces access control for the comment (checking
both the user's access to the course and to the comment's thread's cohort if
applicable). Raises Http404 if the comment does not exist or the user cannot
access it.
"""
try:
cc_comment = Comment(id=comment_id).retrieve()
_, context = _get_thread_and_context(
request,
cc_comment["thread_id"],
cc_comment["parent_id"]
)
return cc_comment, context
except CommentClientRequestError:
raise Http404
def _is_user_author_or_privileged(cc_content, context):
"""
Check if the user is the author of a content object or a privileged user.
Returns:
Boolean
"""
return (
context["is_requester_privileged"] or
context["cc_requester"]["id"] == cc_thread["user_id"]
context["cc_requester"]["id"] == cc_content["user_id"]
)
......@@ -442,6 +463,47 @@ def update_thread(request, thread_id, update_data):
return api_thread
def update_comment(request, comment_id, update_data):
"""
Update a comment.
Parameters:
request: The django request object used for build_absolute_uri and
determining the requesting user.
comment_id: The id for the comment to update.
update_data: The data to update in the comment.
Returns:
The updated comment; see discussion_api.views.CommentViewSet for more
detail.
Raises:
Http404: if the comment does not exist or is not accessible to the
requesting user
PermissionDenied: if the comment is accessible to but not editable by
the requesting user
ValidationError: if there is an error applying the update (e.g. raw_body
is empty or thread_id is included)
"""
cc_comment, context = _get_comment_and_context(request, comment_id)
if not _is_user_author_or_privileged(cc_comment, context):
raise PermissionDenied()
serializer = CommentSerializer(cc_comment, data=update_data, partial=True, context=context)
if not serializer.is_valid():
raise ValidationError(serializer.errors)
# Only save comment object if the comment is actually modified
if update_data:
serializer.save()
return serializer.data
def delete_thread(request, thread_id):
"""
Delete a thread.
......
......@@ -69,12 +69,23 @@ class _ContentSerializer(serializers.Serializer):
voted = serializers.SerializerMethodField("get_voted")
vote_count = serializers.SerializerMethodField("get_vote_count")
non_updatable_fields = ()
def __init__(self, *args, **kwargs):
super(_ContentSerializer, self).__init__(*args, **kwargs)
# id is an invalid class attribute name, so we must declare a different
# name above and modify it here
self.fields["id"] = self.fields.pop("id_")
for field in self.non_updatable_fields:
setattr(self, "validate_{}".format(field), self._validate_non_updatable)
def _validate_non_updatable(self, attrs, _source):
"""Ensure that a field is not edited in an update operation."""
if self.object:
raise ValidationError("This field is not allowed in an update.")
return attrs
def _is_user_privileged(self, user_id):
"""
Returns a boolean indicating whether the given user_id identifies a
......@@ -156,6 +167,8 @@ class ThreadSerializer(_ContentSerializer):
endorsed_comment_list_url = serializers.SerializerMethodField("get_endorsed_comment_list_url")
non_endorsed_comment_list_url = serializers.SerializerMethodField("get_non_endorsed_comment_list_url")
non_updatable_fields = ("course_id",)
def __init__(self, *args, **kwargs):
super(ThreadSerializer, self).__init__(*args, **kwargs)
# type is an invalid class attribute name, so we must declare a
......@@ -199,12 +212,6 @@ class ThreadSerializer(_ContentSerializer):
"""Returns the URL to retrieve the thread's non-endorsed comments."""
return self.get_comment_list_url(obj, endorsed=False)
def validate_course_id(self, attrs, _source):
"""Ensure that course_id is not edited in an update operation."""
if self.object:
raise ValidationError("This field is not allowed in an update.")
return attrs
def restore_object(self, attrs, instance=None):
if instance:
for key, val in attrs.items():
......@@ -230,6 +237,8 @@ class CommentSerializer(_ContentSerializer):
endorsed_at = serializers.SerializerMethodField("get_endorsed_at")
children = serializers.SerializerMethodField("get_children")
non_updatable_fields = ("thread_id", "parent_id")
def get_endorsed_by(self, obj):
"""
Returns the username of the endorsing user, if the information is
......@@ -288,8 +297,10 @@ class CommentSerializer(_ContentSerializer):
return attrs
def restore_object(self, attrs, instance=None):
if instance: # pragma: no cover
raise ValueError("CommentSerializer cannot be used for updates.")
if instance:
for key, val in attrs.items():
instance[key] = val
return instance
return Comment(
course_id=self.context["thread"]["course_id"],
user_id=self.context["cc_requester"]["id"],
......
......@@ -27,6 +27,7 @@ from discussion_api.api import (
get_comment_list,
get_course_topics,
get_thread_list,
update_comment,
update_thread,
)
from discussion_api.tests.utils import (
......@@ -1638,6 +1639,173 @@ class UpdateThreadTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTestC
@ddt.ddt
class UpdateCommentTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTestCase):
"""Tests for update_comment"""
@mock.patch.dict("django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True})
def setUp(self):
super(UpdateCommentTest, self).setUp()
httpretty.reset()
httpretty.enable()
self.addCleanup(httpretty.disable)
self.user = UserFactory.create()
self.register_get_user_response(self.user)
self.request = RequestFactory().get("/test_path")
self.request.user = self.user
self.course = CourseFactory.create()
CourseEnrollmentFactory.create(user=self.user, course_id=self.course.id)
def register_comment(self, overrides=None, thread_overrides=None):
"""
Make a comment with appropriate data overridden by the overrides
parameter and register mock responses for both GET and PUT on its
endpoint. Also mock GET for the related thread with thread_overrides.
"""
cs_thread_data = make_minimal_cs_thread({
"id": "test_thread",
"course_id": unicode(self.course.id)
})
cs_thread_data.update(thread_overrides or {})
self.register_get_thread_response(cs_thread_data)
cs_comment_data = make_minimal_cs_comment({
"id": "test_comment",
"course_id": cs_thread_data["course_id"],
"thread_id": cs_thread_data["id"],
"username": self.user.username,
"user_id": str(self.user.id),
"created_at": "2015-06-03T00:00:00Z",
"updated_at": "2015-06-03T00:00:00Z",
"body": "Original body",
})
cs_comment_data.update(overrides or {})
self.register_get_comment_response(cs_comment_data)
self.register_put_comment_response(cs_comment_data)
def test_empty(self):
"""Check that an empty update does not make any modifying requests."""
self.register_comment()
update_comment(self.request, "test_comment", {})
for request in httpretty.httpretty.latest_requests:
self.assertEqual(request.method, "GET")
def test_basic(self):
self.register_comment()
actual = update_comment(self.request, "test_comment", {"raw_body": "Edited body"})
expected = {
"id": "test_comment",
"thread_id": "test_thread",
"parent_id": None, # TODO: we can't get this without retrieving from the thread :-(
"author": self.user.username,
"author_label": None,
"created_at": "2015-06-03T00:00:00Z",
"updated_at": "2015-06-03T00:00:00Z",
"raw_body": "Edited body",
"endorsed": False,
"endorsed_by": None,
"endorsed_by_label": None,
"endorsed_at": None,
"abuse_flagged": False,
"voted": False,
"vote_count": 0,
"children": [],
}
self.assertEqual(actual, expected)
self.assertEqual(
httpretty.last_request().parsed_body,
{
"body": ["Edited body"],
"course_id": [unicode(self.course.id)],
"user_id": [str(self.user.id)],
"anonymous": ["False"],
"anonymous_to_peers": ["False"],
"endorsed": ["False"],
}
)
def test_nonexistent_comment(self):
self.register_get_comment_error_response("test_comment", 404)
with self.assertRaises(Http404):
update_comment(self.request, "test_comment", {})
def test_nonexistent_course(self):
self.register_comment(thread_overrides={"course_id": "non/existent/course"})
with self.assertRaises(Http404):
update_comment(self.request, "test_comment", {})
def test_unenrolled(self):
self.register_comment()
self.request.user = UserFactory.create()
with self.assertRaises(Http404):
update_comment(self.request, "test_comment", {})
def test_discussions_disabled(self):
_remove_discussion_tab(self.course, self.user.id)
self.register_comment()
with self.assertRaises(Http404):
update_comment(self.request, "test_comment", {})
@ddt.data(
*itertools.product(
[
FORUM_ROLE_ADMINISTRATOR,
FORUM_ROLE_MODERATOR,
FORUM_ROLE_COMMUNITY_TA,
FORUM_ROLE_STUDENT,
],
[True, False],
["no_group", "match_group", "different_group"],
)
)
@ddt.unpack
def test_group_access(self, role_name, course_is_cohorted, thread_group_state):
cohort_course = CourseFactory.create(cohort_config={"cohorted": course_is_cohorted})
CourseEnrollmentFactory.create(user=self.user, course_id=cohort_course.id)
cohort = CohortFactory.create(course_id=cohort_course.id, users=[self.user])
role = Role.objects.create(name=role_name, course_id=cohort_course.id)
role.users = [self.user]
self.register_get_thread_response(make_minimal_cs_thread())
self.register_comment(
{"thread_id": "test_thread"},
thread_overrides={
"id": "test_thread",
"course_id": unicode(cohort_course.id),
"group_id": (
None if thread_group_state == "no_group" else
cohort.id if thread_group_state == "match_group" else
cohort.id + 1
),
}
)
expected_error = (
role_name == FORUM_ROLE_STUDENT and
course_is_cohorted and
thread_group_state == "different_group"
)
try:
update_comment(self.request, "test_comment", {})
self.assertFalse(expected_error)
except Http404:
self.assertTrue(expected_error)
@ddt.data(
FORUM_ROLE_ADMINISTRATOR,
FORUM_ROLE_MODERATOR,
FORUM_ROLE_COMMUNITY_TA,
FORUM_ROLE_STUDENT,
)
def test_role_access(self, role_name):
role = Role.objects.create(name=role_name, course_id=self.course.id)
role.users = [self.user]
self.register_comment({"user_id": str(self.user.id + 1)})
expected_error = role_name == FORUM_ROLE_STUDENT
try:
update_comment(self.request, "test_comment", {"raw_body": "edited"})
self.assertFalse(expected_error)
except PermissionDenied:
self.assertTrue(expected_error)
@ddt.ddt
class DeleteThreadTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTestCase):
"""Tests for delete_thread"""
@mock.patch.dict("django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True})
......
......@@ -23,6 +23,7 @@ from django_comment_common.models import (
FORUM_ROLE_STUDENT,
Role,
)
from lms.lib.comment_client.comment import Comment
from lms.lib.comment_client.thread import Thread
from student.tests.factories import UserFactory
from util.testing import UrlResetMixin
......@@ -553,8 +554,15 @@ class CommentSerializerDeserializationTest(CommentsServiceMockMixin, ModuleStore
"thread_id": "test_thread",
"raw_body": "Test body",
}
self.existing_comment = Comment(**make_minimal_cs_comment({
"id": "existing_comment",
"thread_id": "existing_thread",
"body": "Original body",
"user_id": str(self.user.id),
"course_id": unicode(self.course.id),
}))
def save_and_reserialize(self, data):
def save_and_reserialize(self, data, instance=None):
"""
Create a serializer with the given data, ensure that it is valid, save
the result, and return the full comment data from the serializer.
......@@ -564,13 +572,18 @@ class CommentSerializerDeserializationTest(CommentsServiceMockMixin, ModuleStore
self.request,
make_minimal_cs_thread({"course_id": unicode(self.course.id)})
)
serializer = CommentSerializer(data=data, context=context)
serializer = CommentSerializer(
instance,
data=data,
partial=(instance is not None),
context=context
)
self.assertTrue(serializer.is_valid())
serializer.save()
return serializer.data
@ddt.data(None, "test_parent")
def test_success(self, parent_id):
def test_create_success(self, parent_id):
data = self.minimal_data.copy()
if parent_id:
data["parent_id"] = parent_id
......@@ -597,7 +610,7 @@ class CommentSerializerDeserializationTest(CommentsServiceMockMixin, ModuleStore
self.assertEqual(saved["id"], "test_comment")
self.assertEqual(saved["parent_id"], parent_id)
def test_parent_id_nonexistent(self):
def test_create_parent_id_nonexistent(self):
self.register_get_comment_error_response("bad_parent", 404)
data = self.minimal_data.copy()
data["parent_id"] = "bad_parent"
......@@ -613,7 +626,7 @@ class CommentSerializerDeserializationTest(CommentsServiceMockMixin, ModuleStore
}
)
def test_parent_id_wrong_thread(self):
def test_create_parent_id_wrong_thread(self):
self.register_get_comment_response({"thread_id": "different_thread", "id": "test_parent"})
data = self.minimal_data.copy()
data["parent_id"] = "test_parent"
......@@ -629,7 +642,7 @@ class CommentSerializerDeserializationTest(CommentsServiceMockMixin, ModuleStore
}
)
def test_missing_field(self):
def test_create_missing_field(self):
for field in self.minimal_data:
data = self.minimal_data.copy()
data.pop(field)
......@@ -642,3 +655,60 @@ class CommentSerializerDeserializationTest(CommentsServiceMockMixin, ModuleStore
serializer.errors,
{field: ["This field is required."]}
)
def test_update_empty(self):
self.register_put_comment_response(self.existing_comment.attributes)
self.save_and_reserialize({}, instance=self.existing_comment)
self.assertEqual(
httpretty.last_request().parsed_body,
{
"body": ["Original body"],
"course_id": [unicode(self.course.id)],
"user_id": [str(self.user.id)],
"anonymous": ["False"],
"anonymous_to_peers": ["False"],
"endorsed": ["False"],
}
)
def test_update_all(self):
self.register_put_comment_response(self.existing_comment.attributes)
data = {"raw_body": "Edited body"}
saved = self.save_and_reserialize(data, instance=self.existing_comment)
self.assertEqual(
httpretty.last_request().parsed_body,
{
"body": ["Edited body"],
"course_id": [unicode(self.course.id)],
"user_id": [str(self.user.id)],
"anonymous": ["False"],
"anonymous_to_peers": ["False"],
"endorsed": ["False"],
}
)
self.assertEqual(saved["raw_body"], data["raw_body"])
def test_update_empty_raw_body(self):
serializer = CommentSerializer(
self.existing_comment,
data={"raw_body": ""},
partial=True,
context=get_context(self.course, self.request)
)
self.assertEqual(
serializer.errors,
{"raw_body": ["This field is required."]}
)
@ddt.data("thread_id", "parent_id")
def test_update_non_updatable(self, field):
serializer = CommentSerializer(
self.existing_comment,
data={field: "different_value"},
partial=True,
context=get_context(self.course, self.request)
)
self.assertEqual(
serializer.errors,
{field: ["This field is not allowed in an update."]}
)
......@@ -13,7 +13,11 @@ from django.core.urlresolvers import reverse
from rest_framework.test import APIClient
from discussion_api.tests.utils import CommentsServiceMockMixin, make_minimal_cs_thread
from discussion_api.tests.utils import (
CommentsServiceMockMixin,
make_minimal_cs_comment,
make_minimal_cs_thread,
)
from student.tests.factories import CourseEnrollmentFactory, UserFactory
from util.testing import UrlResetMixin
from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase
......@@ -633,3 +637,85 @@ class CommentViewSetCreateTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
self.assertEqual(response.status_code, 400)
response_data = json.loads(response.content)
self.assertEqual(response_data, expected_response_data)
class CommentViewSetPartialUpdateTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
"""Tests for CommentViewSet partial_update"""
def setUp(self):
super(CommentViewSetPartialUpdateTest, self).setUp()
httpretty.reset()
httpretty.enable()
self.addCleanup(httpretty.disable)
self.register_get_user_response(self.user)
self.url = reverse("comment-detail", kwargs={"comment_id": "test_comment"})
cs_thread = make_minimal_cs_thread({
"id": "test_thread",
"course_id": unicode(self.course.id),
})
self.register_get_thread_response(cs_thread)
cs_comment = make_minimal_cs_comment({
"id": "test_comment",
"course_id": cs_thread["course_id"],
"thread_id": cs_thread["id"],
"username": self.user.username,
"user_id": str(self.user.id),
"created_at": "2015-06-03T00:00:00Z",
"updated_at": "2015-06-03T00:00:00Z",
"body": "Original body",
})
self.register_get_comment_response(cs_comment)
self.register_put_comment_response(cs_comment)
def test_basic(self):
request_data = {"raw_body": "Edited body"}
expected_response_data = {
"id": "test_comment",
"thread_id": "test_thread",
"parent_id": None,
"author": self.user.username,
"author_label": None,
"created_at": "2015-06-03T00:00:00Z",
"updated_at": "2015-06-03T00:00:00Z",
"raw_body": "Edited body",
"endorsed": False,
"endorsed_by": None,
"endorsed_by_label": None,
"endorsed_at": None,
"abuse_flagged": False,
"voted": False,
"vote_count": 0,
"children": [],
}
response = self.client.patch( # pylint: disable=no-member
self.url,
json.dumps(request_data),
content_type="application/json"
)
self.assertEqual(response.status_code, 200)
response_data = json.loads(response.content)
self.assertEqual(response_data, expected_response_data)
self.assertEqual(
httpretty.last_request().parsed_body,
{
"body": ["Edited body"],
"course_id": [unicode(self.course.id)],
"user_id": [str(self.user.id)],
"anonymous": ["False"],
"anonymous_to_peers": ["False"],
"endorsed": ["False"],
}
)
def test_error(self):
request_data = {"raw_body": ""}
response = self.client.patch( # pylint: disable=no-member
self.url,
json.dumps(request_data),
content_type="application/json"
)
expected_response_data = {
"field_errors": {"raw_body": {"developer_message": "This field is required."}}
}
self.assertEqual(response.status_code, 400)
response_data = json.loads(response.content)
self.assertEqual(response_data, expected_response_data)
......@@ -30,6 +30,32 @@ def _get_thread_callback(thread_data):
return callback
def _get_comment_callback(comment_data, thread_id, parent_id):
"""
Get a callback function that will return a comment containing the given data
plus necessary dummy data, overridden by the content of the POST/PUT
request.
"""
def callback(request, _uri, headers):
"""
Simulate the comment creation or update endpoint as described above.
"""
response_data = make_minimal_cs_comment(comment_data)
# thread_id and parent_id are not included in request payload but
# are returned by the comments service
response_data["thread_id"] = thread_id
response_data["parent_id"] = parent_id
for key, val_list in request.parsed_body.items():
val = val_list[0]
if key in ["anonymous", "anonymous_to_peers", "endorsed"]:
response_data[key] = val == "True"
else:
response_data[key] = val
return (200, headers, json.dumps(response_data))
return callback
class CommentsServiceMockMixin(object):
"""Mixin with utility methods for mocking the comments service"""
def register_get_threads_response(self, threads, page, num_pages):
......@@ -84,33 +110,35 @@ class CommentsServiceMockMixin(object):
status=200
)
def register_post_comment_response(self, response_overrides, thread_id, parent_id=None):
def register_post_comment_response(self, comment_data, thread_id, parent_id=None):
"""
Register a mock response for POST on the CS comments endpoint for the
given thread or parent; exactly one of thread_id and parent_id must be
specified.
"""
def callback(request, _uri, headers):
"""
Simulate the comment creation endpoint by returning the provided data
along with the data from response_overrides.
"""
response_data = make_minimal_cs_comment(
{key: val[0] for key, val in request.parsed_body.items()}
)
response_data.update(response_overrides or {})
# thread_id and parent_id are not included in request payload but
# are returned by the comments service
response_data["thread_id"] = thread_id
response_data["parent_id"] = parent_id
return (200, headers, json.dumps(response_data))
if parent_id:
url = "http://localhost:4567/api/v1/comments/{}".format(parent_id)
else:
url = "http://localhost:4567/api/v1/threads/{}/comments".format(thread_id)
httpretty.register_uri(httpretty.POST, url, body=callback)
httpretty.register_uri(
httpretty.POST,
url,
body=_get_comment_callback(comment_data, thread_id, parent_id)
)
def register_put_comment_response(self, comment_data):
"""
Register a mock response for PUT on the CS endpoint for the given
comment data (which must include the key "id").
"""
thread_id = comment_data["thread_id"]
parent_id = comment_data.get("parent_id")
httpretty.register_uri(
httpretty.PUT,
"http://localhost:4567/api/v1/comments/{}".format(comment_data["id"]),
body=_get_comment_callback(comment_data, thread_id, parent_id)
)
def register_get_comment_error_response(self, comment_id, status_code):
"""
......
......@@ -18,6 +18,7 @@ from discussion_api.api import (
get_comment_list,
get_course_topics,
get_thread_list,
update_comment,
update_thread,
)
from discussion_api.forms import CommentListGetForm, ThreadListGetForm
......@@ -212,7 +213,8 @@ class CommentViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet):
"""
**Use Cases**
Retrieve the list of comments in a thread.
Retrieve the list of comments in a thread, create a comment, or modify
an existing comment.
**Example Requests**:
......@@ -224,6 +226,9 @@ class CommentViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet):
"raw_body": "Body text"
}
PATCH /api/discussion/v1/comments/comment_id
{"raw_body": "Edited text"}
**GET Parameters**:
* thread_id (required): The thread to retrieve comments for
......@@ -245,6 +250,10 @@ class CommentViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet):
* raw_body: The comment's raw body text
**PATCH Parameters**:
raw_body is accepted with the same meaning as in a POST request
**GET Response Values**:
* results: The list of comments; each item in the list has the same
......@@ -254,7 +263,7 @@ class CommentViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet):
* previous: The URL of the previous page (or null if last page)
**POST Response Values**:
**POST/PATCH Response Values**:
* id: The id of the comment
......@@ -298,6 +307,8 @@ class CommentViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet):
* children: The list of child comments (with the same format)
"""
lookup_field = "comment_id"
def list(self, request):
"""
Implements the GET method for the list endpoint as described in the
......@@ -322,3 +333,10 @@ class CommentViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet):
class docstring.
"""
return Response(create_comment(request, request.DATA))
def partial_update(self, request, comment_id):
"""
Implements the PATCH method for the instance endpoint as described in
the class docstring.
"""
return Response(update_comment(request, comment_id, request.DATA))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment