Commit a26b00ba by Christopher Lee

Merge pull request #8656 from edx/gprice/discussion-api-set-group-id

Allow group_id to be set in discussion API
parents 1aa074da b8576e74
...@@ -17,7 +17,12 @@ from opaque_keys.edx.locator import CourseKey ...@@ -17,7 +17,12 @@ from opaque_keys.edx.locator import CourseKey
from courseware.courses import get_course_with_access from courseware.courses import get_course_with_access
from discussion_api.forms import CommentActionsForm, ThreadActionsForm from discussion_api.forms import CommentActionsForm, ThreadActionsForm
from discussion_api.pagination import get_paginated_data from discussion_api.pagination import get_paginated_data
from discussion_api.permissions import can_delete, get_editable_fields from discussion_api.permissions import (
can_delete,
get_editable_fields,
get_initializable_comment_fields,
get_initializable_thread_fields,
)
from discussion_api.serializers import CommentSerializer, ThreadSerializer, get_context from discussion_api.serializers import CommentSerializer, ThreadSerializer, get_context
from django_comment_client.base.views import ( from django_comment_client.base.views import (
THREAD_CREATED_EVENT_NAME, THREAD_CREATED_EVENT_NAME,
...@@ -295,7 +300,7 @@ def get_comment_list(request, thread_id, endorsed, page, page_size): ...@@ -295,7 +300,7 @@ def get_comment_list(request, thread_id, endorsed, page, page_size):
""" """
Return the list of comments in the given thread. Return the list of comments in the given thread.
Parameters: Arguments:
request: The django request object used for build_absolute_uri and request: The django request object used for build_absolute_uri and
determining the requesting user. determining the requesting user.
...@@ -361,19 +366,76 @@ def get_comment_list(request, thread_id, endorsed, page, page_size): ...@@ -361,19 +366,76 @@ def get_comment_list(request, thread_id, endorsed, page, page_size):
return get_paginated_data(request, results, page, num_pages) return get_paginated_data(request, results, page, num_pages)
def _check_fields(allowed_fields, data, message):
"""
Checks that the keys given in data is in allowed_fields
Arguments:
allowed_fields (set): A set of allowed fields
data (dict): The data to compare the allowed_fields against
message (str): The message to return if there are any invalid fields
Raises:
ValidationError if the given data contains a key that is not in
allowed_fields
"""
non_allowed_fields = {field: [message] for field in data.keys() if field not in allowed_fields}
if non_allowed_fields:
raise ValidationError(non_allowed_fields)
def _check_initializable_thread_fields(data, context): # pylint: disable=invalid-name
"""
Checks if the given data contains a thread field that is not initializable
by the requesting user
Arguments:
data (dict): The data to compare the allowed_fields against
context (dict): The context appropriate for use with the thread which
includes the requesting user
Raises:
ValidationError if the given data contains a thread field that is not
initializable by the requesting user
"""
_check_fields(
get_initializable_thread_fields(context),
data,
"This field is not initializable."
)
def _check_initializable_comment_fields(data, context): # pylint: disable=invalid-name
"""
Checks if the given data contains a comment field that is not initializable
by the requesting user
Arguments:
data (dict): The data to compare the allowed_fields against
context (dict): The context appropriate for use with the comment which
includes the requesting user
Raises:
ValidationError if the given data contains a comment field that is not
initializable by the requesting user
"""
_check_fields(
get_initializable_comment_fields(context),
data,
"This field is not initializable."
)
def _check_editable_fields(cc_content, data, context): def _check_editable_fields(cc_content, data, context):
""" """
Raise ValidationError if the given update data contains a field that is not Raise ValidationError if the given update data contains a field that is not
in editable_fields. editable by the requesting user
""" """
editable_fields = get_editable_fields(cc_content, context) _check_fields(
non_editable_errors = { get_editable_fields(cc_content, context),
field: ["This field is not editable."] data,
for field in data.keys() "This field is not editable."
if field not in editable_fields )
}
if non_editable_errors:
raise ValidationError(non_editable_errors)
def _do_extra_actions(api_content, cc_content, request_fields, actions_form, context): def _do_extra_actions(api_content, cc_content, request_fields, actions_form, context):
...@@ -406,7 +468,7 @@ def create_thread(request, thread_data): ...@@ -406,7 +468,7 @@ def create_thread(request, thread_data):
""" """
Create a thread. Create a thread.
Parameters: Arguments:
request: The django request object used for build_absolute_uri and request: The django request object used for build_absolute_uri and
determining the requesting user. determining the requesting user.
...@@ -428,6 +490,13 @@ def create_thread(request, thread_data): ...@@ -428,6 +490,13 @@ def create_thread(request, thread_data):
raise ValidationError({"course_id": ["Invalid value."]}) raise ValidationError({"course_id": ["Invalid value."]})
context = get_context(course, request) context = get_context(course, request)
_check_initializable_thread_fields(thread_data, context)
if (
"group_id" not in thread_data and
is_commentable_cohorted(course_key, thread_data.get("topic_id"))
):
thread_data = thread_data.copy()
thread_data["group_id"] = get_cohort_id(request.user, course_key)
serializer = ThreadSerializer(data=thread_data, context=context) serializer = ThreadSerializer(data=thread_data, context=context)
actions_form = ThreadActionsForm(thread_data) actions_form = ThreadActionsForm(thread_data)
if not (serializer.is_valid() and actions_form.is_valid()): if not (serializer.is_valid() and actions_form.is_valid()):
...@@ -453,7 +522,7 @@ def create_comment(request, comment_data): ...@@ -453,7 +522,7 @@ def create_comment(request, comment_data):
""" """
Create a comment. Create a comment.
Parameters: Arguments:
request: The django request object used for build_absolute_uri and request: The django request object used for build_absolute_uri and
determining the requesting user. determining the requesting user.
...@@ -473,6 +542,7 @@ def create_comment(request, comment_data): ...@@ -473,6 +542,7 @@ def create_comment(request, comment_data):
except Http404: except Http404:
raise ValidationError({"thread_id": ["Invalid value."]}) raise ValidationError({"thread_id": ["Invalid value."]})
_check_initializable_comment_fields(comment_data, context)
serializer = CommentSerializer(data=comment_data, context=context) serializer = CommentSerializer(data=comment_data, context=context)
actions_form = CommentActionsForm(comment_data) actions_form = CommentActionsForm(comment_data)
if not (serializer.is_valid() and actions_form.is_valid()): if not (serializer.is_valid() and actions_form.is_valid()):
...@@ -498,7 +568,7 @@ def update_thread(request, thread_id, update_data): ...@@ -498,7 +568,7 @@ def update_thread(request, thread_id, update_data):
""" """
Update a thread. Update a thread.
Parameters: Arguments:
request: The django request object used for build_absolute_uri and request: The django request object used for build_absolute_uri and
determining the requesting user. determining the requesting user.
...@@ -530,7 +600,7 @@ def update_comment(request, comment_id, update_data): ...@@ -530,7 +600,7 @@ def update_comment(request, comment_id, update_data):
""" """
Update a comment. Update a comment.
Parameters: Arguments:
request: The django request object used for build_absolute_uri and request: The django request object used for build_absolute_uri and
determining the requesting user. determining the requesting user.
...@@ -573,7 +643,7 @@ def delete_thread(request, thread_id): ...@@ -573,7 +643,7 @@ def delete_thread(request, thread_id):
""" """
Delete a thread. Delete a thread.
Parameters: Arguments:
request: The django request object used for build_absolute_uri and request: The django request object used for build_absolute_uri and
determining the requesting user. determining the requesting user.
...@@ -596,7 +666,7 @@ def delete_comment(request, comment_id): ...@@ -596,7 +666,7 @@ def delete_comment(request, comment_id):
""" """
Delete a comment. Delete a comment.
Parameters: Arguments:
request: The django request object used for build_absolute_uri and request: The django request object used for build_absolute_uri and
determining the requesting user. determining the requesting user.
......
""" """
Discussion API permission logic Discussion API permission logic
""" """
from lms.lib.comment_client.comment import Comment
from lms.lib.comment_client.thread import Thread
def _is_author(cc_content, context): def _is_author(cc_content, context):
...@@ -18,6 +20,38 @@ def _is_author_or_privileged(cc_content, context): ...@@ -18,6 +20,38 @@ def _is_author_or_privileged(cc_content, context):
return context["is_requester_privileged"] or _is_author(cc_content, context) return context["is_requester_privileged"] or _is_author(cc_content, context)
NON_UPDATABLE_THREAD_FIELDS = {"course_id"}
NON_UPDATABLE_COMMENT_FIELDS = {"thread_id", "parent_id"}
def get_initializable_thread_fields(context):
"""
Return the set of fields that the requester can initialize for a thread
Any field that is editable by the author should also be initializable.
"""
ret = get_editable_fields(
Thread(user_id=context["cc_requester"]["id"], type="thread"),
context
)
ret |= NON_UPDATABLE_THREAD_FIELDS
return ret
def get_initializable_comment_fields(context): # pylint: disable=invalid-name
"""
Return the set of fields that the requester can initialize for a comment
Any field that is editable by the author should also be initializable.
"""
ret = get_editable_fields(
Comment(user_id=context["cc_requester"]["id"], type="comment"),
context
)
ret |= NON_UPDATABLE_COMMENT_FIELDS
return ret
def get_editable_fields(cc_content, context): def get_editable_fields(cc_content, context):
""" """
Return the set of fields that the requester can edit on the given content Return the set of fields that the requester can edit on the given content
...@@ -32,6 +66,8 @@ def get_editable_fields(cc_content, context): ...@@ -32,6 +66,8 @@ def get_editable_fields(cc_content, context):
ret |= {"following"} ret |= {"following"}
if _is_author_or_privileged(cc_content, context): if _is_author_or_privileged(cc_content, context):
ret |= {"topic_id", "type", "title"} ret |= {"topic_id", "type", "title"}
if context["is_requester_privileged"] and context["course"].is_cohorted:
ret |= {"group_id"}
# Comment fields # Comment fields
if ( if (
......
...@@ -10,7 +10,11 @@ from django.core.urlresolvers import reverse ...@@ -10,7 +10,11 @@ from django.core.urlresolvers import reverse
from rest_framework import serializers from rest_framework import serializers
from discussion_api.permissions import get_editable_fields from discussion_api.permissions import (
NON_UPDATABLE_COMMENT_FIELDS,
NON_UPDATABLE_THREAD_FIELDS,
get_editable_fields,
)
from discussion_api.render import render_body from discussion_api.render import render_body
from django_comment_client.utils import is_comment_too_deep from django_comment_client.utils import is_comment_too_deep
from django_comment_common.models import ( from django_comment_common.models import (
...@@ -76,7 +80,7 @@ class _ContentSerializer(serializers.Serializer): ...@@ -76,7 +80,7 @@ class _ContentSerializer(serializers.Serializer):
vote_count = serializers.SerializerMethodField("get_vote_count") vote_count = serializers.SerializerMethodField("get_vote_count")
editable_fields = serializers.SerializerMethodField("get_editable_fields") editable_fields = serializers.SerializerMethodField("get_editable_fields")
non_updatable_fields = () non_updatable_fields = set()
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(_ContentSerializer, self).__init__(*args, **kwargs) super(_ContentSerializer, self).__init__(*args, **kwargs)
...@@ -166,7 +170,7 @@ class ThreadSerializer(_ContentSerializer): ...@@ -166,7 +170,7 @@ class ThreadSerializer(_ContentSerializer):
""" """
course_id = serializers.CharField() course_id = serializers.CharField()
topic_id = NonEmptyCharField(source="commentable_id") topic_id = NonEmptyCharField(source="commentable_id")
group_id = serializers.IntegerField(read_only=True) group_id = serializers.IntegerField(required=False)
group_name = serializers.SerializerMethodField("get_group_name") group_name = serializers.SerializerMethodField("get_group_name")
type_ = serializers.ChoiceField( type_ = serializers.ChoiceField(
source="thread_type", source="thread_type",
...@@ -182,7 +186,7 @@ class ThreadSerializer(_ContentSerializer): ...@@ -182,7 +186,7 @@ class ThreadSerializer(_ContentSerializer):
endorsed_comment_list_url = serializers.SerializerMethodField("get_endorsed_comment_list_url") endorsed_comment_list_url = serializers.SerializerMethodField("get_endorsed_comment_list_url")
non_endorsed_comment_list_url = serializers.SerializerMethodField("get_non_endorsed_comment_list_url") non_endorsed_comment_list_url = serializers.SerializerMethodField("get_non_endorsed_comment_list_url")
non_updatable_fields = ("course_id",) non_updatable_fields = NON_UPDATABLE_THREAD_FIELDS
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(ThreadSerializer, self).__init__(*args, **kwargs) super(ThreadSerializer, self).__init__(*args, **kwargs)
...@@ -256,7 +260,7 @@ class CommentSerializer(_ContentSerializer): ...@@ -256,7 +260,7 @@ class CommentSerializer(_ContentSerializer):
endorsed_at = serializers.SerializerMethodField("get_endorsed_at") endorsed_at = serializers.SerializerMethodField("get_endorsed_at")
children = serializers.SerializerMethodField("get_children") children = serializers.SerializerMethodField("get_children")
non_updatable_fields = ("thread_id", "parent_id") non_updatable_fields = NON_UPDATABLE_COMMENT_FIELDS
def get_endorsed_by(self, obj): def get_endorsed_by(self, obj):
""" """
......
...@@ -1182,6 +1182,7 @@ class GetCommentListTest(CommentsServiceMockMixin, ModuleStoreTestCase): ...@@ -1182,6 +1182,7 @@ class GetCommentListTest(CommentsServiceMockMixin, ModuleStoreTestCase):
self.get_comment_list(thread, endorsed=True, page=2, page_size=10) self.get_comment_list(thread, endorsed=True, page=2, page_size=10)
@ddt.ddt
class CreateThreadTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTestCase): class CreateThreadTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTestCase):
"""Tests for create_thread""" """Tests for create_thread"""
@mock.patch.dict("django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True}) @mock.patch.dict("django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True})
...@@ -1273,6 +1274,69 @@ class CreateThreadTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTestC ...@@ -1273,6 +1274,69 @@ class CreateThreadTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTestC
} }
) )
@ddt.data(
*itertools.product(
[
FORUM_ROLE_ADMINISTRATOR,
FORUM_ROLE_MODERATOR,
FORUM_ROLE_COMMUNITY_TA,
FORUM_ROLE_STUDENT,
],
[True, False],
[True, False],
["no_group_set", "group_is_none", "group_is_set"],
)
)
@ddt.unpack
def test_group_id(self, role_name, course_is_cohorted, topic_is_cohorted, data_group_state):
"""
Tests whether the user has permission to create a thread with certain
group_id values.
If there is no group, user cannot create a thread.
Else if group is None or set, and the course is not cohorted and/or the
role is a student, user can create a thread.
"""
cohort_course = CourseFactory.create(
discussion_topics={"Test Topic": {"id": "test_topic"}},
cohort_config={
"cohorted": course_is_cohorted,
"cohorted_discussions": ["test_topic"] if topic_is_cohorted else [],
}
)
CourseEnrollmentFactory.create(user=self.user, course_id=cohort_course.id)
if course_is_cohorted:
cohort = CohortFactory.create(course_id=cohort_course.id, users=[self.user])
role = Role.objects.create(name=role_name, course_id=cohort_course.id)
role.users = [self.user]
self.register_post_thread_response({})
data = self.minimal_data.copy()
data["course_id"] = unicode(cohort_course.id)
if data_group_state == "group_is_none":
data["group_id"] = None
elif data_group_state == "group_is_set":
if course_is_cohorted:
data["group_id"] = cohort.id + 1
else:
data["group_id"] = 1 # Set to any value since there is no cohort
expected_error = (
data_group_state in ["group_is_none", "group_is_set"] and
(not course_is_cohorted or role_name == FORUM_ROLE_STUDENT)
)
try:
create_thread(self.request, data)
self.assertFalse(expected_error)
actual_post_data = httpretty.last_request().parsed_body
if data_group_state == "group_is_set":
self.assertEqual(actual_post_data["group_id"], [str(data["group_id"])])
elif data_group_state == "no_group_set" and course_is_cohorted and topic_is_cohorted:
self.assertEqual(actual_post_data["group_id"], [str(cohort.id)])
else:
self.assertNotIn("group_id", actual_post_data)
except ValidationError:
self.assertTrue(expected_error)
def test_following(self): def test_following(self):
self.register_post_thread_response({"id": "test_id"}) self.register_post_thread_response({"id": "test_id"})
self.register_subscription_response(self.user) self.register_subscription_response(self.user)
...@@ -1456,6 +1520,44 @@ class CreateCommentTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTest ...@@ -1456,6 +1520,44 @@ class CreateCommentTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTest
self.assertEqual(actual_event_name, expected_event_name) self.assertEqual(actual_event_name, expected_event_name)
self.assertEqual(actual_event_data, expected_event_data) self.assertEqual(actual_event_data, expected_event_data)
@ddt.data(
*itertools.product(
[
FORUM_ROLE_ADMINISTRATOR,
FORUM_ROLE_MODERATOR,
FORUM_ROLE_COMMUNITY_TA,
FORUM_ROLE_STUDENT,
],
[True, False],
["question", "discussion"],
)
)
@ddt.unpack
def test_endorsed(self, role_name, is_thread_author, thread_type):
role = Role.objects.create(name=role_name, course_id=self.course.id)
role.users = [self.user]
self.register_get_thread_response(
make_minimal_cs_thread({
"id": "test_thread",
"course_id": unicode(self.course.id),
"thread_type": thread_type,
"user_id": str(self.user.id) if is_thread_author else str(self.user.id + 1),
})
)
self.register_post_comment_response({}, "test_thread")
data = self.minimal_data.copy()
data["endorsed"] = True
expected_error = (
role_name == FORUM_ROLE_STUDENT and
(not is_thread_author or thread_type == "discussion")
)
try:
create_comment(self.request, data)
self.assertEqual(httpretty.last_request().parsed_body["endorsed"], ["True"])
self.assertFalse(expected_error)
except ValidationError:
self.assertTrue(expected_error)
def test_voted(self): def test_voted(self):
self.register_post_comment_response({"id": "test_comment"}, "test_thread") self.register_post_comment_response({"id": "test_comment"}, "test_thread")
self.register_comment_votes_response("test_comment") self.register_comment_votes_response("test_comment")
......
...@@ -2,37 +2,86 @@ ...@@ -2,37 +2,86 @@
Tests for discussion API permission logic Tests for discussion API permission logic
""" """
import itertools import itertools
from unittest import TestCase
import ddt import ddt
from discussion_api.permissions import can_delete, get_editable_fields from discussion_api.permissions import (
can_delete,
get_editable_fields,
get_initializable_comment_fields,
get_initializable_thread_fields,
)
from lms.lib.comment_client.comment import Comment from lms.lib.comment_client.comment import Comment
from lms.lib.comment_client.thread import Thread from lms.lib.comment_client.thread import Thread
from lms.lib.comment_client.user import User from lms.lib.comment_client.user import User
from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase
from xmodule.modulestore.tests.factories import CourseFactory
def _get_context(requester_id, is_requester_privileged, thread=None): def _get_context(requester_id, is_requester_privileged, is_cohorted=False, thread=None):
"""Return a context suitable for testing the permissions module""" """Return a context suitable for testing the permissions module"""
return { return {
"cc_requester": User(id=requester_id), "cc_requester": User(id=requester_id),
"is_requester_privileged": is_requester_privileged, "is_requester_privileged": is_requester_privileged,
"course": CourseFactory(cohort_config={"cohorted": is_cohorted}),
"thread": thread, "thread": thread,
} }
@ddt.ddt @ddt.ddt
class GetEditableFieldsTest(TestCase): class GetInitializableFieldsTest(ModuleStoreTestCase):
"""Tests for get_editable_fields""" """Tests for get_*_initializable_fields"""
@ddt.data(*itertools.product([True, False], [True, False])) @ddt.data(*itertools.product([True, False], [True, False]))
@ddt.unpack @ddt.unpack
def test_thread(self, is_author, is_privileged): def test_thread(self, is_privileged, is_cohorted):
context = _get_context(
requester_id="5",
is_requester_privileged=is_privileged,
is_cohorted=is_cohorted
)
actual = get_initializable_thread_fields(context)
expected = {
"abuse_flagged", "course_id", "following", "raw_body", "title", "topic_id", "type", "voted"
}
if is_privileged and is_cohorted:
expected |= {"group_id"}
self.assertEqual(actual, expected)
@ddt.data(*itertools.product([True, False], ["question", "discussion"], [True, False]))
@ddt.unpack
def test_comment(self, is_thread_author, thread_type, is_privileged):
context = _get_context(
requester_id="5",
is_requester_privileged=is_privileged,
thread=Thread(user_id="5" if is_thread_author else "6", thread_type=thread_type)
)
actual = get_initializable_comment_fields(context)
expected = {
"abuse_flagged", "parent_id", "raw_body", "thread_id", "voted"
}
if (is_thread_author and thread_type == "question") or is_privileged:
expected |= {"endorsed"}
self.assertEqual(actual, expected)
@ddt.ddt
class GetEditableFieldsTest(ModuleStoreTestCase):
"""Tests for get_editable_fields"""
@ddt.data(*itertools.product([True, False], [True, False], [True, False]))
@ddt.unpack
def test_thread(self, is_author, is_privileged, is_cohorted):
thread = Thread(user_id="5" if is_author else "6", type="thread") thread = Thread(user_id="5" if is_author else "6", type="thread")
context = _get_context(requester_id="5", is_requester_privileged=is_privileged) context = _get_context(
requester_id="5",
is_requester_privileged=is_privileged,
is_cohorted=is_cohorted
)
actual = get_editable_fields(thread, context) actual = get_editable_fields(thread, context)
expected = {"abuse_flagged", "following", "voted"} expected = {"abuse_flagged", "following", "voted"}
if is_author or is_privileged: if is_author or is_privileged:
expected |= {"topic_id", "type", "title", "raw_body"} expected |= {"topic_id", "type", "title", "raw_body"}
if is_privileged and is_cohorted:
expected |= {"group_id"}
self.assertEqual(actual, expected) self.assertEqual(actual, expected)
@ddt.data(*itertools.product([True, False], [True, False], ["question", "discussion"], [True, False])) @ddt.data(*itertools.product([True, False], [True, False], ["question", "discussion"], [True, False]))
...@@ -54,7 +103,7 @@ class GetEditableFieldsTest(TestCase): ...@@ -54,7 +103,7 @@ class GetEditableFieldsTest(TestCase):
@ddt.ddt @ddt.ddt
class CanDeleteTest(TestCase): class CanDeleteTest(ModuleStoreTestCase):
"""Tests for can_delete""" """Tests for can_delete"""
@ddt.data(*itertools.product([True, False], [True, False])) @ddt.data(*itertools.product([True, False], [True, False]))
@ddt.unpack @ddt.unpack
......
...@@ -462,6 +462,24 @@ class ThreadSerializerDeserializationTest(CommentsServiceMockMixin, UrlResetMixi ...@@ -462,6 +462,24 @@ class ThreadSerializerDeserializationTest(CommentsServiceMockMixin, UrlResetMixi
) )
self.assertEqual(saved["id"], "test_id") self.assertEqual(saved["id"], "test_id")
def test_create_all_fields(self):
self.register_post_thread_response({"id": "test_id"})
data = self.minimal_data.copy()
data["group_id"] = 42
self.save_and_reserialize(data)
self.assertEqual(
httpretty.last_request().parsed_body,
{
"course_id": [unicode(self.course.id)],
"commentable_id": ["test_topic"],
"thread_type": ["discussion"],
"title": ["Test Title"],
"body": ["Test body"],
"user_id": [str(self.user.id)],
"group_id": ["42"],
}
)
def test_create_missing_field(self): def test_create_missing_field(self):
for field in self.minimal_data: for field in self.minimal_data:
data = self.minimal_data.copy() data = self.minimal_data.copy()
...@@ -638,6 +656,27 @@ class CommentSerializerDeserializationTest(CommentsServiceMockMixin, ModuleStore ...@@ -638,6 +656,27 @@ class CommentSerializerDeserializationTest(CommentsServiceMockMixin, ModuleStore
self.assertEqual(saved["id"], "test_comment") self.assertEqual(saved["id"], "test_comment")
self.assertEqual(saved["parent_id"], parent_id) self.assertEqual(saved["parent_id"], parent_id)
def test_create_all_fields(self):
data = self.minimal_data.copy()
data["parent_id"] = "test_parent"
data["endorsed"] = True
self.register_get_comment_response({"thread_id": "test_thread", "id": "test_parent"})
self.register_post_comment_response(
{"id": "test_comment"},
thread_id="test_thread",
parent_id="test_parent"
)
self.save_and_reserialize(data)
self.assertEqual(
httpretty.last_request().parsed_body,
{
"course_id": [unicode(self.course.id)],
"body": ["Test body"],
"user_id": [str(self.user.id)],
"endorsed": ["True"],
}
)
def test_create_parent_id_nonexistent(self): def test_create_parent_id_nonexistent(self):
self.register_get_comment_error_response("bad_parent", 404) self.register_get_comment_error_response("bad_parent", 404)
data = self.minimal_data.copy() data = self.minimal_data.copy()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment