Commit e124fb06 by Greg Price

Merge pull request #8059 from edx/gprice/discussion-api-thread-list-refactor

Refactor Discussion API thread list
parents 1775dd53 7309352e
......@@ -5,26 +5,34 @@ from django.http import Http404
from collections import defaultdict
from courseware.courses import get_course_with_access
from discussion_api.pagination import get_paginated_data
from discussion_api.serializers import ThreadSerializer, get_context
from django_comment_client.utils import get_accessible_discussion_modules
from django_comment_common.models import (
FORUM_ROLE_ADMINISTRATOR,
FORUM_ROLE_COMMUNITY_TA,
FORUM_ROLE_MODERATOR,
Role,
)
from lms.lib.comment_client.thread import Thread
from lms.lib.comment_client.user import User
from openedx.core.djangoapps.course_groups.cohorts import get_cohort_id, get_cohort_names
from openedx.core.djangoapps.course_groups.cohorts import get_cohort_id
from xmodule.tabs import DiscussionTab
def get_course_topics(course, user):
def _get_course_or_404(course_key, user):
"""
Get the course descriptor, raising Http404 if the course is not found,
the user cannot access forums for the course, or the discussion tab is
disabled for the course.
"""
course = get_course_with_access(user, 'load_forum', course_key)
if not any([isinstance(tab, DiscussionTab) for tab in course.tabs]):
raise Http404
return course
def get_course_topics(course_key, user):
"""
Return the course topic listing for the given course and user.
Parameters:
course: The course to get topics for
course_key: The key of the course to get topics for
user: The requesting user, for access control
Returns:
......@@ -39,6 +47,7 @@ def get_course_topics(course, user):
"""
return module.sort_key or module.discussion_target
course = _get_course_or_404(course_key, user)
discussion_modules = get_accessible_discussion_modules(course, user)
modules_by_category = defaultdict(list)
for module in discussion_modules:
......@@ -77,75 +86,14 @@ def get_course_topics(course, user):
}
def _cc_thread_to_api_thread(thread, cc_user, staff_user_ids, ta_user_ids, group_ids_to_names):
"""
Convert a thread data dict from the comment_client format (which is a direct
representation of the format returned by the comments service) to the format
used in this API
Arguments:
thread (comment_client.thread.Thread): The thread to convert
cc_user (comment_client.user.User): The comment_client representation of
the requesting user
staff_user_ids (set): The set of user ids for users with the Moderator or
Administrator role in the course
ta_user_ids (set): The set of user ids for users with the Community TA
role in the course
group_ids_to_names (dict): A mapping of group ids to names
Returns:
dict: The discussion_api format representation of the thread.
"""
is_anonymous = (
thread["anonymous"] or
(
thread["anonymous_to_peers"] and
int(cc_user["id"]) not in (staff_user_ids | ta_user_ids)
)
)
ret = {
key: thread[key]
for key in [
"id",
"course_id",
"group_id",
"created_at",
"updated_at",
"title",
"pinned",
"closed",
]
}
ret.update({
"topic_id": thread["commentable_id"],
"group_name": group_ids_to_names.get(thread["group_id"]),
"author": None if is_anonymous else thread["username"],
"author_label": (
None if is_anonymous else
"staff" if int(thread["user_id"]) in staff_user_ids else
"community_ta" if int(thread["user_id"]) in ta_user_ids else
None
),
"type": thread["thread_type"],
"raw_body": thread["body"],
"following": thread["id"] in cc_user["subscribed_thread_ids"],
"abuse_flagged": cc_user["id"] in thread["abuse_flaggers"],
"voted": thread["id"] in cc_user["upvoted_ids"],
"vote_count": thread["votes"]["up_count"],
"comment_count": thread["comments_count"],
"unread_comment_count": thread["unread_comments_count"],
})
return ret
def get_thread_list(request, course, page, page_size):
def get_thread_list(request, course_key, page, page_size):
"""
Return the list of all discussion threads pertaining to the given course
Parameters:
request: The django request objects used for build_absolute_uri
course: The course to get discussion threads for
course_key: The key of the course to get discussion threads for
page: The page number (1-indexed) to retrieve
page_size: The number of threads to retrieve per page
......@@ -154,15 +102,14 @@ def get_thread_list(request, course, page, page_size):
A paginated result containing a list of threads; see
discussion_api.views.ThreadViewSet for more detail.
"""
user_is_privileged = Role.objects.filter(
course_id=course.id,
name__in=[FORUM_ROLE_ADMINISTRATOR, FORUM_ROLE_MODERATOR, FORUM_ROLE_COMMUNITY_TA],
users=request.user
).exists()
cc_user = User.from_django_user(request.user).retrieve()
course = _get_course_or_404(course_key, request.user)
context = get_context(course, request.user)
threads, result_page, num_pages, _ = Thread.search({
"course_id": unicode(course.id),
"group_id": None if user_is_privileged else get_cohort_id(request.user, course.id),
"group_id": (
None if context["is_requester_privileged"] else
get_cohort_id(request.user, course.id)
),
"sort_key": "date",
"sort_order": "desc",
"page": page,
......@@ -173,25 +120,6 @@ def get_thread_list(request, course, page, page_size):
# behavior and return a 404 in that case
if result_page != page:
raise Http404
# TODO: cache staff_user_ids and ta_user_ids if we need to improve perf
staff_user_ids = {
user.id
for role in Role.objects.filter(
name__in=[FORUM_ROLE_ADMINISTRATOR, FORUM_ROLE_MODERATOR],
course_id=course.id
)
for user in role.users.all()
}
ta_user_ids = {
user.id
for role in Role.objects.filter(name=FORUM_ROLE_COMMUNITY_TA, course_id=course.id)
for user in role.users.all()
}
# For now, the only groups are cohorts
group_ids_to_names = get_cohort_names(course)
results = [
_cc_thread_to_api_thread(thread, cc_user, staff_user_ids, ta_user_ids, group_ids_to_names)
for thread in threads
]
results = [ThreadSerializer(thread, context=context).data for thread in threads]
return get_paginated_data(request, results, page, num_pages)
"""
Discussion API serializers
"""
from rest_framework import serializers
from django_comment_common.models import (
FORUM_ROLE_ADMINISTRATOR,
FORUM_ROLE_COMMUNITY_TA,
FORUM_ROLE_MODERATOR,
Role,
)
from lms.lib.comment_client.user import User
from openedx.core.djangoapps.course_groups.cohorts import get_cohort_names
def get_context(course, requester):
"""Returns a context appropriate for use with ThreadSerializer."""
# TODO: cache staff_user_ids and ta_user_ids if we need to improve perf
staff_user_ids = {
user.id
for role in Role.objects.filter(
name__in=[FORUM_ROLE_ADMINISTRATOR, FORUM_ROLE_MODERATOR],
course_id=course.id
)
for user in role.users.all()
}
ta_user_ids = {
user.id
for role in Role.objects.filter(name=FORUM_ROLE_COMMUNITY_TA, course_id=course.id)
for user in role.users.all()
}
return {
# For now, the only groups are cohorts
"group_ids_to_names": get_cohort_names(course),
"is_requester_privileged": requester.id in staff_user_ids or requester.id in ta_user_ids,
"staff_user_ids": staff_user_ids,
"ta_user_ids": ta_user_ids,
"cc_requester": User.from_django_user(requester).retrieve(),
}
class ThreadSerializer(serializers.Serializer):
"""
A serializer for thread data.
N.B. This should not be used with a comment_client Thread object that has
not had retrieve() called, because of the interaction between DRF's attempts
at introspection and Thread's __getattr__.
"""
id_ = serializers.CharField(read_only=True)
course_id = serializers.CharField()
topic_id = serializers.CharField(source="commentable_id")
group_id = serializers.IntegerField()
group_name = serializers.SerializerMethodField("get_group_name")
author = serializers.SerializerMethodField("get_author")
author_label = serializers.SerializerMethodField("get_author_label")
created_at = serializers.CharField(read_only=True)
updated_at = serializers.CharField(read_only=True)
type_ = serializers.ChoiceField(source="thread_type", choices=("discussion", "question"))
title = serializers.CharField()
raw_body = serializers.CharField(source="body")
pinned = serializers.BooleanField()
closed = serializers.BooleanField()
following = serializers.SerializerMethodField("get_following")
abuse_flagged = serializers.SerializerMethodField("get_abuse_flagged")
voted = serializers.SerializerMethodField("get_voted")
vote_count = serializers.SerializerMethodField("get_vote_count")
comment_count = serializers.IntegerField(source="comments_count")
unread_comment_count = serializers.IntegerField(source="unread_comments_count")
def __init__(self, *args, **kwargs):
super(ThreadSerializer, self).__init__(*args, **kwargs)
# type and id are invalid class attribute names, so we must declare
# different names above and modify them here
self.fields["id"] = self.fields.pop("id_")
self.fields["type"] = self.fields.pop("type_")
def get_group_name(self, obj):
"""Returns the name of the group identified by the thread's group_id."""
return self.context["group_ids_to_names"].get(obj["group_id"])
def _is_anonymous(self, obj):
"""
Returns a boolean indicating whether the thread should be anonymous to
the requester.
"""
return (
obj["anonymous"] or
obj["anonymous_to_peers"] and not self.context["is_requester_privileged"]
)
def get_author(self, obj):
"""Returns the author's username, or None if the thread is anonymous."""
return None if self._is_anonymous(obj) else obj["username"]
def _get_user_label(self, user_id):
"""
Returns the role label (i.e. "staff" or "community_ta") for the user
with the given id.
"""
return (
"staff" if user_id in self.context["staff_user_ids"] else
"community_ta" if user_id in self.context["ta_user_ids"] else
None
)
def get_author_label(self, obj):
"""Returns the role label for the thread author."""
return None if self._is_anonymous(obj) else self._get_user_label(int(obj["user_id"]))
def get_following(self, obj):
"""
Returns a boolean indicating whether the requester is following the
thread.
"""
return obj["id"] in self.context["cc_requester"]["subscribed_thread_ids"]
def get_abuse_flagged(self, obj):
"""
Returns a boolean indicating whether the requester has flagged the
thread as abusive.
"""
return self.context["cc_requester"]["id"] in obj["abuse_flaggers"]
def get_voted(self, obj):
"""
Returns a boolean indicating whether the requester has voted for the
thread.
"""
return obj["id"] in self.context["cc_requester"]["upvoted_ids"]
def get_vote_count(self, obj):
"""Returns the number of votes for the thread."""
return obj["votes"]["up_count"]
......@@ -12,6 +12,8 @@ from pytz import UTC
from django.http import Http404
from django.test.client import RequestFactory
from opaque_keys.edx.locator import CourseLocator
from courseware.tests.factories import BetaTesterFactory, StaffFactory
from discussion_api.api import get_course_topics, get_thread_list
from discussion_api.tests.utils import CommentsServiceMockMixin
......@@ -24,10 +26,22 @@ from django_comment_common.models import (
)
from openedx.core.djangoapps.course_groups.models import CourseUserGroupPartitionGroup
from openedx.core.djangoapps.course_groups.tests.helpers import CohortFactory
from student.tests.factories import UserFactory
from student.tests.factories import CourseEnrollmentFactory, UserFactory
from xmodule.modulestore.django import modulestore
from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase
from xmodule.modulestore.tests.factories import CourseFactory, ItemFactory
from xmodule.partitions.partitions import Group, UserPartition
from xmodule.tabs import DiscussionTab
def _remove_discussion_tab(course, user_id):
"""
Remove the discussion tab for the course.
user_id is passed to the modulestore as the editor of the module.
"""
course.tabs = [tab for tab in course.tabs if not isinstance(tab, DiscussionTab)]
modulestore().update_item(course, user_id)
@mock.patch.dict("django.conf.settings.FEATURES", {"DISABLE_START_DATES": False})
......@@ -49,12 +63,13 @@ class GetCourseTopicsTest(ModuleStoreTestCase):
course="y",
run="z",
start=datetime.now(UTC),
discussion_topics={},
discussion_topics={"Test Topic": {"id": "non-courseware-topic-id"}},
user_partitions=[self.partition],
cohort_config={"cohorted": True},
days_early_for_beta=3
)
self.user = UserFactory.create()
CourseEnrollmentFactory.create(user=self.user, course_id=self.course.id)
def make_discussion_module(self, topic_id, category, subcategory, **kwargs):
"""Build a discussion module in self.course"""
......@@ -72,7 +87,7 @@ class GetCourseTopicsTest(ModuleStoreTestCase):
Get course topics for self.course, using the given user or self.user if
not provided, and generating absolute URIs with a test scheme/host.
"""
return get_course_topics(self.course, user or self.user)
return get_course_topics(self.course.id, user or self.user)
def make_expected_tree(self, topic_id, name, children=None):
"""
......@@ -87,50 +102,58 @@ class GetCourseTopicsTest(ModuleStoreTestCase):
}
return node
def test_empty(self):
actual = self.get_course_topics()
expected = {
"courseware_topics": [],
"non_courseware_topics": [],
}
self.assertEqual(actual, expected)
def test_nonexistent_course(self):
with self.assertRaises(Http404):
get_course_topics(CourseLocator.from_string("non/existent/course"), self.user)
def test_not_enrolled(self):
unenrolled_user = UserFactory.create()
with self.assertRaises(Http404):
get_course_topics(self.course.id, unenrolled_user)
def test_discussions_disabled(self):
_remove_discussion_tab(self.course, self.user.id)
with self.assertRaises(Http404):
self.get_course_topics()
def test_non_courseware(self):
self.course.discussion_topics = {"Topic Name": {"id": "topic-id"}}
self.course.save()
def test_without_courseware(self):
actual = self.get_course_topics()
expected = {
"courseware_topics": [],
"non_courseware_topics": [self.make_expected_tree("topic-id", "Topic Name")],
"non_courseware_topics": [
self.make_expected_tree("non-courseware-topic-id", "Test Topic")
],
}
self.assertEqual(actual, expected)
def test_courseware(self):
self.make_discussion_module("topic-id", "Foo", "Bar")
def test_with_courseware(self):
self.make_discussion_module("courseware-topic-id", "Foo", "Bar")
actual = self.get_course_topics()
expected = {
"courseware_topics": [
self.make_expected_tree(
None,
"Foo",
[self.make_expected_tree("topic-id", "Bar")]
[self.make_expected_tree("courseware-topic-id", "Bar")]
),
],
"non_courseware_topics": [],
"non_courseware_topics": [
self.make_expected_tree("non-courseware-topic-id", "Test Topic")
],
}
self.assertEqual(actual, expected)
def test_many(self):
self.course.discussion_topics = {
"A": {"id": "non-courseware-1"},
"B": {"id": "non-courseware-2"},
}
modulestore().update_item(self.course, self.user.id)
self.make_discussion_module("courseware-1", "A", "1")
self.make_discussion_module("courseware-2", "A", "2")
self.make_discussion_module("courseware-3", "B", "1")
self.make_discussion_module("courseware-4", "B", "2")
self.make_discussion_module("courseware-5", "C", "1")
self.course.discussion_topics = {
"A": {"id": "non-courseware-1"},
"B": {"id": "non-courseware-2"},
}
self.course.save()
actual = self.get_course_topics()
expected = {
"courseware_topics": [
......@@ -164,6 +187,13 @@ class GetCourseTopicsTest(ModuleStoreTestCase):
self.assertEqual(actual, expected)
def test_sort_key(self):
self.course.discussion_topics = {
"W": {"id": "non-courseware-1", "sort_key": "Z"},
"X": {"id": "non-courseware-2"},
"Y": {"id": "non-courseware-3", "sort_key": "Y"},
"Z": {"id": "non-courseware-4", "sort_key": "W"},
}
modulestore().update_item(self.course, self.user.id)
self.make_discussion_module("courseware-1", "First", "A", sort_key="D")
self.make_discussion_module("courseware-2", "First", "B", sort_key="B")
self.make_discussion_module("courseware-3", "First", "C", sort_key="E")
......@@ -171,13 +201,6 @@ class GetCourseTopicsTest(ModuleStoreTestCase):
self.make_discussion_module("courseware-5", "Second", "B", sort_key="G")
self.make_discussion_module("courseware-6", "Second", "C")
self.make_discussion_module("courseware-7", "Second", "D", sort_key="A")
self.course.discussion_topics = {
"W": {"id": "non-courseware-1", "sort_key": "Z"},
"X": {"id": "non-courseware-2"},
"Y": {"id": "non-courseware-3", "sort_key": "Y"},
"Z": {"id": "non-courseware-4", "sort_key": "W"},
}
self.course.save()
actual = self.get_course_topics()
expected = {
"courseware_topics": [
......@@ -223,6 +246,7 @@ class GetCourseTopicsTest(ModuleStoreTestCase):
subcategories does not appear in the result.
"""
beta_tester = BetaTesterFactory.create(course_key=self.course.id)
CourseEnrollmentFactory.create(user=beta_tester, course_id=self.course.id)
staff = StaffFactory.create(course_key=self.course.id)
for user, group_idx in [(self.user, 0), (beta_tester, 1)]:
cohort = CohortFactory.create(
......@@ -269,7 +293,9 @@ class GetCourseTopicsTest(ModuleStoreTestCase):
]
),
],
"non_courseware_topics": [],
"non_courseware_topics": [
self.make_expected_tree("non-courseware-topic-id", "Test Topic"),
],
}
self.assertEqual(student_actual, student_expected)
......@@ -290,7 +316,9 @@ class GetCourseTopicsTest(ModuleStoreTestCase):
[self.make_expected_tree("courseware-5", "Future Start Date")]
),
],
"non_courseware_topics": [],
"non_courseware_topics": [
self.make_expected_tree("non-courseware-topic-id", "Test Topic"),
],
}
self.assertEqual(beta_actual, beta_expected)
......@@ -315,7 +343,9 @@ class GetCourseTopicsTest(ModuleStoreTestCase):
]
),
],
"non_courseware_topics": [],
"non_courseware_topics": [
self.make_expected_tree("non-courseware-topic-id", "Test Topic"),
],
}
self.assertEqual(staff_actual, staff_expected)
......@@ -334,6 +364,7 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase):
self.request = RequestFactory().get("/test_path")
self.request.user = self.user
self.course = CourseFactory.create()
CourseEnrollmentFactory.create(user=self.user, course_id=self.course.id)
self.author = UserFactory.create()
self.cohort = CohortFactory.create(course_id=self.course.id)
......@@ -344,7 +375,7 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase):
"""
course = course or self.course
self.register_get_threads_response(threads, page, num_pages)
ret = get_thread_list(self.request, course, page, page_size)
ret = get_thread_list(self.request, course.id, page, page_size)
return ret
def create_role(self, role_name, users):
......@@ -353,34 +384,19 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase):
role.users = users
role.save()
def make_cs_thread(self, thread_data):
"""
Create a dictionary containing all needed thread fields as returned by
the comments service with dummy data overridden by thread_data
"""
ret = {
"id": "dummy",
"course_id": unicode(self.course.id),
"commentable_id": "dummy",
"group_id": None,
"user_id": str(self.author.id),
"username": self.author.username,
"anonymous": False,
"anonymous_to_peers": False,
"created_at": "1970-01-01T00:00:00Z",
"updated_at": "1970-01-01T00:00:00Z",
"thread_type": "discussion",
"title": "dummy",
"body": "dummy",
"pinned": False,
"closed": False,
"abuse_flaggers": [],
"votes": {"up_count": 0},
"comments_count": 0,
"unread_comments_count": 0,
}
ret.update(thread_data)
return ret
def test_nonexistent_course(self):
with self.assertRaises(Http404):
get_thread_list(self.request, CourseLocator.from_string("non/existent/course"), 1, 1)
def test_not_enrolled(self):
self.request.user = UserFactory.create()
with self.assertRaises(Http404):
self.get_thread_list([])
def test_discussions_disabled(self):
_remove_discussion_tab(self.course, self.user.id)
with self.assertRaises(Http404):
self.get_thread_list([])
def test_empty(self):
self.assertEqual(
......@@ -404,11 +420,6 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase):
})
def test_thread_content(self):
self.register_get_user_response(
self.user,
subscribed_thread_ids=["test_thread_id_0"],
upvoted_ids=["test_thread_id_1"]
)
source_threads = [
{
"id": "test_thread_id_0",
......@@ -452,27 +463,6 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase):
"comments_count": 18,
"unread_comments_count": 0,
},
{
"id": "test_thread_id_2",
"course_id": unicode(self.course.id),
"commentable_id": "topic_x",
"group_id": self.cohort.id + 1, # non-existent group
"user_id": str(self.author.id),
"username": self.author.username,
"anonymous": False,
"anonymous_to_peers": False,
"created_at": "2015-04-28T00:44:44Z",
"updated_at": "2015-04-28T00:55:55Z",
"thread_type": "discussion",
"title": "Yet Another Test Title",
"body": "Still more content",
"pinned": True,
"closed": False,
"abuse_flaggers": [str(self.user.id)],
"votes": {"up_count": 0},
"comments_count": 0,
"unread_comments_count": 0,
},
]
expected_threads = [
{
......@@ -490,7 +480,7 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase):
"raw_body": "Test body",
"pinned": False,
"closed": False,
"following": True,
"following": False,
"abuse_flagged": False,
"voted": False,
"vote_count": 4,
......@@ -514,33 +504,11 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase):
"closed": True,
"following": False,
"abuse_flagged": False,
"voted": True,
"voted": False,
"vote_count": 9,
"comment_count": 18,
"unread_comment_count": 0,
},
{
"id": "test_thread_id_2",
"course_id": unicode(self.course.id),
"topic_id": "topic_x",
"group_id": self.cohort.id + 1,
"group_name": None,
"author": self.author.username,
"author_label": None,
"created_at": "2015-04-28T00:44:44Z",
"updated_at": "2015-04-28T00:55:55Z",
"type": "discussion",
"title": "Yet Another Test Title",
"raw_body": "Still more content",
"pinned": True,
"closed": False,
"following": False,
"abuse_flagged": True,
"voted": False,
"vote_count": 0,
"comment_count": 0,
"unread_comment_count": 0,
},
]
self.assertEqual(
self.get_thread_list(source_threads),
......@@ -565,6 +533,7 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase):
@ddt.unpack
def test_request_group(self, role_name, course_is_cohorted):
cohort_course = CourseFactory.create(cohort_config={"cohorted": course_is_cohorted})
CourseEnrollmentFactory.create(user=self.user, course_id=cohort_course.id)
CohortFactory.create(course_id=cohort_course.id, users=[self.user])
role = Role.objects.create(name=role_name, course_id=cohort_course.id)
role.users = [self.user]
......@@ -603,70 +572,4 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase):
# Test page past the last one
self.register_get_threads_response([], page=3, num_pages=3)
with self.assertRaises(Http404):
get_thread_list(self.request, self.course, page=4, page_size=10)
@ddt.data(
(FORUM_ROLE_ADMINISTRATOR, True, False, True),
(FORUM_ROLE_ADMINISTRATOR, False, True, False),
(FORUM_ROLE_MODERATOR, True, False, True),
(FORUM_ROLE_MODERATOR, False, True, False),
(FORUM_ROLE_COMMUNITY_TA, True, False, True),
(FORUM_ROLE_COMMUNITY_TA, False, True, False),
(FORUM_ROLE_STUDENT, True, False, True),
(FORUM_ROLE_STUDENT, False, True, True),
)
@ddt.unpack
def test_anonymity(self, role_name, anonymous, anonymous_to_peers, expected_api_anonymous):
"""
Test that a thread is properly made anonymous.
A thread should be anonymous iff the anonymous field is true or the
anonymous_to_peers field is true and the requester does not have a
privileged role.
role_name is the name of the requester's role.
thread_anon is the value of the anonymous field in the thread data.
thread_anon_to_peers is the value of the anonymous_to_peers field in the
thread data.
expected_api_anonymous is whether the thread should actually be
anonymous in the API output when requested by a user with the given
role.
"""
self.create_role(role_name, [self.user])
result = self.get_thread_list([
self.make_cs_thread({
"anonymous": anonymous,
"anonymous_to_peers": anonymous_to_peers,
})
])
actual_api_anonymous = result["results"][0]["author"] is None
self.assertEqual(actual_api_anonymous, expected_api_anonymous)
@ddt.data(
(FORUM_ROLE_ADMINISTRATOR, False, "staff"),
(FORUM_ROLE_ADMINISTRATOR, True, None),
(FORUM_ROLE_MODERATOR, False, "staff"),
(FORUM_ROLE_MODERATOR, True, None),
(FORUM_ROLE_COMMUNITY_TA, False, "community_ta"),
(FORUM_ROLE_COMMUNITY_TA, True, None),
(FORUM_ROLE_STUDENT, False, None),
(FORUM_ROLE_STUDENT, True, None),
)
@ddt.unpack
def test_author_labels(self, role_name, anonymous, expected_label):
"""
Test correctness of the author_label field.
The label should be "staff", "staff", or "community_ta" for the
Administrator, Moderator, and Community TA roles, respectively, but
the label should not be present if the thread is anonymous.
role_name is the name of the author's role.
anonymous is the value of the anonymous field in the thread data.
expected_label is the expected value of the author_label field in the
API output.
"""
self.create_role(role_name, [self.author])
result = self.get_thread_list([self.make_cs_thread({"anonymous": anonymous})])
actual_label = result["results"][0]["author_label"]
self.assertEqual(actual_label, expected_label)
get_thread_list(self.request, self.course.id, page=4, page_size=10)
"""
Tests for Discussion API serializers
"""
import ddt
import httpretty
from discussion_api.serializers import ThreadSerializer, get_context
from discussion_api.tests.utils import CommentsServiceMockMixin
from django_comment_common.models import (
FORUM_ROLE_ADMINISTRATOR,
FORUM_ROLE_COMMUNITY_TA,
FORUM_ROLE_MODERATOR,
FORUM_ROLE_STUDENT,
Role,
)
from student.tests.factories import UserFactory
from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase
from xmodule.modulestore.tests.factories import CourseFactory
from openedx.core.djangoapps.course_groups.tests.helpers import CohortFactory
@ddt.ddt
class ThreadSerializerTest(CommentsServiceMockMixin, ModuleStoreTestCase):
"""Tests for ThreadSerializer."""
def setUp(self):
super(ThreadSerializerTest, self).setUp()
httpretty.reset()
httpretty.enable()
self.addCleanup(httpretty.disable)
self.maxDiff = None # pylint: disable=invalid-name
self.user = UserFactory.create()
self.register_get_user_response(self.user)
self.course = CourseFactory.create()
self.author = UserFactory.create()
def create_role(self, role_name, users, course=None):
"""Create a Role in self.course with the given name and users"""
course = course or self.course
role = Role.objects.create(name=role_name, course_id=course.id)
role.users = users
def make_cs_thread(self, thread_data=None):
"""
Create a dictionary containing all needed thread fields as returned by
the comments service with dummy data overridden by thread_data
"""
ret = {
"id": "dummy",
"course_id": unicode(self.course.id),
"commentable_id": "dummy",
"group_id": None,
"user_id": str(self.author.id),
"username": self.author.username,
"anonymous": False,
"anonymous_to_peers": False,
"created_at": "1970-01-01T00:00:00Z",
"updated_at": "1970-01-01T00:00:00Z",
"thread_type": "discussion",
"title": "dummy",
"body": "dummy",
"pinned": False,
"closed": False,
"abuse_flaggers": [],
"votes": {"up_count": 0},
"comments_count": 0,
"unread_comments_count": 0,
"children": [],
"resp_total": 0,
}
if thread_data:
ret.update(thread_data)
return ret
def serialize(self, thread):
"""
Create a serializer with an appropriate context and use it to serialize
the given thread, returning the result.
"""
return ThreadSerializer(thread, context=get_context(self.course, self.user)).data
def test_basic(self):
thread = {
"id": "test_thread",
"course_id": unicode(self.course.id),
"commentable_id": "test_topic",
"group_id": None,
"user_id": str(self.author.id),
"username": self.author.username,
"anonymous": False,
"anonymous_to_peers": False,
"created_at": "2015-04-28T00:00:00Z",
"updated_at": "2015-04-28T11:11:11Z",
"thread_type": "discussion",
"title": "Test Title",
"body": "Test body",
"pinned": True,
"closed": False,
"abuse_flaggers": [],
"votes": {"up_count": 4},
"comments_count": 5,
"unread_comments_count": 3,
}
expected = {
"id": "test_thread",
"course_id": unicode(self.course.id),
"topic_id": "test_topic",
"group_id": None,
"group_name": None,
"author": self.author.username,
"author_label": None,
"created_at": "2015-04-28T00:00:00Z",
"updated_at": "2015-04-28T11:11:11Z",
"type": "discussion",
"title": "Test Title",
"raw_body": "Test body",
"pinned": True,
"closed": False,
"following": False,
"abuse_flagged": False,
"voted": False,
"vote_count": 4,
"comment_count": 5,
"unread_comment_count": 3,
}
self.assertEqual(self.serialize(thread), expected)
def test_group(self):
cohort = CohortFactory.create(course_id=self.course.id)
serialized = self.serialize(self.make_cs_thread({"group_id": cohort.id}))
self.assertEqual(serialized["group_id"], cohort.id)
self.assertEqual(serialized["group_name"], cohort.name)
@ddt.data(
(FORUM_ROLE_ADMINISTRATOR, True, False, True),
(FORUM_ROLE_ADMINISTRATOR, False, True, False),
(FORUM_ROLE_MODERATOR, True, False, True),
(FORUM_ROLE_MODERATOR, False, True, False),
(FORUM_ROLE_COMMUNITY_TA, True, False, True),
(FORUM_ROLE_COMMUNITY_TA, False, True, False),
(FORUM_ROLE_STUDENT, True, False, True),
(FORUM_ROLE_STUDENT, False, True, True),
)
@ddt.unpack
def test_anonymity(self, role_name, anonymous, anonymous_to_peers, expected_serialized_anonymous):
"""
Test that content is properly made anonymous.
Content should be anonymous iff the anonymous field is true or the
anonymous_to_peers field is true and the requester does not have a
privileged role.
role_name is the name of the requester's role.
anonymous is the value of the anonymous field in the content.
anonymous_to_peers is the value of the anonymous_to_peers field in the
content.
expected_serialized_anonymous is whether the content should actually be
anonymous in the API output when requested by a user with the given
role.
"""
self.create_role(role_name, [self.user])
serialized = self.serialize(
self.make_cs_thread({"anonymous": anonymous, "anonymous_to_peers": anonymous_to_peers})
)
actual_serialized_anonymous = serialized["author"] is None
self.assertEqual(actual_serialized_anonymous, expected_serialized_anonymous)
@ddt.data(
(FORUM_ROLE_ADMINISTRATOR, False, "staff"),
(FORUM_ROLE_ADMINISTRATOR, True, None),
(FORUM_ROLE_MODERATOR, False, "staff"),
(FORUM_ROLE_MODERATOR, True, None),
(FORUM_ROLE_COMMUNITY_TA, False, "community_ta"),
(FORUM_ROLE_COMMUNITY_TA, True, None),
(FORUM_ROLE_STUDENT, False, None),
(FORUM_ROLE_STUDENT, True, None),
)
@ddt.unpack
def test_author_labels(self, role_name, anonymous, expected_label):
"""
Test correctness of the author_label field.
The label should be "staff", "staff", or "community_ta" for the
Administrator, Moderator, and Community TA roles, respectively, but
the label should not be present if the thread is anonymous.
role_name is the name of the author's role.
anonymous is the value of the anonymous field in the content.
expected_label is the expected value of the author_label field in the
API output.
"""
self.create_role(role_name, [self.author])
serialized = self.serialize(self.make_cs_thread({"anonymous": anonymous}))
self.assertEqual(serialized["author_label"], expected_label)
def test_following(self):
thread_id = "test_thread"
self.register_get_user_response(self.user, subscribed_thread_ids=[thread_id])
serialized = self.serialize(self.make_cs_thread({"id": thread_id}))
self.assertEqual(serialized["following"], True)
def test_abuse_flagged(self):
serialized = self.serialize(self.make_cs_thread({"abuse_flaggers": [str(self.user.id)]}))
self.assertEqual(serialized["abuse_flagged"], True)
def test_voted(self):
thread_id = "test_thread"
self.register_get_user_response(self.user, upvoted_ids=[thread_id])
serialized = self.serialize(self.make_cs_thread({"id": thread_id}))
self.assertEqual(serialized["voted"], True)
......@@ -42,11 +42,6 @@ class DiscussionAPIViewTestMixin(CommentsServiceMockMixin, UrlResetMixin):
CourseEnrollmentFactory.create(user=self.user, course_id=self.course.id)
self.client.login(username=self.user.username, password=self.password)
def login_unenrolled_user(self):
"""Create a user not enrolled in the course and log it in"""
unenrolled_user = UserFactory.create(password=self.password)
self.client.login(username=unenrolled_user.username, password=self.password)
def assert_response_correct(self, response, expected_status, expected_content):
"""
Assert that the response has the given status code and parsed content
......@@ -71,7 +66,7 @@ class CourseTopicsViewTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
super(CourseTopicsViewTest, self).setUp()
self.url = reverse("course_topics", kwargs={"course_id": unicode(self.course.id)})
def test_non_existent_course(self):
def test_404(self):
response = self.client.get(
reverse("course_topics", kwargs={"course_id": "non/existent/course"})
)
......@@ -81,26 +76,7 @@ class CourseTopicsViewTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
{"developer_message": "Not found."}
)
def test_not_enrolled(self):
self.login_unenrolled_user()
response = self.client.get(self.url)
self.assert_response_correct(
response,
404,
{"developer_message": "Not found."}
)
def test_discussions_disabled(self):
self.course.tabs = [tab for tab in self.course.tabs if not isinstance(tab, DiscussionTab)]
modulestore().update_item(self.course, self.user.id)
response = self.client.get(self.url)
self.assert_response_correct(
response,
404,
{"developer_message": "Not found."}
)
def test_get(self):
def test_get_success(self):
response = self.client.get(self.url)
self.assert_response_correct(
response,
......@@ -132,9 +108,8 @@ class ThreadViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
{"field_errors": {"course_id": "This field is required."}}
)
def test_not_enrolled(self):
self.login_unenrolled_user()
response = self.client.get(self.url, {"course_id": unicode(self.course.id)})
def test_404(self):
response = self.client.get(self.url, {"course_id": unicode("non/existent/course")})
self.assert_response_correct(
response,
404,
......
......@@ -2,7 +2,6 @@
Discussion API views
"""
from django.core.exceptions import ValidationError
from django.http import Http404
from rest_framework.authentication import OAuth2Authentication, SessionAuthentication
from rest_framework.permissions import IsAuthenticated
......@@ -12,11 +11,9 @@ from rest_framework.viewsets import ViewSet
from opaque_keys.edx.locator import CourseLocator
from courseware.courses import get_course_with_access
from discussion_api.api import get_course_topics, get_thread_list
from discussion_api.forms import ThreadListGetForm
from openedx.core.lib.api.view_utils import DeveloperErrorViewMixin
from xmodule.tabs import DiscussionTab
class _ViewMixin(object):
......@@ -27,17 +24,6 @@ class _ViewMixin(object):
authentication_classes = (OAuth2Authentication, SessionAuthentication)
permission_classes = (IsAuthenticated,)
def get_course_or_404(self, user, course_key):
"""
Get the course descriptor, raising Http404 if the course is not found,
the user cannot access forums for the course, or the discussion tab is
disabled for the course.
"""
course = get_course_with_access(user, 'load_forum', course_key)
if not any([isinstance(tab, DiscussionTab) for tab in course.tabs]):
raise Http404
return course
class CourseTopicsView(_ViewMixin, DeveloperErrorViewMixin, APIView):
"""
......@@ -68,8 +54,7 @@ class CourseTopicsView(_ViewMixin, DeveloperErrorViewMixin, APIView):
def get(self, request, course_id):
"""Implements the GET method as described in the class docstring."""
course_key = CourseLocator.from_string(course_id)
course = self.get_course_or_404(request.user, course_key)
return Response(get_course_topics(course, request.user))
return Response(get_course_topics(course_key, request.user))
class ThreadViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet):
......@@ -133,11 +118,10 @@ class ThreadViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet):
form = ThreadListGetForm(request.GET)
if not form.is_valid():
raise ValidationError(form.errors)
course = self.get_course_or_404(request.user, form.cleaned_data["course_id"])
return Response(
get_thread_list(
request,
course,
form.cleaned_data["course_id"],
form.cleaned_data["page"],
form.cleaned_data["page_size"]
)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment