Commit 9ff87497 by Greg Price

Add ability to follow a thread in discussion API

parent 298ba727
...@@ -9,13 +9,11 @@ from django.core.exceptions import ValidationError ...@@ -9,13 +9,11 @@ from django.core.exceptions import ValidationError
from django.core.urlresolvers import reverse from django.core.urlresolvers import reverse
from django.http import Http404 from django.http import Http404
from rest_framework.exceptions import PermissionDenied
from opaque_keys import InvalidKeyError from opaque_keys import InvalidKeyError
from opaque_keys.edx.locator import CourseKey from opaque_keys.edx.locator import CourseKey
from courseware.courses import get_course_with_access from courseware.courses import get_course_with_access
from discussion_api.forms import ThreadCreateExtrasForm from discussion_api.forms import ThreadActionsForm
from discussion_api.pagination import get_paginated_data from discussion_api.pagination import get_paginated_data
from discussion_api.serializers import CommentSerializer, ThreadSerializer, get_context from discussion_api.serializers import CommentSerializer, ThreadSerializer, get_context
from django_comment_client.base.views import ( from django_comment_client.base.views import (
...@@ -267,6 +265,20 @@ def get_comment_list(request, thread_id, endorsed, page, page_size): ...@@ -267,6 +265,20 @@ def get_comment_list(request, thread_id, endorsed, page, page_size):
return get_paginated_data(request, results, page, num_pages) return get_paginated_data(request, results, page, num_pages)
def _do_extra_thread_actions(api_thread, cc_thread, request_fields, actions_form, context):
"""
Perform any necessary additional actions related to thread creation or
update that require a separate comments service request.
"""
form_following = actions_form.cleaned_data["following"]
if "following" in request_fields and form_following != api_thread["following"]:
if form_following:
context["cc_requester"].follow(cc_thread)
else:
context["cc_requester"].unfollow(cc_thread)
api_thread["following"] = form_following
def create_thread(request, thread_data): def create_thread(request, thread_data):
""" """
Create a thread. Create a thread.
...@@ -294,27 +306,24 @@ def create_thread(request, thread_data): ...@@ -294,27 +306,24 @@ def create_thread(request, thread_data):
context = get_context(course, request) context = get_context(course, request)
serializer = ThreadSerializer(data=thread_data, context=context) serializer = ThreadSerializer(data=thread_data, context=context)
extras_form = ThreadCreateExtrasForm(thread_data) actions_form = ThreadActionsForm(thread_data)
if not (serializer.is_valid() and extras_form.is_valid()): if not (serializer.is_valid() and actions_form.is_valid()):
raise ValidationError(dict(serializer.errors.items() + extras_form.errors.items())) raise ValidationError(dict(serializer.errors.items() + actions_form.errors.items()))
serializer.save() serializer.save()
thread = serializer.object cc_thread = serializer.object
ret = serializer.data api_thread = serializer.data
following = extras_form.cleaned_data["following"] _do_extra_thread_actions(api_thread, cc_thread, thread_data.keys(), actions_form, context)
if following:
context["cc_requester"].follow(thread)
ret["following"] = True
track_forum_event( track_forum_event(
request, request,
THREAD_CREATED_EVENT_NAME, THREAD_CREATED_EVENT_NAME,
course, course,
thread, cc_thread,
get_thread_created_event_data(thread, followed=following) get_thread_created_event_data(cc_thread, followed=actions_form.cleaned_data["following"])
) )
return ret return api_thread
def create_comment(request, comment_data): def create_comment(request, comment_data):
...@@ -359,6 +368,21 @@ def create_comment(request, comment_data): ...@@ -359,6 +368,21 @@ def create_comment(request, comment_data):
return serializer.data return serializer.data
_THREAD_EDITABLE_BY_ANY = {"following"}
_THREAD_EDITABLE_BY_AUTHOR = {"topic_id", "type", "title", "raw_body"} | _THREAD_EDITABLE_BY_ANY
def _get_thread_editable_fields(cc_thread, context):
"""
Get the list of editable fields for the given thread in the given context
"""
is_author = context["cc_requester"]["id"] == cc_thread["user_id"]
if context["is_requester_privileged"] or is_author:
return _THREAD_EDITABLE_BY_AUTHOR
else:
return _THREAD_EDITABLE_BY_ANY
def update_thread(request, thread_id, update_data): def update_thread(request, thread_id, update_data):
""" """
Update a thread. Update a thread.
...@@ -378,11 +402,21 @@ def update_thread(request, thread_id, update_data): ...@@ -378,11 +402,21 @@ def update_thread(request, thread_id, update_data):
detail. detail.
""" """
cc_thread, context = _get_thread_and_context(request, thread_id) cc_thread, context = _get_thread_and_context(request, thread_id)
is_author = str(request.user.id) == cc_thread["user_id"] editable_fields = _get_thread_editable_fields(cc_thread, context)
if not (context["is_requester_privileged"] or is_author): non_editable_errors = {
raise PermissionDenied() field: ["This field is not editable."]
for field in update_data.keys()
if field not in editable_fields
}
if non_editable_errors:
raise ValidationError(non_editable_errors)
serializer = ThreadSerializer(cc_thread, data=update_data, partial=True, context=context) serializer = ThreadSerializer(cc_thread, data=update_data, partial=True, context=context)
if not serializer.is_valid(): actions_form = ThreadActionsForm(update_data)
raise ValidationError(serializer.errors) if not (serializer.is_valid() and actions_form.is_valid()):
serializer.save() raise ValidationError(dict(serializer.errors.items() + actions_form.errors.items()))
return serializer.data # Only save thread object if some of the edited fields are in the thread data, not extra actions
if set(update_data) - set(actions_form.fields):
serializer.save()
api_thread = serializer.data
_do_extra_thread_actions(api_thread, cc_thread, update_data.keys(), actions_form, context)
return api_thread
...@@ -57,7 +57,7 @@ class ThreadListGetForm(_PaginationForm): ...@@ -57,7 +57,7 @@ class ThreadListGetForm(_PaginationForm):
raise ValidationError("'{}' is not a valid course id".format(value)) raise ValidationError("'{}' is not a valid course id".format(value))
class ThreadCreateExtrasForm(Form): class ThreadActionsForm(Form):
""" """
A form to handle fields in thread creation that require separate A form to handle fields in thread creation that require separate
interactions with the comments service. interactions with the comments service.
......
...@@ -3,7 +3,7 @@ Tests for Discussion API internal interface ...@@ -3,7 +3,7 @@ Tests for Discussion API internal interface
""" """
from datetime import datetime, timedelta from datetime import datetime, timedelta
import itertools import itertools
from urlparse import urlparse, urlunparse from urlparse import parse_qs, urlparse, urlunparse
from urllib import urlencode from urllib import urlencode
import ddt import ddt
...@@ -15,8 +15,6 @@ from django.core.exceptions import ValidationError ...@@ -15,8 +15,6 @@ from django.core.exceptions import ValidationError
from django.http import Http404 from django.http import Http404
from django.test.client import RequestFactory from django.test.client import RequestFactory
from rest_framework.exceptions import PermissionDenied
from opaque_keys.edx.locator import CourseLocator from opaque_keys.edx.locator import CourseLocator
from courseware.tests.factories import BetaTesterFactory, StaffFactory from courseware.tests.factories import BetaTesterFactory, StaffFactory
...@@ -1132,6 +1130,7 @@ class CreateThreadTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTestC ...@@ -1132,6 +1130,7 @@ class CreateThreadTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTestC
urlparse(cs_request.path).path, urlparse(cs_request.path).path,
"/api/v1/users/{}/subscriptions".format(self.user.id) "/api/v1/users/{}/subscriptions".format(self.user.id)
) )
self.assertEqual(cs_request.method, "POST")
self.assertEqual( self.assertEqual(
cs_request.parsed_body, cs_request.parsed_body,
{"source_type": ["thread"], "source_id": ["test_id"]} {"source_type": ["thread"], "source_id": ["test_id"]}
...@@ -1396,6 +1395,15 @@ class UpdateThreadTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTestC ...@@ -1396,6 +1395,15 @@ class UpdateThreadTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTestC
self.register_get_thread_response(cs_data) self.register_get_thread_response(cs_data)
self.register_put_thread_response(cs_data) self.register_put_thread_response(cs_data)
def test_empty(self):
"""Check that an empty update does not make any modifying requests."""
# Ensure that the default following value of False is not applied implicitly
self.register_get_user_response(self.user, subscribed_thread_ids=["test_thread"])
self.register_thread()
update_thread(self.request, "test_thread", {})
for request in httpretty.httpretty.latest_requests:
self.assertEqual(request.method, "GET")
def test_basic(self): def test_basic(self):
self.register_thread() self.register_thread()
actual = update_thread(self.request, "test_thread", {"raw_body": "Edited body"}) actual = update_thread(self.request, "test_thread", {"raw_body": "Edited body"})
...@@ -1507,16 +1515,61 @@ class UpdateThreadTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTestC ...@@ -1507,16 +1515,61 @@ class UpdateThreadTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTestC
FORUM_ROLE_COMMUNITY_TA, FORUM_ROLE_COMMUNITY_TA,
FORUM_ROLE_STUDENT, FORUM_ROLE_STUDENT,
) )
def test_non_author_access(self, role_name): def test_author_only_fields(self, role_name):
role = Role.objects.create(name=role_name, course_id=self.course.id) role = Role.objects.create(name=role_name, course_id=self.course.id)
role.users = [self.user] role.users = [self.user]
self.register_thread({"user_id": str(self.user.id + 1)}) self.register_thread({"user_id": str(self.user.id + 1)})
data = {field: "edited" for field in ["topic_id", "title", "raw_body"]}
data["type"] = "question"
expected_error = role_name == FORUM_ROLE_STUDENT expected_error = role_name == FORUM_ROLE_STUDENT
try: try:
update_thread(self.request, "test_thread", {}) update_thread(self.request, "test_thread", data)
self.assertFalse(expected_error) self.assertFalse(expected_error)
except PermissionDenied: except ValidationError as err:
self.assertTrue(expected_error) self.assertTrue(expected_error)
self.assertEqual(
err.message_dict,
{field: ["This field is not editable."] for field in data.keys()}
)
@ddt.data(*itertools.product([True, False], [True, False]))
@ddt.unpack
def test_following(self, old_following, new_following):
"""
Test attempts to edit the "following" field.
old_following indicates whether the thread should be followed at the
start of the test. new_following indicates the value for the "following"
field in the update. If old_following and new_following are the same, no
update should be made. Otherwise, a subscription should be POSTed or
DELETEd according to the new_following value.
"""
if old_following:
self.register_get_user_response(self.user, subscribed_thread_ids=["test_thread"])
self.register_subscription_response(self.user)
self.register_thread()
data = {"following": new_following}
result = update_thread(self.request, "test_thread", data)
self.assertEqual(result["following"], new_following)
last_request_path = urlparse(httpretty.last_request().path).path
subscription_url = "/api/v1/users/{}/subscriptions".format(self.user.id)
if old_following == new_following:
self.assertNotEqual(last_request_path, subscription_url)
else:
self.assertEqual(last_request_path, subscription_url)
self.assertEqual(
httpretty.last_request().method,
"POST" if new_following else "DELETE"
)
request_data = (
httpretty.last_request().parsed_body if new_following else
parse_qs(urlparse(httpretty.last_request().path).query)
)
request_data.pop("request_id", None)
self.assertEqual(
request_data,
{"source_type": ["thread"], "source_id": ["test_thread"]}
)
def test_invalid_field(self): def test_invalid_field(self):
self.register_thread() self.register_thread()
......
...@@ -149,14 +149,16 @@ class CommentsServiceMockMixin(object): ...@@ -149,14 +149,16 @@ class CommentsServiceMockMixin(object):
def register_subscription_response(self, user): def register_subscription_response(self, user):
""" """
Register a mock response for POST on the CS user subscription endpoint Register a mock response for POST and DELETE on the CS user subscription
""" endpoint
httpretty.register_uri( """
httpretty.POST, for method in [httpretty.POST, httpretty.DELETE]:
"http://localhost:4567/api/v1/users/{id}/subscriptions".format(id=user.id), httpretty.register_uri(
body=json.dumps({}), # body is unused method,
status=200 "http://localhost:4567/api/v1/users/{id}/subscriptions".format(id=user.id),
) body=json.dumps({}), # body is unused
status=200
)
def assert_query_params_equal(self, httpretty_request, expected_params): def assert_query_params_equal(self, httpretty_request, expected_params):
""" """
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment