Commit 1f03a4df by Greg Price

Merge pull request #7966 from edx/gprice/discussion-api-thread-list-more-fields

Add fields to threads returned by Discussion API thread list endpoint
parents e60b72cf 7ce579c2
......@@ -14,7 +14,8 @@ from django_comment_common.models import (
Role,
)
from lms.lib.comment_client.thread import Thread
from openedx.core.djangoapps.course_groups.cohorts import get_cohort_id
from lms.lib.comment_client.user import User
from openedx.core.djangoapps.course_groups.cohorts import get_cohort_id, get_cohort_names
def get_course_topics(course, user):
......@@ -76,17 +77,38 @@ def get_course_topics(course, user):
}
def _cc_thread_to_api_thread(thread):
def _cc_thread_to_api_thread(thread, cc_user, staff_user_ids, ta_user_ids, group_ids_to_names):
"""
Convert a thread data dict from the comment_client format (which is a direct
representation of the format returned by the comments service) to the format
used in this API
Arguments:
thread (comment_client.thread.Thread): The thread to convert
cc_user (comment_client.user.User): The comment_client representation of
the requesting user
staff_user_ids (set): The set of user ids for users with the Moderator or
Administrator role in the course
ta_user_ids (set): The set of user ids for users with the Community TA
role in the course
group_ids_to_names (dict): A mapping of group ids to names
Returns:
dict: The discussion_api format representation of the thread.
"""
is_anonymous = (
thread["anonymous"] or
(
thread["anonymous_to_peers"] and
int(cc_user["id"]) not in (staff_user_ids | ta_user_ids)
)
)
ret = {
key: thread[key]
for key in [
"id",
"course_id",
"group_id",
"created_at",
"updated_at",
"type",
......@@ -97,21 +119,33 @@ def _cc_thread_to_api_thread(thread):
}
ret.update({
"topic_id": thread["commentable_id"],
"group_name": group_ids_to_names.get(thread["group_id"]),
"author": None if is_anonymous else thread["username"],
"author_label": (
None if is_anonymous else
"staff" if int(thread["user_id"]) in staff_user_ids else
"community_ta" if int(thread["user_id"]) in ta_user_ids else
None
),
"raw_body": thread["body"],
"following": thread["id"] in cc_user["subscribed_thread_ids"],
"abuse_flagged": cc_user["id"] in thread["abuse_flaggers"],
"voted": thread["id"] in cc_user["upvoted_ids"],
"vote_count": thread["votes"]["up_count"],
"comment_count": thread["comments_count"],
"unread_comment_count": thread["unread_comments_count"],
})
return ret
def get_thread_list(request, course_key, page, page_size):
def get_thread_list(request, course, page, page_size):
"""
Return the list of all discussion threads pertaining to the given course
Parameters:
request: The django request objects used for build_absolute_uri
course_key: The key of the course to get discussion threads for
course: The course to get discussion threads for
page: The page number (1-indexed) to retrieve
page_size: The number of threads to retrieve per page
......@@ -121,13 +155,14 @@ def get_thread_list(request, course_key, page, page_size):
discussion_api.views.ThreadViewSet for more detail.
"""
user_is_privileged = Role.objects.filter(
course_id=course_key,
course_id=course.id,
name__in=[FORUM_ROLE_ADMINISTRATOR, FORUM_ROLE_MODERATOR, FORUM_ROLE_COMMUNITY_TA],
users=request.user
).exists()
cc_user = User.from_django_user(request.user).retrieve()
threads, result_page, num_pages, _ = Thread.search({
"course_id": unicode(course_key),
"group_id": None if user_is_privileged else get_cohort_id(request.user, course_key),
"course_id": unicode(course.id),
"group_id": None if user_is_privileged else get_cohort_id(request.user, course.id),
"sort_key": "date",
"sort_order": "desc",
"page": page,
......@@ -138,6 +173,25 @@ def get_thread_list(request, course_key, page, page_size):
# behavior and return a 404 in that case
if result_page != page:
raise Http404
# TODO: cache staff_user_ids and ta_user_ids if we need to improve perf
staff_user_ids = {
user.id
for role in Role.objects.filter(
name__in=[FORUM_ROLE_ADMINISTRATOR, FORUM_ROLE_MODERATOR],
course_id=course.id
)
for user in role.users.all()
}
ta_user_ids = {
user.id
for role in Role.objects.filter(name=FORUM_ROLE_COMMUNITY_TA, course_id=course.id)
for user in role.users.all()
}
# For now, the only groups are cohorts
group_ids_to_names = get_cohort_names(course)
results = [_cc_thread_to_api_thread(thread) for thread in threads]
results = [
_cc_thread_to_api_thread(thread, cc_user, staff_user_ids, ta_user_ids, group_ids_to_names)
for thread in threads
]
return get_paginated_data(request, results, page, num_pages)
......@@ -121,6 +121,7 @@ class ThreadViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
"""Tests for ThreadViewSet list"""
def setUp(self):
super(ThreadViewSetListTest, self).setUp()
self.author = UserFactory.create()
self.url = reverse("thread-list")
def test_course_id_missing(self):
......@@ -141,10 +142,16 @@ class ThreadViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
)
def test_basic(self):
self.register_get_user_response(self.user, upvoted_ids=["test_thread"])
source_threads = [{
"id": "test_thread",
"course_id": unicode(self.course.id),
"commentable_id": "test_topic",
"group_id": None,
"user_id": str(self.author.id),
"username": self.author.username,
"anonymous": False,
"anonymous_to_peers": False,
"created_at": "2015-04-28T00:00:00Z",
"updated_at": "2015-04-28T11:11:11Z",
"type": "discussion",
......@@ -152,6 +159,8 @@ class ThreadViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
"body": "Test body",
"pinned": False,
"closed": False,
"abuse_flaggers": [],
"votes": {"up_count": 4},
"comments_count": 5,
"unread_comments_count": 3,
}]
......@@ -159,6 +168,10 @@ class ThreadViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
"id": "test_thread",
"course_id": unicode(self.course.id),
"topic_id": "test_topic",
"group_id": None,
"group_name": None,
"author": self.author.username,
"author_label": None,
"created_at": "2015-04-28T00:00:00Z",
"updated_at": "2015-04-28T11:11:11Z",
"type": "discussion",
......@@ -166,6 +179,10 @@ class ThreadViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
"raw_body": "Test body",
"pinned": False,
"closed": False,
"following": False,
"abuse_flagged": False,
"voted": True,
"vote_count": 4,
"comment_count": 5,
"unread_comment_count": 3,
}]
......@@ -190,6 +207,7 @@ class ThreadViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
})
def test_pagination(self):
self.register_get_user_response(self.user)
self.register_get_threads_response([], page=1, num_pages=1)
response = self.client.get(
self.url,
......
......@@ -21,6 +21,19 @@ class CommentsServiceMockMixin(object):
status=200
)
def register_get_user_response(self, user, subscribed_thread_ids=None, upvoted_ids=None):
"""Register a mock response for GET on the CS user instance endpoint"""
httpretty.register_uri(
httpretty.GET,
"http://localhost:4567/api/v1/users/{id}".format(id=user.id),
body=json.dumps({
"id": str(user.id),
"subscribed_thread_ids": subscribed_thread_ids or [],
"upvoted_ids": upvoted_ids or [],
}),
status=200
)
def assert_last_query_params(self, expected_params):
"""
Assert that the last mock request had the expected query parameters
......
......@@ -133,12 +133,11 @@ class ThreadViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet):
form = ThreadListGetForm(request.GET)
if not form.is_valid():
raise ValidationError(form.errors)
course_key = form.cleaned_data["course_id"]
self.get_course_or_404(request.user, course_key)
course = self.get_course_or_404(request.user, form.cleaned_data["course_id"])
return Response(
get_thread_list(
request,
course_key,
course,
form.cleaned_data["page"],
form.cleaned_data["page_size"]
)
......
......@@ -293,6 +293,12 @@ def get_course_cohorts(course, assignment_type=None):
query_set = query_set.filter(cohort__assignment_type=assignment_type) if assignment_type else query_set
return list(query_set)
def get_cohort_names(course):
"""Return a dict that maps cohort ids to names for the given course"""
return {cohort.id: cohort.name for cohort in get_course_cohorts(course)}
### Helpers for cohort management views
......
......@@ -462,6 +462,15 @@ class TestCohorts(ModuleStoreTestCase):
cohort_set = {c.name for c in cohorts.get_course_cohorts(course)}
self.assertEqual(cohort_set, {"AutoGroup1", "AutoGroup2", "ManualCohort", "ManualCohort2"})
def test_get_cohort_names(self):
course = modulestore().get_course(self.toy_course_key)
cohort1 = CohortFactory(course_id=course.id, name="Cohort1")
cohort2 = CohortFactory(course_id=course.id, name="Cohort2")
self.assertEqual(
cohorts.get_cohort_names(course),
{cohort1.id: cohort1.name, cohort2.id: cohort2.name}
)
def test_is_commentable_cohorted(self):
course = modulestore().get_course(self.toy_course_key)
self.assertFalse(cohorts.is_course_cohorted(course.id))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment