Commit 0998df8c by Greg Price

Add comment voting to discussion API

parent 905512b6
...@@ -15,7 +15,7 @@ from opaque_keys import InvalidKeyError ...@@ -15,7 +15,7 @@ from opaque_keys import InvalidKeyError
from opaque_keys.edx.locator import CourseKey from opaque_keys.edx.locator import CourseKey
from courseware.courses import get_course_with_access from courseware.courses import get_course_with_access
from discussion_api.forms import ThreadActionsForm from discussion_api.forms import CommentActionsForm, ThreadActionsForm
from discussion_api.pagination import get_paginated_data from discussion_api.pagination import get_paginated_data
from discussion_api.serializers import CommentSerializer, ThreadSerializer, get_context from discussion_api.serializers import CommentSerializer, ThreadSerializer, get_context
from django_comment_client.base.views import ( from django_comment_client.base.views import (
...@@ -335,25 +335,25 @@ def get_comment_list(request, thread_id, endorsed, page, page_size): ...@@ -335,25 +335,25 @@ def get_comment_list(request, thread_id, endorsed, page, page_size):
return get_paginated_data(request, results, page, num_pages) return get_paginated_data(request, results, page, num_pages)
def _do_extra_thread_actions(api_thread, cc_thread, request_fields, actions_form, context): def _do_extra_actions(api_content, cc_content, request_fields, actions_form, context):
""" """
Perform any necessary additional actions related to thread creation or Perform any necessary additional actions related to content creation or
update that require a separate comments service request. update that require a separate comments service request.
""" """
for field, form_value in actions_form.cleaned_data.items(): for field, form_value in actions_form.cleaned_data.items():
if field in request_fields and form_value != api_thread[field]: if field in request_fields and form_value != api_content[field]:
api_thread[field] = form_value api_content[field] = form_value
if field == "following": if field == "following":
if form_value: if form_value:
context["cc_requester"].follow(cc_thread) context["cc_requester"].follow(cc_content)
else: else:
context["cc_requester"].unfollow(cc_thread) context["cc_requester"].unfollow(cc_content)
else: else:
assert field == "voted" assert field == "voted"
if form_value: if form_value:
context["cc_requester"].vote(cc_thread, "up") context["cc_requester"].vote(cc_content, "up")
else: else:
context["cc_requester"].unvote(cc_thread) context["cc_requester"].unvote(cc_content)
def create_thread(request, thread_data): def create_thread(request, thread_data):
...@@ -390,7 +390,7 @@ def create_thread(request, thread_data): ...@@ -390,7 +390,7 @@ def create_thread(request, thread_data):
cc_thread = serializer.object cc_thread = serializer.object
api_thread = serializer.data api_thread = serializer.data
_do_extra_thread_actions(api_thread, cc_thread, thread_data.keys(), actions_form, context) _do_extra_actions(api_thread, cc_thread, thread_data.keys(), actions_form, context)
track_forum_event( track_forum_event(
request, request,
...@@ -428,11 +428,15 @@ def create_comment(request, comment_data): ...@@ -428,11 +428,15 @@ def create_comment(request, comment_data):
raise ValidationError({"thread_id": ["Invalid value."]}) raise ValidationError({"thread_id": ["Invalid value."]})
serializer = CommentSerializer(data=comment_data, context=context) serializer = CommentSerializer(data=comment_data, context=context)
if not serializer.is_valid(): actions_form = CommentActionsForm(comment_data)
raise ValidationError(serializer.errors) if not (serializer.is_valid() and actions_form.is_valid()):
raise ValidationError(dict(serializer.errors.items() + actions_form.errors.items()))
serializer.save() serializer.save()
cc_comment = serializer.object cc_comment = serializer.object
api_comment = serializer.data
_do_extra_actions(api_comment, cc_comment, comment_data.keys(), actions_form, context)
track_forum_event( track_forum_event(
request, request,
get_comment_created_event_name(cc_comment), get_comment_created_event_name(cc_comment),
...@@ -501,7 +505,7 @@ def update_thread(request, thread_id, update_data): ...@@ -501,7 +505,7 @@ def update_thread(request, thread_id, update_data):
if set(update_data) - set(actions_form.fields): if set(update_data) - set(actions_form.fields):
serializer.save() serializer.save()
api_thread = serializer.data api_thread = serializer.data
_do_extra_thread_actions(api_thread, cc_thread, update_data.keys(), actions_form, context) _do_extra_actions(api_thread, cc_thread, update_data.keys(), actions_form, context)
return api_thread return api_thread
...@@ -513,7 +517,7 @@ def _get_comment_editable_fields(cc_comment, context): ...@@ -513,7 +517,7 @@ def _get_comment_editable_fields(cc_comment, context):
""" """
Get the list of editable fields for the given comment in the given context Get the list of editable fields for the given comment in the given context
""" """
ret = set() ret = {"voted"}
if _is_user_author_or_privileged(cc_comment, context): if _is_user_author_or_privileged(cc_comment, context):
ret |= _COMMENT_EDITABLE_BY_AUTHOR ret |= _COMMENT_EDITABLE_BY_AUTHOR
if _is_user_author_or_privileged(context["thread"], context): if _is_user_author_or_privileged(context["thread"], context):
...@@ -554,12 +558,15 @@ def update_comment(request, comment_id, update_data): ...@@ -554,12 +558,15 @@ def update_comment(request, comment_id, update_data):
editable_fields = _get_comment_editable_fields(cc_comment, context) editable_fields = _get_comment_editable_fields(cc_comment, context)
_check_editable_fields(editable_fields, update_data) _check_editable_fields(editable_fields, update_data)
serializer = CommentSerializer(cc_comment, data=update_data, partial=True, context=context) serializer = CommentSerializer(cc_comment, data=update_data, partial=True, context=context)
if not serializer.is_valid(): actions_form = CommentActionsForm(update_data)
raise ValidationError(serializer.errors) if not (serializer.is_valid() and actions_form.is_valid()):
# Only save comment object if the comment is actually modified raise ValidationError(dict(serializer.errors.items() + actions_form.errors.items()))
if update_data: # Only save thread object if some of the edited fields are in the thread data, not extra actions
if set(update_data) - set(actions_form.fields):
serializer.save() serializer.save()
return serializer.data api_comment = serializer.data
_do_extra_actions(api_comment, cc_comment, update_data.keys(), actions_form, context)
return api_comment
def delete_thread(request, thread_id): def delete_thread(request, thread_id):
......
...@@ -59,7 +59,7 @@ class ThreadListGetForm(_PaginationForm): ...@@ -59,7 +59,7 @@ class ThreadListGetForm(_PaginationForm):
class ThreadActionsForm(Form): class ThreadActionsForm(Form):
""" """
A form to handle fields in thread creation that require separate A form to handle fields in thread creation/update that require separate
interactions with the comments service. interactions with the comments service.
""" """
following = BooleanField(required=False) following = BooleanField(required=False)
...@@ -74,3 +74,11 @@ class CommentListGetForm(_PaginationForm): ...@@ -74,3 +74,11 @@ class CommentListGetForm(_PaginationForm):
# TODO: should we use something better here? This only accepts "True", # TODO: should we use something better here? This only accepts "True",
# "False", "1", and "0" # "False", "1", and "0"
endorsed = NullBooleanField(required=False) endorsed = NullBooleanField(required=False)
class CommentActionsForm(Form):
"""
A form to handle fields in comment creation/update that require separate
interactions with the comments service.
"""
voted = BooleanField(required=False)
...@@ -1359,6 +1359,21 @@ class CreateCommentTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTest ...@@ -1359,6 +1359,21 @@ class CreateCommentTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTest
self.assertEqual(actual_event_name, expected_event_name) self.assertEqual(actual_event_name, expected_event_name)
self.assertEqual(actual_event_data, expected_event_data) self.assertEqual(actual_event_data, expected_event_data)
def test_voted(self):
self.register_post_comment_response({"id": "test_comment"}, "test_thread")
self.register_comment_votes_response("test_comment")
data = self.minimal_data.copy()
data["voted"] = "True"
result = create_comment(self.request, data)
self.assertEqual(result["voted"], True)
cs_request = httpretty.last_request()
self.assertEqual(urlparse(cs_request.path).path, "/api/v1/comments/test_comment/votes")
self.assertEqual(cs_request.method, "PUT")
self.assertEqual(
cs_request.parsed_body,
{"user_id": [str(self.user.id)], "value": ["up"]}
)
def test_thread_id_missing(self): def test_thread_id_missing(self):
with self.assertRaises(ValidationError) as assertion: with self.assertRaises(ValidationError) as assertion:
create_comment(self.request, {}) create_comment(self.request, {})
...@@ -1918,6 +1933,45 @@ class UpdateCommentTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTest ...@@ -1918,6 +1933,45 @@ class UpdateCommentTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTest
{"endorsed": ["This field is not editable."]} {"endorsed": ["This field is not editable."]}
) )
@ddt.data(*itertools.product([True, False], [True, False]))
@ddt.unpack
def test_voted(self, old_voted, new_voted):
"""
Test attempts to edit the "voted" field.
old_voted indicates whether the comment should be upvoted at the start of
the test. new_voted indicates the value for the "voted" field in the
update. If old_voted and new_voted are the same, no update should be
made. Otherwise, a vote should be PUT or DELETEd according to the
new_voted value.
"""
if old_voted:
self.register_get_user_response(self.user, upvoted_ids=["test_comment"])
self.register_comment_votes_response("test_comment")
self.register_comment()
data = {"voted": new_voted}
result = update_comment(self.request, "test_comment", data)
self.assertEqual(result["voted"], new_voted)
last_request_path = urlparse(httpretty.last_request().path).path
votes_url = "/api/v1/comments/test_comment/votes"
if old_voted == new_voted:
self.assertNotEqual(last_request_path, votes_url)
else:
self.assertEqual(last_request_path, votes_url)
self.assertEqual(
httpretty.last_request().method,
"PUT" if new_voted else "DELETE"
)
actual_request_data = (
httpretty.last_request().parsed_body if new_voted else
parse_qs(urlparse(httpretty.last_request().path).query)
)
actual_request_data.pop("request_id", None)
expected_request_data = {"user_id": [str(self.user.id)]}
if new_voted:
expected_request_data["value"] = ["up"]
self.assertEqual(actual_request_data, expected_request_data)
@ddt.ddt @ddt.ddt
class DeleteThreadTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTestCase): class DeleteThreadTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTestCase):
......
...@@ -203,6 +203,19 @@ class CommentsServiceMockMixin(object): ...@@ -203,6 +203,19 @@ class CommentsServiceMockMixin(object):
status=200 status=200
) )
def register_comment_votes_response(self, comment_id):
"""
Register a mock response for PUT and DELETE on the CS comment votes
endpoint
"""
for method in [httpretty.PUT, httpretty.DELETE]:
httpretty.register_uri(
method,
"http://localhost:4567/api/v1/comments/{}/votes".format(comment_id),
body=json.dumps({}), # body is unused
status=200
)
def register_delete_thread_response(self, thread_id): def register_delete_thread_response(self, thread_id):
""" """
Register a mock response for DELETE on the CS thread instance endpoint Register a mock response for DELETE on the CS thread instance endpoint
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment