Commit b8576e74 by Greg Price Committed by christopher lee

Allow group_id to be set in discussion API

This also fixes two bugs:
* A thread created by a non-privileged user in a cohorted course would
  not have group_id set
* Any user could set the endorsed field on creation of a comment
parent 1e8ebc80
......@@ -17,7 +17,12 @@ from opaque_keys.edx.locator import CourseKey
from courseware.courses import get_course_with_access
from discussion_api.forms import CommentActionsForm, ThreadActionsForm
from discussion_api.pagination import get_paginated_data
from discussion_api.permissions import can_delete, get_editable_fields
from discussion_api.permissions import (
can_delete,
get_editable_fields,
get_initializable_comment_fields,
get_initializable_thread_fields,
)
from discussion_api.serializers import CommentSerializer, ThreadSerializer, get_context
from django_comment_client.base.views import (
THREAD_CREATED_EVENT_NAME,
......@@ -295,7 +300,7 @@ def get_comment_list(request, thread_id, endorsed, page, page_size):
"""
Return the list of comments in the given thread.
Parameters:
Arguments:
request: The django request object used for build_absolute_uri and
determining the requesting user.
......@@ -361,19 +366,76 @@ def get_comment_list(request, thread_id, endorsed, page, page_size):
return get_paginated_data(request, results, page, num_pages)
def _check_fields(allowed_fields, data, message):
"""
Checks that the keys given in data is in allowed_fields
Arguments:
allowed_fields (set): A set of allowed fields
data (dict): The data to compare the allowed_fields against
message (str): The message to return if there are any invalid fields
Raises:
ValidationError if the given data contains a key that is not in
allowed_fields
"""
non_allowed_fields = {field: [message] for field in data.keys() if field not in allowed_fields}
if non_allowed_fields:
raise ValidationError(non_allowed_fields)
def _check_initializable_thread_fields(data, context): # pylint: disable=invalid-name
"""
Checks if the given data contains a thread field that is not initializable
by the requesting user
Arguments:
data (dict): The data to compare the allowed_fields against
context (dict): The context appropriate for use with the thread which
includes the requesting user
Raises:
ValidationError if the given data contains a thread field that is not
initializable by the requesting user
"""
_check_fields(
get_initializable_thread_fields(context),
data,
"This field is not initializable."
)
def _check_initializable_comment_fields(data, context): # pylint: disable=invalid-name
"""
Checks if the given data contains a comment field that is not initializable
by the requesting user
Arguments:
data (dict): The data to compare the allowed_fields against
context (dict): The context appropriate for use with the comment which
includes the requesting user
Raises:
ValidationError if the given data contains a comment field that is not
initializable by the requesting user
"""
_check_fields(
get_initializable_comment_fields(context),
data,
"This field is not initializable."
)
def _check_editable_fields(cc_content, data, context):
"""
Raise ValidationError if the given update data contains a field that is not
in editable_fields.
editable by the requesting user
"""
editable_fields = get_editable_fields(cc_content, context)
non_editable_errors = {
field: ["This field is not editable."]
for field in data.keys()
if field not in editable_fields
}
if non_editable_errors:
raise ValidationError(non_editable_errors)
_check_fields(
get_editable_fields(cc_content, context),
data,
"This field is not editable."
)
def _do_extra_actions(api_content, cc_content, request_fields, actions_form, context):
......@@ -406,7 +468,7 @@ def create_thread(request, thread_data):
"""
Create a thread.
Parameters:
Arguments:
request: The django request object used for build_absolute_uri and
determining the requesting user.
......@@ -428,6 +490,13 @@ def create_thread(request, thread_data):
raise ValidationError({"course_id": ["Invalid value."]})
context = get_context(course, request)
_check_initializable_thread_fields(thread_data, context)
if (
"group_id" not in thread_data and
is_commentable_cohorted(course_key, thread_data.get("topic_id"))
):
thread_data = thread_data.copy()
thread_data["group_id"] = get_cohort_id(request.user, course_key)
serializer = ThreadSerializer(data=thread_data, context=context)
actions_form = ThreadActionsForm(thread_data)
if not (serializer.is_valid() and actions_form.is_valid()):
......@@ -453,7 +522,7 @@ def create_comment(request, comment_data):
"""
Create a comment.
Parameters:
Arguments:
request: The django request object used for build_absolute_uri and
determining the requesting user.
......@@ -473,6 +542,7 @@ def create_comment(request, comment_data):
except Http404:
raise ValidationError({"thread_id": ["Invalid value."]})
_check_initializable_comment_fields(comment_data, context)
serializer = CommentSerializer(data=comment_data, context=context)
actions_form = CommentActionsForm(comment_data)
if not (serializer.is_valid() and actions_form.is_valid()):
......@@ -498,7 +568,7 @@ def update_thread(request, thread_id, update_data):
"""
Update a thread.
Parameters:
Arguments:
request: The django request object used for build_absolute_uri and
determining the requesting user.
......@@ -530,7 +600,7 @@ def update_comment(request, comment_id, update_data):
"""
Update a comment.
Parameters:
Arguments:
request: The django request object used for build_absolute_uri and
determining the requesting user.
......@@ -573,7 +643,7 @@ def delete_thread(request, thread_id):
"""
Delete a thread.
Parameters:
Arguments:
request: The django request object used for build_absolute_uri and
determining the requesting user.
......@@ -596,7 +666,7 @@ def delete_comment(request, comment_id):
"""
Delete a comment.
Parameters:
Arguments:
request: The django request object used for build_absolute_uri and
determining the requesting user.
......
"""
Discussion API permission logic
"""
from lms.lib.comment_client.comment import Comment
from lms.lib.comment_client.thread import Thread
def _is_author(cc_content, context):
......@@ -18,6 +20,38 @@ def _is_author_or_privileged(cc_content, context):
return context["is_requester_privileged"] or _is_author(cc_content, context)
NON_UPDATABLE_THREAD_FIELDS = {"course_id"}
NON_UPDATABLE_COMMENT_FIELDS = {"thread_id", "parent_id"}
def get_initializable_thread_fields(context):
"""
Return the set of fields that the requester can initialize for a thread
Any field that is editable by the author should also be initializable.
"""
ret = get_editable_fields(
Thread(user_id=context["cc_requester"]["id"], type="thread"),
context
)
ret |= NON_UPDATABLE_THREAD_FIELDS
return ret
def get_initializable_comment_fields(context): # pylint: disable=invalid-name
"""
Return the set of fields that the requester can initialize for a comment
Any field that is editable by the author should also be initializable.
"""
ret = get_editable_fields(
Comment(user_id=context["cc_requester"]["id"], type="comment"),
context
)
ret |= NON_UPDATABLE_COMMENT_FIELDS
return ret
def get_editable_fields(cc_content, context):
"""
Return the set of fields that the requester can edit on the given content
......@@ -32,6 +66,8 @@ def get_editable_fields(cc_content, context):
ret |= {"following"}
if _is_author_or_privileged(cc_content, context):
ret |= {"topic_id", "type", "title"}
if context["is_requester_privileged"] and context["course"].is_cohorted:
ret |= {"group_id"}
# Comment fields
if (
......
......@@ -10,7 +10,11 @@ from django.core.urlresolvers import reverse
from rest_framework import serializers
from discussion_api.permissions import get_editable_fields
from discussion_api.permissions import (
NON_UPDATABLE_COMMENT_FIELDS,
NON_UPDATABLE_THREAD_FIELDS,
get_editable_fields,
)
from discussion_api.render import render_body
from django_comment_client.utils import is_comment_too_deep
from django_comment_common.models import (
......@@ -76,7 +80,7 @@ class _ContentSerializer(serializers.Serializer):
vote_count = serializers.SerializerMethodField("get_vote_count")
editable_fields = serializers.SerializerMethodField("get_editable_fields")
non_updatable_fields = ()
non_updatable_fields = set()
def __init__(self, *args, **kwargs):
super(_ContentSerializer, self).__init__(*args, **kwargs)
......@@ -166,7 +170,7 @@ class ThreadSerializer(_ContentSerializer):
"""
course_id = serializers.CharField()
topic_id = NonEmptyCharField(source="commentable_id")
group_id = serializers.IntegerField(read_only=True)
group_id = serializers.IntegerField(required=False)
group_name = serializers.SerializerMethodField("get_group_name")
type_ = serializers.ChoiceField(
source="thread_type",
......@@ -182,7 +186,7 @@ class ThreadSerializer(_ContentSerializer):
endorsed_comment_list_url = serializers.SerializerMethodField("get_endorsed_comment_list_url")
non_endorsed_comment_list_url = serializers.SerializerMethodField("get_non_endorsed_comment_list_url")
non_updatable_fields = ("course_id",)
non_updatable_fields = NON_UPDATABLE_THREAD_FIELDS
def __init__(self, *args, **kwargs):
super(ThreadSerializer, self).__init__(*args, **kwargs)
......@@ -256,7 +260,7 @@ class CommentSerializer(_ContentSerializer):
endorsed_at = serializers.SerializerMethodField("get_endorsed_at")
children = serializers.SerializerMethodField("get_children")
non_updatable_fields = ("thread_id", "parent_id")
non_updatable_fields = NON_UPDATABLE_COMMENT_FIELDS
def get_endorsed_by(self, obj):
"""
......
......@@ -1182,6 +1182,7 @@ class GetCommentListTest(CommentsServiceMockMixin, ModuleStoreTestCase):
self.get_comment_list(thread, endorsed=True, page=2, page_size=10)
@ddt.ddt
class CreateThreadTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTestCase):
"""Tests for create_thread"""
@mock.patch.dict("django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True})
......@@ -1273,6 +1274,69 @@ class CreateThreadTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTestC
}
)
@ddt.data(
*itertools.product(
[
FORUM_ROLE_ADMINISTRATOR,
FORUM_ROLE_MODERATOR,
FORUM_ROLE_COMMUNITY_TA,
FORUM_ROLE_STUDENT,
],
[True, False],
[True, False],
["no_group_set", "group_is_none", "group_is_set"],
)
)
@ddt.unpack
def test_group_id(self, role_name, course_is_cohorted, topic_is_cohorted, data_group_state):
"""
Tests whether the user has permission to create a thread with certain
group_id values.
If there is no group, user cannot create a thread.
Else if group is None or set, and the course is not cohorted and/or the
role is a student, user can create a thread.
"""
cohort_course = CourseFactory.create(
discussion_topics={"Test Topic": {"id": "test_topic"}},
cohort_config={
"cohorted": course_is_cohorted,
"cohorted_discussions": ["test_topic"] if topic_is_cohorted else [],
}
)
CourseEnrollmentFactory.create(user=self.user, course_id=cohort_course.id)
if course_is_cohorted:
cohort = CohortFactory.create(course_id=cohort_course.id, users=[self.user])
role = Role.objects.create(name=role_name, course_id=cohort_course.id)
role.users = [self.user]
self.register_post_thread_response({})
data = self.minimal_data.copy()
data["course_id"] = unicode(cohort_course.id)
if data_group_state == "group_is_none":
data["group_id"] = None
elif data_group_state == "group_is_set":
if course_is_cohorted:
data["group_id"] = cohort.id + 1
else:
data["group_id"] = 1 # Set to any value since there is no cohort
expected_error = (
data_group_state in ["group_is_none", "group_is_set"] and
(not course_is_cohorted or role_name == FORUM_ROLE_STUDENT)
)
try:
create_thread(self.request, data)
self.assertFalse(expected_error)
actual_post_data = httpretty.last_request().parsed_body
if data_group_state == "group_is_set":
self.assertEqual(actual_post_data["group_id"], [str(data["group_id"])])
elif data_group_state == "no_group_set" and course_is_cohorted and topic_is_cohorted:
self.assertEqual(actual_post_data["group_id"], [str(cohort.id)])
else:
self.assertNotIn("group_id", actual_post_data)
except ValidationError:
self.assertTrue(expected_error)
def test_following(self):
self.register_post_thread_response({"id": "test_id"})
self.register_subscription_response(self.user)
......@@ -1456,6 +1520,44 @@ class CreateCommentTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTest
self.assertEqual(actual_event_name, expected_event_name)
self.assertEqual(actual_event_data, expected_event_data)
@ddt.data(
*itertools.product(
[
FORUM_ROLE_ADMINISTRATOR,
FORUM_ROLE_MODERATOR,
FORUM_ROLE_COMMUNITY_TA,
FORUM_ROLE_STUDENT,
],
[True, False],
["question", "discussion"],
)
)
@ddt.unpack
def test_endorsed(self, role_name, is_thread_author, thread_type):
role = Role.objects.create(name=role_name, course_id=self.course.id)
role.users = [self.user]
self.register_get_thread_response(
make_minimal_cs_thread({
"id": "test_thread",
"course_id": unicode(self.course.id),
"thread_type": thread_type,
"user_id": str(self.user.id) if is_thread_author else str(self.user.id + 1),
})
)
self.register_post_comment_response({}, "test_thread")
data = self.minimal_data.copy()
data["endorsed"] = True
expected_error = (
role_name == FORUM_ROLE_STUDENT and
(not is_thread_author or thread_type == "discussion")
)
try:
create_comment(self.request, data)
self.assertEqual(httpretty.last_request().parsed_body["endorsed"], ["True"])
self.assertFalse(expected_error)
except ValidationError:
self.assertTrue(expected_error)
def test_voted(self):
self.register_post_comment_response({"id": "test_comment"}, "test_thread")
self.register_comment_votes_response("test_comment")
......
......@@ -2,37 +2,86 @@
Tests for discussion API permission logic
"""
import itertools
from unittest import TestCase
import ddt
from discussion_api.permissions import can_delete, get_editable_fields
from discussion_api.permissions import (
can_delete,
get_editable_fields,
get_initializable_comment_fields,
get_initializable_thread_fields,
)
from lms.lib.comment_client.comment import Comment
from lms.lib.comment_client.thread import Thread
from lms.lib.comment_client.user import User
from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase
from xmodule.modulestore.tests.factories import CourseFactory
def _get_context(requester_id, is_requester_privileged, thread=None):
def _get_context(requester_id, is_requester_privileged, is_cohorted=False, thread=None):
"""Return a context suitable for testing the permissions module"""
return {
"cc_requester": User(id=requester_id),
"is_requester_privileged": is_requester_privileged,
"course": CourseFactory(cohort_config={"cohorted": is_cohorted}),
"thread": thread,
}
@ddt.ddt
class GetEditableFieldsTest(TestCase):
"""Tests for get_editable_fields"""
class GetInitializableFieldsTest(ModuleStoreTestCase):
"""Tests for get_*_initializable_fields"""
@ddt.data(*itertools.product([True, False], [True, False]))
@ddt.unpack
def test_thread(self, is_author, is_privileged):
def test_thread(self, is_privileged, is_cohorted):
context = _get_context(
requester_id="5",
is_requester_privileged=is_privileged,
is_cohorted=is_cohorted
)
actual = get_initializable_thread_fields(context)
expected = {
"abuse_flagged", "course_id", "following", "raw_body", "title", "topic_id", "type", "voted"
}
if is_privileged and is_cohorted:
expected |= {"group_id"}
self.assertEqual(actual, expected)
@ddt.data(*itertools.product([True, False], ["question", "discussion"], [True, False]))
@ddt.unpack
def test_comment(self, is_thread_author, thread_type, is_privileged):
context = _get_context(
requester_id="5",
is_requester_privileged=is_privileged,
thread=Thread(user_id="5" if is_thread_author else "6", thread_type=thread_type)
)
actual = get_initializable_comment_fields(context)
expected = {
"abuse_flagged", "parent_id", "raw_body", "thread_id", "voted"
}
if (is_thread_author and thread_type == "question") or is_privileged:
expected |= {"endorsed"}
self.assertEqual(actual, expected)
@ddt.ddt
class GetEditableFieldsTest(ModuleStoreTestCase):
"""Tests for get_editable_fields"""
@ddt.data(*itertools.product([True, False], [True, False], [True, False]))
@ddt.unpack
def test_thread(self, is_author, is_privileged, is_cohorted):
thread = Thread(user_id="5" if is_author else "6", type="thread")
context = _get_context(requester_id="5", is_requester_privileged=is_privileged)
context = _get_context(
requester_id="5",
is_requester_privileged=is_privileged,
is_cohorted=is_cohorted
)
actual = get_editable_fields(thread, context)
expected = {"abuse_flagged", "following", "voted"}
if is_author or is_privileged:
expected |= {"topic_id", "type", "title", "raw_body"}
if is_privileged and is_cohorted:
expected |= {"group_id"}
self.assertEqual(actual, expected)
@ddt.data(*itertools.product([True, False], [True, False], ["question", "discussion"], [True, False]))
......@@ -54,7 +103,7 @@ class GetEditableFieldsTest(TestCase):
@ddt.ddt
class CanDeleteTest(TestCase):
class CanDeleteTest(ModuleStoreTestCase):
"""Tests for can_delete"""
@ddt.data(*itertools.product([True, False], [True, False]))
@ddt.unpack
......
......@@ -462,6 +462,24 @@ class ThreadSerializerDeserializationTest(CommentsServiceMockMixin, UrlResetMixi
)
self.assertEqual(saved["id"], "test_id")
def test_create_all_fields(self):
self.register_post_thread_response({"id": "test_id"})
data = self.minimal_data.copy()
data["group_id"] = 42
self.save_and_reserialize(data)
self.assertEqual(
httpretty.last_request().parsed_body,
{
"course_id": [unicode(self.course.id)],
"commentable_id": ["test_topic"],
"thread_type": ["discussion"],
"title": ["Test Title"],
"body": ["Test body"],
"user_id": [str(self.user.id)],
"group_id": ["42"],
}
)
def test_create_missing_field(self):
for field in self.minimal_data:
data = self.minimal_data.copy()
......@@ -638,6 +656,27 @@ class CommentSerializerDeserializationTest(CommentsServiceMockMixin, ModuleStore
self.assertEqual(saved["id"], "test_comment")
self.assertEqual(saved["parent_id"], parent_id)
def test_create_all_fields(self):
data = self.minimal_data.copy()
data["parent_id"] = "test_parent"
data["endorsed"] = True
self.register_get_comment_response({"thread_id": "test_thread", "id": "test_parent"})
self.register_post_comment_response(
{"id": "test_comment"},
thread_id="test_thread",
parent_id="test_parent"
)
self.save_and_reserialize(data)
self.assertEqual(
httpretty.last_request().parsed_body,
{
"course_id": [unicode(self.course.id)],
"body": ["Test body"],
"user_id": [str(self.user.id)],
"endorsed": ["True"],
}
)
def test_create_parent_id_nonexistent(self):
self.register_get_comment_error_response("bad_parent", 404)
data = self.minimal_data.copy()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment