Commit a9e0082c by Usman Khalid

Cache cohort info during requests to reduce SQL queries.

TNL-1258
parent d59be994
...@@ -110,12 +110,12 @@ def is_course_cohorted(course_key): ...@@ -110,12 +110,12 @@ def is_course_cohorted(course_key):
return get_course_cohort_settings(course_key).is_cohorted return get_course_cohort_settings(course_key).is_cohorted
def get_cohort_id(user, course_key): def get_cohort_id(user, course_key, use_cached=False):
""" """
Given a course key and a user, return the id of the cohort that user is Given a course key and a user, return the id of the cohort that user is
assigned to in that course. If they don't have a cohort, return None. assigned to in that course. If they don't have a cohort, return None.
""" """
cohort = get_cohort(user, course_key) cohort = get_cohort(user, course_key, use_cached=use_cached)
return None if cohort is None else cohort.id return None if cohort is None else cohort.id
...@@ -172,7 +172,7 @@ def get_cohorted_commentables(course_key): ...@@ -172,7 +172,7 @@ def get_cohorted_commentables(course_key):
@transaction.commit_on_success @transaction.commit_on_success
def get_cohort(user, course_key, assign=True): def get_cohort(user, course_key, assign=True, use_cached=False):
""" """
Given a Django user and a CourseKey, return the user's cohort in that Given a Django user and a CourseKey, return the user's cohort in that
cohort. cohort.
...@@ -181,6 +181,7 @@ def get_cohort(user, course_key, assign=True): ...@@ -181,6 +181,7 @@ def get_cohort(user, course_key, assign=True):
user: a Django User object. user: a Django User object.
course_key: CourseKey course_key: CourseKey
assign (bool): if False then we don't assign a group to user assign (bool): if False then we don't assign a group to user
use_cached (bool): Whether to use the cached value or fetch from database.
Returns: Returns:
A CourseUserGroup object if the course is cohorted and the User has a A CourseUserGroup object if the course is cohorted and the User has a
...@@ -189,23 +190,33 @@ def get_cohort(user, course_key, assign=True): ...@@ -189,23 +190,33 @@ def get_cohort(user, course_key, assign=True):
Raises: Raises:
ValueError if the CourseKey doesn't exist. ValueError if the CourseKey doesn't exist.
""" """
# pylint: disable=protected-access
# We cache the cohort on the user object so that we do not have to repeatedly
# query the database during a request. If the cached value exists, just return it.
if use_cached and hasattr(user, '_cohort'):
return user._cohort
# First check whether the course is cohorted (users shouldn't be in a cohort # First check whether the course is cohorted (users shouldn't be in a cohort
# in non-cohorted courses, but settings can change after course starts) # in non-cohorted courses, but settings can change after course starts)
course = courses.get_course(course_key) course = courses.get_course(course_key)
course_cohort_settings = get_course_cohort_settings(course.id) course_cohort_settings = get_course_cohort_settings(course.id)
if not course_cohort_settings.is_cohorted: if not course_cohort_settings.is_cohorted:
return None user._cohort = None
return user._cohort
try: try:
return CourseUserGroup.objects.get( user._cohort = CourseUserGroup.objects.get(
course_id=course_key, course_id=course_key,
group_type=CourseUserGroup.COHORT, group_type=CourseUserGroup.COHORT,
users__id=user.id, users__id=user.id,
) )
return user._cohort
except CourseUserGroup.DoesNotExist: except CourseUserGroup.DoesNotExist:
# Didn't find the group. We'll go on to create one if needed. # Didn't find the group. We'll go on to create one if needed.
if not assign: if not assign:
# Do not cache the cohort here, because in the next call assign
# may be True, and we will have to assign the user a cohort.
return None return None
cohorts = get_course_cohorts(course, assignment_type=CourseCohort.RANDOM) cohorts = get_course_cohorts(course, assignment_type=CourseCohort.RANDOM)
...@@ -220,7 +231,8 @@ def get_cohort(user, course_key, assign=True): ...@@ -220,7 +231,8 @@ def get_cohort(user, course_key, assign=True):
user.course_groups.add(cohort) user.course_groups.add(cohort)
return cohort user._cohort = cohort
return user._cohort
def migrate_cohort_settings(course): def migrate_cohort_settings(course):
...@@ -387,7 +399,7 @@ def add_user_to_cohort(cohort, username_or_email): ...@@ -387,7 +399,7 @@ def add_user_to_cohort(cohort, username_or_email):
return (user, previous_cohort_name) return (user, previous_cohort_name)
def get_group_info_for_cohort(cohort): def get_group_info_for_cohort(cohort, use_cached=False):
""" """
Get the ids of the group and partition to which this cohort has been linked Get the ids of the group and partition to which this cohort has been linked
as a tuple of (int, int). as a tuple of (int, int).
...@@ -395,9 +407,21 @@ def get_group_info_for_cohort(cohort): ...@@ -395,9 +407,21 @@ def get_group_info_for_cohort(cohort):
If the cohort has not been linked to any group/partition, both values in the If the cohort has not been linked to any group/partition, both values in the
tuple will be None. tuple will be None.
""" """
res = CourseUserGroupPartitionGroup.objects.filter(course_user_group=cohort) # pylint: disable=protected-access
if len(res): # We cache the partition group on the cohort object so that we do not have to repeatedly
return res[0].group_id, res[0].partition_id # query the database during a request.
if not use_cached and hasattr(cohort, '_partition_group'):
delattr(cohort, '_partition_group')
if not hasattr(cohort, '_partition_group'):
try:
cohort._partition_group = CourseUserGroupPartitionGroup.objects.get(course_user_group=cohort)
except CourseUserGroupPartitionGroup.DoesNotExist:
cohort._partition_group = None
if cohort._partition_group:
return cohort._partition_group.group_id, cohort._partition_group.partition_id
return None, None return None, None
......
...@@ -22,7 +22,7 @@ class CohortPartitionScheme(object): ...@@ -22,7 +22,7 @@ class CohortPartitionScheme(object):
# pylint: disable=unused-argument # pylint: disable=unused-argument
@classmethod @classmethod
def get_group_for_user(cls, course_key, user, user_partition, track_function=None): def get_group_for_user(cls, course_key, user, user_partition, track_function=None, use_cached=True):
""" """
Returns the Group from the specified user partition to which the user Returns the Group from the specified user partition to which the user
is assigned, via their cohort membership and any mappings from cohorts is assigned, via their cohort membership and any mappings from cohorts
...@@ -48,12 +48,12 @@ class CohortPartitionScheme(object): ...@@ -48,12 +48,12 @@ class CohortPartitionScheme(object):
return None return None
return None return None
cohort = get_cohort(user, course_key) cohort = get_cohort(user, course_key, use_cached=use_cached)
if cohort is None: if cohort is None:
# student doesn't have a cohort # student doesn't have a cohort
return None return None
group_id, partition_id = get_group_info_for_cohort(cohort) group_id, partition_id = get_group_info_for_cohort(cohort, use_cached=use_cached)
if partition_id is None: if partition_id is None:
# cohort isn't mapped to any partition group. # cohort isn't mapped to any partition group.
return None return None
......
...@@ -10,8 +10,7 @@ from opaque_keys.edx.locations import SlashSeparatedCourseKey ...@@ -10,8 +10,7 @@ from opaque_keys.edx.locations import SlashSeparatedCourseKey
from xmodule.modulestore.django import modulestore from xmodule.modulestore.django import modulestore
from xmodule.modulestore import ModuleStoreEnum from xmodule.modulestore import ModuleStoreEnum
import json from ..cohorts import set_course_cohort_settings
from ..cohorts import get_course_cohort_settings, set_course_cohort_settings
from ..models import CourseUserGroup, CourseCohort, CourseCohortsSettings from ..models import CourseUserGroup, CourseCohort, CourseCohortsSettings
...@@ -126,6 +125,7 @@ def config_course_cohorts_legacy( ...@@ -126,6 +125,7 @@ def config_course_cohorts_legacy(
pass pass
# pylint: disable=dangerous-default-value
def config_course_cohorts( def config_course_cohorts(
course, course,
is_cohorted, is_cohorted,
...@@ -154,13 +154,14 @@ def config_course_cohorts( ...@@ -154,13 +154,14 @@ def config_course_cohorts(
Nothing -- modifies course in place. Nothing -- modifies course in place.
""" """
def to_id(name): def to_id(name):
"""Convert name to id."""
return topic_name_to_id(course, name) return topic_name_to_id(course, name)
set_course_cohort_settings( set_course_cohort_settings(
course.id, course.id,
is_cohorted = is_cohorted, is_cohorted=is_cohorted,
cohorted_discussions = [to_id(name) for name in cohorted_discussions], cohorted_discussions=[to_id(name) for name in cohorted_discussions],
always_cohort_inline_discussions = always_cohort_inline_discussions always_cohort_inline_discussions=always_cohort_inline_discussions
) )
for cohort_name in auto_cohorts: for cohort_name in auto_cohorts:
......
...@@ -2,13 +2,14 @@ ...@@ -2,13 +2,14 @@
Tests for cohorts Tests for cohorts
""" """
# pylint: disable=no-member # pylint: disable=no-member
import ddt
from mock import call, patch
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.db import IntegrityError from django.db import IntegrityError
from django.http import Http404 from django.http import Http404
from django.test import TestCase from django.test import TestCase
from django.test.utils import override_settings from django.test.utils import override_settings
from mock import call, patch
from opaque_keys.edx.locations import SlashSeparatedCourseKey from opaque_keys.edx.locations import SlashSeparatedCourseKey
from student.models import CourseEnrollment from student.models import CourseEnrollment
...@@ -121,6 +122,7 @@ class TestCohortSignals(TestCase): ...@@ -121,6 +122,7 @@ class TestCohortSignals(TestCase):
self.assertFalse(mock_tracker.emit.called) self.assertFalse(mock_tracker.emit.called)
@ddt.ddt
class TestCohorts(ModuleStoreTestCase): class TestCohorts(ModuleStoreTestCase):
""" """
Test the cohorts feature Test the cohorts feature
...@@ -243,12 +245,33 @@ class TestCohorts(ModuleStoreTestCase): ...@@ -243,12 +245,33 @@ class TestCohorts(ModuleStoreTestCase):
cohort.id, cohort.id,
"user should be assigned to the correct cohort" "user should be assigned to the correct cohort"
) )
self.assertEquals( self.assertEquals(
cohorts.get_cohort(other_user, course.id).id, cohorts.get_cohort(other_user, course.id).id,
cohorts.get_cohort_by_name(course.id, cohorts.DEFAULT_COHORT_NAME).id, cohorts.get_cohort_by_name(course.id, cohorts.DEFAULT_COHORT_NAME).id,
"other_user should be assigned to the default cohort" "other_user should be assigned to the default cohort"
) )
@ddt.data(
(True, 2),
(False, 6),
)
@ddt.unpack
def test_get_cohort_sql_queries(self, use_cached, num_sql_queries):
"""
Test number of queries by cohorts.get_cohort() with and without caching.
"""
course = modulestore().get_course(self.toy_course_key)
config_course_cohorts(course, is_cohorted=True)
cohort = CohortFactory(course_id=course.id, name="TestCohort")
user = UserFactory(username="test", email="a@b.com")
cohort.users.add(user)
with self.assertNumQueries(num_sql_queries):
for __ in range(3):
cohorts.get_cohort(user, course.id, use_cached=use_cached)
def test_get_cohort_with_assign(self): def test_get_cohort_with_assign(self):
""" """
Make sure cohorts.get_cohort() returns None if no group is already Make sure cohorts.get_cohort() returns None if no group is already
...@@ -473,7 +496,7 @@ class TestCohorts(ModuleStoreTestCase): ...@@ -473,7 +496,7 @@ class TestCohorts(ModuleStoreTestCase):
config_course_cohorts( config_course_cohorts(
course, course,
is_cohorted=True, is_cohorted=True,
discussion_topics= ["General", "Feedback"], discussion_topics=["General", "Feedback"],
cohorted_discussions=["Feedback"] cohorted_discussions=["Feedback"]
) )
...@@ -497,7 +520,7 @@ class TestCohorts(ModuleStoreTestCase): ...@@ -497,7 +520,7 @@ class TestCohorts(ModuleStoreTestCase):
config_course_cohorts( config_course_cohorts(
course, course,
is_cohorted=True, is_cohorted=True,
discussion_topics =["General", "Feedback"], discussion_topics=["General", "Feedback"],
cohorted_discussions=["Feedback", "random_inline"] cohorted_discussions=["Feedback", "random_inline"]
) )
self.assertTrue( self.assertTrue(
...@@ -741,6 +764,7 @@ class TestCohorts(ModuleStoreTestCase): ...@@ -741,6 +764,7 @@ class TestCohorts(ModuleStoreTestCase):
) )
@ddt.ddt
class TestCohortsAndPartitionGroups(ModuleStoreTestCase): class TestCohortsAndPartitionGroups(ModuleStoreTestCase):
""" """
Test Cohorts and Partitions Groups. Test Cohorts and Partitions Groups.
...@@ -803,6 +827,25 @@ class TestCohortsAndPartitionGroups(ModuleStoreTestCase): ...@@ -803,6 +827,25 @@ class TestCohortsAndPartitionGroups(ModuleStoreTestCase):
(None, None), (None, None),
) )
@ddt.data(
(True, 1),
(False, 3),
)
@ddt.unpack
def test_get_group_info_for_cohort_queries(self, use_cached, num_sql_queries):
"""
Basic test of the partition_group_info accessor function
"""
# create a link for the cohort in the db
self._link_cohort_partition_group(
self.first_cohort,
self.partition_id,
self.group1_id
)
with self.assertNumQueries(num_sql_queries):
for __ in range(3):
self.assertIsNotNone(cohorts.get_group_info_for_cohort(self.first_cohort, use_cached=use_cached))
def test_multiple_cohorts(self): def test_multiple_cohorts(self):
""" """
Test that multiple cohorts can be linked to the same partition group Test that multiple cohorts can be linked to the same partition group
......
...@@ -63,6 +63,7 @@ class TestCohortPartitionScheme(ModuleStoreTestCase): ...@@ -63,6 +63,7 @@ class TestCohortPartitionScheme(ModuleStoreTestCase):
self.course_key, self.course_key,
self.student, self.student,
partition or self.user_partition, partition or self.user_partition,
use_cached=False
), ),
group group
) )
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment