Commit 90f28dce by Greg Price

Add thread list endpoint to the new discussion API

This is an initial implementation that only allows retrieval of all
threads for a course and only returns an easily computed subset of the
fields that are needed, in order to keep this change from getting too
large.

JIRA: MA-641
parent 73bbb892
""" """
Discussion API internal interface Discussion API internal interface
""" """
from django.http import Http404
from collections import defaultdict from collections import defaultdict
from lms.lib.comment_client.thread import Thread
from discussion_api.pagination import get_paginated_data
from django_comment_client.utils import get_accessible_discussion_modules from django_comment_client.utils import get_accessible_discussion_modules
...@@ -63,3 +67,62 @@ def get_course_topics(course, user): ...@@ -63,3 +67,62 @@ def get_course_topics(course, user):
"courseware_topics": courseware_topics, "courseware_topics": courseware_topics,
"non_courseware_topics": non_courseware_topics, "non_courseware_topics": non_courseware_topics,
} }
def _cc_thread_to_api_thread(thread):
"""
Convert a thread data dict from the comment_client format (which is a direct
representation of the format returned by the comments service) to the format
used in this API
"""
ret = {
key: thread[key]
for key in [
"id",
"course_id",
"created_at",
"updated_at",
"type",
"title",
"pinned",
"closed",
]
}
ret.update({
"topic_id": thread["commentable_id"],
"raw_body": thread["body"],
"comment_count": thread["comments_count"],
"unread_comment_count": thread["unread_comments_count"],
})
return ret
def get_thread_list(request, course_key, page, page_size):
"""
Return the list of all discussion threads pertaining to the given course
Parameters:
request: The django request objects used for build_absolute_uri
course_key: The key of the course to get discussion threads for
page: The page number (1-indexed) to retrieve
page_size: The number of threads to retrieve per page
Returns:
A paginated result containing a list of threads; see
discussion_api.views.ThreadViewSet for more detail.
"""
threads, result_page, num_pages, _ = Thread.search({
"course_id": unicode(course_key),
"page": page,
"per_page": page_size
})
# The comments service returns the last page of results if the requested
# page is beyond the last page, but we want be consistent with DRF's general
# behavior and return a 404 in that case
if result_page != page:
raise Http404
results = [_cc_thread_to_api_thread(thread) for thread in threads]
return get_paginated_data(request, results, page, num_pages)
"""
Discussion API forms
"""
from django.core.exceptions import ValidationError
from django.forms import Form, CharField, IntegerField
from opaque_keys import InvalidKeyError
from opaque_keys.edx.locator import CourseLocator
class ThreadListGetForm(Form):
"""
A form to validate query parameters in the thread list retrieval endpoint
"""
course_id = CharField()
page = IntegerField(required=False, min_value=1)
page_size = IntegerField(required=False, min_value=1)
def clean_course_id(self):
"""Validate course_id"""
value = self.cleaned_data["course_id"]
try:
return CourseLocator.from_string(value)
except InvalidKeyError:
raise ValidationError("'{}' is not a valid course id".format(value))
def clean_page(self):
"""Return given valid page or default of 1"""
return self.cleaned_data.get("page") or 1
def clean_page_size(self):
"""Return given valid page_size (capped at 100) or default of 10"""
return min(self.cleaned_data.get("page_size") or 10, 100)
"""
Discussion API pagination support
"""
from rest_framework.pagination import BasePaginationSerializer, NextPageField, PreviousPageField
class _PaginationSerializer(BasePaginationSerializer):
"""
A pagination serializer without the count field, because the Comments
Service does not return result counts
"""
next = NextPageField(source="*")
previous = PreviousPageField(source="*")
class _Page(object):
"""
Implements just enough of the django.core.paginator.Page interface to allow
PaginationSerializer to work.
"""
def __init__(self, object_list, page_num, num_pages):
"""
Create a new page containing the given objects, with the given page
number and number of pages
"""
self.object_list = object_list
self.page_num = page_num
self.num_pages = num_pages
def has_next(self):
"""Returns True if there is a page after this one, otherwise False"""
return self.page_num < self.num_pages
def has_previous(self):
"""Returns True if there is a page before this one, otherwise False"""
return self.page_num > 1
def next_page_number(self):
"""Returns the number of the next page"""
return self.page_num + 1
def previous_page_number(self):
"""Returns the number of the previous page"""
return self.page_num - 1
def get_paginated_data(request, results, page_num, per_page):
"""
Return a dict with the following values:
next: The URL for the next page
previous: The URL for the previous page
results: The results on this page
"""
return _PaginationSerializer(
instance=_Page(results, page_num, per_page),
context={"request": request}
).data
...@@ -3,11 +3,18 @@ Tests for Discussion API internal interface ...@@ -3,11 +3,18 @@ Tests for Discussion API internal interface
""" """
from datetime import datetime, timedelta from datetime import datetime, timedelta
import httpretty
import mock import mock
from pytz import UTC from pytz import UTC
from django.http import Http404
from django.test.client import RequestFactory
from opaque_keys.edx.locator import CourseLocator
from courseware.tests.factories import BetaTesterFactory, StaffFactory from courseware.tests.factories import BetaTesterFactory, StaffFactory
from discussion_api.api import get_course_topics from discussion_api.api import get_course_topics, get_thread_list
from discussion_api.tests.utils import CommentsServiceMockMixin
from openedx.core.djangoapps.course_groups.models import CourseUserGroupPartitionGroup from openedx.core.djangoapps.course_groups.models import CourseUserGroupPartitionGroup
from openedx.core.djangoapps.course_groups.tests.helpers import CohortFactory from openedx.core.djangoapps.course_groups.tests.helpers import CohortFactory
from student.tests.factories import UserFactory from student.tests.factories import UserFactory
...@@ -304,3 +311,171 @@ class GetCourseTopicsTest(ModuleStoreTestCase): ...@@ -304,3 +311,171 @@ class GetCourseTopicsTest(ModuleStoreTestCase):
"non_courseware_topics": [], "non_courseware_topics": [],
} }
self.assertEqual(staff_actual, staff_expected) self.assertEqual(staff_actual, staff_expected)
@httpretty.activate
class GetThreadListTest(CommentsServiceMockMixin, ModuleStoreTestCase):
"""Test for get_thread_list"""
def setUp(self):
super(GetThreadListTest, self).setUp()
self.maxDiff = None # pylint: disable=invalid-name
self.request = RequestFactory().get("/test_path")
self.course_key = CourseLocator.from_string("a/b/c")
def get_thread_list(self, threads, page=1, page_size=1, num_pages=1):
"""
Register the appropriate comments service response, then call
get_thread_list and return the result.
"""
self.register_get_threads_response(threads, page, num_pages)
ret = get_thread_list(self.request, self.course_key, page, page_size)
return ret
def test_empty(self):
self.assertEqual(
self.get_thread_list([]),
{
"results": [],
"next": None,
"previous": None,
}
)
def test_basic_query_params(self):
self.get_thread_list([], page=6, page_size=14)
self.assert_last_query_params({
"course_id": [unicode(self.course_key)],
"page": ["6"],
"per_page": ["14"],
"recursive": ["False"],
})
def test_thread_content(self):
source_threads = [
{
"id": "test_thread_id_0",
"course_id": unicode(self.course_key),
"commentable_id": "topic_x",
"created_at": "2015-04-28T00:00:00Z",
"updated_at": "2015-04-28T11:11:11Z",
"type": "discussion",
"title": "Test Title",
"body": "Test body",
"pinned": False,
"closed": False,
"comments_count": 5,
"unread_comments_count": 3,
},
{
"id": "test_thread_id_1",
"course_id": unicode(self.course_key),
"commentable_id": "topic_y",
"created_at": "2015-04-28T22:22:22Z",
"updated_at": "2015-04-28T00:33:33Z",
"type": "question",
"title": "Another Test Title",
"body": "More content",
"pinned": False,
"closed": True,
"comments_count": 18,
"unread_comments_count": 0,
},
{
"id": "test_thread_id_2",
"course_id": unicode(self.course_key),
"commentable_id": "topic_x",
"created_at": "2015-04-28T00:44:44Z",
"updated_at": "2015-04-28T00:55:55Z",
"type": "discussion",
"title": "Yet Another Test Title",
"body": "Still more content",
"pinned": True,
"closed": False,
"comments_count": 0,
"unread_comments_count": 0,
},
]
expected_threads = [
{
"id": "test_thread_id_0",
"course_id": unicode(self.course_key),
"topic_id": "topic_x",
"created_at": "2015-04-28T00:00:00Z",
"updated_at": "2015-04-28T11:11:11Z",
"type": "discussion",
"title": "Test Title",
"raw_body": "Test body",
"pinned": False,
"closed": False,
"comment_count": 5,
"unread_comment_count": 3,
},
{
"id": "test_thread_id_1",
"course_id": unicode(self.course_key),
"topic_id": "topic_y",
"created_at": "2015-04-28T22:22:22Z",
"updated_at": "2015-04-28T00:33:33Z",
"type": "question",
"title": "Another Test Title",
"raw_body": "More content",
"pinned": False,
"closed": True,
"comment_count": 18,
"unread_comment_count": 0,
},
{
"id": "test_thread_id_2",
"course_id": unicode(self.course_key),
"topic_id": "topic_x",
"created_at": "2015-04-28T00:44:44Z",
"updated_at": "2015-04-28T00:55:55Z",
"type": "discussion",
"title": "Yet Another Test Title",
"raw_body": "Still more content",
"pinned": True,
"closed": False,
"comment_count": 0,
"unread_comment_count": 0,
},
]
self.assertEqual(
self.get_thread_list(source_threads),
{
"results": expected_threads,
"next": None,
"previous": None,
}
)
def test_pagination(self):
# N.B. Empty thread list is not realistic but convenient for this test
self.assertEqual(
self.get_thread_list([], page=1, num_pages=3),
{
"results": [],
"next": "http://testserver/test_path?page=2",
"previous": None,
}
)
self.assertEqual(
self.get_thread_list([], page=2, num_pages=3),
{
"results": [],
"next": "http://testserver/test_path?page=3",
"previous": "http://testserver/test_path?page=1",
}
)
self.assertEqual(
self.get_thread_list([], page=3, num_pages=3),
{
"results": [],
"next": None,
"previous": "http://testserver/test_path?page=2",
}
)
# Test page past the last one
self.register_get_threads_response([], page=3, num_pages=3)
with self.assertRaises(Http404):
get_thread_list(self.request, self.course_key, page=4, page_size=10)
"""
Tests for Discussion API forms
"""
from unittest import TestCase
from opaque_keys.edx.locator import CourseLocator
from discussion_api.forms import ThreadListGetForm
class ThreadListGetFormTest(TestCase):
"""Tests for ThreadListGetForm"""
def setUp(self):
super(ThreadListGetFormTest, self).setUp()
self.form_data = {
"course_id": "Foo/Bar/Baz",
"page": "2",
"page_size": "13",
}
def get_form(self, expected_valid):
"""
Return a form bound to self.form_data, asserting its validity (or lack
thereof) according to expected_valid
"""
form = ThreadListGetForm(self.form_data)
self.assertEqual(form.is_valid(), expected_valid)
return form
def assert_error(self, expected_field, expected_message):
"""
Create a form bound to self.form_data, assert its invalidity, and assert
that its error dictionary contains one entry with the expected field and
message
"""
form = self.get_form(expected_valid=False)
self.assertEqual(form.errors, {expected_field: [expected_message]})
def assert_field_value(self, field, expected_value):
"""
Create a form bound to self.form_data, assert its validity, and assert
that the given field in the cleaned data has the expected value
"""
form = self.get_form(expected_valid=True)
self.assertEqual(form.cleaned_data[field], expected_value)
def test_basic(self):
form = self.get_form(expected_valid=True)
self.assertEqual(
form.cleaned_data,
{
"course_id": CourseLocator.from_string("Foo/Bar/Baz"),
"page": 2,
"page_size": 13,
}
)
def test_missing_course_id(self):
self.form_data.pop("course_id")
self.assert_error("course_id", "This field is required.")
def test_invalid_course_id(self):
self.form_data["course_id"] = "invalid course id"
self.assert_error("course_id", "'invalid course id' is not a valid course id")
def test_missing_page(self):
self.form_data.pop("page")
self.assert_field_value("page", 1)
def test_invalid_page(self):
self.form_data["page"] = "0"
self.assert_error("page", "Ensure this value is greater than or equal to 1.")
def test_missing_page_size(self):
self.form_data.pop("page_size")
self.assert_field_value("page_size", 10)
def test_zero_page_size(self):
self.form_data["page_size"] = "0"
self.assert_error("page_size", "Ensure this value is greater than or equal to 1.")
def test_excessive_page_size(self):
self.form_data["page_size"] = "101"
self.assert_field_value("page_size", 100)
"""
Tests for Discussion API pagination support
"""
from unittest import TestCase
from django.test import RequestFactory
from discussion_api.pagination import get_paginated_data
class PaginationSerializerTest(TestCase):
"""Tests for PaginationSerializer"""
def do_case(self, objects, page_num, num_pages, expected):
"""
Make a dummy request, and assert that get_paginated_data with the given
parameters returns the expected result
"""
request = RequestFactory().get("/test")
actual = get_paginated_data(request, objects, page_num, num_pages)
self.assertEqual(actual, expected)
def test_empty(self):
self.do_case(
[], 1, 0,
{
"next": None,
"previous": None,
"results": [],
}
)
def test_only_page(self):
self.do_case(
["foo"], 1, 1,
{
"next": None,
"previous": None,
"results": ["foo"],
}
)
def test_first_of_many(self):
self.do_case(
["foo"], 1, 3,
{
"next": "http://testserver/test?page=2",
"previous": None,
"results": ["foo"],
}
)
def test_last_of_many(self):
self.do_case(
["foo"], 3, 3,
{
"next": None,
"previous": "http://testserver/test?page=2",
"results": ["foo"],
}
)
def test_middle_of_many(self):
self.do_case(
["foo"], 2, 3,
{
"next": "http://testserver/test?page=3",
"previous": "http://testserver/test?page=1",
"results": ["foo"],
}
)
...@@ -4,11 +4,13 @@ Tests for Discussion API views ...@@ -4,11 +4,13 @@ Tests for Discussion API views
from datetime import datetime from datetime import datetime
import json import json
import httpretty
import mock import mock
from pytz import UTC from pytz import UTC
from django.core.urlresolvers import reverse from django.core.urlresolvers import reverse
from discussion_api.tests.utils import CommentsServiceMockMixin
from student.tests.factories import CourseEnrollmentFactory, UserFactory from student.tests.factories import CourseEnrollmentFactory, UserFactory
from util.testing import UrlResetMixin from util.testing import UrlResetMixin
from xmodule.modulestore.django import modulestore from xmodule.modulestore.django import modulestore
...@@ -17,12 +19,16 @@ from xmodule.modulestore.tests.factories import CourseFactory ...@@ -17,12 +19,16 @@ from xmodule.modulestore.tests.factories import CourseFactory
from xmodule.tabs import DiscussionTab from xmodule.tabs import DiscussionTab
class CourseTopicsViewTest(UrlResetMixin, ModuleStoreTestCase): class DiscussionAPIViewTestMixin(CommentsServiceMockMixin, UrlResetMixin):
"""Tests for CourseTopicsView""" """
Mixin for common code in tests of Discussion API views. This includes
creation of common structures (e.g. a course, user, and enrollment), logging
in the test client, utility functions, and a test case for unauthenticated
requests. Subclasses must set self.url in their setUp methods.
"""
@mock.patch.dict("django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True}) @mock.patch.dict("django.conf.settings.FEATURES", {"ENABLE_DISCUSSION_SERVICE": True})
def setUp(self): def setUp(self):
super(CourseTopicsViewTest, self).setUp() super(DiscussionAPIViewTestMixin, self).setUp()
self.maxDiff = None # pylint: disable=invalid-name self.maxDiff = None # pylint: disable=invalid-name
self.course = CourseFactory.create( self.course = CourseFactory.create(
org="x", org="x",
...@@ -34,9 +40,13 @@ class CourseTopicsViewTest(UrlResetMixin, ModuleStoreTestCase): ...@@ -34,9 +40,13 @@ class CourseTopicsViewTest(UrlResetMixin, ModuleStoreTestCase):
self.password = "password" self.password = "password"
self.user = UserFactory.create(password=self.password) self.user = UserFactory.create(password=self.password)
CourseEnrollmentFactory.create(user=self.user, course_id=self.course.id) CourseEnrollmentFactory.create(user=self.user, course_id=self.course.id)
self.url = reverse("course_topics", kwargs={"course_id": unicode(self.course.id)})
self.client.login(username=self.user.username, password=self.password) self.client.login(username=self.user.username, password=self.password)
def login_unenrolled_user(self):
"""Create a user not enrolled in the course and log it in"""
unenrolled_user = UserFactory.create(password=self.password)
self.client.login(username=unenrolled_user.username, password=self.password)
def assert_response_correct(self, response, expected_status, expected_content): def assert_response_correct(self, response, expected_status, expected_content):
""" """
Assert that the response has the given status code and parsed content Assert that the response has the given status code and parsed content
...@@ -54,6 +64,13 @@ class CourseTopicsViewTest(UrlResetMixin, ModuleStoreTestCase): ...@@ -54,6 +64,13 @@ class CourseTopicsViewTest(UrlResetMixin, ModuleStoreTestCase):
{"developer_message": "Authentication credentials were not provided."} {"developer_message": "Authentication credentials were not provided."}
) )
class CourseTopicsViewTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
"""Tests for CourseTopicsView"""
def setUp(self):
super(CourseTopicsViewTest, self).setUp()
self.url = reverse("course_topics", kwargs={"course_id": unicode(self.course.id)})
def test_non_existent_course(self): def test_non_existent_course(self):
response = self.client.get( response = self.client.get(
reverse("course_topics", kwargs={"course_id": "non/existent/course"}) reverse("course_topics", kwargs={"course_id": "non/existent/course"})
...@@ -65,8 +82,7 @@ class CourseTopicsViewTest(UrlResetMixin, ModuleStoreTestCase): ...@@ -65,8 +82,7 @@ class CourseTopicsViewTest(UrlResetMixin, ModuleStoreTestCase):
) )
def test_not_enrolled(self): def test_not_enrolled(self):
unenrolled_user = UserFactory.create(password=self.password) self.login_unenrolled_user()
self.client.login(username=unenrolled_user.username, password=self.password)
response = self.client.get(self.url) response = self.client.get(self.url)
self.assert_response_correct( self.assert_response_correct(
response, response,
...@@ -98,3 +114,93 @@ class CourseTopicsViewTest(UrlResetMixin, ModuleStoreTestCase): ...@@ -98,3 +114,93 @@ class CourseTopicsViewTest(UrlResetMixin, ModuleStoreTestCase):
}], }],
} }
) )
@httpretty.activate
class ThreadViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
"""Tests for ThreadViewSet list"""
def setUp(self):
super(ThreadViewSetListTest, self).setUp()
self.url = reverse("thread-list")
def test_course_id_missing(self):
response = self.client.get(self.url)
self.assert_response_correct(
response,
400,
{"field_errors": {"course_id": "This field is required."}}
)
def test_not_enrolled(self):
self.login_unenrolled_user()
response = self.client.get(self.url, {"course_id": unicode(self.course.id)})
self.assert_response_correct(
response,
404,
{"developer_message": "Not found."}
)
def test_basic(self):
source_threads = [{
"id": "test_thread",
"course_id": unicode(self.course.id),
"commentable_id": "test_topic",
"created_at": "2015-04-28T00:00:00Z",
"updated_at": "2015-04-28T11:11:11Z",
"type": "discussion",
"title": "Test Title",
"body": "Test body",
"pinned": False,
"closed": False,
"comments_count": 5,
"unread_comments_count": 3,
}]
expected_threads = [{
"id": "test_thread",
"course_id": unicode(self.course.id),
"topic_id": "test_topic",
"created_at": "2015-04-28T00:00:00Z",
"updated_at": "2015-04-28T11:11:11Z",
"type": "discussion",
"title": "Test Title",
"raw_body": "Test body",
"pinned": False,
"closed": False,
"comment_count": 5,
"unread_comment_count": 3,
}]
self.register_get_threads_response(source_threads, page=1, num_pages=2)
response = self.client.get(self.url, {"course_id": unicode(self.course.id)})
self.assert_response_correct(
response,
200,
{
"results": expected_threads,
"next": "http://testserver/api/discussion/v1/threads/?course_id=x%2Fy%2Fz&page=2",
"previous": None,
}
)
self.assert_last_query_params({
"course_id": [unicode(self.course.id)],
"page": ["1"],
"per_page": ["10"],
"recursive": ["False"],
})
def test_pagination(self):
self.register_get_threads_response([], page=1, num_pages=1)
response = self.client.get(
self.url,
{"course_id": unicode(self.course.id), "page": "18", "page_size": "4"}
)
self.assert_response_correct(
response,
404,
{"developer_message": "Not found."}
)
self.assert_last_query_params({
"course_id": [unicode(self.course.id)],
"page": ["18"],
"per_page": ["4"],
"recursive": ["False"],
})
"""
Discussion API test utilities
"""
import json
import httpretty
class CommentsServiceMockMixin(object):
"""Mixin with utility methods for mocking the comments service"""
def register_get_threads_response(self, threads, page, num_pages):
"""Register a mock response for GET on the CS thread list endpoint"""
httpretty.register_uri(
httpretty.GET,
"http://localhost:4567/api/v1/threads",
body=json.dumps({
"collection": threads,
"page": page,
"num_pages": num_pages,
}),
status=200
)
def assert_last_query_params(self, expected_params):
"""
Assert that the last mock request had the expected query parameters
"""
actual_params = dict(httpretty.last_request().querystring)
actual_params.pop("request_id") # request_id is random
self.assertEqual(actual_params, expected_params)
...@@ -2,10 +2,15 @@ ...@@ -2,10 +2,15 @@
Discussion API URLs Discussion API URLs
""" """
from django.conf import settings from django.conf import settings
from django.conf.urls import patterns, url from django.conf.urls import include, patterns, url
from discussion_api.views import CourseTopicsView from rest_framework.routers import SimpleRouter
from discussion_api.views import CourseTopicsView, ThreadViewSet
ROUTER = SimpleRouter()
ROUTER.register("threads", ThreadViewSet, base_name="thread")
urlpatterns = patterns( urlpatterns = patterns(
"discussion_api", "discussion_api",
...@@ -14,4 +19,5 @@ urlpatterns = patterns( ...@@ -14,4 +19,5 @@ urlpatterns = patterns(
CourseTopicsView.as_view(), CourseTopicsView.as_view(),
name="course_topics" name="course_topics"
), ),
url("^v1/", include(ROUTER.urls)),
) )
""" """
Discussion API views Discussion API views
""" """
from django.core.exceptions import ValidationError
from django.http import Http404 from django.http import Http404
from rest_framework.authentication import OAuth2Authentication, SessionAuthentication from rest_framework.authentication import OAuth2Authentication, SessionAuthentication
from rest_framework.permissions import IsAuthenticated from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.views import APIView from rest_framework.views import APIView
from rest_framework.viewsets import ViewSet
from opaque_keys.edx.locator import CourseLocator from opaque_keys.edx.locator import CourseLocator
from courseware.courses import get_course_with_access from courseware.courses import get_course_with_access
from discussion_api.api import get_course_topics from discussion_api.api import get_course_topics, get_thread_list
from discussion_api.forms import ThreadListGetForm
from openedx.core.lib.api.view_utils import DeveloperErrorViewMixin from openedx.core.lib.api.view_utils import DeveloperErrorViewMixin
from xmodule.tabs import DiscussionTab from xmodule.tabs import DiscussionTab
class CourseTopicsView(DeveloperErrorViewMixin, APIView): class _ViewMixin(object):
"""
Mixin to provide common characteristics and utility functions for Discussion
API views
"""
authentication_classes = (OAuth2Authentication, SessionAuthentication)
permission_classes = (IsAuthenticated,)
def get_course_or_404(self, user, course_key):
"""
Get the course descriptor, raising Http404 if the course is not found,
the user cannot access forums for the course, or the discussion tab is
disabled for the course.
"""
course = get_course_with_access(user, 'load_forum', course_key)
if not any([isinstance(tab, DiscussionTab) for tab in course.tabs]):
raise Http404
return course
class CourseTopicsView(_ViewMixin, DeveloperErrorViewMixin, APIView):
""" """
**Use Cases** **Use Cases**
...@@ -42,13 +65,81 @@ class CourseTopicsView(DeveloperErrorViewMixin, APIView): ...@@ -42,13 +65,81 @@ class CourseTopicsView(DeveloperErrorViewMixin, APIView):
* non_courseware_topics: The list of topic trees that are not linked to * non_courseware_topics: The list of topic trees that are not linked to
courseware. Items are of the same format as in courseware_topics. courseware. Items are of the same format as in courseware_topics.
""" """
authentication_classes = (OAuth2Authentication, SessionAuthentication)
permission_classes = (IsAuthenticated,)
def get(self, request, course_id): def get(self, request, course_id):
"""Implements the GET method as described in the class docstring.""" """Implements the GET method as described in the class docstring."""
course_key = CourseLocator.from_string(course_id) course_key = CourseLocator.from_string(course_id)
course = get_course_with_access(request.user, 'load_forum', course_key) course = self.get_course_or_404(request.user, course_key)
if not any([isinstance(tab, DiscussionTab) for tab in course.tabs]):
raise Http404
return Response(get_course_topics(course, request.user)) return Response(get_course_topics(course, request.user))
class ThreadViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet):
"""
**Use Cases**
Retrieve the list of threads for a course.
**Example Requests**:
GET /api/discussion/v1/threads/?course_id=ExampleX/Demo/2015
**GET Parameters**:
* course_id (required): The course to retrieve threads for
* page: The (1-indexed) page to retrieve (default is 1)
* page_size: The number of items per page (default is 10, max is 100)
**Response Values**:
* results: The list of threads. Each item in the list includes:
* id: The id of the thread
* course_id: The id of the thread's course
* topic_id: The id of the thread's topic
* created_at: The ISO 8601 timestamp for the creation of the thread
* updated_at: The ISO 8601 timestamp for the last modification of
the thread, which may not have been an update of the title/body
* type: The thread's type (either "question" or "discussion")
* title: The thread's title
* raw_body: The thread's raw body text without any rendering applied
* pinned: Boolean indicating whether the thread has been pinned
* closed: Boolean indicating whether the thread has been closed
* comment_count: The number of comments within the thread
* unread_comment_count: The number of comments within the thread
that were created or updated since the last time the user read
the thread
* next: The URL of the next page (or null if first page)
* previous: The URL of the previous page (or null if last page)
"""
def list(self, request):
"""
Implements the GET method for the list endpoint as described in the
class docstring.
"""
form = ThreadListGetForm(request.GET)
if not form.is_valid():
raise ValidationError(form.errors)
course_key = form.cleaned_data["course_id"]
self.get_course_or_404(request.user, course_key)
return Response(
get_thread_list(
request,
course_key,
form.cleaned_data["page"],
form.cleaned_data["page_size"]
)
)
""" """
Utilities related to API views Utilities related to API views
""" """
from django.core.exceptions import NON_FIELD_ERRORS, ValidationError
from django.http import Http404 from django.http import Http404
from rest_framework.exceptions import APIException from rest_framework.exceptions import APIException
...@@ -19,10 +20,31 @@ class DeveloperErrorViewMixin(object): ...@@ -19,10 +20,31 @@ class DeveloperErrorViewMixin(object):
""" """
return Response({"developer_message": developer_message}, status=status_code) return Response({"developer_message": developer_message}, status=status_code)
def make_validation_error_response(self, validation_error):
"""
Build a 400 error response from the given ValidationError
"""
if hasattr(validation_error, "message_dict"):
response_obj = {}
message_dict = dict(validation_error.message_dict)
non_field_error_list = message_dict.pop(NON_FIELD_ERRORS, None)
if non_field_error_list:
response_obj["developer_message"] = non_field_error_list[0]
if message_dict:
response_obj["field_errors"] = {
field: message_dict[field][0]
for field in message_dict
}
return Response(response_obj, status=400)
else:
return self.make_error_response(400, validation_error.messages[0])
def handle_exception(self, exc): def handle_exception(self, exc):
if isinstance(exc, APIException): if isinstance(exc, APIException):
return self.make_error_response(exc.status_code, exc.detail) return self.make_error_response(exc.status_code, exc.detail)
elif isinstance(exc, Http404): elif isinstance(exc, Http404):
return self.make_error_response(404, "Not found.") return self.make_error_response(404, "Not found.")
elif isinstance(exc, ValidationError):
return self.make_validation_error_response(exc)
else: else:
raise raise
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment