Commit a9e0082c by Usman Khalid

Cache cohort info during requests to reduce SQL queries.

TNL-1258
parent d59be994
......@@ -110,12 +110,12 @@ def is_course_cohorted(course_key):
return get_course_cohort_settings(course_key).is_cohorted
def get_cohort_id(user, course_key):
def get_cohort_id(user, course_key, use_cached=False):
"""
Given a course key and a user, return the id of the cohort that user is
assigned to in that course. If they don't have a cohort, return None.
"""
cohort = get_cohort(user, course_key)
cohort = get_cohort(user, course_key, use_cached=use_cached)
return None if cohort is None else cohort.id
......@@ -172,7 +172,7 @@ def get_cohorted_commentables(course_key):
@transaction.commit_on_success
def get_cohort(user, course_key, assign=True):
def get_cohort(user, course_key, assign=True, use_cached=False):
"""
Given a Django user and a CourseKey, return the user's cohort in that
cohort.
......@@ -181,6 +181,7 @@ def get_cohort(user, course_key, assign=True):
user: a Django User object.
course_key: CourseKey
assign (bool): if False then we don't assign a group to user
use_cached (bool): Whether to use the cached value or fetch from database.
Returns:
A CourseUserGroup object if the course is cohorted and the User has a
......@@ -189,23 +190,33 @@ def get_cohort(user, course_key, assign=True):
Raises:
ValueError if the CourseKey doesn't exist.
"""
# pylint: disable=protected-access
# We cache the cohort on the user object so that we do not have to repeatedly
# query the database during a request. If the cached value exists, just return it.
if use_cached and hasattr(user, '_cohort'):
return user._cohort
# First check whether the course is cohorted (users shouldn't be in a cohort
# in non-cohorted courses, but settings can change after course starts)
course = courses.get_course(course_key)
course_cohort_settings = get_course_cohort_settings(course.id)
if not course_cohort_settings.is_cohorted:
return None
user._cohort = None
return user._cohort
try:
return CourseUserGroup.objects.get(
user._cohort = CourseUserGroup.objects.get(
course_id=course_key,
group_type=CourseUserGroup.COHORT,
users__id=user.id,
)
return user._cohort
except CourseUserGroup.DoesNotExist:
# Didn't find the group. We'll go on to create one if needed.
if not assign:
# Do not cache the cohort here, because in the next call assign
# may be True, and we will have to assign the user a cohort.
return None
cohorts = get_course_cohorts(course, assignment_type=CourseCohort.RANDOM)
......@@ -220,7 +231,8 @@ def get_cohort(user, course_key, assign=True):
user.course_groups.add(cohort)
return cohort
user._cohort = cohort
return user._cohort
def migrate_cohort_settings(course):
......@@ -387,7 +399,7 @@ def add_user_to_cohort(cohort, username_or_email):
return (user, previous_cohort_name)
def get_group_info_for_cohort(cohort):
def get_group_info_for_cohort(cohort, use_cached=False):
"""
Get the ids of the group and partition to which this cohort has been linked
as a tuple of (int, int).
......@@ -395,9 +407,21 @@ def get_group_info_for_cohort(cohort):
If the cohort has not been linked to any group/partition, both values in the
tuple will be None.
"""
res = CourseUserGroupPartitionGroup.objects.filter(course_user_group=cohort)
if len(res):
return res[0].group_id, res[0].partition_id
# pylint: disable=protected-access
# We cache the partition group on the cohort object so that we do not have to repeatedly
# query the database during a request.
if not use_cached and hasattr(cohort, '_partition_group'):
delattr(cohort, '_partition_group')
if not hasattr(cohort, '_partition_group'):
try:
cohort._partition_group = CourseUserGroupPartitionGroup.objects.get(course_user_group=cohort)
except CourseUserGroupPartitionGroup.DoesNotExist:
cohort._partition_group = None
if cohort._partition_group:
return cohort._partition_group.group_id, cohort._partition_group.partition_id
return None, None
......
......@@ -22,7 +22,7 @@ class CohortPartitionScheme(object):
# pylint: disable=unused-argument
@classmethod
def get_group_for_user(cls, course_key, user, user_partition, track_function=None):
def get_group_for_user(cls, course_key, user, user_partition, track_function=None, use_cached=True):
"""
Returns the Group from the specified user partition to which the user
is assigned, via their cohort membership and any mappings from cohorts
......@@ -48,12 +48,12 @@ class CohortPartitionScheme(object):
return None
return None
cohort = get_cohort(user, course_key)
cohort = get_cohort(user, course_key, use_cached=use_cached)
if cohort is None:
# student doesn't have a cohort
return None
group_id, partition_id = get_group_info_for_cohort(cohort)
group_id, partition_id = get_group_info_for_cohort(cohort, use_cached=use_cached)
if partition_id is None:
# cohort isn't mapped to any partition group.
return None
......
......@@ -10,8 +10,7 @@ from opaque_keys.edx.locations import SlashSeparatedCourseKey
from xmodule.modulestore.django import modulestore
from xmodule.modulestore import ModuleStoreEnum
import json
from ..cohorts import get_course_cohort_settings, set_course_cohort_settings
from ..cohorts import set_course_cohort_settings
from ..models import CourseUserGroup, CourseCohort, CourseCohortsSettings
......@@ -126,6 +125,7 @@ def config_course_cohorts_legacy(
pass
# pylint: disable=dangerous-default-value
def config_course_cohorts(
course,
is_cohorted,
......@@ -154,13 +154,14 @@ def config_course_cohorts(
Nothing -- modifies course in place.
"""
def to_id(name):
"""Convert name to id."""
return topic_name_to_id(course, name)
set_course_cohort_settings(
course.id,
is_cohorted = is_cohorted,
cohorted_discussions = [to_id(name) for name in cohorted_discussions],
always_cohort_inline_discussions = always_cohort_inline_discussions
is_cohorted=is_cohorted,
cohorted_discussions=[to_id(name) for name in cohorted_discussions],
always_cohort_inline_discussions=always_cohort_inline_discussions
)
for cohort_name in auto_cohorts:
......
......@@ -2,13 +2,14 @@
Tests for cohorts
"""
# pylint: disable=no-member
import ddt
from mock import call, patch
from django.contrib.auth.models import User
from django.db import IntegrityError
from django.http import Http404
from django.test import TestCase
from django.test.utils import override_settings
from mock import call, patch
from opaque_keys.edx.locations import SlashSeparatedCourseKey
from student.models import CourseEnrollment
......@@ -121,6 +122,7 @@ class TestCohortSignals(TestCase):
self.assertFalse(mock_tracker.emit.called)
@ddt.ddt
class TestCohorts(ModuleStoreTestCase):
"""
Test the cohorts feature
......@@ -243,12 +245,33 @@ class TestCohorts(ModuleStoreTestCase):
cohort.id,
"user should be assigned to the correct cohort"
)
self.assertEquals(
cohorts.get_cohort(other_user, course.id).id,
cohorts.get_cohort_by_name(course.id, cohorts.DEFAULT_COHORT_NAME).id,
"other_user should be assigned to the default cohort"
)
@ddt.data(
(True, 2),
(False, 6),
)
@ddt.unpack
def test_get_cohort_sql_queries(self, use_cached, num_sql_queries):
"""
Test number of queries by cohorts.get_cohort() with and without caching.
"""
course = modulestore().get_course(self.toy_course_key)
config_course_cohorts(course, is_cohorted=True)
cohort = CohortFactory(course_id=course.id, name="TestCohort")
user = UserFactory(username="test", email="a@b.com")
cohort.users.add(user)
with self.assertNumQueries(num_sql_queries):
for __ in range(3):
cohorts.get_cohort(user, course.id, use_cached=use_cached)
def test_get_cohort_with_assign(self):
"""
Make sure cohorts.get_cohort() returns None if no group is already
......@@ -473,7 +496,7 @@ class TestCohorts(ModuleStoreTestCase):
config_course_cohorts(
course,
is_cohorted=True,
discussion_topics= ["General", "Feedback"],
discussion_topics=["General", "Feedback"],
cohorted_discussions=["Feedback"]
)
......@@ -497,7 +520,7 @@ class TestCohorts(ModuleStoreTestCase):
config_course_cohorts(
course,
is_cohorted=True,
discussion_topics =["General", "Feedback"],
discussion_topics=["General", "Feedback"],
cohorted_discussions=["Feedback", "random_inline"]
)
self.assertTrue(
......@@ -741,6 +764,7 @@ class TestCohorts(ModuleStoreTestCase):
)
@ddt.ddt
class TestCohortsAndPartitionGroups(ModuleStoreTestCase):
"""
Test Cohorts and Partitions Groups.
......@@ -803,6 +827,25 @@ class TestCohortsAndPartitionGroups(ModuleStoreTestCase):
(None, None),
)
@ddt.data(
(True, 1),
(False, 3),
)
@ddt.unpack
def test_get_group_info_for_cohort_queries(self, use_cached, num_sql_queries):
"""
Basic test of the partition_group_info accessor function
"""
# create a link for the cohort in the db
self._link_cohort_partition_group(
self.first_cohort,
self.partition_id,
self.group1_id
)
with self.assertNumQueries(num_sql_queries):
for __ in range(3):
self.assertIsNotNone(cohorts.get_group_info_for_cohort(self.first_cohort, use_cached=use_cached))
def test_multiple_cohorts(self):
"""
Test that multiple cohorts can be linked to the same partition group
......
......@@ -63,6 +63,7 @@ class TestCohortPartitionScheme(ModuleStoreTestCase):
self.course_key,
self.student,
partition or self.user_partition,
use_cached=False
),
group
)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment