Commit 58229195 by Nimisha Asthagiri

Cache Enrollment state for (user, course) in request cache.

parent 2251097c
...@@ -10,7 +10,7 @@ file and check it in at the same time as your model changes. To do that, ...@@ -10,7 +10,7 @@ file and check it in at the same time as your model changes. To do that,
2. ./manage.py lms schemamigration student --auto description_of_your_change 2. ./manage.py lms schemamigration student --auto description_of_your_change
3. Add the migration file created in edx-platform/common/djangoapps/student/migrations/ 3. Add the migration file created in edx-platform/common/djangoapps/student/migrations/
""" """
from collections import defaultdict, OrderedDict from collections import defaultdict, OrderedDict, namedtuple
from datetime import datetime, timedelta from datetime import datetime, timedelta
from functools import total_ordering from functools import total_ordering
import hashlib import hashlib
...@@ -54,6 +54,7 @@ from enrollment.api import _default_course_mode ...@@ -54,6 +54,7 @@ from enrollment.api import _default_course_mode
import lms.lib.comment_client as cc import lms.lib.comment_client as cc
from openedx.core.djangoapps.commerce.utils import ecommerce_api_client, ECOMMERCE_DATE_FORMAT from openedx.core.djangoapps.commerce.utils import ecommerce_api_client, ECOMMERCE_DATE_FORMAT
from openedx.core.djangoapps.content.course_overviews.models import CourseOverview from openedx.core.djangoapps.content.course_overviews.models import CourseOverview
import request_cache
from openedx.core.djangoapps.site_configuration import helpers as configuration_helpers from openedx.core.djangoapps.site_configuration import helpers as configuration_helpers
from util.model_utils import emit_field_changed_events, get_changed_fields_dict from util.model_utils import emit_field_changed_events, get_changed_fields_dict
from util.query import use_read_replica_if_available from util.query import use_read_replica_if_available
...@@ -974,6 +975,12 @@ class CourseEnrollmentManager(models.Manager): ...@@ -974,6 +975,12 @@ class CourseEnrollmentManager(models.Manager):
) )
# Named tuple for fields pertaining to the state of
# CourseEnrollment for a user in a course. This type
# is used to cache the state in the request cache.
CourseEnrollmentState = namedtuple('CourseEnrollmentState', 'mode, is_active')
class CourseEnrollment(models.Model): class CourseEnrollment(models.Model):
""" """
Represents a Student's Enrollment record for a single Course. You should Represents a Student's Enrollment record for a single Course. You should
...@@ -1120,6 +1127,11 @@ class CourseEnrollment(models.Model): ...@@ -1120,6 +1127,11 @@ class CourseEnrollment(models.Model):
if activation_changed or mode_changed: if activation_changed or mode_changed:
self.save() self.save()
self._update_enrollment_in_request_cache(
self.user,
self.course_id,
CourseEnrollmentState(self.mode, self.is_active),
)
if activation_changed: if activation_changed:
if self.is_active: if self.is_active:
...@@ -1389,12 +1401,9 @@ class CourseEnrollment(models.Model): ...@@ -1389,12 +1401,9 @@ class CourseEnrollment(models.Model):
""" """
if not user.is_authenticated(): if not user.is_authenticated():
return False return False
else:
try: enrollment_state = cls._get_enrollment_state(user, course_key)
record = cls.objects.get(user=user, course_id=course_key) return enrollment_state.is_active or False
return record.is_active
except cls.DoesNotExist:
return False
@classmethod @classmethod
def is_enrolled_by_partial(cls, user, course_id_partial): def is_enrolled_by_partial(cls, user, course_id_partial):
...@@ -1436,11 +1445,8 @@ class CourseEnrollment(models.Model): ...@@ -1436,11 +1445,8 @@ class CourseEnrollment(models.Model):
and is_active is whether the enrollment is active. and is_active is whether the enrollment is active.
Returns (None, None) if the courseenrollment record does not exist. Returns (None, None) if the courseenrollment record does not exist.
""" """
try: enrollment_state = cls._get_enrollment_state(user, course_id)
record = cls.objects.get(user=user, course_id=course_id) return enrollment_state.mode, enrollment_state.is_active
return (record.mode, record.is_active)
except cls.DoesNotExist:
return (None, None)
@classmethod @classmethod
def enrollments_for_user(cls, user): def enrollments_for_user(cls, user):
...@@ -1593,6 +1599,45 @@ class CourseEnrollment(models.Model): ...@@ -1593,6 +1599,45 @@ class CourseEnrollment(models.Model):
""" """
return cls.COURSE_ENROLLMENT_CACHE_KEY.format(user_id, unicode(course_key)) return cls.COURSE_ENROLLMENT_CACHE_KEY.format(user_id, unicode(course_key))
@classmethod
def _get_enrollment_state(cls, user, course_key):
"""
Returns the CourseEnrollmentState for the given user
and course_key, caching the result for later retrieval.
"""
enrollment_state = cls._get_enrollment_in_request_cache(user, course_key)
if not enrollment_state:
try:
record = cls.objects.get(user=user, course_id=course_key)
enrollment_state = CourseEnrollmentState(record.mode, record.is_active)
except cls.DoesNotExist:
enrollment_state = CourseEnrollmentState(None, None)
cls._update_enrollment_in_request_cache(user, course_key, enrollment_state)
return enrollment_state
@classmethod
def _get_mode_active_request_cache(cls):
"""
Returns the request-specific cache for CourseEnrollment
"""
return request_cache.get_cache('CourseEnrollment.mode_and_active')
@classmethod
def _get_enrollment_in_request_cache(cls, user, course_key):
"""
Returns the cached value (CourseEnrollmentState) for the user's
enrollment in the request cache. If not cached, returns None.
"""
return cls._get_mode_active_request_cache().get((user.id, course_key))
@classmethod
def _update_enrollment_in_request_cache(cls, user, course_key, enrollment_state):
"""
Updates the cached value for the user's enrollment in the
request cache.
"""
cls._get_mode_active_request_cache()[(user.id, course_key)] = enrollment_state
@receiver(models.signals.post_save, sender=CourseEnrollment) @receiver(models.signals.post_save, sender=CourseEnrollment)
@receiver(models.signals.post_delete, sender=CourseEnrollment) @receiver(models.signals.post_delete, sender=CourseEnrollment)
......
...@@ -6,7 +6,6 @@ from datetime import datetime, timedelta ...@@ -6,7 +6,6 @@ from datetime import datetime, timedelta
import json import json
import logging import logging
import unittest import unittest
from urlparse import urljoin
import ddt import ddt
from django.conf import settings from django.conf import settings
...@@ -30,6 +29,7 @@ from certificates.tests.factories import GeneratedCertificateFactory # pylint: ...@@ -30,6 +29,7 @@ from certificates.tests.factories import GeneratedCertificateFactory # pylint:
from config_models.models import cache from config_models.models import cache
from course_modes.models import CourseMode from course_modes.models import CourseMode
from lms.djangoapps.verify_student.models import SoftwareSecurePhotoVerification from lms.djangoapps.verify_student.models import SoftwareSecurePhotoVerification
from openedx.core.djangolib.testing.utils import CacheIsolationTestCase
from openedx.core.djangoapps.programs.models import ProgramsApiConfig from openedx.core.djangoapps.programs.models import ProgramsApiConfig
from openedx.core.djangoapps.programs.tests import factories as programs_factories from openedx.core.djangoapps.programs.tests import factories as programs_factories
from openedx.core.djangoapps.programs.tests.mixins import ProgramsApiConfigMixin from openedx.core.djangoapps.programs.tests.mixins import ProgramsApiConfigMixin
...@@ -589,7 +589,7 @@ class EnrollmentEventTestMixin(EventTestMixin): ...@@ -589,7 +589,7 @@ class EnrollmentEventTestMixin(EventTestMixin):
self.mock_tracker.reset_mock() self.mock_tracker.reset_mock()
class EnrollInCourseTest(EnrollmentEventTestMixin, TestCase): class EnrollInCourseTest(EnrollmentEventTestMixin, CacheIsolationTestCase):
"""Tests enrolling and unenrolling in courses.""" """Tests enrolling and unenrolling in courses."""
@unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms') @unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms')
......
...@@ -229,18 +229,18 @@ class TestFieldOverrideMongoPerformance(FieldOverridePerformanceTestCase): ...@@ -229,18 +229,18 @@ class TestFieldOverrideMongoPerformance(FieldOverridePerformanceTestCase):
# # of sql queries to default, # # of sql queries to default,
# # of mongo queries, # # of mongo queries,
# ) # )
('no_overrides', 1, True, False): (22, 6), ('no_overrides', 1, True, False): (17, 6),
('no_overrides', 2, True, False): (22, 6), ('no_overrides', 2, True, False): (17, 6),
('no_overrides', 3, True, False): (22, 6), ('no_overrides', 3, True, False): (17, 6),
('ccx', 1, True, False): (22, 6), ('ccx', 1, True, False): (17, 6),
('ccx', 2, True, False): (22, 6), ('ccx', 2, True, False): (17, 6),
('ccx', 3, True, False): (22, 6), ('ccx', 3, True, False): (17, 6),
('no_overrides', 1, False, False): (22, 6), ('no_overrides', 1, False, False): (17, 6),
('no_overrides', 2, False, False): (22, 6), ('no_overrides', 2, False, False): (17, 6),
('no_overrides', 3, False, False): (22, 6), ('no_overrides', 3, False, False): (17, 6),
('ccx', 1, False, False): (22, 6), ('ccx', 1, False, False): (17, 6),
('ccx', 2, False, False): (22, 6), ('ccx', 2, False, False): (17, 6),
('ccx', 3, False, False): (22, 6), ('ccx', 3, False, False): (17, 6),
} }
...@@ -252,19 +252,19 @@ class TestFieldOverrideSplitPerformance(FieldOverridePerformanceTestCase): ...@@ -252,19 +252,19 @@ class TestFieldOverrideSplitPerformance(FieldOverridePerformanceTestCase):
__test__ = True __test__ = True
TEST_DATA = { TEST_DATA = {
('no_overrides', 1, True, False): (22, 3), ('no_overrides', 1, True, False): (17, 3),
('no_overrides', 2, True, False): (22, 3), ('no_overrides', 2, True, False): (17, 3),
('no_overrides', 3, True, False): (22, 3), ('no_overrides', 3, True, False): (17, 3),
('ccx', 1, True, False): (22, 3), ('ccx', 1, True, False): (17, 3),
('ccx', 2, True, False): (22, 3), ('ccx', 2, True, False): (17, 3),
('ccx', 3, True, False): (22, 3), ('ccx', 3, True, False): (17, 3),
('ccx', 1, True, True): (23, 3), ('ccx', 1, True, True): (18, 3),
('ccx', 2, True, True): (23, 3), ('ccx', 2, True, True): (18, 3),
('ccx', 3, True, True): (23, 3), ('ccx', 3, True, True): (18, 3),
('no_overrides', 1, False, False): (22, 3), ('no_overrides', 1, False, False): (17, 3),
('no_overrides', 2, False, False): (22, 3), ('no_overrides', 2, False, False): (17, 3),
('no_overrides', 3, False, False): (22, 3), ('no_overrides', 3, False, False): (17, 3),
('ccx', 1, False, False): (22, 3), ('ccx', 1, False, False): (17, 3),
('ccx', 2, False, False): (22, 3), ('ccx', 2, False, False): (17, 3),
('ccx', 3, False, False): (22, 3), ('ccx', 3, False, False): (17, 3),
} }
...@@ -316,7 +316,7 @@ class SelfPacedCourseInfoTestCase(LoginEnrollmentTestCase, SharedModuleStoreTest ...@@ -316,7 +316,7 @@ class SelfPacedCourseInfoTestCase(LoginEnrollmentTestCase, SharedModuleStoreTest
self.assertEqual(resp.status_code, 200) self.assertEqual(resp.status_code, 200)
def test_num_queries_instructor_paced(self): def test_num_queries_instructor_paced(self):
self.fetch_course_info_with_queries(self.instructor_paced_course, 23, 4) self.fetch_course_info_with_queries(self.instructor_paced_course, 18, 4)
def test_num_queries_self_paced(self): def test_num_queries_self_paced(self):
self.fetch_course_info_with_queries(self.self_paced_course, 23, 4) self.fetch_course_info_with_queries(self.self_paced_course, 18, 4)
...@@ -1346,7 +1346,7 @@ class ProgressPageTests(ModuleStoreTestCase): ...@@ -1346,7 +1346,7 @@ class ProgressPageTests(ModuleStoreTestCase):
self.assertContains(resp, u"Download Your Certificate") self.assertContains(resp, u"Download Your Certificate")
@ddt.data( @ddt.data(
*itertools.product(((39, 4, True), (39, 4, False)), (True, False)) *itertools.product(((34, 4, True), (34, 4, False)), (True, False))
) )
@ddt.unpack @ddt.unpack
def test_query_counts(self, (sql_calls, mongo_calls, self_paced), self_paced_enabled): def test_query_counts(self, (sql_calls, mongo_calls, self_paced), self_paced_enabled):
......
...@@ -347,11 +347,11 @@ class SingleThreadQueryCountTestCase(ModuleStoreTestCase): ...@@ -347,11 +347,11 @@ class SingleThreadQueryCountTestCase(ModuleStoreTestCase):
# course is outside the context manager that is verifying the number of queries, # course is outside the context manager that is verifying the number of queries,
# and with split mongo, that method ends up querying disabled_xblocks (which is then # and with split mongo, that method ends up querying disabled_xblocks (which is then
# cached and hence not queried as part of call_single_thread). # cached and hence not queried as part of call_single_thread).
(ModuleStoreEnum.Type.mongo, 1, 6, 4, 18, 8), (ModuleStoreEnum.Type.mongo, 1, 6, 4, 18, 7),
(ModuleStoreEnum.Type.mongo, 50, 6, 4, 18, 8), (ModuleStoreEnum.Type.mongo, 50, 6, 4, 18, 7),
# split mongo: 3 queries, regardless of thread response size. # split mongo: 3 queries, regardless of thread response size.
(ModuleStoreEnum.Type.split, 1, 3, 3, 17, 8), (ModuleStoreEnum.Type.split, 1, 3, 3, 17, 7),
(ModuleStoreEnum.Type.split, 50, 3, 3, 17, 8), (ModuleStoreEnum.Type.split, 50, 3, 3, 17, 7),
) )
@ddt.unpack @ddt.unpack
def test_number_of_mongo_queries( def test_number_of_mongo_queries(
......
...@@ -9,7 +9,6 @@ from mock import patch ...@@ -9,7 +9,6 @@ from mock import patch
from abc import ABCMeta from abc import ABCMeta
from courseware.models import StudentModule from courseware.models import StudentModule
from django.conf import settings from django.conf import settings
from django.test import TestCase
from django.utils.translation import get_language from django.utils.translation import get_language
from django.utils.translation import override as override_language from django.utils.translation import override as override_language
from nose.plugins.attrib import attr from nose.plugins.attrib import attr
...@@ -18,6 +17,7 @@ from student.tests.factories import UserFactory ...@@ -18,6 +17,7 @@ from student.tests.factories import UserFactory
from xmodule.modulestore.tests.factories import CourseFactory, ItemFactory from xmodule.modulestore.tests.factories import CourseFactory, ItemFactory
from lms.djangoapps.ccx.tests.factories import CcxFactory from lms.djangoapps.ccx.tests.factories import CcxFactory
from openedx.core.djangolib.testing.utils import CacheIsolationTestCase
from student.models import CourseEnrollment, CourseEnrollmentAllowed from student.models import CourseEnrollment, CourseEnrollmentAllowed
from student.roles import CourseCcxCoachRole from student.roles import CourseCcxCoachRole
from student.tests.factories import ( from student.tests.factories import (
...@@ -40,7 +40,7 @@ from xmodule.modulestore.tests.django_utils import SharedModuleStoreTestCase, TE ...@@ -40,7 +40,7 @@ from xmodule.modulestore.tests.django_utils import SharedModuleStoreTestCase, TE
@attr(shard=1) @attr(shard=1)
class TestSettableEnrollmentState(TestCase): class TestSettableEnrollmentState(CacheIsolationTestCase):
""" Test the basis class for enrollment tests. """ """ Test the basis class for enrollment tests. """
def setUp(self): def setUp(self):
super(TestSettableEnrollmentState, self).setUp() super(TestSettableEnrollmentState, self).setUp()
...@@ -62,7 +62,7 @@ class TestSettableEnrollmentState(TestCase): ...@@ -62,7 +62,7 @@ class TestSettableEnrollmentState(TestCase):
self.assertEqual(mes, ees) self.assertEqual(mes, ees)
class TestEnrollmentChangeBase(TestCase): class TestEnrollmentChangeBase(CacheIsolationTestCase):
""" """
Test instructor enrollment administration against database effects. Test instructor enrollment administration against database effects.
...@@ -565,7 +565,7 @@ class SettableEnrollmentState(EmailEnrollmentState): ...@@ -565,7 +565,7 @@ class SettableEnrollmentState(EmailEnrollmentState):
@attr(shard=1) @attr(shard=1)
class TestSendBetaRoleEmail(TestCase): class TestSendBetaRoleEmail(CacheIsolationTestCase):
""" """
Test edge cases for `send_beta_role_email` Test edge cases for `send_beta_role_email`
""" """
......
...@@ -611,7 +611,7 @@ class CreditRequirementApiTests(CreditApiTestBase): ...@@ -611,7 +611,7 @@ class CreditRequirementApiTests(CreditApiTestBase):
api.set_credit_requirements(self.course_key, requirements) api.set_credit_requirements(self.course_key, requirements)
# Satisfy one of the requirements, but not the other # Satisfy one of the requirements, but not the other
with self.assertNumQueries(13): with self.assertNumQueries(12):
api.set_credit_requirement_status( api.set_credit_requirement_status(
user, user,
self.course_key, self.course_key,
...@@ -623,7 +623,7 @@ class CreditRequirementApiTests(CreditApiTestBase): ...@@ -623,7 +623,7 @@ class CreditRequirementApiTests(CreditApiTestBase):
self.assertFalse(api.is_user_eligible_for_credit(user.username, self.course_key)) self.assertFalse(api.is_user_eligible_for_credit(user.username, self.course_key))
# Satisfy the other requirement # Satisfy the other requirement
with self.assertNumQueries(22): with self.assertNumQueries(21):
api.set_credit_requirement_status( api.set_credit_requirement_status(
user, user,
self.course_key, self.course_key,
...@@ -677,7 +677,7 @@ class CreditRequirementApiTests(CreditApiTestBase): ...@@ -677,7 +677,7 @@ class CreditRequirementApiTests(CreditApiTestBase):
# Delete the eligibility entries and satisfy the user's eligibility # Delete the eligibility entries and satisfy the user's eligibility
# requirement again to trigger eligibility notification # requirement again to trigger eligibility notification
CreditEligibility.objects.all().delete() CreditEligibility.objects.all().delete()
with self.assertNumQueries(17): with self.assertNumQueries(16):
api.set_credit_requirement_status( api.set_credit_requirement_status(
user, user,
self.course_key, self.course_key,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment