Commit f552eca7 by christopher lee

MA-642 Add topic_id query to discussion/thread api

Also added thread_list_url field.
parent f81d09cb
""" """
Discussion API internal interface Discussion API internal interface
""" """
from urllib import urlencode
from urlparse import urlunparse
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.core.urlresolvers import reverse
from django.http import Http404 from django.http import Http404
from collections import defaultdict from collections import defaultdict
from opaque_keys import InvalidKeyError from opaque_keys import InvalidKeyError
...@@ -39,7 +44,16 @@ def _get_course_or_404(course_key, user): ...@@ -39,7 +44,16 @@ def _get_course_or_404(course_key, user):
return course return course
def get_course_topics(course_key, user): def get_thread_list_url(request, course_key, topic_id_list):
"""
Returns the URL for the thread_list_url field, given a list of topic_ids
"""
path = reverse("thread-list")
query_list = [("course_id", unicode(course_key))] + [("topic_id", topic_id) for topic_id in topic_id_list]
return request.build_absolute_uri(urlunparse(("", "", path, "", urlencode(query_list), "")))
def get_course_topics(request, course_key):
""" """
Return the course topic listing for the given course and user. Return the course topic listing for the given course and user.
...@@ -60,22 +74,33 @@ def get_course_topics(course_key, user): ...@@ -60,22 +74,33 @@ def get_course_topics(course_key, user):
""" """
return module.sort_key or module.discussion_target return module.sort_key or module.discussion_target
course = _get_course_or_404(course_key, user) course = _get_course_or_404(course_key, request.user)
discussion_modules = get_accessible_discussion_modules(course, user) discussion_modules = get_accessible_discussion_modules(course, request.user)
modules_by_category = defaultdict(list) modules_by_category = defaultdict(list)
for module in discussion_modules: for module in discussion_modules:
modules_by_category[module.discussion_category].append(module) modules_by_category[module.discussion_category].append(module)
def get_sorted_modules(category):
"""Returns key sorted modules by category"""
return sorted(modules_by_category[category], key=get_module_sort_key)
courseware_topics = [ courseware_topics = [
{ {
"id": None, "id": None,
"name": category, "name": category,
"thread_list_url": get_thread_list_url(
request,
course_key,
[item.discussion_id for item in get_sorted_modules(category)]
),
"children": [ "children": [
{ {
"id": module.discussion_id, "id": module.discussion_id,
"name": module.discussion_target, "name": module.discussion_target,
"thread_list_url": get_thread_list_url(request, course_key, [module.discussion_id]),
"children": [], "children": [],
} }
for module in sorted(modules_by_category[category], key=get_module_sort_key) for module in get_sorted_modules(category)
], ],
} }
for category in sorted(modules_by_category.keys()) for category in sorted(modules_by_category.keys())
...@@ -85,6 +110,7 @@ def get_course_topics(course_key, user): ...@@ -85,6 +110,7 @@ def get_course_topics(course_key, user):
{ {
"id": entry["id"], "id": entry["id"],
"name": name, "name": name,
"thread_list_url": get_thread_list_url(request, course_key, [entry["id"]]),
"children": [], "children": [],
} }
for name, entry in sorted( for name, entry in sorted(
...@@ -99,7 +125,7 @@ def get_course_topics(course_key, user): ...@@ -99,7 +125,7 @@ def get_course_topics(course_key, user):
} }
def get_thread_list(request, course_key, page, page_size): def get_thread_list(request, course_key, page, page_size, topic_id_list=None):
""" """
Return the list of all discussion threads pertaining to the given course Return the list of all discussion threads pertaining to the given course
...@@ -109,6 +135,7 @@ def get_thread_list(request, course_key, page, page_size): ...@@ -109,6 +135,7 @@ def get_thread_list(request, course_key, page, page_size):
course_key: The key of the course to get discussion threads for course_key: The key of the course to get discussion threads for
page: The page number (1-indexed) to retrieve page: The page number (1-indexed) to retrieve
page_size: The number of threads to retrieve per page page_size: The number of threads to retrieve per page
topic_id_list: The list of topic_ids to get the discussion threads for
Returns: Returns:
...@@ -117,6 +144,7 @@ def get_thread_list(request, course_key, page, page_size): ...@@ -117,6 +144,7 @@ def get_thread_list(request, course_key, page, page_size):
""" """
course = _get_course_or_404(course_key, request.user) course = _get_course_or_404(course_key, request.user)
context = get_context(course, request) context = get_context(course, request)
topic_ids_csv = ",".join(topic_id_list) if topic_id_list else None
threads, result_page, num_pages, _ = Thread.search({ threads, result_page, num_pages, _ = Thread.search({
"course_id": unicode(course.id), "course_id": unicode(course.id),
"group_id": ( "group_id": (
...@@ -127,6 +155,7 @@ def get_thread_list(request, course_key, page, page_size): ...@@ -127,6 +155,7 @@ def get_thread_list(request, course_key, page, page_size):
"sort_order": "desc", "sort_order": "desc",
"page": page, "page": page,
"per_page": page_size, "per_page": page_size,
"commentable_ids": topic_ids_csv,
}) })
# The comments service returns the last page of results if the requested # The comments service returns the last page of results if the requested
# page is beyond the last page, but we want be consistent with DRF's general # page is beyond the last page, but we want be consistent with DRF's general
......
...@@ -2,12 +2,31 @@ ...@@ -2,12 +2,31 @@
Discussion API forms Discussion API forms
""" """
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.forms import BooleanField, CharField, Form, IntegerField, NullBooleanField from django.forms import (
BooleanField,
CharField,
Field,
Form,
IntegerField,
MultipleHiddenInput,
NullBooleanField,
)
from opaque_keys import InvalidKeyError from opaque_keys import InvalidKeyError
from opaque_keys.edx.locator import CourseLocator from opaque_keys.edx.locator import CourseLocator
class TopicIdField(Field):
"""
Field for a list of topic_ids
"""
widget = MultipleHiddenInput
def validate(self, value):
if value and "" in value:
raise ValidationError("This field cannot be empty.")
class _PaginationForm(Form): class _PaginationForm(Form):
"""A form that includes pagination fields""" """A form that includes pagination fields"""
page = IntegerField(required=False, min_value=1) page = IntegerField(required=False, min_value=1)
...@@ -27,6 +46,7 @@ class ThreadListGetForm(_PaginationForm): ...@@ -27,6 +46,7 @@ class ThreadListGetForm(_PaginationForm):
A form to validate query parameters in the thread list retrieval endpoint A form to validate query parameters in the thread list retrieval endpoint
""" """
course_id = CharField() course_id = CharField()
topic_id = TopicIdField(required=False)
def clean_course_id(self): def clean_course_id(self):
"""Validate course_id""" """Validate course_id"""
......
...@@ -3,7 +3,8 @@ Tests for Discussion API internal interface ...@@ -3,7 +3,8 @@ Tests for Discussion API internal interface
""" """
from datetime import datetime, timedelta from datetime import datetime, timedelta
import itertools import itertools
from urlparse import urlparse from urlparse import urlparse, urlunparse
from urllib import urlencode
import ddt import ddt
import httpretty import httpretty
...@@ -58,9 +59,10 @@ def _remove_discussion_tab(course, user_id): ...@@ -58,9 +59,10 @@ def _remove_discussion_tab(course, user_id):
@mock.patch.dict("django.conf.settings.FEATURES", {"DISABLE_START_DATES": False}) @mock.patch.dict("django.conf.settings.FEATURES", {"DISABLE_START_DATES": False})
class GetCourseTopicsTest(ModuleStoreTestCase): class GetCourseTopicsTest(UrlResetMixin, ModuleStoreTestCase):
"""Test for get_course_topics""" """Test for get_course_topics"""
@mock.patch.dict("django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True})
def setUp(self): def setUp(self):
super(GetCourseTopicsTest, self).setUp() super(GetCourseTopicsTest, self).setUp()
self.maxDiff = None # pylint: disable=invalid-name self.maxDiff = None # pylint: disable=invalid-name
...@@ -82,6 +84,8 @@ class GetCourseTopicsTest(ModuleStoreTestCase): ...@@ -82,6 +84,8 @@ class GetCourseTopicsTest(ModuleStoreTestCase):
days_early_for_beta=3 days_early_for_beta=3
) )
self.user = UserFactory.create() self.user = UserFactory.create()
self.request = RequestFactory().get("/dummy")
self.request.user = self.user
CourseEnrollmentFactory.create(user=self.user, course_id=self.course.id) CourseEnrollmentFactory.create(user=self.user, course_id=self.course.id)
def make_discussion_module(self, topic_id, category, subcategory, **kwargs): def make_discussion_module(self, topic_id, category, subcategory, **kwargs):
...@@ -95,34 +99,46 @@ class GetCourseTopicsTest(ModuleStoreTestCase): ...@@ -95,34 +99,46 @@ class GetCourseTopicsTest(ModuleStoreTestCase):
**kwargs **kwargs
) )
def get_course_topics(self, user=None): def get_thread_list_url(self, topic_id_list):
"""
Returns the URL for the thread_list_url field, given a list of topic_ids
"""
path = "http://testserver/api/discussion/v1/threads/"
query_list = [("course_id", unicode(self.course.id))] + [("topic_id", topic_id) for topic_id in topic_id_list]
return urlunparse(("", "", path, "", urlencode(query_list), ""))
def get_course_topics(self):
""" """
Get course topics for self.course, using the given user or self.user if Get course topics for self.course, using the given user or self.user if
not provided, and generating absolute URIs with a test scheme/host. not provided, and generating absolute URIs with a test scheme/host.
""" """
return get_course_topics(self.course.id, user or self.user) return get_course_topics(self.request, self.course.id)
def make_expected_tree(self, topic_id, name, children=None): def make_expected_tree(self, topic_id, name, children=None):
""" """
Build an expected result tree given a topic id, display name, and Build an expected result tree given a topic id, display name, and
children children
""" """
topic_id_list = [topic_id] if topic_id else [child["id"] for child in children]
children = children or [] children = children or []
node = { node = {
"id": topic_id, "id": topic_id,
"name": name, "name": name,
"children": children, "children": children,
"thread_list_url": self.get_thread_list_url(topic_id_list)
} }
return node return node
def test_nonexistent_course(self): def test_nonexistent_course(self):
with self.assertRaises(Http404): with self.assertRaises(Http404):
get_course_topics(CourseLocator.from_string("non/existent/course"), self.user) get_course_topics(self.request, CourseLocator.from_string("non/existent/course"))
def test_not_enrolled(self): def test_not_enrolled(self):
unenrolled_user = UserFactory.create() unenrolled_user = UserFactory.create()
self.request.user = unenrolled_user
with self.assertRaises(Http404): with self.assertRaises(Http404):
get_course_topics(self.course.id, unenrolled_user) self.get_course_topics()
def test_discussions_disabled(self): def test_discussions_disabled(self):
_remove_discussion_tab(self.course, self.user.id) _remove_discussion_tab(self.course, self.user.id)
...@@ -311,8 +327,8 @@ class GetCourseTopicsTest(ModuleStoreTestCase): ...@@ -311,8 +327,8 @@ class GetCourseTopicsTest(ModuleStoreTestCase):
], ],
} }
self.assertEqual(student_actual, student_expected) self.assertEqual(student_actual, student_expected)
self.request.user = beta_tester
beta_actual = self.get_course_topics(beta_tester) beta_actual = self.get_course_topics()
beta_expected = { beta_expected = {
"courseware_topics": [ "courseware_topics": [
self.make_expected_tree( self.make_expected_tree(
...@@ -335,7 +351,8 @@ class GetCourseTopicsTest(ModuleStoreTestCase): ...@@ -335,7 +351,8 @@ class GetCourseTopicsTest(ModuleStoreTestCase):
} }
self.assertEqual(beta_actual, beta_expected) self.assertEqual(beta_actual, beta_expected)
staff_actual = self.get_course_topics(staff) self.request.user = staff
staff_actual = self.get_course_topics()
staff_expected = { staff_expected = {
"courseware_topics": [ "courseware_topics": [
self.make_expected_tree( self.make_expected_tree(
...@@ -382,14 +399,14 @@ class GetThreadListTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTest ...@@ -382,14 +399,14 @@ class GetThreadListTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTest
self.author = UserFactory.create() self.author = UserFactory.create()
self.cohort = CohortFactory.create(course_id=self.course.id) self.cohort = CohortFactory.create(course_id=self.course.id)
def get_thread_list(self, threads, page=1, page_size=1, num_pages=1, course=None): def get_thread_list(self, threads, page=1, page_size=1, num_pages=1, course=None, topic_id_list=None):
""" """
Register the appropriate comments service response, then call Register the appropriate comments service response, then call
get_thread_list and return the result. get_thread_list and return the result.
""" """
course = course or self.course course = course or self.course
self.register_get_threads_response(threads, page, num_pages) self.register_get_threads_response(threads, page, num_pages)
ret = get_thread_list(self.request, course.id, page, page_size) ret = get_thread_list(self.request, course.id, page, page_size, topic_id_list)
return ret return ret
def test_nonexistent_course(self): def test_nonexistent_course(self):
...@@ -416,6 +433,19 @@ class GetThreadListTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTest ...@@ -416,6 +433,19 @@ class GetThreadListTest(CommentsServiceMockMixin, UrlResetMixin, ModuleStoreTest
} }
) )
def test_get_threads_by_topic_id(self):
self.get_thread_list([], topic_id_list=["topic_x", "topic_meow"])
self.assertEqual(urlparse(httpretty.last_request().path).path, "/api/v1/threads")
self.assert_last_query_params({
"course_id": [unicode(self.course.id)],
"sort_key": ["date"],
"sort_order": ["desc"],
"page": ["1"],
"per_page": ["1"],
"recursive": ["False"],
"commentable_ids": ["topic_x,topic_meow"]
})
def test_basic_query_params(self): def test_basic_query_params(self):
self.get_thread_list([], page=6, page_size=14) self.get_thread_list([], page=6, page_size=14)
self.assert_last_query_params({ self.assert_last_query_params({
......
...@@ -2,6 +2,9 @@ ...@@ -2,6 +2,9 @@
Tests for Discussion API forms Tests for Discussion API forms
""" """
from unittest import TestCase from unittest import TestCase
from urllib import urlencode
from django.http import QueryDict
from opaque_keys.edx.locator import CourseLocator from opaque_keys.edx.locator import CourseLocator
...@@ -66,13 +69,19 @@ class ThreadListGetFormTest(FormTestMixin, PaginationTestMixin, TestCase): ...@@ -66,13 +69,19 @@ class ThreadListGetFormTest(FormTestMixin, PaginationTestMixin, TestCase):
def setUp(self): def setUp(self):
super(ThreadListGetFormTest, self).setUp() super(ThreadListGetFormTest, self).setUp()
self.form_data = { self.form_data = QueryDict(
"course_id": "Foo/Bar/Baz", urlencode(
"page": "2", {
"page_size": "13", "course_id": "Foo/Bar/Baz",
} "page": "2",
"page_size": "13",
}
),
mutable=True
)
def test_basic(self): def test_basic(self):
self.form_data.setlist("topic_id", ["example topic_id", "example 2nd topic_id"])
form = self.get_form(expected_valid=True) form = self.get_form(expected_valid=True)
self.assertEqual( self.assertEqual(
form.cleaned_data, form.cleaned_data,
...@@ -80,6 +89,7 @@ class ThreadListGetFormTest(FormTestMixin, PaginationTestMixin, TestCase): ...@@ -80,6 +89,7 @@ class ThreadListGetFormTest(FormTestMixin, PaginationTestMixin, TestCase):
"course_id": CourseLocator.from_string("Foo/Bar/Baz"), "course_id": CourseLocator.from_string("Foo/Bar/Baz"),
"page": 2, "page": 2,
"page_size": 13, "page_size": 13,
"topic_id": ["example topic_id", "example 2nd topic_id"],
} }
) )
...@@ -91,6 +101,10 @@ class ThreadListGetFormTest(FormTestMixin, PaginationTestMixin, TestCase): ...@@ -91,6 +101,10 @@ class ThreadListGetFormTest(FormTestMixin, PaginationTestMixin, TestCase):
self.form_data["course_id"] = "invalid course id" self.form_data["course_id"] = "invalid course id"
self.assert_error("course_id", "'invalid course id' is not a valid course id") self.assert_error("course_id", "'invalid course id' is not a valid course id")
def test_empty_topic_id(self):
self.form_data.setlist("topic_id", ["", "not empty"])
self.assert_error("topic_id", "This field cannot be empty.")
class CommentListGetFormTest(FormTestMixin, PaginationTestMixin, TestCase): class CommentListGetFormTest(FormTestMixin, PaginationTestMixin, TestCase):
"""Tests for CommentListGetForm""" """Tests for CommentListGetForm"""
......
...@@ -85,7 +85,9 @@ class CourseTopicsViewTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase): ...@@ -85,7 +85,9 @@ class CourseTopicsViewTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
"non_courseware_topics": [{ "non_courseware_topics": [{
"id": "test_topic", "id": "test_topic",
"name": "Test Topic", "name": "Test Topic",
"children": [] "children": [],
"thread_list_url":
"http://testserver/api/discussion/v1/threads/?course_id=x%2Fy%2Fz&topic_id=test_topic",
}], }],
} }
) )
......
...@@ -60,7 +60,7 @@ class CourseTopicsView(_ViewMixin, DeveloperErrorViewMixin, APIView): ...@@ -60,7 +60,7 @@ class CourseTopicsView(_ViewMixin, DeveloperErrorViewMixin, APIView):
def get(self, request, course_id): def get(self, request, course_id):
"""Implements the GET method as described in the class docstring.""" """Implements the GET method as described in the class docstring."""
course_key = CourseLocator.from_string(course_id) course_key = CourseLocator.from_string(course_id)
return Response(get_course_topics(course_key, request.user)) return Response(get_course_topics(request, course_key))
class ThreadViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet): class ThreadViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet):
...@@ -90,6 +90,10 @@ class ThreadViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet): ...@@ -90,6 +90,10 @@ class ThreadViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet):
* page_size: The number of items per page (default is 10, max is 100) * page_size: The number of items per page (default is 10, max is 100)
* topic_id: The id of the topic to retrieve the threads. There can be
multiple topic_id queries to retrieve threads from multiple topics
at once.
**POST Parameters**: **POST Parameters**:
* course_id (required): The course to create the thread in * course_id (required): The course to create the thread in
...@@ -157,7 +161,8 @@ class ThreadViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet): ...@@ -157,7 +161,8 @@ class ThreadViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet):
request, request,
form.cleaned_data["course_id"], form.cleaned_data["course_id"],
form.cleaned_data["page"], form.cleaned_data["page"],
form.cleaned_data["page_size"] form.cleaned_data["page_size"],
form.cleaned_data["topic_id"],
) )
) )
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment