Commit 73d89131 by Greg Price

Add thread editing to discussion API

This is done via PATCH on a thread instance endpoint.
parent 7c0709ce
"""
Discussion API internal interface
"""
from collections import defaultdict
from urllib import urlencode
from urlparse import urlunparse
......@@ -8,11 +9,10 @@ from django.core.exceptions import ValidationError
from django.core.urlresolvers import reverse
from django.http import Http404
from collections import defaultdict
from rest_framework.exceptions import PermissionDenied
from opaque_keys import InvalidKeyError
from opaque_keys.edx.locator import CourseLocator
from opaque_keys.edx.locator import CourseKey
from courseware.courses import get_course_with_access
from discussion_api.forms import ThreadCreateExtrasForm
......@@ -28,7 +28,7 @@ from django_comment_client.base.views import (
from django_comment_client.utils import get_accessible_discussion_modules
from lms.lib.comment_client.thread import Thread
from lms.lib.comment_client.utils import CommentClientRequestError
from openedx.core.djangoapps.course_groups.cohorts import get_cohort_id
from openedx.core.djangoapps.course_groups.cohorts import get_cohort_id, is_commentable_cohorted
def _get_course_or_404(course_key, user):
......@@ -43,6 +43,37 @@ def _get_course_or_404(course_key, user):
return course
def _get_thread_and_context(request, thread_id, parent_id=None, retrieve_kwargs=None):
"""
Retrieve the given thread and build a serializer context for it, returning
both. This function also enforces access control for the thread (checking
both the user's access to the course and to the thread's cohort if
applicable). Raises Http404 if the thread does not exist or the user cannot
access it.
"""
retrieve_kwargs = retrieve_kwargs or {}
try:
if "mark_as_read" not in retrieve_kwargs:
retrieve_kwargs["mark_as_read"] = False
cc_thread = Thread(id=thread_id).retrieve(**retrieve_kwargs)
course_key = CourseKey.from_string(cc_thread["course_id"])
course = _get_course_or_404(course_key, request.user)
context = get_context(course, request, cc_thread, parent_id)
if (
not context["is_requester_privileged"] and
cc_thread["group_id"] and
is_commentable_cohorted(course.id, cc_thread["commentable_id"])
):
requester_cohort = get_cohort_id(request.user, course.id)
if requester_cohort is not None and cc_thread["group_id"] != requester_cohort:
raise Http404
return cc_thread, context
except CommentClientRequestError:
# params are validated at a higher level, so the only possible request
# error is if the thread doesn't exist
raise Http404
def get_thread_list_url(request, course_key, topic_id_list):
"""
Returns the URL for the thread_list_url field, given a list of topic_ids
......@@ -191,28 +222,17 @@ def get_comment_list(request, thread_id, endorsed, page, page_size):
discussion_api.views.CommentViewSet for more detail.
"""
response_skip = page_size * (page - 1)
try:
cc_thread = Thread(id=thread_id).retrieve(
recursive=True,
user_id=request.user.id,
mark_as_read=True,
response_skip=response_skip,
response_limit=page_size
)
except CommentClientRequestError:
# page and page_size are validated at a higher level, so the only
# possible request error is if the thread doesn't exist
raise Http404
course_key = CourseLocator.from_string(cc_thread["course_id"])
course = _get_course_or_404(course_key, request.user)
context = get_context(course, request, cc_thread)
# Ensure user has access to the thread
if not context["is_requester_privileged"] and cc_thread["group_id"]:
requester_cohort = get_cohort_id(request.user, course_key)
if requester_cohort is not None and cc_thread["group_id"] != requester_cohort:
raise Http404
cc_thread, context = _get_thread_and_context(
request,
thread_id,
retrieve_kwargs={
"recursive": True,
"user_id": request.user.id,
"mark_as_read": True,
"response_skip": response_skip,
"response_limit": page_size,
}
)
# Responses to discussion threads cannot be separated by endorsed, but
# responses to question threads must be separated by endorsed due to the
......@@ -267,7 +287,7 @@ def create_thread(request, thread_data):
if not course_id:
raise ValidationError({"course_id": ["This field is required."]})
try:
course_key = CourseLocator.from_string(course_id)
course_key = CourseKey.from_string(course_id)
course = _get_course_or_404(course_key, request.user)
except (Http404, InvalidKeyError):
raise ValidationError({"course_id": ["Invalid value."]})
......@@ -314,29 +334,55 @@ def create_comment(request, comment_data):
detail.
"""
thread_id = comment_data.get("thread_id")
parent_id = comment_data.get("parent_id")
if not thread_id:
raise ValidationError({"thread_id": ["This field is required."]})
try:
thread = Thread(id=thread_id).retrieve(mark_as_read=False)
course_key = CourseLocator.from_string(thread["course_id"])
course = _get_course_or_404(course_key, request.user)
except (Http404, CommentClientRequestError):
cc_thread, context = _get_thread_and_context(request, thread_id, parent_id)
except Http404:
raise ValidationError({"thread_id": ["Invalid value."]})
parent_id = comment_data.get("parent_id")
context = get_context(course, request, thread, parent_id)
serializer = CommentSerializer(data=comment_data, context=context)
if not serializer.is_valid():
raise ValidationError(serializer.errors)
serializer.save()
comment = serializer.object
cc_comment = serializer.object
track_forum_event(
request,
get_comment_created_event_name(comment),
course,
comment,
get_comment_created_event_data(comment, thread["commentable_id"], followed=False)
get_comment_created_event_name(cc_comment),
context["course"],
cc_comment,
get_comment_created_event_data(cc_comment, cc_thread["commentable_id"], followed=False)
)
return serializer.data
def update_thread(request, thread_id, update_data):
"""
Update a thread.
Parameters:
request: The django request object used for build_absolute_uri and
determining the requesting user.
thread_id: The id for the thread to update.
update_data: The data to update in the thread.
Returns:
The updated thread; see discussion_api.views.ThreadViewSet for more
detail.
"""
cc_thread, context = _get_thread_and_context(request, thread_id)
is_author = str(request.user.id) == cc_thread["user_id"]
if not (context["is_requester_privileged"] or is_author):
raise PermissionDenied()
serializer = ThreadSerializer(cc_thread, data=update_data, partial=True, context=context)
if not serializer.is_valid():
raise ValidationError(serializer.errors)
serializer.save()
return serializer.data
......@@ -21,6 +21,7 @@ from lms.lib.comment_client.thread import Thread
from lms.lib.comment_client.user import User as CommentClientUser
from lms.lib.comment_client.utils import CommentClientRequestError
from openedx.core.djangoapps.course_groups.cohorts import get_cohort_names
from openedx.core.lib.api.fields import NonEmptyCharField
def get_context(course, request, thread=None, parent_id=None):
......@@ -44,15 +45,16 @@ def get_context(course, request, thread=None, parent_id=None):
}
requester = request.user
return {
# For now, the only groups are cohorts
"course": course,
"request": request,
"thread": thread,
"parent_id": parent_id,
# For now, the only groups are cohorts
"group_ids_to_names": get_cohort_names(course),
"is_requester_privileged": requester.id in staff_user_ids or requester.id in ta_user_ids,
"staff_user_ids": staff_user_ids,
"ta_user_ids": ta_user_ids,
"cc_requester": CommentClientUser.from_django_user(requester).retrieve(),
"thread": thread,
"parent_id": parent_id,
}
......@@ -63,7 +65,7 @@ class _ContentSerializer(serializers.Serializer):
author_label = serializers.SerializerMethodField("get_author_label")
created_at = serializers.CharField(read_only=True)
updated_at = serializers.CharField(read_only=True)
raw_body = serializers.CharField(source="body")
raw_body = NonEmptyCharField(source="body")
abuse_flagged = serializers.SerializerMethodField("get_abuse_flagged")
voted = serializers.SerializerMethodField("get_voted")
vote_count = serializers.SerializerMethodField("get_vote_count")
......@@ -138,14 +140,14 @@ class ThreadSerializer(_ContentSerializer):
at introspection and Thread's __getattr__.
"""
course_id = serializers.CharField()
topic_id = serializers.CharField(source="commentable_id")
topic_id = NonEmptyCharField(source="commentable_id")
group_id = serializers.IntegerField(read_only=True)
group_name = serializers.SerializerMethodField("get_group_name")
type_ = serializers.ChoiceField(
source="thread_type",
choices=[(val, val) for val in ["discussion", "question"]]
)
title = serializers.CharField()
title = NonEmptyCharField()
pinned = serializers.BooleanField(read_only=True)
closed = serializers.BooleanField(read_only=True)
following = serializers.SerializerMethodField("get_following")
......@@ -198,10 +200,19 @@ class ThreadSerializer(_ContentSerializer):
"""Returns the URL to retrieve the thread's non-endorsed comments."""
return self.get_comment_list_url(obj, endorsed=False)
def validate_course_id(self, attrs, _source):
"""Ensure that course_id is not edited in an update operation."""
if self.object:
raise ValidationError("This field is not allowed in an update.")
return attrs
def restore_object(self, attrs, instance=None):
if instance:
raise ValueError("ThreadSerializer cannot be used for updates.")
return Thread(user_id=self.context["cc_requester"]["id"], **attrs)
for key, val in attrs.items():
instance[key] = val
return instance
else:
return Thread(user_id=self.context["cc_requester"]["id"], **attrs)
class CommentSerializer(_ContentSerializer):
......
......@@ -23,6 +23,7 @@ from django_comment_common.models import (
FORUM_ROLE_STUDENT,
Role,
)
from lms.lib.comment_client.thread import Thread
from student.tests.factories import UserFactory
from util.testing import UrlResetMixin
from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase
......@@ -378,7 +379,6 @@ class ThreadSerializerDeserializationTest(CommentsServiceMockMixin, UrlResetMixi
httpretty.reset()
httpretty.enable()
self.addCleanup(httpretty.disable)
self.register_post_thread_response({"id": "test_id"})
self.course = CourseFactory.create()
self.user = UserFactory.create()
self.register_get_user_response(self.user)
......@@ -391,18 +391,34 @@ class ThreadSerializerDeserializationTest(CommentsServiceMockMixin, UrlResetMixi
"title": "Test Title",
"raw_body": "Test body",
}
self.existing_thread = Thread(**make_minimal_cs_thread({
"id": "existing_thread",
"course_id": unicode(self.course.id),
"commentable_id": "original_topic",
"thread_type": "discussion",
"title": "Original Title",
"body": "Original body",
"user_id": str(self.user.id),
}))
def save_and_reserialize(self, data):
def save_and_reserialize(self, data, instance=None):
"""
Create a serializer with the given data, ensure that it is valid, save
the result, and return the full thread data from the serializer.
Create a serializer with the given data and (if updating) instance,
ensure that it is valid, save the result, and return the full thread
data from the serializer.
"""
serializer = ThreadSerializer(data=data, context=get_context(self.course, self.request))
serializer = ThreadSerializer(
instance,
data=data,
partial=(instance is not None),
context=get_context(self.course, self.request)
)
self.assertTrue(serializer.is_valid())
serializer.save()
return serializer.data
def test_minimal(self):
def test_create_minimal(self):
self.register_post_thread_response({"id": "test_id"})
saved = self.save_and_reserialize(self.minimal_data)
self.assertEqual(
urlparse(httpretty.last_request().path).path,
......@@ -421,7 +437,7 @@ class ThreadSerializerDeserializationTest(CommentsServiceMockMixin, UrlResetMixi
)
self.assertEqual(saved["id"], "test_id")
def test_missing_field(self):
def test_create_missing_field(self):
for field in self.minimal_data:
data = self.minimal_data.copy()
data.pop(field)
......@@ -432,7 +448,8 @@ class ThreadSerializerDeserializationTest(CommentsServiceMockMixin, UrlResetMixi
{field: ["This field is required."]}
)
def test_type(self):
def test_create_type(self):
self.register_post_thread_response({"id": "test_id"})
data = self.minimal_data.copy()
data["type"] = "question"
self.save_and_reserialize(data)
......@@ -441,6 +458,76 @@ class ThreadSerializerDeserializationTest(CommentsServiceMockMixin, UrlResetMixi
serializer = ThreadSerializer(data=data)
self.assertFalse(serializer.is_valid())
def test_update_empty(self):
self.register_put_thread_response(self.existing_thread.attributes)
self.save_and_reserialize({}, self.existing_thread)
self.assertEqual(
httpretty.last_request().parsed_body,
{
"course_id": [unicode(self.course.id)],
"commentable_id": ["original_topic"],
"thread_type": ["discussion"],
"title": ["Original Title"],
"body": ["Original body"],
"anonymous": ["False"],
"anonymous_to_peers": ["False"],
"closed": ["False"],
"pinned": ["False"],
"user_id": [str(self.user.id)],
}
)
def test_update_all(self):
self.register_put_thread_response(self.existing_thread.attributes)
data = {
"topic_id": "edited_topic",
"type": "question",
"title": "Edited Title",
"raw_body": "Edited body",
}
saved = self.save_and_reserialize(data, self.existing_thread)
self.assertEqual(
httpretty.last_request().parsed_body,
{
"course_id": [unicode(self.course.id)],
"commentable_id": ["edited_topic"],
"thread_type": ["question"],
"title": ["Edited Title"],
"body": ["Edited body"],
"anonymous": ["False"],
"anonymous_to_peers": ["False"],
"closed": ["False"],
"pinned": ["False"],
"user_id": [str(self.user.id)],
}
)
for key in data:
self.assertEqual(saved[key], data[key])
def test_update_empty_string(self):
serializer = ThreadSerializer(
self.existing_thread,
data={field: "" for field in ["topic_id", "title", "raw_body"]},
partial=True,
context=get_context(self.course, self.request)
)
self.assertEqual(
serializer.errors,
{field: ["This field is required."] for field in ["topic_id", "title", "raw_body"]}
)
def test_update_course_id(self):
serializer = ThreadSerializer(
self.existing_thread,
data={"course_id": "some/other/course"},
partial=True,
context=get_context(self.course, self.request)
)
self.assertEqual(
serializer.errors,
{"course_id": ["This field is not allowed in an update."]}
)
@ddt.ddt
class CommentSerializerDeserializationTest(CommentsServiceMockMixin, ModuleStoreTestCase):
......
......@@ -11,6 +11,8 @@ from pytz import UTC
from django.core.urlresolvers import reverse
from rest_framework.test import APIClient
from discussion_api.tests.utils import CommentsServiceMockMixin, make_minimal_cs_thread
from student.tests.factories import CourseEnrollmentFactory, UserFactory
from util.testing import UrlResetMixin
......@@ -25,6 +27,8 @@ class DiscussionAPIViewTestMixin(CommentsServiceMockMixin, UrlResetMixin):
in the test client, utility functions, and a test case for unauthenticated
requests. Subclasses must set self.url in their setUp methods.
"""
client_class = APIClient
@mock.patch.dict("django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True})
def setUp(self):
super(DiscussionAPIViewTestMixin, self).setUp()
......@@ -295,6 +299,101 @@ class ThreadViewSetCreateTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
@httpretty.activate
class ThreadViewSetPartialUpdateTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
"""Tests for ThreadViewSet partial_update"""
def setUp(self):
super(ThreadViewSetPartialUpdateTest, self).setUp()
self.url = reverse("thread-detail", kwargs={"thread_id": "test_thread"})
def test_basic(self):
self.register_get_user_response(self.user)
cs_thread = make_minimal_cs_thread({
"id": "test_thread",
"course_id": unicode(self.course.id),
"commentable_id": "original_topic",
"username": self.user.username,
"user_id": str(self.user.id),
"created_at": "2015-05-29T00:00:00Z",
"updated_at": "2015-05-29T00:00:00Z",
"thread_type": "discussion",
"title": "Original Title",
"body": "Original body",
})
self.register_get_thread_response(cs_thread)
self.register_put_thread_response(cs_thread)
request_data = {"raw_body": "Edited body"}
expected_response_data = {
"id": "test_thread",
"course_id": unicode(self.course.id),
"topic_id": "original_topic",
"group_id": None,
"group_name": None,
"author": self.user.username,
"author_label": None,
"created_at": "2015-05-29T00:00:00Z",
"updated_at": "2015-05-29T00:00:00Z",
"type": "discussion",
"title": "Original Title",
"raw_body": "Edited body",
"pinned": False,
"closed": False,
"following": False,
"abuse_flagged": False,
"voted": False,
"vote_count": 0,
"comment_count": 0,
"unread_comment_count": 0,
"comment_list_url": "http://testserver/api/discussion/v1/comments/?thread_id=test_thread",
"endorsed_comment_list_url": None,
"non_endorsed_comment_list_url": None,
}
response = self.client.patch( # pylint: disable=no-member
self.url,
json.dumps(request_data),
content_type="application/json"
)
self.assertEqual(response.status_code, 200)
response_data = json.loads(response.content)
self.assertEqual(response_data, expected_response_data)
self.assertEqual(
httpretty.last_request().parsed_body,
{
"course_id": [unicode(self.course.id)],
"commentable_id": ["original_topic"],
"thread_type": ["discussion"],
"title": ["Original Title"],
"body": ["Edited body"],
"user_id": [str(self.user.id)],
"anonymous": ["False"],
"anonymous_to_peers": ["False"],
"closed": ["False"],
"pinned": ["False"],
}
)
def test_error(self):
self.register_get_user_response(self.user)
cs_thread = make_minimal_cs_thread({
"id": "test_thread",
"course_id": unicode(self.course.id),
"user_id": str(self.user.id),
})
self.register_get_thread_response(cs_thread)
request_data = {"title": ""}
response = self.client.patch( # pylint: disable=no-member
self.url,
json.dumps(request_data),
content_type="application/json"
)
expected_response_data = {
"field_errors": {"title": {"developer_message": "This field is required."}}
}
self.assertEqual(response.status_code, 400)
response_data = json.loads(response.content)
self.assertEqual(response_data, expected_response_data)
@httpretty.activate
class CommentViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
"""Tests for CommentViewSet list"""
def setUp(self):
......
......@@ -7,6 +7,29 @@ import re
import httpretty
def _get_thread_callback(thread_data):
"""
Get a callback function that will return POST/PUT data overridden by
response_overrides.
"""
def callback(request, _uri, headers):
"""
Simulate the thread creation or update endpoint by returning the provided
data along with the data from response_overrides and dummy values for any
additional required fields.
"""
response_data = make_minimal_cs_thread(thread_data)
for key, val_list in request.parsed_body.items():
val = val_list[0]
if key in ["anonymous", "anonymous_to_peers", "closed", "pinned"]:
response_data[key] = val == "True"
else:
response_data[key] = val
return (200, headers, json.dumps(response_data))
return callback
class CommentsServiceMockMixin(object):
"""Mixin with utility methods for mocking the comments service"""
def register_get_threads_response(self, threads, page, num_pages):
......@@ -22,23 +45,23 @@ class CommentsServiceMockMixin(object):
status=200
)
def register_post_thread_response(self, response_overrides):
def register_post_thread_response(self, thread_data):
"""Register a mock response for POST on the CS commentable endpoint"""
def callback(request, _uri, headers):
"""
Simulate the thread creation endpoint by returning the provided data
along with the data from response_overrides.
"""
response_data = make_minimal_cs_thread(
{key: val[0] for key, val in request.parsed_body.items()}
)
response_data.update(response_overrides)
return (200, headers, json.dumps(response_data))
httpretty.register_uri(
httpretty.POST,
re.compile(r"http://localhost:4567/api/v1/(\w+)/threads"),
body=callback
body=_get_thread_callback(thread_data)
)
def register_put_thread_response(self, thread_data):
"""
Register a mock response for PUT on the CS endpoint for the given
thread_id.
"""
httpretty.register_uri(
httpretty.PUT,
"http://localhost:4567/api/v1/threads/{}".format(thread_data["id"]),
body=_get_thread_callback(thread_data)
)
def register_get_thread_error_response(self, thread_id, status_code):
......
......@@ -17,6 +17,7 @@ from discussion_api.api import (
get_comment_list,
get_course_topics,
get_thread_list,
update_thread,
)
from discussion_api.forms import CommentListGetForm, ThreadListGetForm
from openedx.core.lib.api.view_utils import DeveloperErrorViewMixin
......@@ -67,7 +68,8 @@ class ThreadViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet):
"""
**Use Cases**
Retrieve the list of threads for a course or post a new thread.
Retrieve the list of threads for a course, post a new thread, or modify
an existing thread.
**Example Requests**:
......@@ -82,6 +84,9 @@ class ThreadViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet):
"body": "Body text"
}
PATCH /api/discussion/v1/threads/thread_id
{"raw_body": "Edited text"}
**GET Parameters**:
* course_id (required): The course to retrieve threads for
......@@ -109,16 +114,21 @@ class ThreadViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet):
* following (optional): A boolean indicating whether the user should
follow the thread upon its creation; defaults to false
**PATCH Parameters**:
topic_id, type, title, and raw_body are accepted with the same meaning
as in a POST request
**GET Response Values**:
* results: The list of threads; each item in the list has the same
fields as the POST response below
fields as the POST/PATCH response below
* next: The URL of the next page (or null if first page)
* previous: The URL of the previous page (or null if last page)
**POST response values**:
**POST/PATCH response values**:
* id: The id of the thread
......@@ -148,6 +158,8 @@ class ThreadViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet):
the thread
"""
lookup_field = "thread_id"
def list(self, request):
"""
Implements the GET method for the list endpoint as described in the
......@@ -173,6 +185,13 @@ class ThreadViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet):
"""
return Response(create_thread(request, request.DATA))
def partial_update(self, request, thread_id):
"""
Implements the PATCH method for the instance endpoint as described in
the class docstring.
"""
return Response(update_thread(request, thread_id, request.DATA))
class CommentViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet):
"""
......
"""Fields useful for edX API implementations."""
from django.core.exceptions import ValidationError
from rest_framework.serializers import Field
from rest_framework.serializers import CharField, Field
class ExpandableField(Field):
......@@ -20,3 +21,17 @@ class ExpandableField(Field):
else:
self.collapsed.initialize(self, field_name)
return self.collapsed.field_to_native(obj, field_name)
class NonEmptyCharField(CharField):
"""
A field that enforces non-emptiness even for partial updates.
This is necessary because prior to version 3, DRF skips validation for empty
values. Thus, CharField's min_length and RegexField cannot be used to
enforce this constraint.
"""
def validate(self, value):
super(NonEmptyCharField, self).validate(value)
if not value:
raise ValidationError(self.error_messages["required"])
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment