Commit 6c2df96c by Greg Price

Merge pull request #8372 from edx/gprice/discussion-api-follow-thread

Add ability to follow a thread in discussion API
parents ef5cd63c 9ff87497
......@@ -9,13 +9,11 @@ from django.core.exceptions import ValidationError
from django.core.urlresolvers import reverse
from django.http import Http404
from rest_framework.exceptions import PermissionDenied
from opaque_keys import InvalidKeyError
from opaque_keys.edx.locator import CourseKey
from courseware.courses import get_course_with_access
from discussion_api.forms import ThreadCreateExtrasForm
from discussion_api.forms import ThreadActionsForm
from discussion_api.pagination import get_paginated_data
from discussion_api.serializers import CommentSerializer, ThreadSerializer, get_context
from django_comment_client.base.views import (
......@@ -267,6 +265,20 @@ def get_comment_list(request, thread_id, endorsed, page, page_size):
return get_paginated_data(request, results, page, num_pages)
def _do_extra_thread_actions(api_thread, cc_thread, request_fields, actions_form, context):
"""
Perform any necessary additional actions related to thread creation or
update that require a separate comments service request.
"""
form_following = actions_form.cleaned_data["following"]
if "following" in request_fields and form_following != api_thread["following"]:
if form_following:
context["cc_requester"].follow(cc_thread)
else:
context["cc_requester"].unfollow(cc_thread)
api_thread["following"] = form_following
def create_thread(request, thread_data):
"""
Create a thread.
......@@ -294,27 +306,24 @@ def create_thread(request, thread_data):
context = get_context(course, request)
serializer = ThreadSerializer(data=thread_data, context=context)
extras_form = ThreadCreateExtrasForm(thread_data)
if not (serializer.is_valid() and extras_form.is_valid()):
raise ValidationError(dict(serializer.errors.items() + extras_form.errors.items()))
actions_form = ThreadActionsForm(thread_data)
if not (serializer.is_valid() and actions_form.is_valid()):
raise ValidationError(dict(serializer.errors.items() + actions_form.errors.items()))
serializer.save()
thread = serializer.object
ret = serializer.data
following = extras_form.cleaned_data["following"]
if following:
context["cc_requester"].follow(thread)
ret["following"] = True
cc_thread = serializer.object
api_thread = serializer.data
_do_extra_thread_actions(api_thread, cc_thread, thread_data.keys(), actions_form, context)
track_forum_event(
request,
THREAD_CREATED_EVENT_NAME,
course,
thread,
get_thread_created_event_data(thread, followed=following)
cc_thread,
get_thread_created_event_data(cc_thread, followed=actions_form.cleaned_data["following"])
)
return ret
return api_thread
def create_comment(request, comment_data):
......@@ -359,6 +368,21 @@ def create_comment(request, comment_data):
return serializer.data
_THREAD_EDITABLE_BY_ANY = {"following"}
_THREAD_EDITABLE_BY_AUTHOR = {"topic_id", "type", "title", "raw_body"} | _THREAD_EDITABLE_BY_ANY
def _get_thread_editable_fields(cc_thread, context):
"""
Get the list of editable fields for the given thread in the given context
"""
is_author = context["cc_requester"]["id"] == cc_thread["user_id"]
if context["is_requester_privileged"] or is_author:
return _THREAD_EDITABLE_BY_AUTHOR
else:
return _THREAD_EDITABLE_BY_ANY
def update_thread(request, thread_id, update_data):
"""
Update a thread.
......@@ -378,11 +402,21 @@ def update_thread(request, thread_id, update_data):
detail.
"""
cc_thread, context = _get_thread_and_context(request, thread_id)
is_author = str(request.user.id) == cc_thread["user_id"]
if not (context["is_requester_privileged"] or is_author):
raise PermissionDenied()
editable_fields = _get_thread_editable_fields(cc_thread, context)
non_editable_errors = {
field: ["This field is not editable."]
for field in update_data.keys()
if field not in editable_fields
}
if non_editable_errors:
raise ValidationError(non_editable_errors)
serializer = ThreadSerializer(cc_thread, data=update_data, partial=True, context=context)
if not serializer.is_valid():
raise ValidationError(serializer.errors)
serializer.save()
return serializer.data
actions_form = ThreadActionsForm(update_data)
if not (serializer.is_valid() and actions_form.is_valid()):
raise ValidationError(dict(serializer.errors.items() + actions_form.errors.items()))
# Only save thread object if some of the edited fields are in the thread data, not extra actions
if set(update_data) - set(actions_form.fields):
serializer.save()
api_thread = serializer.data
_do_extra_thread_actions(api_thread, cc_thread, update_data.keys(), actions_form, context)
return api_thread
......@@ -57,7 +57,7 @@ class ThreadListGetForm(_PaginationForm):
raise ValidationError("'{}' is not a valid course id".format(value))
class ThreadCreateExtrasForm(Form):
class ThreadActionsForm(Form):
"""
A form to handle fields in thread creation that require separate
interactions with the comments service.
......
......@@ -3,7 +3,7 @@ Tests for Discussion API internal interface
"""
from datetime import datetime, timedelta
import itertools
from urlparse import urlparse, urlunparse
from urlparse import parse_qs, urlparse, urlunparse
from urllib import urlencode
import ddt
......@@ -15,8 +15,6 @@ from django.core.exceptions import ValidationError
from django.http import Http404
from django.test.client import RequestFactory
from rest_framework.exceptions import PermissionDenied
from opaque_keys.edx.locator import CourseLocator
from courseware.tests.factories import BetaTesterFactory, StaffFactory
......@@ -1132,6 +1130,7 @@ class CreateThreadTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTestC
urlparse(cs_request.path).path,
"/api/v1/users/{}/subscriptions".format(self.user.id)
)
self.assertEqual(cs_request.method, "POST")
self.assertEqual(
cs_request.parsed_body,
{"source_type": ["thread"], "source_id": ["test_id"]}
......@@ -1396,6 +1395,15 @@ class UpdateThreadTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTestC
self.register_get_thread_response(cs_data)
self.register_put_thread_response(cs_data)
def test_empty(self):
"""Check that an empty update does not make any modifying requests."""
# Ensure that the default following value of False is not applied implicitly
self.register_get_user_response(self.user, subscribed_thread_ids=["test_thread"])
self.register_thread()
update_thread(self.request, "test_thread", {})
for request in httpretty.httpretty.latest_requests:
self.assertEqual(request.method, "GET")
def test_basic(self):
self.register_thread()
actual = update_thread(self.request, "test_thread", {"raw_body": "Edited body"})
......@@ -1507,16 +1515,61 @@ class UpdateThreadTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTestC
FORUM_ROLE_COMMUNITY_TA,
FORUM_ROLE_STUDENT,
)
def test_non_author_access(self, role_name):
def test_author_only_fields(self, role_name):
role = Role.objects.create(name=role_name, course_id=self.course.id)
role.users = [self.user]
self.register_thread({"user_id": str(self.user.id + 1)})
data = {field: "edited" for field in ["topic_id", "title", "raw_body"]}
data["type"] = "question"
expected_error = role_name == FORUM_ROLE_STUDENT
try:
update_thread(self.request, "test_thread", {})
update_thread(self.request, "test_thread", data)
self.assertFalse(expected_error)
except PermissionDenied:
except ValidationError as err:
self.assertTrue(expected_error)
self.assertEqual(
err.message_dict,
{field: ["This field is not editable."] for field in data.keys()}
)
@ddt.data(*itertools.product([True, False], [True, False]))
@ddt.unpack
def test_following(self, old_following, new_following):
"""
Test attempts to edit the "following" field.
old_following indicates whether the thread should be followed at the
start of the test. new_following indicates the value for the "following"
field in the update. If old_following and new_following are the same, no
update should be made. Otherwise, a subscription should be POSTed or
DELETEd according to the new_following value.
"""
if old_following:
self.register_get_user_response(self.user, subscribed_thread_ids=["test_thread"])
self.register_subscription_response(self.user)
self.register_thread()
data = {"following": new_following}
result = update_thread(self.request, "test_thread", data)
self.assertEqual(result["following"], new_following)
last_request_path = urlparse(httpretty.last_request().path).path
subscription_url = "/api/v1/users/{}/subscriptions".format(self.user.id)
if old_following == new_following:
self.assertNotEqual(last_request_path, subscription_url)
else:
self.assertEqual(last_request_path, subscription_url)
self.assertEqual(
httpretty.last_request().method,
"POST" if new_following else "DELETE"
)
request_data = (
httpretty.last_request().parsed_body if new_following else
parse_qs(urlparse(httpretty.last_request().path).query)
)
request_data.pop("request_id", None)
self.assertEqual(
request_data,
{"source_type": ["thread"], "source_id": ["test_thread"]}
)
def test_invalid_field(self):
self.register_thread()
......
......@@ -149,14 +149,16 @@ class CommentsServiceMockMixin(object):
def register_subscription_response(self, user):
"""
Register a mock response for POST on the CS user subscription endpoint
"""
httpretty.register_uri(
httpretty.POST,
"http://localhost:4567/api/v1/users/{id}/subscriptions".format(id=user.id),
body=json.dumps({}), # body is unused
status=200
)
Register a mock response for POST and DELETE on the CS user subscription
endpoint
"""
for method in [httpretty.POST, httpretty.DELETE]:
httpretty.register_uri(
method,
"http://localhost:4567/api/v1/users/{id}/subscriptions".format(id=user.id),
body=json.dumps({}), # body is unused
status=200
)
def assert_query_params_equal(self, httpretty_request, expected_params):
"""
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment