Commit df3e6535 by Greg Price

Add "following" parameter to thread creation

This allows authors to follow the thread immediately upon creation.
parent 93178bc3
...@@ -10,6 +10,7 @@ from opaque_keys import InvalidKeyError ...@@ -10,6 +10,7 @@ from opaque_keys import InvalidKeyError
from opaque_keys.edx.locator import CourseLocator from opaque_keys.edx.locator import CourseLocator
from courseware.courses import get_course_with_access from courseware.courses import get_course_with_access
from discussion_api.forms import ThreadCreateExtrasForm
from discussion_api.pagination import get_paginated_data from discussion_api.pagination import get_paginated_data
from discussion_api.serializers import CommentSerializer, ThreadSerializer, get_context from discussion_api.serializers import CommentSerializer, ThreadSerializer, get_context
from django_comment_client.base.views import ( from django_comment_client.base.views import (
...@@ -243,17 +244,24 @@ def create_thread(request, thread_data): ...@@ -243,17 +244,24 @@ def create_thread(request, thread_data):
context = get_context(course, request) context = get_context(course, request)
serializer = ThreadSerializer(data=thread_data, context=context) serializer = ThreadSerializer(data=thread_data, context=context)
if not serializer.is_valid(): extras_form = ThreadCreateExtrasForm(thread_data)
raise ValidationError(serializer.errors) if not (serializer.is_valid() and extras_form.is_valid()):
raise ValidationError(dict(serializer.errors.items() + extras_form.errors.items()))
serializer.save() serializer.save()
thread = serializer.object thread = serializer.object
ret = serializer.data
following = extras_form.cleaned_data["following"]
if following:
context["cc_requester"].follow(thread)
ret["following"] = True
track_forum_event( track_forum_event(
request, request,
THREAD_CREATED_EVENT_NAME, THREAD_CREATED_EVENT_NAME,
course, course,
thread, thread,
get_thread_created_event_data(thread, followed=False) get_thread_created_event_data(thread, followed=following)
) )
return serializer.data return serializer.data
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
Discussion API forms Discussion API forms
""" """
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.forms import CharField, Form, IntegerField, NullBooleanField from django.forms import BooleanField, CharField, Form, IntegerField, NullBooleanField
from opaque_keys import InvalidKeyError from opaque_keys import InvalidKeyError
from opaque_keys.edx.locator import CourseLocator from opaque_keys.edx.locator import CourseLocator
...@@ -37,6 +37,14 @@ class ThreadListGetForm(_PaginationForm): ...@@ -37,6 +37,14 @@ class ThreadListGetForm(_PaginationForm):
raise ValidationError("'{}' is not a valid course id".format(value)) raise ValidationError("'{}' is not a valid course id".format(value))
class ThreadCreateExtrasForm(Form):
"""
A form to handle fields in thread creation that require separate
interactions with the comments service.
"""
following = BooleanField(required=False)
class CommentListGetForm(_PaginationForm): class CommentListGetForm(_PaginationForm):
""" """
A form to validate query parameters in the comment list retrieval endpoint A form to validate query parameters in the comment list retrieval endpoint
......
...@@ -3,6 +3,7 @@ Tests for Discussion API internal interface ...@@ -3,6 +3,7 @@ Tests for Discussion API internal interface
""" """
from datetime import datetime, timedelta from datetime import datetime, timedelta
import itertools import itertools
from urlparse import urlparse
import ddt import ddt
import httpretty import httpretty
...@@ -1066,6 +1067,23 @@ class CreateThreadTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTestC ...@@ -1066,6 +1067,23 @@ class CreateThreadTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTestC
} }
) )
def test_following(self):
self.register_post_thread_response({"id": "test_id"})
self.register_subscription_response(self.user)
data = self.minimal_data.copy()
data["following"] = "True"
result = create_thread(self.request, data)
self.assertEqual(result["following"], True)
cs_request = httpretty.last_request()
self.assertEqual(
urlparse(cs_request.path).path,
"/api/v1/users/{}/subscriptions".format(self.user.id)
)
self.assertEqual(
cs_request.parsed_body,
{"source_type": ["thread"], "source_id": ["test_id"]}
)
def test_course_id_missing(self): def test_course_id_missing(self):
with self.assertRaises(ValidationError) as assertion: with self.assertRaises(ValidationError) as assertion:
create_thread(self.request, {}) create_thread(self.request, {})
......
...@@ -74,6 +74,17 @@ class CommentsServiceMockMixin(object): ...@@ -74,6 +74,17 @@ class CommentsServiceMockMixin(object):
status=200 status=200
) )
def register_subscription_response(self, user):
"""
Register a mock response for POST on the CS user subscription endpoint
"""
httpretty.register_uri(
httpretty.POST,
"http://localhost:4567/api/v1/users/{id}/subscriptions".format(id=user.id),
body=json.dumps({}), # body is unused
status=200
)
def assert_query_params_equal(self, httpretty_request, expected_params): def assert_query_params_equal(self, httpretty_request, expected_params):
""" """
Assert that the given mock request had the expected query parameters Assert that the given mock request had the expected query parameters
......
...@@ -96,6 +96,9 @@ class ThreadViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet): ...@@ -96,6 +96,9 @@ class ThreadViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet):
* raw_body (required): The thread's raw body text * raw_body (required): The thread's raw body text
* following (optional): A boolean indicating whether the user should
follow the thread upon its creation; defaults to false
**GET Response Values**: **GET Response Values**:
* results: The list of threads; each item in the list has the same * results: The list of threads; each item in the list has the same
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment