Commit 565cdb8e by Greg Price

Move Discussion API access control checks

The checks are now within the Python API instead of the DRF view. This
will be necessary for certain operations (like fetching/editing threads)
because the relevant course cannot be known until the thread is fetched
from the comments service. This commit updates the existing endpoints to
fit that pattern.
parent b328a2e8
......@@ -5,6 +5,7 @@ from django.http import Http404
from collections import defaultdict
from courseware.courses import get_course_with_access
from discussion_api.pagination import get_paginated_data
from django_comment_client.utils import get_accessible_discussion_modules
from django_comment_common.models import (
......@@ -16,15 +17,28 @@ from django_comment_common.models import (
from lms.lib.comment_client.thread import Thread
from lms.lib.comment_client.user import User
from openedx.core.djangoapps.course_groups.cohorts import get_cohort_id, get_cohort_names
from xmodule.tabs import DiscussionTab
def get_course_topics(course, user):
def _get_course_or_404(course_key, user):
"""
Get the course descriptor, raising Http404 if the course is not found,
the user cannot access forums for the course, or the discussion tab is
disabled for the course.
"""
course = get_course_with_access(user, 'load_forum', course_key)
if not any([isinstance(tab, DiscussionTab) for tab in course.tabs]):
raise Http404
return course
def get_course_topics(course_key, user):
"""
Return the course topic listing for the given course and user.
Parameters:
course: The course to get topics for
course_key: The key of the course to get topics for
user: The requesting user, for access control
Returns:
......@@ -39,6 +53,7 @@ def get_course_topics(course, user):
"""
return module.sort_key or module.discussion_target
course = _get_course_or_404(course_key, user)
discussion_modules = get_accessible_discussion_modules(course, user)
modules_by_category = defaultdict(list)
for module in discussion_modules:
......@@ -138,14 +153,14 @@ def _cc_thread_to_api_thread(thread, cc_user, staff_user_ids, ta_user_ids, group
return ret
def get_thread_list(request, course, page, page_size):
def get_thread_list(request, course_key, page, page_size):
"""
Return the list of all discussion threads pertaining to the given course
Parameters:
request: The django request objects used for build_absolute_uri
course: The course to get discussion threads for
course_key: The key of the course to get discussion threads for
page: The page number (1-indexed) to retrieve
page_size: The number of threads to retrieve per page
......@@ -154,6 +169,7 @@ def get_thread_list(request, course, page, page_size):
A paginated result containing a list of threads; see
discussion_api.views.ThreadViewSet for more detail.
"""
course = _get_course_or_404(course_key, request.user)
user_is_privileged = Role.objects.filter(
course_id=course.id,
name__in=[FORUM_ROLE_ADMINISTRATOR, FORUM_ROLE_MODERATOR, FORUM_ROLE_COMMUNITY_TA],
......
......@@ -42,11 +42,6 @@ class DiscussionAPIViewTestMixin(CommentsServiceMockMixin, UrlResetMixin):
CourseEnrollmentFactory.create(user=self.user, course_id=self.course.id)
self.client.login(username=self.user.username, password=self.password)
def login_unenrolled_user(self):
"""Create a user not enrolled in the course and log it in"""
unenrolled_user = UserFactory.create(password=self.password)
self.client.login(username=unenrolled_user.username, password=self.password)
def assert_response_correct(self, response, expected_status, expected_content):
"""
Assert that the response has the given status code and parsed content
......@@ -71,7 +66,7 @@ class CourseTopicsViewTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
super(CourseTopicsViewTest, self).setUp()
self.url = reverse("course_topics", kwargs={"course_id": unicode(self.course.id)})
def test_non_existent_course(self):
def test_404(self):
response = self.client.get(
reverse("course_topics", kwargs={"course_id": "non/existent/course"})
)
......@@ -81,26 +76,7 @@ class CourseTopicsViewTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
{"developer_message": "Not found."}
)
def test_not_enrolled(self):
self.login_unenrolled_user()
response = self.client.get(self.url)
self.assert_response_correct(
response,
404,
{"developer_message": "Not found."}
)
def test_discussions_disabled(self):
self.course.tabs = [tab for tab in self.course.tabs if not isinstance(tab, DiscussionTab)]
modulestore().update_item(self.course, self.user.id)
response = self.client.get(self.url)
self.assert_response_correct(
response,
404,
{"developer_message": "Not found."}
)
def test_get(self):
def test_get_success(self):
response = self.client.get(self.url)
self.assert_response_correct(
response,
......@@ -132,9 +108,8 @@ class ThreadViewSetListTest(DiscussionAPIViewTestMixin, ModuleStoreTestCase):
{"field_errors": {"course_id": "This field is required."}}
)
def test_not_enrolled(self):
self.login_unenrolled_user()
response = self.client.get(self.url, {"course_id": unicode(self.course.id)})
def test_404(self):
response = self.client.get(self.url, {"course_id": unicode("non/existent/course")})
self.assert_response_correct(
response,
404,
......
......@@ -2,7 +2,6 @@
Discussion API views
"""
from django.core.exceptions import ValidationError
from django.http import Http404
from rest_framework.authentication import OAuth2Authentication, SessionAuthentication
from rest_framework.permissions import IsAuthenticated
......@@ -12,11 +11,9 @@ from rest_framework.viewsets import ViewSet
from opaque_keys.edx.locator import CourseLocator
from courseware.courses import get_course_with_access
from discussion_api.api import get_course_topics, get_thread_list
from discussion_api.forms import ThreadListGetForm
from openedx.core.lib.api.view_utils import DeveloperErrorViewMixin
from xmodule.tabs import DiscussionTab
class _ViewMixin(object):
......@@ -27,17 +24,6 @@ class _ViewMixin(object):
authentication_classes = (OAuth2Authentication, SessionAuthentication)
permission_classes = (IsAuthenticated,)
def get_course_or_404(self, user, course_key):
"""
Get the course descriptor, raising Http404 if the course is not found,
the user cannot access forums for the course, or the discussion tab is
disabled for the course.
"""
course = get_course_with_access(user, 'load_forum', course_key)
if not any([isinstance(tab, DiscussionTab) for tab in course.tabs]):
raise Http404
return course
class CourseTopicsView(_ViewMixin, DeveloperErrorViewMixin, APIView):
"""
......@@ -68,8 +54,7 @@ class CourseTopicsView(_ViewMixin, DeveloperErrorViewMixin, APIView):
def get(self, request, course_id):
"""Implements the GET method as described in the class docstring."""
course_key = CourseLocator.from_string(course_id)
course = self.get_course_or_404(request.user, course_key)
return Response(get_course_topics(course, request.user))
return Response(get_course_topics(course_key, request.user))
class ThreadViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet):
......@@ -133,11 +118,10 @@ class ThreadViewSet(_ViewMixin, DeveloperErrorViewMixin, ViewSet):
form = ThreadListGetForm(request.GET)
if not form.is_valid():
raise ValidationError(form.errors)
course = self.get_course_or_404(request.user, form.cleaned_data["course_id"])
return Response(
get_thread_list(
request,
course,
form.cleaned_data["course_id"],
form.cleaned_data["page"],
form.cleaned_data["page_size"]
)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment