Commit 1f03a4df by Greg Price

Merge pull request #7966 from edx/gprice/discussion-api-thread-list-more-fields

Add fields to threads returned by Discussion API thread list endpoint
parents e60b72cf 7ce579c2
...@@ -14,7 +14,8 @@ from django_comment_common.models import ( ...@@ -14,7 +14,8 @@ from django_comment_common.models import (
Role, Role,
) )
from lms.lib.comment_client.thread import Thread from lms.lib.comment_client.thread import Thread
from openedx.core.djangoapps.course_groups.cohorts import get_cohort_id from lms.lib.comment_client.user import User
from openedx.core.djangoapps.course_groups.cohorts import get_cohort_id, get_cohort_names
def get_course_topics(course, user): def get_course_topics(course, user):
...@@ -76,17 +77,38 @@ def get_course_topics(course, user): ...@@ -76,17 +77,38 @@ def get_course_topics(course, user):
} }
def _cc_thread_to_api_thread(thread): def _cc_thread_to_api_thread(thread, cc_user, staff_user_ids, ta_user_ids, group_ids_to_names):
""" """
Convert a thread data dict from the comment_client format (which is a direct Convert a thread data dict from the comment_client format (which is a direct
representation of the format returned by the comments service) to the format representation of the format returned by the comments service) to the format
used in this API used in this API
Arguments:
thread (comment_client.thread.Thread): The thread to convert
cc_user (comment_client.user.User): The comment_client representation of
the requesting user
staff_user_ids (set): The set of user ids for users with the Moderator or
Administrator role in the course
ta_user_ids (set): The set of user ids for users with the Community TA
role in the course
group_ids_to_names (dict): A mapping of group ids to names
Returns:
dict: The discussion_api format representation of the thread.
""" """
is_anonymous = (
thread["anonymous"] or
(
thread["anonymous_to_peers"] and
int(cc_user["id"]) not in (staff_user_ids | ta_user_ids)
)
)
ret = { ret = {
key: thread[key] key: thread[key]
for key in [ for key in [
"id", "id",
"course_id", "course_id",
"group_id",
"created_at", "created_at",
"updated_at", "updated_at",
"type", "type",
...@@ -97,21 +119,33 @@ def _cc_thread_to_api_thread(thread): ...@@ -97,21 +119,33 @@ def _cc_thread_to_api_thread(thread):
} }
ret.update({ ret.update({
"topic_id": thread["commentable_id"], "topic_id": thread["commentable_id"],
"group_name": group_ids_to_names.get(thread["group_id"]),
"author": None if is_anonymous else thread["username"],
"author_label": (
None if is_anonymous else
"staff" if int(thread["user_id"]) in staff_user_ids else
"community_ta" if int(thread["user_id"]) in ta_user_ids else
None
),
"raw_body": thread["body"], "raw_body": thread["body"],
"following": thread["id"] in cc_user["subscribed_thread_ids"],
"abuse_flagged": cc_user["id"] in thread["abuse_flaggers"],
"voted": thread["id"] in cc_user["upvoted_ids"],
"vote_count": thread["votes"]["up_count"],
"comment_count": thread["comments_count"], "comment_count": thread["comments_count"],
"unread_comment_count": thread["unread_comments_count"], "unread_comment_count": thread["unread_comments_count"],
}) })
return ret return ret
def get_thread_list(request, course_key, page, page_size): def get_thread_list(request, course, page, page_size):
""" """
Return the list of all discussion threads pertaining to the given course Return the list of all discussion threads pertaining to the given course
Parameters: Parameters:
request: The django request objects used for build_absolute_uri request: The django request objects used for build_absolute_uri
course_key: The key of the course to get discussion threads for course: The course to get discussion threads for
page: The page number (1-indexed) to retrieve page: The page number (1-indexed) to retrieve
page_size: The number of threads to retrieve per page page_size: The number of threads to retrieve per page
...@@ -121,13 +155,14 @@ def get_thread_list(request, course_key, page, page_size): ...@@ -121,13 +155,14 @@ def get_thread_list(request, course_key, page, page_size):
discussion_api.views.ThreadViewSet for more detail. discussion_api.views.ThreadViewSet for more detail.
""" """
user_is_privileged = Role.objects.filter( user_is_privileged = Role.objects.filter(
course_id=course_key, course_id=course.id,
name__in=[FORUM_ROLE_ADMINISTRATOR, FORUM_ROLE_MODERATOR, FORUM_ROLE_COMMUNITY_TA], name__in=[FORUM_ROLE_ADMINISTRATOR, FORUM_ROLE_MODERATOR, FORUM_ROLE_COMMUNITY_TA],
users=request.user users=request.user
).exists() ).exists()
cc_user = User.from_django_user(request.user).retrieve()
threads, result_page, num_pages, _ = Thread.search({ threads, result_page, num_pages, _ = Thread.search({
"course_id": unicode(course_key), "course_id": unicode(course.id),
"group_id": None if user_is_privileged else get_cohort_id(request.user, course_key), "group_id": None if user_is_privileged else get_cohort_id(request.user, course.id),
"sort_key": "date", "sort_key": "date",
"sort_order": "desc", "sort_order": "desc",
"page": page, "page": page,
...@@ -138,6 +173,25 @@ def get_thread_list(request, course_key, page, page_size): ...@@ -138,6 +173,25 @@ def get_thread_list(request, course_key, page, page_size):
# behavior and return a 404 in that case # behavior and return a 404 in that case
if result_page != page: if result_page != page:
raise Http404 raise Http404
# TODO: cache staff_user_ids and ta_user_ids if we need to improve perf
staff_user_ids = {
user.id
for role in Role.objects.filter(
name__in=[FORUM_ROLE_ADMINISTRATOR, FORUM_ROLE_MODERATOR],
course_id=course.id
)
for user in role.users.all()
}
ta_user_ids = {
user.id
for role in Role.objects.filter(name=FORUM_ROLE_COMMUNITY_TA, course_id=course.id)
for user in role.users.all()
}
# For now, the only groups are cohorts
group_ids_to_names = get_cohort_names(course)
results = [_cc_thread_to_api_thread(thread) for thread in threads] results = [
_cc_thread_to_api_thread(thread, cc_user, staff_user_ids, ta_user_ids, group_ids_to_names)
for thread in threads
]
return get_paginated_data(request, results, page, num_pages) return get_paginated_data(request, results, page, num_pages)
...@@ -12,8 +12,6 @@ from pytz import UTC ...@@ -12,8 +12,6 @@ from pytz import UTC
from django.http import Http404 from django.http import Http404
from django.test.client import RequestFactory from django.test.client import RequestFactory
from opaque_keys.edx.locator import CourseLocator
from courseware.tests.factories import BetaTesterFactory, StaffFactory from courseware.tests.factories import BetaTesterFactory, StaffFactory
from discussion_api.api import get_course_topics, get_thread_list from discussion_api.api import get_course_topics, get_thread_list
from discussion_api.tests.utils import CommentsServiceMockMixin from discussion_api.tests.utils import CommentsServiceMockMixin
...@@ -323,16 +321,21 @@ class GetCourseTopicsTest(ModuleStoreTestCase): ...@@ -323,16 +321,21 @@ class GetCourseTopicsTest(ModuleStoreTestCase):
@ddt.ddt @ddt.ddt
@httpretty.activate
class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase): class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase):
"""Test for get_thread_list""" """Test for get_thread_list"""
def setUp(self): def setUp(self):
super(GetThreadListTest, self).setUp() super(GetThreadListTest, self).setUp()
httpretty.reset()
httpretty.enable()
self.addCleanup(httpretty.disable)
self.maxDiff = None # pylint: disable=invalid-name self.maxDiff = None # pylint: disable=invalid-name
self.user = UserFactory.create() self.user = UserFactory.create()
self.register_get_user_response(self.user)
self.request = RequestFactory().get("/test_path") self.request = RequestFactory().get("/test_path")
self.request.user = self.user self.request.user = self.user
self.course = CourseFactory.create() self.course = CourseFactory.create()
self.author = UserFactory.create()
self.cohort = CohortFactory.create(course_id=self.course.id)
def get_thread_list(self, threads, page=1, page_size=1, num_pages=1, course=None): def get_thread_list(self, threads, page=1, page_size=1, num_pages=1, course=None):
""" """
...@@ -341,7 +344,42 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase): ...@@ -341,7 +344,42 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase):
""" """
course = course or self.course course = course or self.course
self.register_get_threads_response(threads, page, num_pages) self.register_get_threads_response(threads, page, num_pages)
ret = get_thread_list(self.request, course.id, page, page_size) ret = get_thread_list(self.request, course, page, page_size)
return ret
def create_role(self, role_name, users):
"""Create a Role in self.course with the given name and users"""
role = Role.objects.create(name=role_name, course_id=self.course.id)
role.users = users
role.save()
def make_cs_thread(self, thread_data):
"""
Create a dictionary containing all needed thread fields as returned by
the comments service with dummy data overridden by thread_data
"""
ret = {
"id": "dummy",
"course_id": unicode(self.course.id),
"commentable_id": "dummy",
"group_id": None,
"user_id": str(self.author.id),
"username": self.author.username,
"anonymous": False,
"anonymous_to_peers": False,
"created_at": "1970-01-01T00:00:00Z",
"updated_at": "1970-01-01T00:00:00Z",
"type": "discussion",
"title": "dummy",
"body": "dummy",
"pinned": False,
"closed": False,
"abuse_flaggers": [],
"votes": {"up_count": 0},
"comments_count": 0,
"unread_comments_count": 0,
}
ret.update(thread_data)
return ret return ret
def test_empty(self): def test_empty(self):
...@@ -366,11 +404,21 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase): ...@@ -366,11 +404,21 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase):
}) })
def test_thread_content(self): def test_thread_content(self):
self.register_get_user_response(
self.user,
subscribed_thread_ids=["test_thread_id_0"],
upvoted_ids=["test_thread_id_1"]
)
source_threads = [ source_threads = [
{ {
"id": "test_thread_id_0", "id": "test_thread_id_0",
"course_id": unicode(self.course.id), "course_id": unicode(self.course.id),
"commentable_id": "topic_x", "commentable_id": "topic_x",
"group_id": None,
"user_id": str(self.author.id),
"username": self.author.username,
"anonymous": False,
"anonymous_to_peers": False,
"created_at": "2015-04-28T00:00:00Z", "created_at": "2015-04-28T00:00:00Z",
"updated_at": "2015-04-28T11:11:11Z", "updated_at": "2015-04-28T11:11:11Z",
"type": "discussion", "type": "discussion",
...@@ -378,6 +426,8 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase): ...@@ -378,6 +426,8 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase):
"body": "Test body", "body": "Test body",
"pinned": False, "pinned": False,
"closed": False, "closed": False,
"abuse_flaggers": [],
"votes": {"up_count": 4},
"comments_count": 5, "comments_count": 5,
"unread_comments_count": 3, "unread_comments_count": 3,
}, },
...@@ -385,6 +435,11 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase): ...@@ -385,6 +435,11 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase):
"id": "test_thread_id_1", "id": "test_thread_id_1",
"course_id": unicode(self.course.id), "course_id": unicode(self.course.id),
"commentable_id": "topic_y", "commentable_id": "topic_y",
"group_id": self.cohort.id,
"user_id": str(self.author.id),
"username": self.author.username,
"anonymous": False,
"anonymous_to_peers": False,
"created_at": "2015-04-28T22:22:22Z", "created_at": "2015-04-28T22:22:22Z",
"updated_at": "2015-04-28T00:33:33Z", "updated_at": "2015-04-28T00:33:33Z",
"type": "question", "type": "question",
...@@ -392,6 +447,8 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase): ...@@ -392,6 +447,8 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase):
"body": "More content", "body": "More content",
"pinned": False, "pinned": False,
"closed": True, "closed": True,
"abuse_flaggers": [],
"votes": {"up_count": 9},
"comments_count": 18, "comments_count": 18,
"unread_comments_count": 0, "unread_comments_count": 0,
}, },
...@@ -399,6 +456,11 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase): ...@@ -399,6 +456,11 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase):
"id": "test_thread_id_2", "id": "test_thread_id_2",
"course_id": unicode(self.course.id), "course_id": unicode(self.course.id),
"commentable_id": "topic_x", "commentable_id": "topic_x",
"group_id": self.cohort.id + 1, # non-existent group
"user_id": str(self.author.id),
"username": self.author.username,
"anonymous": False,
"anonymous_to_peers": False,
"created_at": "2015-04-28T00:44:44Z", "created_at": "2015-04-28T00:44:44Z",
"updated_at": "2015-04-28T00:55:55Z", "updated_at": "2015-04-28T00:55:55Z",
"type": "discussion", "type": "discussion",
...@@ -406,6 +468,8 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase): ...@@ -406,6 +468,8 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase):
"body": "Still more content", "body": "Still more content",
"pinned": True, "pinned": True,
"closed": False, "closed": False,
"abuse_flaggers": [str(self.user.id)],
"votes": {"up_count": 0},
"comments_count": 0, "comments_count": 0,
"unread_comments_count": 0, "unread_comments_count": 0,
}, },
...@@ -415,6 +479,10 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase): ...@@ -415,6 +479,10 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase):
"id": "test_thread_id_0", "id": "test_thread_id_0",
"course_id": unicode(self.course.id), "course_id": unicode(self.course.id),
"topic_id": "topic_x", "topic_id": "topic_x",
"group_id": None,
"group_name": None,
"author": self.author.username,
"author_label": None,
"created_at": "2015-04-28T00:00:00Z", "created_at": "2015-04-28T00:00:00Z",
"updated_at": "2015-04-28T11:11:11Z", "updated_at": "2015-04-28T11:11:11Z",
"type": "discussion", "type": "discussion",
...@@ -422,6 +490,10 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase): ...@@ -422,6 +490,10 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase):
"raw_body": "Test body", "raw_body": "Test body",
"pinned": False, "pinned": False,
"closed": False, "closed": False,
"following": True,
"abuse_flagged": False,
"voted": False,
"vote_count": 4,
"comment_count": 5, "comment_count": 5,
"unread_comment_count": 3, "unread_comment_count": 3,
}, },
...@@ -429,6 +501,10 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase): ...@@ -429,6 +501,10 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase):
"id": "test_thread_id_1", "id": "test_thread_id_1",
"course_id": unicode(self.course.id), "course_id": unicode(self.course.id),
"topic_id": "topic_y", "topic_id": "topic_y",
"group_id": self.cohort.id,
"group_name": self.cohort.name,
"author": self.author.username,
"author_label": None,
"created_at": "2015-04-28T22:22:22Z", "created_at": "2015-04-28T22:22:22Z",
"updated_at": "2015-04-28T00:33:33Z", "updated_at": "2015-04-28T00:33:33Z",
"type": "question", "type": "question",
...@@ -436,6 +512,10 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase): ...@@ -436,6 +512,10 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase):
"raw_body": "More content", "raw_body": "More content",
"pinned": False, "pinned": False,
"closed": True, "closed": True,
"following": False,
"abuse_flagged": False,
"voted": True,
"vote_count": 9,
"comment_count": 18, "comment_count": 18,
"unread_comment_count": 0, "unread_comment_count": 0,
}, },
...@@ -443,6 +523,10 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase): ...@@ -443,6 +523,10 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase):
"id": "test_thread_id_2", "id": "test_thread_id_2",
"course_id": unicode(self.course.id), "course_id": unicode(self.course.id),
"topic_id": "topic_x", "topic_id": "topic_x",
"group_id": self.cohort.id + 1,
"group_name": None,
"author": self.author.username,
"author_label": None,
"created_at": "2015-04-28T00:44:44Z", "created_at": "2015-04-28T00:44:44Z",
"updated_at": "2015-04-28T00:55:55Z", "updated_at": "2015-04-28T00:55:55Z",
"type": "discussion", "type": "discussion",
...@@ -450,6 +534,10 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase): ...@@ -450,6 +534,10 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase):
"raw_body": "Still more content", "raw_body": "Still more content",
"pinned": True, "pinned": True,
"closed": False, "closed": False,
"following": False,
"abuse_flagged": True,
"voted": False,
"vote_count": 0,
"comment_count": 0, "comment_count": 0,
"unread_comment_count": 0, "unread_comment_count": 0,
}, },
...@@ -515,4 +603,70 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase): ...@@ -515,4 +603,70 @@ class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase):
# Test page past the last one # Test page past the last one
self.register_get_threads_response([], page=3, num_pages=3) self.register_get_threads_response([], page=3, num_pages=3)
with self.assertRaises(Http404): with self.assertRaises(Http404):
get_thread_list(self.request, self.course.id, page=4, page_size=10) get_thread_list(self.request, self.course, page=4, page_size=10)
@ddt.data(
(FORUM_ROLE_ADMINISTRATOR, True, False, True),
(FORUM_ROLE_ADMINISTRATOR, False, True, False),
(FORUM_ROLE_MODERATOR, True, False, True),
(FORUM_ROLE_MODERATOR, False, True, False),
(FORUM_ROLE_COMMUNITY_TA, True, False, True),
(FORUM_ROLE_COMMUNITY_TA, False, True, False),
(FORUM_ROLE_STUDENT, True, False, True),
(FORUM_ROLE_STUDENT, False, True, True),
)
@ddt.unpack
def test_anonymity(self, role_name, anonymous, anonymous_to_peers, expected_api_anonymous):
"""
Test that a thread is properly made anonymous.
A thread should be anonymous iff the anonymous field is true or the
anonymous_to_peers field is true and the requester does not have a
privileged role.
role_name is the name of the requester's role.
thread_anon is the value of the anonymous field in the thread data.
thread_anon_to_peers is the value of the anonymous_to_peers field in the
thread data.
expected_api_anonymous is whether the thread should actually be
anonymous in the API output when requested by a user with the given
role.
"""
self.create_role(role_name, [self.user])
result = self.get_thread_list([
self.make_cs_thread({
"anonymous": anonymous,
"anonymous_to_peers": anonymous_to_peers,
})
])
actual_api_anonymous = result["results"][0]["author"] is None
self.assertEqual(actual_api_anonymous, expected_api_anonymous)
@ddt.data(
(FORUM_ROLE_ADMINISTRATOR, False, "staff"),
(FORUM_ROLE_ADMINISTRATOR, True, None),
(FORUM_ROLE_MODERATOR, False, "staff"),
(FORUM_ROLE_MODERATOR, True, None),
(FORUM_ROLE_COMMUNITY_TA, False, "community_ta"),
(FORUM_ROLE_COMMUNITY_TA, True, None),
(FORUM_ROLE_STUDENT, False, None),
(FORUM_ROLE_STUDENT, True, None),
)
@ddt.unpack
def test_author_labels(self, role_name, anonymous, expected_label):
"""
Test correctness of the author_label field.
The label should be "staff", "staff", or "community_ta" for the
Administrator, Moderator, and Community TA roles, respectively, but
the label should not be present if the thread is anonymous.
role_name is the name of the author's role.
anonymous is the value of the anonymous field in the thread data.
expected_label is the expected value of the author_label field in the
API output.
"""
self.create_role(role_name, [self.author])
result = self.get_thread_list([self.make_cs_thread({"anonymous": anonymous})])
actual_label = result["results"][0]["author_label"]
self.assertEqual(actual_label, expected_label)
...@@ -121,6 +121,7 @@ class ThreadViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase): ...@@ -121,6 +121,7 @@ class ThreadViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
"""Tests for ThreadViewSet list""" """Tests for ThreadViewSet list"""
def setUp(self): def setUp(self):
super(ThreadViewSetListTest, self).setUp() super(ThreadViewSetListTest, self).setUp()
self.author = UserFactory.create()
self.url = reverse("thread-list") self.url = reverse("thread-list")
def test_course_id_missing(self): def test_course_id_missing(self):
...@@ -141,10 +142,16 @@ class ThreadViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase): ...@@ -141,10 +142,16 @@ class ThreadViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
) )
def test_basic(self): def test_basic(self):
self.register_get_user_response(self.user, upvoted_ids=["test_thread"])
source_threads = [{ source_threads = [{
"id": "test_thread", "id": "test_thread",
"course_id": unicode(self.course.id), "course_id": unicode(self.course.id),
"commentable_id": "test_topic", "commentable_id": "test_topic",
"group_id": None,
"user_id": str(self.author.id),
"username": self.author.username,
"anonymous": False,
"anonymous_to_peers": False,
"created_at": "2015-04-28T00:00:00Z", "created_at": "2015-04-28T00:00:00Z",
"updated_at": "2015-04-28T11:11:11Z", "updated_at": "2015-04-28T11:11:11Z",
"type": "discussion", "type": "discussion",
...@@ -152,6 +159,8 @@ class ThreadViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase): ...@@ -152,6 +159,8 @@ class ThreadViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
"body": "Test body", "body": "Test body",
"pinned": False, "pinned": False,
"closed": False, "closed": False,
"abuse_flaggers": [],
"votes": {"up_count": 4},
"comments_count": 5, "comments_count": 5,
"unread_comments_count": 3, "unread_comments_count": 3,
}] }]
...@@ -159,6 +168,10 @@ class ThreadViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase): ...@@ -159,6 +168,10 @@ class ThreadViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
"id": "test_thread", "id": "test_thread",
"course_id": unicode(self.course.id), "course_id": unicode(self.course.id),
"topic_id": "test_topic", "topic_id": "test_topic",
"group_id": None,
"group_name": None,
"author": self.author.username,
"author_label": None,
"created_at": "2015-04-28T00:00:00Z", "created_at": "2015-04-28T00:00:00Z",
"updated_at": "2015-04-28T11:11:11Z", "updated_at": "2015-04-28T11:11:11Z",
"type": "discussion", "type": "discussion",
...@@ -166,6 +179,10 @@ class ThreadViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase): ...@@ -166,6 +179,10 @@ class ThreadViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
"raw_body": "Test body", "raw_body": "Test body",
"pinned": False, "pinned": False,
"closed": False, "closed": False,
"following": False,
"abuse_flagged": False,
"voted": True,
"vote_count": 4,
"comment_count": 5, "comment_count": 5,
"unread_comment_count": 3, "unread_comment_count": 3,
}] }]
...@@ -190,6 +207,7 @@ class ThreadViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase): ...@@ -190,6 +207,7 @@ class ThreadViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
}) })
def test_pagination(self): def test_pagination(self):
self.register_get_user_response(self.user)
self.register_get_threads_response([], page=1, num_pages=1) self.register_get_threads_response([], page=1, num_pages=1)
response = self.client.get( response = self.client.get(
self.url, self.url,
......
...@@ -21,6 +21,19 @@ class CommentsServiceMockMixin(object): ...@@ -21,6 +21,19 @@ class CommentsServiceMockMixin(object):
status=200 status=200
) )
def register_get_user_response(self, user, subscribed_thread_ids=None, upvoted_ids=None):
"""Register a mock response for GET on the CS user instance endpoint"""
httpretty.register_uri(
httpretty.GET,
"http://localhost:4567/api/v1/users/{id}".format(id=user.id),
body=json.dumps({
"id": str(user.id),
"subscribed_thread_ids": subscribed_thread_ids or [],
"upvoted_ids": upvoted_ids or [],
}),
status=200
)
def assert_last_query_params(self, expected_params): def assert_last_query_params(self, expected_params):
""" """
Assert that the last mock request had the expected query parameters Assert that the last mock request had the expected query parameters
......
...@@ -133,12 +133,11 @@ class ThreadViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet): ...@@ -133,12 +133,11 @@ class ThreadViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet):
form = ThreadListGetForm(request.GET) form = ThreadListGetForm(request.GET)
if not form.is_valid(): if not form.is_valid():
raise ValidationError(form.errors) raise ValidationError(form.errors)
course_key = form.cleaned_data["course_id"] course = self.get_course_or_404(request.user, form.cleaned_data["course_id"])
self.get_course_or_404(request.user, course_key)
return Response( return Response(
get_thread_list( get_thread_list(
request, request,
course_key, course,
form.cleaned_data["page"], form.cleaned_data["page"],
form.cleaned_data["page_size"] form.cleaned_data["page_size"]
) )
......
...@@ -293,6 +293,12 @@ def get_course_cohorts(course, assignment_type=None): ...@@ -293,6 +293,12 @@ def get_course_cohorts(course, assignment_type=None):
query_set = query_set.filter(cohort__assignment_type=assignment_type) if assignment_type else query_set query_set = query_set.filter(cohort__assignment_type=assignment_type) if assignment_type else query_set
return list(query_set) return list(query_set)
def get_cohort_names(course):
"""Return a dict that maps cohort ids to names for the given course"""
return {cohort.id: cohort.name for cohort in get_course_cohorts(course)}
### Helpers for cohort management views ### Helpers for cohort management views
......
...@@ -462,6 +462,15 @@ class TestCohorts(ModuleStoreTestCase): ...@@ -462,6 +462,15 @@ class TestCohorts(ModuleStoreTestCase):
cohort_set = {c.name for c in cohorts.get_course_cohorts(course)} cohort_set = {c.name for c in cohorts.get_course_cohorts(course)}
self.assertEqual(cohort_set, {"AutoGroup1", "AutoGroup2", "ManualCohort", "ManualCohort2"}) self.assertEqual(cohort_set, {"AutoGroup1", "AutoGroup2", "ManualCohort", "ManualCohort2"})
def test_get_cohort_names(self):
course = modulestore().get_course(self.toy_course_key)
cohort1 = CohortFactory(course_id=course.id, name="Cohort1")
cohort2 = CohortFactory(course_id=course.id, name="Cohort2")
self.assertEqual(
cohorts.get_cohort_names(course),
{cohort1.id: cohort1.name, cohort2.id: cohort2.name}
)
def test_is_commentable_cohorted(self): def test_is_commentable_cohorted(self):
course = modulestore().get_course(self.toy_course_key) course = modulestore().get_course(self.toy_course_key)
self.assertFalse(cohorts.is_course_cohorted(course.id)) self.assertFalse(cohorts.is_course_cohorted(course.id))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment