Commit ef3b0e47 by Toby Lawrence Committed by GitHub

Merge pull request #13544 from edx/TNL-5632

[TNL-5632] Don't request responses/comments for non-AJAX requests.
parents 971f102b 6704e17a
...@@ -58,3 +58,51 @@ class RequestCache(object): ...@@ -58,3 +58,51 @@ class RequestCache(object):
""" """
self.clear_request_cache() self.clear_request_cache()
return None return None
def request_cached(f):
"""
A decorator for wrapping a function and automatically handles caching its return value, as well as returning
that cached value for subsequent calls to the same function, with the same parameters, within a given request.
Notes:
- we convert arguments and keyword arguments to their string form to build the cache key, so if you have
args/kwargs that can't be converted to strings, you're gonna have a bad time (don't do it)
- cache key cardinality depends on the args/kwargs, so if you're caching a function that takes five arguments,
you might have deceptively low cache efficiency. prefer function with fewer arguments.
- we use the default request cache, not a named request cache (this shouldn't matter, but just mentioning it)
- benchmark, benchmark, benchmark! if you never measure, how will you know you've improved? or regressed?
Arguments:
f (func): the function to wrap
Returns:
func: a wrapper function which will call the wrapped function, passing in the same args/kwargs,
cache the value it returns, and return that cached value for subsequent calls with the
same args/kwargs within a single request
"""
def wrapper(*args, **kwargs):
"""
Wrapper function to decorate with.
"""
# Build our cache key based on the module the function belongs to, the functions name, and a stringified
# list of arguments and a query string-style stringified list of keyword arguments.
converted_args = map(str, args)
converted_kwargs = map(str, reduce(list.__add__, map(list, sorted(kwargs.iteritems())), []))
cache_keys = [f.__module__, f.func_name] + converted_args + converted_kwargs
cache_key = '.'.join(cache_keys)
# Check to see if we have a result in cache. If not, invoke our wrapped
# function. Cache and return the result to the caller.
rcache = RequestCache.get_request_cache()
if cache_key in rcache.data:
return rcache.data.get(cache_key)
else:
result = f(*args, **kwargs)
rcache.data[cache_key] = result
return result
return wrapper
...@@ -4,8 +4,10 @@ Tests for the request cache. ...@@ -4,8 +4,10 @@ Tests for the request cache.
from celery.task import task from celery.task import task
from django.conf import settings from django.conf import settings
from django.test import TestCase from django.test import TestCase
from mock import Mock
from request_cache import get_request_or_stub from request_cache import get_request_or_stub
from request_cache.middleware import RequestCache, request_cached
from xmodule.modulestore.django import modulestore from xmodule.modulestore.django import modulestore
...@@ -33,3 +35,190 @@ class TestRequestCache(TestCase): ...@@ -33,3 +35,190 @@ class TestRequestCache(TestCase):
""" Test that the request cache is cleared after a task is run. """ """ Test that the request cache is cleared after a task is run. """
self._dummy_task.apply(args=(self,)).get() self._dummy_task.apply(args=(self,)).get()
self.assertEqual(modulestore().request_cache.data, {}) self.assertEqual(modulestore().request_cache.data, {})
def test_request_cached_miss_and_then_hit(self):
"""
Ensure that after a cache miss, we fill the cache and can hit it.
"""
RequestCache.clear_request_cache()
to_be_wrapped = Mock()
to_be_wrapped.return_value = 42
self.assertEqual(to_be_wrapped.call_count, 0)
def mock_wrapper(*args, **kwargs):
"""Simple wrapper to let us decorate our mock."""
return to_be_wrapped(*args, **kwargs)
wrapped = request_cached(mock_wrapper)
result = wrapped()
self.assertEqual(result, 42)
self.assertEqual(to_be_wrapped.call_count, 1)
result = wrapped()
self.assertEqual(result, 42)
self.assertEqual(to_be_wrapped.call_count, 1)
def test_request_cached_with_caches_despite_changing_wrapped_result(self):
"""
Ensure that after caching a result, we always send it back, even if the underlying result changes.
"""
RequestCache.clear_request_cache()
to_be_wrapped = Mock()
to_be_wrapped.side_effect = [1, 2, 3]
self.assertEqual(to_be_wrapped.call_count, 0)
def mock_wrapper(*args, **kwargs):
"""Simple wrapper to let us decorate our mock."""
return to_be_wrapped(*args, **kwargs)
wrapped = request_cached(mock_wrapper)
result = wrapped()
self.assertEqual(result, 1)
self.assertEqual(to_be_wrapped.call_count, 1)
result = wrapped()
self.assertEqual(result, 1)
self.assertEqual(to_be_wrapped.call_count, 1)
direct_result = mock_wrapper()
self.assertEqual(direct_result, 2)
self.assertEqual(to_be_wrapped.call_count, 2)
result = wrapped()
self.assertEqual(result, 1)
self.assertEqual(to_be_wrapped.call_count, 2)
direct_result = mock_wrapper()
self.assertEqual(direct_result, 3)
self.assertEqual(to_be_wrapped.call_count, 3)
def test_request_cached_with_changing_args(self):
"""
Ensure that calling a decorated function with different positional arguments
will not use a cached value invoked by a previous call with different arguments.
"""
RequestCache.clear_request_cache()
to_be_wrapped = Mock()
to_be_wrapped.side_effect = [1, 2, 3, 4, 5, 6]
self.assertEqual(to_be_wrapped.call_count, 0)
def mock_wrapper(*args, **kwargs):
"""Simple wrapper to let us decorate our mock."""
return to_be_wrapped(*args, **kwargs)
wrapped = request_cached(mock_wrapper)
# This will be a miss, and make an underlying call.
result = wrapped(1)
self.assertEqual(result, 1)
self.assertEqual(to_be_wrapped.call_count, 1)
# This will be a miss, and make an underlying call.
result = wrapped(2)
self.assertEqual(result, 2)
self.assertEqual(to_be_wrapped.call_count, 2)
# This is bypass of the decorator.
direct_result = mock_wrapper(3)
self.assertEqual(direct_result, 3)
self.assertEqual(to_be_wrapped.call_count, 3)
# These will be hits, and not make an underlying call.
result = wrapped(1)
self.assertEqual(result, 1)
self.assertEqual(to_be_wrapped.call_count, 3)
result = wrapped(2)
self.assertEqual(result, 2)
self.assertEqual(to_be_wrapped.call_count, 3)
def test_request_cached_with_changing_kwargs(self):
"""
Ensure that calling a decorated function with different keyword arguments
will not use a cached value invoked by a previous call with different arguments.
"""
RequestCache.clear_request_cache()
to_be_wrapped = Mock()
to_be_wrapped.side_effect = [1, 2, 3, 4, 5, 6]
self.assertEqual(to_be_wrapped.call_count, 0)
def mock_wrapper(*args, **kwargs):
"""Simple wrapper to let us decorate our mock."""
return to_be_wrapped(*args, **kwargs)
wrapped = request_cached(mock_wrapper)
# This will be a miss, and make an underlying call.
result = wrapped(1, foo=1)
self.assertEqual(result, 1)
self.assertEqual(to_be_wrapped.call_count, 1)
# This will be a miss, and make an underlying call.
result = wrapped(2, foo=2)
self.assertEqual(result, 2)
self.assertEqual(to_be_wrapped.call_count, 2)
# This is bypass of the decorator.
direct_result = mock_wrapper(3, foo=3)
self.assertEqual(direct_result, 3)
self.assertEqual(to_be_wrapped.call_count, 3)
# These will be hits, and not make an underlying call.
result = wrapped(1, foo=1)
self.assertEqual(result, 1)
self.assertEqual(to_be_wrapped.call_count, 3)
result = wrapped(2, foo=2)
self.assertEqual(result, 2)
self.assertEqual(to_be_wrapped.call_count, 3)
# Since we're changing foo, this will be a miss.
result = wrapped(2, foo=5)
self.assertEqual(result, 4)
self.assertEqual(to_be_wrapped.call_count, 4)
def test_request_cached_with_none_result(self):
"""
Ensure that calling a decorated function that returns None
properly caches the result and doesn't recall the underlying
function.
"""
RequestCache.clear_request_cache()
to_be_wrapped = Mock()
to_be_wrapped.side_effect = [None, None, None, 1, 1]
self.assertEqual(to_be_wrapped.call_count, 0)
def mock_wrapper(*args, **kwargs):
"""Simple wrapper to let us decorate our mock."""
return to_be_wrapped(*args, **kwargs)
wrapped = request_cached(mock_wrapper)
# This will be a miss, and make an underlying call.
result = wrapped(1)
self.assertEqual(result, None)
self.assertEqual(to_be_wrapped.call_count, 1)
# This will be a miss, and make an underlying call.
result = wrapped(2)
self.assertEqual(result, None)
self.assertEqual(to_be_wrapped.call_count, 2)
# This is bypass of the decorator.
direct_result = mock_wrapper(3)
self.assertEqual(direct_result, None)
self.assertEqual(to_be_wrapped.call_count, 3)
# These will be hits, and not make an underlying call.
result = wrapped(1)
self.assertEqual(result, None)
self.assertEqual(to_be_wrapped.call_count, 3)
result = wrapped(2)
self.assertEqual(result, None)
self.assertEqual(to_be_wrapped.call_count, 3)
...@@ -435,22 +435,6 @@ class DiscussionTabMultipleThreadTest(BaseDiscussionTestCase): ...@@ -435,22 +435,6 @@ class DiscussionTabMultipleThreadTest(BaseDiscussionTestCase):
view = MultipleThreadFixture(threads) view = MultipleThreadFixture(threads)
view.push() view.push()
def test_page_scroll_on_thread_change_view(self):
"""
Check switching between threads changes the page focus
"""
# verify threads are rendered on the page
self.assertTrue(
self.thread_page_1.check_threads_rendered_successfully(thread_count=self.thread_count)
)
# From the thread_page_1 open & verify next thread
self.thread_page_1.click_and_open_thread(thread_id=self.thread_ids[1])
self.assertTrue(self.thread_page_2.is_browser_on_page())
# Verify that the focus is changed
self.thread_page_2.check_focus_is_set(selector=".discussion-article")
@attr('a11y') @attr('a11y')
def test_page_accessibility(self): def test_page_accessibility(self):
self.thread_page_1.a11y_audit.config.set_rules({ self.thread_page_1.a11y_audit.config.set_rules({
......
...@@ -12,8 +12,9 @@ from lms.lib.comment_client.utils import CommentClientPaginatedResult ...@@ -12,8 +12,9 @@ from lms.lib.comment_client.utils import CommentClientPaginatedResult
from django_comment_common.utils import ThreadContext from django_comment_common.utils import ThreadContext
from django_comment_client.permissions import get_team from django_comment_client.permissions import get_team
from django_comment_client.tests.group_id import ( from django_comment_client.tests.group_id import (
GroupIdAssertionMixin,
CohortedTopicGroupIdTestMixin, CohortedTopicGroupIdTestMixin,
NonCohortedTopicGroupIdTestMixin NonCohortedTopicGroupIdTestMixin,
) )
from django_comment_client.tests.unicode import UnicodeTestMixin from django_comment_client.tests.unicode import UnicodeTestMixin
from django_comment_client.tests.utils import CohortedTestCase from django_comment_client.tests.utils import CohortedTestCase
...@@ -347,11 +348,11 @@ class SingleThreadQueryCountTestCase(ModuleStoreTestCase): ...@@ -347,11 +348,11 @@ class SingleThreadQueryCountTestCase(ModuleStoreTestCase):
# course is outside the context manager that is verifying the number of queries, # course is outside the context manager that is verifying the number of queries,
# and with split mongo, that method ends up querying disabled_xblocks (which is then # and with split mongo, that method ends up querying disabled_xblocks (which is then
# cached and hence not queried as part of call_single_thread). # cached and hence not queried as part of call_single_thread).
(ModuleStoreEnum.Type.mongo, 1, 6, 4, 18, 7), (ModuleStoreEnum.Type.mongo, 1, 6, 4, 15, 3),
(ModuleStoreEnum.Type.mongo, 50, 6, 4, 18, 7), (ModuleStoreEnum.Type.mongo, 50, 6, 4, 15, 3),
# split mongo: 3 queries, regardless of thread response size. # split mongo: 3 queries, regardless of thread response size.
(ModuleStoreEnum.Type.split, 1, 3, 3, 17, 7), (ModuleStoreEnum.Type.split, 1, 3, 3, 14, 3),
(ModuleStoreEnum.Type.split, 50, 3, 3, 17, 7), (ModuleStoreEnum.Type.split, 50, 3, 3, 14, 3),
) )
@ddt.unpack @ddt.unpack
def test_number_of_mongo_queries( def test_number_of_mongo_queries(
...@@ -550,8 +551,8 @@ class SingleThreadAccessTestCase(CohortedTestCase): ...@@ -550,8 +551,8 @@ class SingleThreadAccessTestCase(CohortedTestCase):
@patch('lms.lib.comment_client.utils.requests.request', autospec=True) @patch('lms.lib.comment_client.utils.requests.request', autospec=True)
class SingleThreadGroupIdTestCase(CohortedTestCase, CohortedTopicGroupIdTestMixin): class SingleThreadGroupIdTestCase(CohortedTestCase, GroupIdAssertionMixin):
cs_endpoint = "/threads" cs_endpoint = "/threads/dummy_thread_id"
def call_view(self, mock_request, commentable_id, user, group_id, pass_group_id=True, is_ajax=False): def call_view(self, mock_request, commentable_id, user, group_id, pass_group_id=True, is_ajax=False):
mock_request.side_effect = make_mock_request_impl( mock_request.side_effect = make_mock_request_impl(
......
...@@ -286,7 +286,11 @@ def forum_form_discussion(request, course_key): ...@@ -286,7 +286,11 @@ def forum_form_discussion(request, course_key):
@use_bulk_ops @use_bulk_ops
def single_thread(request, course_key, discussion_id, thread_id): def single_thread(request, course_key, discussion_id, thread_id):
""" """
Renders a response to display a single discussion thread. Renders a response to display a single discussion thread. This could either be a page refresh
after navigating to a single thread, a direct link to a single thread, or an AJAX call from the
discussions UI loading the responses/comments for a single thread.
Depending on the HTTP headers, we'll adjust our response accordingly.
""" """
nr_transaction = newrelic.agent.current_transaction() nr_transaction = newrelic.agent.current_transaction()
...@@ -295,13 +299,11 @@ def single_thread(request, course_key, discussion_id, thread_id): ...@@ -295,13 +299,11 @@ def single_thread(request, course_key, discussion_id, thread_id):
cc_user = cc.User.from_django_user(request.user) cc_user = cc.User.from_django_user(request.user)
user_info = cc_user.to_dict() user_info = cc_user.to_dict()
is_moderator = has_permission(request.user, "see_all_cohorts", course_key) is_moderator = has_permission(request.user, "see_all_cohorts", course_key)
is_staff = has_permission(request.user, 'openclose_thread', course.id)
# Currently, the front end always loads responses via AJAX, even for this
# page; it would be a nice optimization to avoid that extra round trip to
# the comments service.
try: try:
thread = cc.Thread.find(thread_id).retrieve( thread = cc.Thread.find(thread_id).retrieve(
with_responses=True, with_responses=request.is_ajax(),
recursive=request.is_ajax(), recursive=request.is_ajax(),
user_id=request.user.id, user_id=request.user.id,
response_skip=request.GET.get("resp_skip"), response_skip=request.GET.get("resp_skip"),
...@@ -323,7 +325,6 @@ def single_thread(request, course_key, discussion_id, thread_id): ...@@ -323,7 +325,6 @@ def single_thread(request, course_key, discussion_id, thread_id):
if getattr(thread, "group_id", None) is not None and user_group_id != thread.group_id: if getattr(thread, "group_id", None) is not None and user_group_id != thread.group_id:
raise Http404 raise Http404
is_staff = has_permission(request.user, 'openclose_thread', course.id)
if request.is_ajax(): if request.is_ajax():
with newrelic.agent.FunctionTrace(nr_transaction, "get_annotated_content_infos"): with newrelic.agent.FunctionTrace(nr_transaction, "get_annotated_content_infos"):
annotated_content_info = utils.get_annotated_content_infos( annotated_content_info = utils.get_annotated_content_infos(
...@@ -332,20 +333,19 @@ def single_thread(request, course_key, discussion_id, thread_id): ...@@ -332,20 +333,19 @@ def single_thread(request, course_key, discussion_id, thread_id):
request.user, request.user,
user_info=user_info user_info=user_info
) )
content = utils.prepare_content(thread.to_dict(), course_key, is_staff) content = utils.prepare_content(thread.to_dict(), course_key, is_staff)
with newrelic.agent.FunctionTrace(nr_transaction, "add_courseware_context"): with newrelic.agent.FunctionTrace(nr_transaction, "add_courseware_context"):
add_courseware_context([content], course, request.user) add_courseware_context([content], course, request.user)
return utils.JsonResponse({ return utils.JsonResponse({
'content': content, 'content': content,
'annotated_content_info': annotated_content_info, 'annotated_content_info': annotated_content_info,
}) })
else: else:
try: # Since we're in page render mode, and the discussions UI will request the thread list itself,
threads, query_params = get_threads(request, course, user_info) # we need only return the thread information for this one.
except ValueError: threads = [thread.to_dict()]
return HttpResponseBadRequest("Invalid group_id")
threads.append(thread.to_dict())
with newrelic.agent.FunctionTrace(nr_transaction, "add_courseware_context"): with newrelic.agent.FunctionTrace(nr_transaction, "add_courseware_context"):
add_courseware_context(threads, course, request.user) add_courseware_context(threads, course, request.user)
...@@ -379,7 +379,7 @@ def single_thread(request, course_key, discussion_id, thread_id): ...@@ -379,7 +379,7 @@ def single_thread(request, course_key, discussion_id, thread_id):
'threads': threads, 'threads': threads,
'roles': utils.get_role_ids(course_key), 'roles': utils.get_role_ids(course_key),
'is_moderator': is_moderator, 'is_moderator': is_moderator,
'thread_pages': query_params['num_pages'], 'thread_pages': 1,
'is_course_cohorted': is_course_cohorted(course_key), 'is_course_cohorted': is_course_cohorted(course_key),
'flag_moderator': bool( 'flag_moderator': bool(
has_permission(request.user, 'openclose_thread', course.id) or has_permission(request.user, 'openclose_thread', course.id) or
......
...@@ -370,8 +370,8 @@ class ViewsQueryCountTestCase(UrlResetMixin, ModuleStoreTestCase, MockRequestSet ...@@ -370,8 +370,8 @@ class ViewsQueryCountTestCase(UrlResetMixin, ModuleStoreTestCase, MockRequestSet
return inner return inner
@ddt.data( @ddt.data(
(ModuleStoreEnum.Type.mongo, 3, 4, 33), (ModuleStoreEnum.Type.mongo, 3, 4, 31),
(ModuleStoreEnum.Type.split, 3, 13, 33), (ModuleStoreEnum.Type.split, 3, 13, 31),
) )
@ddt.unpack @ddt.unpack
@count_queries @count_queries
......
...@@ -5,7 +5,7 @@ Module for checking permissions with the comment_client backend ...@@ -5,7 +5,7 @@ Module for checking permissions with the comment_client backend
import logging import logging
from types import NoneType from types import NoneType
from request_cache.middleware import RequestCache from request_cache.middleware import RequestCache, request_cached
from lms.lib.comment_client import Thread from lms.lib.comment_client import Thread
from opaque_keys.edx.keys import CourseKey from opaque_keys.edx.keys import CourseKey
...@@ -31,19 +31,14 @@ def has_permission(user, permission, course_id=None): ...@@ -31,19 +31,14 @@ def has_permission(user, permission, course_id=None):
CONDITIONS = ['is_open', 'is_author', 'is_question_author', 'is_team_member_if_applicable'] CONDITIONS = ['is_open', 'is_author', 'is_question_author', 'is_team_member_if_applicable']
@request_cached
def get_team(commentable_id): def get_team(commentable_id):
""" Returns the team that the commentable_id belongs to if it exists. Returns None otherwise. """ """ Returns the team that the commentable_id belongs to if it exists. Returns None otherwise. """
request_cache_dict = RequestCache.get_request_cache().data
cache_key = "django_comment_client.team_commentable.{}".format(commentable_id)
if cache_key in request_cache_dict:
return request_cache_dict[cache_key]
try: try:
team = CourseTeam.objects.get(discussion_topic_id=commentable_id) team = CourseTeam.objects.get(discussion_topic_id=commentable_id)
except CourseTeam.DoesNotExist: except CourseTeam.DoesNotExist:
team = None team = None
request_cache_dict[cache_key] = team
return team return team
......
...@@ -241,17 +241,17 @@ class CachedDiscussionIdMapTestCase(ModuleStoreTestCase): ...@@ -241,17 +241,17 @@ class CachedDiscussionIdMapTestCase(ModuleStoreTestCase):
) )
def test_cache_returns_correct_key(self): def test_cache_returns_correct_key(self):
usage_key = utils.get_cached_discussion_key(self.course, 'test_discussion_id') usage_key = utils.get_cached_discussion_key(self.course.id, 'test_discussion_id')
self.assertEqual(usage_key, self.discussion.location) self.assertEqual(usage_key, self.discussion.location)
def test_cache_returns_none_if_id_is_not_present(self): def test_cache_returns_none_if_id_is_not_present(self):
usage_key = utils.get_cached_discussion_key(self.course, 'bogus_id') usage_key = utils.get_cached_discussion_key(self.course.id, 'bogus_id')
self.assertIsNone(usage_key) self.assertIsNone(usage_key)
def test_cache_raises_exception_if_course_structure_not_cached(self): def test_cache_raises_exception_if_course_structure_not_cached(self):
CourseStructure.objects.all().delete() CourseStructure.objects.all().delete()
with self.assertRaises(utils.DiscussionIdMapIsNotCached): with self.assertRaises(utils.DiscussionIdMapIsNotCached):
utils.get_cached_discussion_key(self.course, 'test_discussion_id') utils.get_cached_discussion_key(self.course.id, 'test_discussion_id')
def test_cache_raises_exception_if_discussion_id_not_cached(self): def test_cache_raises_exception_if_discussion_id_not_cached(self):
cache = CourseStructure.objects.get(course_id=self.course.id) cache = CourseStructure.objects.get(course_id=self.course.id)
...@@ -259,7 +259,7 @@ class CachedDiscussionIdMapTestCase(ModuleStoreTestCase): ...@@ -259,7 +259,7 @@ class CachedDiscussionIdMapTestCase(ModuleStoreTestCase):
cache.save() cache.save()
with self.assertRaises(utils.DiscussionIdMapIsNotCached): with self.assertRaises(utils.DiscussionIdMapIsNotCached):
utils.get_cached_discussion_key(self.course, 'test_discussion_id') utils.get_cached_discussion_key(self.course.id, 'test_discussion_id')
def test_xblock_does_not_have_required_keys(self): def test_xblock_does_not_have_required_keys(self):
self.assertTrue(utils.has_required_keys(self.discussion)) self.assertTrue(utils.has_required_keys(self.discussion))
......
...@@ -27,6 +27,7 @@ from openedx.core.djangoapps.course_groups.cohorts import ( ...@@ -27,6 +27,7 @@ from openedx.core.djangoapps.course_groups.cohorts import (
get_course_cohort_settings, get_cohort_by_id, get_cohort_id, is_course_cohorted get_course_cohort_settings, get_cohort_by_id, get_cohort_id, is_course_cohorted
) )
from openedx.core.djangoapps.course_groups.models import CourseUserGroup from openedx.core.djangoapps.course_groups.models import CourseUserGroup
from request_cache.middleware import request_cached
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
...@@ -154,17 +155,19 @@ class DiscussionIdMapIsNotCached(Exception): ...@@ -154,17 +155,19 @@ class DiscussionIdMapIsNotCached(Exception):
pass pass
def get_cached_discussion_key(course, discussion_id): @request_cached
def get_cached_discussion_key(course_id, discussion_id):
""" """
Returns the usage key of the discussion xblock associated with discussion_id if it is cached. If the discussion id Returns the usage key of the discussion xblock associated with discussion_id if it is cached. If the discussion id
map is cached but does not contain discussion_id, returns None. If the discussion id map is not cached for course, map is cached but does not contain discussion_id, returns None. If the discussion id map is not cached for course,
raises a DiscussionIdMapIsNotCached exception. raises a DiscussionIdMapIsNotCached exception.
""" """
try: try:
cached_mapping = CourseStructure.objects.get(course_id=course.id).discussion_id_map mapping = CourseStructure.objects.get(course_id=course_id).discussion_id_map
if not cached_mapping: if not mapping:
raise DiscussionIdMapIsNotCached() raise DiscussionIdMapIsNotCached()
return cached_mapping.get(discussion_id)
return mapping.get(discussion_id)
except CourseStructure.DoesNotExist: except CourseStructure.DoesNotExist:
raise DiscussionIdMapIsNotCached() raise DiscussionIdMapIsNotCached()
...@@ -177,7 +180,7 @@ def get_cached_discussion_id_map(course, discussion_ids, user): ...@@ -177,7 +180,7 @@ def get_cached_discussion_id_map(course, discussion_ids, user):
try: try:
entries = [] entries = []
for discussion_id in discussion_ids: for discussion_id in discussion_ids:
key = get_cached_discussion_key(course, discussion_id) key = get_cached_discussion_key(course.id, discussion_id)
if not key: if not key:
continue continue
xblock = modulestore().get_item(key) xblock = modulestore().get_item(key)
...@@ -398,7 +401,7 @@ def discussion_category_id_access(course, user, discussion_id, xblock=None): ...@@ -398,7 +401,7 @@ def discussion_category_id_access(course, user, discussion_id, xblock=None):
return True return True
try: try:
if not xblock: if not xblock:
key = get_cached_discussion_key(course, discussion_id) key = get_cached_discussion_key(course.id, discussion_id)
if not key: if not key:
return False return False
xblock = modulestore().get_item(key) xblock = modulestore().get_item(key)
......
...@@ -13,7 +13,7 @@ from django.utils.translation import ugettext as _ ...@@ -13,7 +13,7 @@ from django.utils.translation import ugettext as _
from courseware import courses from courseware import courses
from eventtracking import tracker from eventtracking import tracker
from request_cache.middleware import RequestCache from request_cache.middleware import RequestCache, request_cached
from student.models import get_user_by_username_or_email from student.models import get_user_by_username_or_email
from .models import ( from .models import (
...@@ -485,6 +485,7 @@ def set_course_cohort_settings(course_key, **kwargs): ...@@ -485,6 +485,7 @@ def set_course_cohort_settings(course_key, **kwargs):
return course_cohort_settings return course_cohort_settings
@request_cached
def get_course_cohort_settings(course_key): def get_course_cohort_settings(course_key):
""" """
Return cohort settings for a course. Return cohort settings for a course.
......
...@@ -264,7 +264,7 @@ class TestCohorts(ModuleStoreTestCase): ...@@ -264,7 +264,7 @@ class TestCohorts(ModuleStoreTestCase):
@ddt.data( @ddt.data(
(True, 3), (True, 3),
(False, 9), (False, 7),
) )
@ddt.unpack @ddt.unpack
def test_get_cohort_sql_queries(self, use_cached, num_sql_queries): def test_get_cohort_sql_queries(self, use_cached, num_sql_queries):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment