Commit f891450f by Greg Price

Add thread creation to discussion API

Also, fix the field-specific error format for all API endpoints.

This requires cs_comments_service commit fdf017c918.
parent a3cea74a
......@@ -6,11 +6,17 @@ from django.http import Http404
from collections import defaultdict
from opaque_keys import InvalidKeyError
from opaque_keys.edx.locator import CourseLocator
from courseware.courses import get_course_with_access
from discussion_api.pagination import get_paginated_data
from discussion_api.serializers import CommentSerializer, ThreadSerializer, get_context
from django_comment_client.base.views import (
THREAD_CREATED_EVENT_NAME,
get_thread_created_event_data,
track_forum_event,
)
from django_comment_client.utils import get_accessible_discussion_modules
from lms.lib.comment_client.thread import Thread
from lms.lib.comment_client.utils import CommentClientRequestError
......@@ -208,3 +214,46 @@ def get_comment_list(request, thread_id, endorsed, page, page_size):
results = [CommentSerializer(response, context=context).data for response in responses]
return get_paginated_data(request, results, page, num_pages)
def create_thread(request, thread_data):
"""
Create a thread.
Parameters:
request: The django request object used for build_absolute_uri and
determining the requesting user.
thread_data: The data for the created thread.
Returns:
The created thread; see discussion_api.views.ThreadViewSet for more
detail.
"""
course_id = thread_data.get("course_id")
if not course_id:
raise ValidationError({"course_id": ["This field is required."]})
try:
course_key = CourseLocator.from_string(course_id)
course = _get_course_or_404(course_key, request.user)
except (Http404, InvalidKeyError):
raise ValidationError({"course_id": ["Invalid value."]})
context = get_context(course, request)
serializer = ThreadSerializer(data=thread_data, context=context)
if not serializer.is_valid():
raise ValidationError(serializer.errors)
serializer.save()
thread = serializer.object
track_forum_event(
request,
THREAD_CREATED_EVENT_NAME,
course,
thread,
get_thread_created_event_data(thread, followed=False)
)
return serializer.data
......@@ -15,6 +15,7 @@ from django_comment_common.models import (
FORUM_ROLE_MODERATOR,
Role,
)
from lms.lib.comment_client.thread import Thread
from lms.lib.comment_client.user import User as CommentClientUser
from openedx.core.djangoapps.course_groups.cohorts import get_cohort_names
......@@ -134,15 +135,18 @@ class ThreadSerializer(_ContentSerializer):
"""
course_id = serializers.CharField()
topic_id = serializers.CharField(source="commentable_id")
group_id = serializers.IntegerField()
group_id = serializers.IntegerField(read_only=True)
group_name = serializers.SerializerMethodField("get_group_name")
type_ = serializers.ChoiceField(source="thread_type", choices=("discussion", "question"))
type_ = serializers.ChoiceField(
source="thread_type",
choices=[(val, val) for val in ["discussion", "question"]]
)
title = serializers.CharField()
pinned = serializers.BooleanField()
closed = serializers.BooleanField()
pinned = serializers.BooleanField(read_only=True)
closed = serializers.BooleanField(read_only=True)
following = serializers.SerializerMethodField("get_following")
comment_count = serializers.IntegerField(source="comments_count")
unread_comment_count = serializers.IntegerField(source="unread_comments_count")
comment_count = serializers.IntegerField(source="comments_count", read_only=True)
unread_comment_count = serializers.IntegerField(source="unread_comments_count", read_only=True)
comment_list_url = serializers.SerializerMethodField("get_comment_list_url")
endorsed_comment_list_url = serializers.SerializerMethodField("get_endorsed_comment_list_url")
non_endorsed_comment_list_url = serializers.SerializerMethodField("get_non_endorsed_comment_list_url")
......@@ -190,6 +194,11 @@ class ThreadSerializer(_ContentSerializer):
"""Returns the URL to retrieve the thread's non-endorsed comments."""
return self.get_comment_list_url(obj, endorsed=False)
def restore_object(self, attrs, instance=None):
if instance:
raise ValueError("ThreadSerializer cannot be used for updates.")
return Thread(user_id=self.context["cc_requester"]["id"], **attrs)
class CommentSerializer(_ContentSerializer):
"""
......
......@@ -16,7 +16,7 @@ from django.test.client import RequestFactory
from opaque_keys.edx.locator import CourseLocator
from courseware.tests.factories import BetaTesterFactory, StaffFactory
from discussion_api.api import get_comment_list, get_course_topics, get_thread_list
from discussion_api.api import create_thread, get_comment_list, get_course_topics, get_thread_list
from discussion_api.tests.utils import (
CommentsServiceMockMixin,
make_minimal_cs_comment,
......@@ -975,3 +975,126 @@ class GetCommentListTest(CommentsServiceMockMixin, ModuleStoreTestCase):
# Page past the end
with self.assertRaises(Http404):
self.get_comment_list(thread, endorsed=True, page=2, page_size=10)
class CreateThreadTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTestCase):
"""Tests for create_thread"""
@mock.patch.dict("django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True})
def setUp(self):
super(CreateThreadTest, self).setUp()
httpretty.reset()
httpretty.enable()
self.addCleanup(httpretty.disable)
self.user = UserFactory.create()
self.register_get_user_response(self.user)
self.request = RequestFactory().get("/test_path")
self.request.user = self.user
self.course = CourseFactory.create()
CourseEnrollmentFactory.create(user=self.user, course_id=self.course.id)
self.minimal_data = {
"course_id": unicode(self.course.id),
"topic_id": "test_topic",
"type": "discussion",
"title": "Test Title",
"raw_body": "Test body",
}
@mock.patch("eventtracking.tracker.emit")
def test_basic(self, mock_emit):
self.register_post_thread_response({
"id": "test_id",
"username": self.user.username,
"created_at": "2015-05-19T00:00:00Z",
"updated_at": "2015-05-19T00:00:00Z",
})
actual = create_thread(self.request, self.minimal_data)
expected = {
"id": "test_id",
"course_id": unicode(self.course.id),
"topic_id": "test_topic",
"group_id": None,
"group_name": None,
"author": self.user.username,
"author_label": None,
"created_at": "2015-05-19T00:00:00Z",
"updated_at": "2015-05-19T00:00:00Z",
"type": "discussion",
"title": "Test Title",
"raw_body": "Test body",
"pinned": False,
"closed": False,
"following": False,
"abuse_flagged": False,
"voted": False,
"vote_count": 0,
"comment_count": 0,
"unread_comment_count": 0,
"comment_list_url": "http://testserver/api/discussion/v1/comments/?thread_id=test_id",
"endorsed_comment_list_url": None,
"non_endorsed_comment_list_url": None,
}
self.assertEqual(actual, expected)
self.assertEqual(
httpretty.last_request().parsed_body,
{
"course_id": [unicode(self.course.id)],
"commentable_id": ["test_topic"],
"thread_type": ["discussion"],
"title": ["Test Title"],
"body": ["Test body"],
"user_id": [str(self.user.id)],
}
)
event_name, event_data = mock_emit.call_args[0]
self.assertEqual(event_name, "edx.forum.thread.created")
self.assertEqual(
event_data,
{
"commentable_id": "test_topic",
"group_id": None,
"thread_type": "discussion",
"title": "Test Title",
"anonymous": False,
"anonymous_to_peers": False,
"options": {"followed": False},
"id": "test_id",
"truncated": False,
"body": "Test body",
"url": "",
"user_forums_roles": [FORUM_ROLE_STUDENT],
"user_course_roles": [],
}
)
def test_course_id_missing(self):
with self.assertRaises(ValidationError) as assertion:
create_thread(self.request, {})
self.assertEqual(assertion.exception.message_dict, {"course_id": ["This field is required."]})
def test_course_id_invalid(self):
with self.assertRaises(ValidationError) as assertion:
create_thread(self.request, {"course_id": "invalid!"})
self.assertEqual(assertion.exception.message_dict, {"course_id": ["Invalid value."]})
def test_nonexistent_course(self):
with self.assertRaises(ValidationError) as assertion:
create_thread(self.request, {"course_id": "non/existent/course"})
self.assertEqual(assertion.exception.message_dict, {"course_id": ["Invalid value."]})
def test_not_enrolled(self):
self.request.user = UserFactory.create()
with self.assertRaises(ValidationError) as assertion:
create_thread(self.request, self.minimal_data)
self.assertEqual(assertion.exception.message_dict, {"course_id": ["Invalid value."]})
def test_discussions_disabled(self):
_remove_discussion_tab(self.course, self.user.id)
with self.assertRaises(ValidationError) as assertion:
create_thread(self.request, self.minimal_data)
self.assertEqual(assertion.exception.message_dict, {"course_id": ["Invalid value."]})
def test_invalid_field(self):
data = self.minimal_data.copy()
data["type"] = "invalid_type"
with self.assertRaises(ValidationError):
create_thread(self.request, data)
......@@ -2,6 +2,7 @@
Tests for Discussion API serializers
"""
import itertools
from urlparse import urlparse
import ddt
import httpretty
......@@ -125,8 +126,8 @@ class SerializerTestMixin(CommentsServiceMockMixin, UrlResetMixin):
@ddt.ddt
class ThreadSerializerTest(SerializerTestMixin, ModuleStoreTestCase):
"""Tests for ThreadSerializer."""
class ThreadSerializerSerializationTest(SerializerTestMixin, ModuleStoreTestCase):
"""Tests for ThreadSerializer serialization."""
def make_cs_content(self, overrides):
"""
Create a thread with the given overrides, plus some useful test data.
......@@ -366,3 +367,76 @@ class CommentSerializerTest(SerializerTestMixin, ModuleStoreTestCase):
self.assertEqual(serialized["children"][1]["parent_id"], "test_root")
self.assertEqual(serialized["children"][1]["children"][0]["id"], "test_grandchild")
self.assertEqual(serialized["children"][1]["children"][0]["parent_id"], "test_child_2")
@ddt.ddt
class ThreadSerializerDeserializationTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTestCase):
"""Tests for ThreadSerializer deserialization."""
@mock.patch.dict("django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True})
def setUp(self):
super(ThreadSerializerDeserializationTest, self).setUp()
httpretty.reset()
httpretty.enable()
self.addCleanup(httpretty.disable)
self.register_post_thread_response({"id": "test_id"})
self.course = CourseFactory.create()
self.user = UserFactory.create()
self.register_get_user_response(self.user)
self.request = RequestFactory().get("/dummy")
self.request.user = self.user
self.minimal_data = {
"course_id": unicode(self.course.id),
"topic_id": "test_topic",
"type": "discussion",
"title": "Test Title",
"raw_body": "Test body",
}
def save_and_reserialize(self, data):
"""
Create a serializer with the given data, ensure that it is valid, save
the result, and return the full thread data from the serializer.
"""
serializer = ThreadSerializer(data=data, context=get_context(self.course, self.request))
self.assertTrue(serializer.is_valid())
serializer.save()
return serializer.data
def test_minimal(self):
saved = self.save_and_reserialize(self.minimal_data)
self.assertEqual(
urlparse(httpretty.last_request().path).path,
"/api/v1/test_topic/threads"
)
self.assertEqual(
httpretty.last_request().parsed_body,
{
"course_id": [unicode(self.course.id)],
"commentable_id": ["test_topic"],
"thread_type": ["discussion"],
"title": ["Test Title"],
"body": ["Test body"],
"user_id": [str(self.user.id)],
}
)
self.assertEqual(saved["id"], "test_id")
def test_missing_field(self):
for field in self.minimal_data:
data = self.minimal_data.copy()
data.pop(field)
serializer = ThreadSerializer(data=data)
self.assertFalse(serializer.is_valid())
self.assertEqual(
serializer.errors,
{field: ["This field is required."]}
)
def test_type(self):
data = self.minimal_data.copy()
data["type"] = "question"
self.save_and_reserialize(data)
data["type"] = "invalid_type"
serializer = ThreadSerializer(data=data)
self.assertFalse(serializer.is_valid())
......@@ -103,7 +103,7 @@ class ThreadViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
self.assert_response_correct(
response,
400,
{"field_errors": {"course_id": "This field is required."}}
{"field_errors": {"course_id": {"developer_message": "This field is required."}}}
)
def test_404(self):
......@@ -205,6 +205,93 @@ class ThreadViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
@httpretty.activate
class ThreadViewSetCreateTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
"""Tests for ThreadViewSet create"""
def setUp(self):
super(ThreadViewSetCreateTest, self).setUp()
self.url = reverse("thread-list")
def test_basic(self):
self.register_get_user_response(self.user)
self.register_post_thread_response({
"id": "test_thread",
"username": self.user.username,
"created_at": "2015-05-19T00:00:00Z",
"updated_at": "2015-05-19T00:00:00Z",
})
request_data = {
"course_id": unicode(self.course.id),
"topic_id": "test_topic",
"type": "discussion",
"title": "Test Title",
"raw_body": "Test body",
}
expected_response_data = {
"id": "test_thread",
"course_id": unicode(self.course.id),
"topic_id": "test_topic",
"group_id": None,
"group_name": None,
"author": self.user.username,
"author_label": None,
"created_at": "2015-05-19T00:00:00Z",
"updated_at": "2015-05-19T00:00:00Z",
"type": "discussion",
"title": "Test Title",
"raw_body": "Test body",
"pinned": False,
"closed": False,
"following": False,
"abuse_flagged": False,
"voted": False,
"vote_count": 0,
"comment_count": 0,
"unread_comment_count": 0,
"comment_list_url": "http://testserver/api/discussion/v1/comments/?thread_id=test_thread",
"endorsed_comment_list_url": None,
"non_endorsed_comment_list_url": None,
}
response = self.client.post(
self.url,
json.dumps(request_data),
content_type="application/json"
)
self.assertEqual(response.status_code, 200)
response_data = json.loads(response.content)
self.assertEqual(response_data, expected_response_data)
self.assertEqual(
httpretty.last_request().parsed_body,
{
"course_id": [unicode(self.course.id)],
"commentable_id": ["test_topic"],
"thread_type": ["discussion"],
"title": ["Test Title"],
"body": ["Test body"],
"user_id": [str(self.user.id)],
}
)
def test_error(self):
request_data = {
"topic_id": "dummy",
"type": "discussion",
"title": "dummy",
"raw_body": "dummy",
}
response = self.client.post(
self.url,
json.dumps(request_data),
content_type="application/json"
)
expected_response_data = {
"field_errors": {"course_id": {"developer_message": "This field is required."}}
}
self.assertEqual(response.status_code, 400)
response_data = json.loads(response.content)
self.assertEqual(response_data, expected_response_data)
@httpretty.activate
class CommentViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
"""Tests for CommentViewSet list"""
def setUp(self):
......@@ -218,7 +305,7 @@ class CommentViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
self.assert_response_correct(
response,
400,
{"field_errors": {"thread_id": "This field is required."}}
{"field_errors": {"thread_id": {"developer_message": "This field is required."}}}
)
def test_404(self):
......
......@@ -2,6 +2,7 @@
Discussion API test utilities
"""
import json
import re
import httpretty
......@@ -21,6 +22,25 @@ class CommentsServiceMockMixin(object):
status=200
)
def register_post_thread_response(self, response_overrides):
"""Register a mock response for POST on the CS commentable endpoint"""
def callback(request, _uri, headers):
"""
Simulate the thread creation endpoint by returning the provided data
along with the data from response_overrides.
"""
response_data = make_minimal_cs_thread(
{key: val[0] for key, val in request.parsed_body.items()}
)
response_data.update(response_overrides)
return (200, headers, json.dumps(response_data))
httpretty.register_uri(
httpretty.POST,
re.compile(r"http://localhost:4567/api/v1/(\w+)/threads"),
body=callback
)
def register_get_thread_error_response(self, thread_id, status_code):
"""Register a mock error response for GET on the CS thread endpoint."""
httpretty.register_uri(
......
......@@ -11,7 +11,7 @@ from rest_framework.viewsets import ViewSet
from opaque_keys.edx.locator import CourseLocator
from discussion_api.api import get_comment_list, get_course_topics, get_thread_list
from discussion_api.api import create_thread, get_comment_list, get_course_topics, get_thread_list
from discussion_api.forms import CommentListGetForm, ThreadListGetForm
from openedx.core.lib.api.view_utils import DeveloperErrorViewMixin
......@@ -61,12 +61,21 @@ class ThreadViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet):
"""
**Use Cases**
Retrieve the list of threads for a course.
Retrieve the list of threads for a course or post a new thread.
**Example Requests**:
GET /api/discussion/v1/threads/?course_id=ExampleX/Demo/2015
POST /api/discussion/v1/threads
{
"course_id": "foo/bar/baz",
"topic_id": "quux",
"type": "discussion",
"title": "Title text",
"body": "Body text"
}
**GET Parameters**:
* course_id (required): The course to retrieve threads for
......@@ -75,9 +84,28 @@ class ThreadViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet):
* page_size: The number of items per page (default is 10, max is 100)
**Response Values**:
**POST Parameters**:
* course_id (required): The course to create the thread in
* topic_id (required): The topic to create the thread in
* type (required): The thread's type (either "question" or "discussion")
* title (required): The thread's title
* raw_body (required): The thread's raw body text
**GET Response Values**:
* results: The list of threads; each item in the list has the same
fields as the POST response below
* next: The URL of the next page (or null if first page)
* previous: The URL of the previous page (or null if last page)
* results: The list of threads. Each item in the list includes:
**POST response values**:
* id: The id of the thread
......@@ -106,9 +134,6 @@ class ThreadViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet):
that were created or updated since the last time the user read
the thread
* next: The URL of the next page (or null if first page)
* previous: The URL of the previous page (or null if last page)
"""
def list(self, request):
"""
......@@ -127,6 +152,13 @@ class ThreadViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet):
)
)
def create(self, request):
"""
Implements the POST method for the list endpoint as described in the
class docstring.
"""
return Response(create_thread(request, request.DATA))
class CommentViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet):
"""
......
......@@ -37,6 +37,8 @@ log = logging.getLogger(__name__)
TRACKING_MAX_FORUM_BODY = 2000
THREAD_CREATED_EVENT_NAME = "edx.forum.thread.created"
def permitted(fn):
@functools.wraps(fn)
......@@ -97,6 +99,26 @@ def track_forum_event(request, event_name, course, obj, data, id_map=None):
tracker.emit(event_name, data)
def get_thread_created_event_data(thread, followed):
"""
Get the event data payload for thread creation (excluding fields populated
by track_forum_event)
"""
return {
'commentable_id': thread.commentable_id,
'group_id': thread.get("group_id"),
'thread_type': thread.thread_type,
'title': thread.title,
'anonymous': thread.anonymous,
'anonymous_to_peers': thread.anonymous_to_peers,
'options': {'followed': followed},
# There is a stated desire for an 'origin' property that will state
# whether this thread was created via courseware or the forum.
# However, the view does not contain that data, and including it will
# likely require changes elsewhere.
}
@require_POST
@login_required
@permitted
......@@ -156,19 +178,7 @@ def create_thread(request, course_id, commentable_id):
user = cc.User.from_django_user(request.user)
user.follow(thread)
event_data = {
'title': thread.title,
'commentable_id': commentable_id,
'options': {'followed': follow},
'anonymous': anonymous,
'thread_type': thread.thread_type,
'group_id': group_id,
'anonymous_to_peers': anonymous_to_peers,
# There is a stated desire for an 'origin' property that will state
# whether this thread was created via courseware or the forum.
# However, the view does not contain that data, and including it will
# likely require changes elsewhere.
}
event_data = get_thread_created_event_data(thread, follow)
data = thread.to_dict()
# Calls to id map are expensive, but we need this more than once.
......@@ -177,7 +187,7 @@ def create_thread(request, course_id, commentable_id):
add_courseware_context([data], course, request.user, id_map=id_map)
track_forum_event(request, 'edx.forum.thread.created',
track_forum_event(request, THREAD_CREATED_EVENT_NAME,
course, thread, event_data, id_map=id_map)
if request.is_ajax():
......
......@@ -32,7 +32,7 @@ class DeveloperErrorViewMixin(object):
response_obj["developer_message"] = non_field_error_list[0]
if message_dict:
response_obj["field_errors"] = {
field: message_dict[field][0]
field: {"developer_message": message_dict[field][0]}
for field in message_dict
}
return Response(response_obj, status=400)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment