Commit 4619e9f4 by Renzo Lucioni

Merge pull request #8927 from edx/renzo/enrollment-api-sans-modulestore

Remove modulestore dependency from Enrollment API
parents 80cf4d6e b50c9058
""" """
Data Aggregation Layer of the Enrollment API. Collects all enrollment specific data into a single Data Aggregation Layer of the Enrollment API. Collects all enrollment specific data into a single
source to be used throughout the API. source to be used throughout the API.
""" """
import logging import logging
from django.contrib.auth.models import User from django.contrib.auth.models import User
from opaque_keys.edx.keys import CourseKey from opaque_keys.edx.keys import CourseKey
from xmodule.modulestore.django import modulestore
from enrollment.errors import ( from enrollment.errors import (
CourseNotFoundError, CourseEnrollmentClosedError, CourseEnrollmentFullError, CourseNotFoundError, CourseEnrollmentClosedError, CourseEnrollmentFullError,
CourseEnrollmentExistsError, UserNotFoundError, InvalidEnrollmentAttribute CourseEnrollmentExistsError, UserNotFoundError, InvalidEnrollmentAttribute
) )
from enrollment.serializers import CourseEnrollmentSerializer, CourseField from enrollment.serializers import CourseEnrollmentSerializer, CourseField
from openedx.core.djangoapps.content.course_overviews.models import CourseOverview
from student.models import ( from student.models import (
CourseEnrollment, NonExistentCourseError, EnrollmentClosedError, CourseEnrollment, NonExistentCourseError, EnrollmentClosedError,
CourseFullError, AlreadyEnrolledError, CourseEnrollmentAttribute CourseFullError, AlreadyEnrolledError, CourseEnrollmentAttribute
) )
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
...@@ -245,7 +247,7 @@ def _invalid_attribute(attributes): ...@@ -245,7 +247,7 @@ def _invalid_attribute(attributes):
def get_course_enrollment_info(course_id, include_expired=False): def get_course_enrollment_info(course_id, include_expired=False):
"""Returns all course enrollment information for the given course. """Returns all course enrollment information for the given course.
Based on the course id, return all related course information.. Based on the course id, return all related course information.
Args: Args:
course_id (str): The course to retrieve enrollment information for. course_id (str): The course to retrieve enrollment information for.
...@@ -261,9 +263,12 @@ def get_course_enrollment_info(course_id, include_expired=False): ...@@ -261,9 +263,12 @@ def get_course_enrollment_info(course_id, include_expired=False):
""" """
course_key = CourseKey.from_string(course_id) course_key = CourseKey.from_string(course_id)
course = modulestore().get_course(course_key)
if course is None: try:
course = CourseOverview.get_from_id(course_key)
except CourseOverview.DoesNotExist:
msg = u"Requested enrollment information for unknown course {course}".format(course=course_id) msg = u"Requested enrollment information for unknown course {course}".format(course=course_id)
log.warning(msg) log.warning(msg)
raise CourseNotFoundError(msg) raise CourseNotFoundError(msg)
return CourseField().to_native(course, include_expired=include_expired) else:
return CourseField().to_native(course, include_expired=include_expired)
""" """
Serializers for all Course Enrollment related return objects. Serializers for all Course Enrollment related return objects.
""" """
import logging import logging
from rest_framework import serializers from rest_framework import serializers
from student.models import CourseEnrollment
from course_modes.models import CourseMode from course_modes.models import CourseMode
from student.models import CourseEnrollment
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
...@@ -39,19 +39,18 @@ class CourseField(serializers.RelatedField): ...@@ -39,19 +39,18 @@ class CourseField(serializers.RelatedField):
""" """
def to_native(self, course, **kwargs): def to_native(self, course, **kwargs):
course_id = unicode(course.id)
course_modes = ModeSerializer( course_modes = ModeSerializer(
CourseMode.modes_for_course(course.id, kwargs.get('include_expired', False), only_selectable=False) CourseMode.modes_for_course(course.id, kwargs.get('include_expired', False), only_selectable=False)
).data # pylint: disable=no-member ).data # pylint: disable=no-member
return { return {
"course_id": course_id, 'course_id': unicode(course.id),
"enrollment_start": course.enrollment_start, 'enrollment_start': course.enrollment_start,
"enrollment_end": course.enrollment_end, 'enrollment_end': course.enrollment_end,
"course_start": course.start, 'course_start': course.start,
"course_end": course.end, 'course_end': course.end,
"invite_only": course.invitation_only, 'invite_only': course.invitation_only,
"course_modes": course_modes, 'course_modes': course_modes,
} }
......
...@@ -16,7 +16,7 @@ from rest_framework.test import APITestCase ...@@ -16,7 +16,7 @@ from rest_framework.test import APITestCase
from rest_framework import status from rest_framework import status
from django.conf import settings from django.conf import settings
from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase
from xmodule.modulestore.tests.factories import CourseFactory from xmodule.modulestore.tests.factories import CourseFactory, check_mongo_calls_range
from django.test.utils import override_settings from django.test.utils import override_settings
from course_modes.models import CourseMode from course_modes.models import CourseMode
...@@ -26,8 +26,9 @@ from util.models import RateLimitConfiguration ...@@ -26,8 +26,9 @@ from util.models import RateLimitConfiguration
from util.testing import UrlResetMixin from util.testing import UrlResetMixin
from enrollment import api from enrollment import api
from enrollment.errors import CourseEnrollmentError from enrollment.errors import CourseEnrollmentError
from openedx.core.lib.django_test_client_utils import get_absolute_url from openedx.core.djangoapps.content.course_overviews.models import CourseOverview
from openedx.core.djangoapps.user_api.models import UserOrgTag from openedx.core.djangoapps.user_api.models import UserOrgTag
from openedx.core.lib.django_test_client_utils import get_absolute_url
from student.tests.factories import UserFactory, CourseModeFactory from student.tests.factories import UserFactory, CourseModeFactory
from student.models import CourseEnrollment from student.models import CourseEnrollment
from embargo.test_utils import restrict_course from embargo.test_utils import restrict_course
...@@ -47,6 +48,8 @@ class EnrollmentTestMixin(object): ...@@ -47,6 +48,8 @@ class EnrollmentTestMixin(object):
mode=CourseMode.HONOR, mode=CourseMode.HONOR,
is_active=None, is_active=None,
enrollment_attributes=None, enrollment_attributes=None,
min_mongo_calls=0,
max_mongo_calls=0,
): ):
""" """
Enroll in the course and verify the response's status code. If the expected status is 200, also validates Enroll in the course and verify the response's status code. If the expected status is 200, also validates
...@@ -77,31 +80,33 @@ class EnrollmentTestMixin(object): ...@@ -77,31 +80,33 @@ class EnrollmentTestMixin(object):
if as_server: if as_server:
extra['HTTP_X_EDX_API_KEY'] = self.API_KEY extra['HTTP_X_EDX_API_KEY'] = self.API_KEY
with patch('enrollment.views.audit_log') as mock_audit_log: # Verify that the modulestore is queried as expected.
url = reverse('courseenrollments') with check_mongo_calls_range(min_finds=min_mongo_calls, max_finds=max_mongo_calls):
response = self.client.post(url, json.dumps(data), content_type='application/json', **extra) with patch('enrollment.views.audit_log') as mock_audit_log:
self.assertEqual(response.status_code, expected_status) url = reverse('courseenrollments')
response = self.client.post(url, json.dumps(data), content_type='application/json', **extra)
self.assertEqual(response.status_code, expected_status)
if expected_status == status.HTTP_200_OK: if expected_status == status.HTTP_200_OK:
data = json.loads(response.content) data = json.loads(response.content)
self.assertEqual(course_id, data['course_details']['course_id']) self.assertEqual(course_id, data['course_details']['course_id'])
if mode is not None: if mode is not None:
self.assertEqual(mode, data['mode']) self.assertEqual(mode, data['mode'])
if is_active is not None: if is_active is not None:
self.assertEqual(is_active, data['is_active']) self.assertEqual(is_active, data['is_active'])
else: else:
self.assertTrue(data['is_active']) self.assertTrue(data['is_active'])
if as_server: if as_server:
# Verify that an audit message was logged. # Verify that an audit message was logged.
self.assertTrue(mock_audit_log.called) self.assertTrue(mock_audit_log.called)
# If multiple enrollment calls are made in the scope of a # If multiple enrollment calls are made in the scope of a
# single test, we want to validate that audit messages are # single test, we want to validate that audit messages are
# logged for each call. # logged for each call.
mock_audit_log.reset_mock() mock_audit_log.reset_mock()
return response return response
...@@ -141,6 +146,10 @@ class EnrollmentTest(EnrollmentTestMixin, ModuleStoreTestCase, APITestCase): ...@@ -141,6 +146,10 @@ class EnrollmentTest(EnrollmentTestMixin, ModuleStoreTestCase, APITestCase):
self.rate_limit, rate_duration = throttle.parse_rate(throttle.rate) self.rate_limit, rate_duration = throttle.parse_rate(throttle.rate)
self.course = CourseFactory.create() self.course = CourseFactory.create()
# Load a CourseOverview. This initial load should result in a cache
# miss; the modulestore is queried and course metadata is cached.
__ = CourseOverview.get_from_id(self.course.id)
self.user = UserFactory.create(username=self.USERNAME, email=self.EMAIL, password=self.PASSWORD) self.user = UserFactory.create(username=self.USERNAME, email=self.EMAIL, password=self.PASSWORD)
self.other_user = UserFactory.create() self.other_user = UserFactory.create()
self.client.login(username=self.USERNAME, password=self.PASSWORD) self.client.login(username=self.USERNAME, password=self.PASSWORD)
...@@ -382,6 +391,10 @@ class EnrollmentTest(EnrollmentTestMixin, ModuleStoreTestCase, APITestCase): ...@@ -382,6 +391,10 @@ class EnrollmentTest(EnrollmentTestMixin, ModuleStoreTestCase, APITestCase):
@ddt.unpack @ddt.unpack
def test_get_course_details_course_dates(self, start_datetime, end_datetime, expected_start, expected_end): def test_get_course_details_course_dates(self, start_datetime, end_datetime, expected_start, expected_end):
course = CourseFactory.create(start=start_datetime, end=end_datetime) course = CourseFactory.create(start=start_datetime, end=end_datetime)
# Load a CourseOverview. This initial load should result in a cache
# miss; the modulestore is queried and course metadata is cached.
__ = CourseOverview.get_from_id(course.id)
self.assert_enrollment_status(course_id=unicode(course.id)) self.assert_enrollment_status(course_id=unicode(course.id))
# Check course details # Check course details
...@@ -411,7 +424,12 @@ class EnrollmentTest(EnrollmentTestMixin, ModuleStoreTestCase, APITestCase): ...@@ -411,7 +424,12 @@ class EnrollmentTest(EnrollmentTestMixin, ModuleStoreTestCase, APITestCase):
self.assertEqual(data[0]['course_details']['course_end'], expected_end) self.assertEqual(data[0]['course_details']['course_end'], expected_end)
def test_with_invalid_course_id(self): def test_with_invalid_course_id(self):
self.assert_enrollment_status(course_id='entirely/fake/course', expected_status=status.HTTP_400_BAD_REQUEST) self.assert_enrollment_status(
course_id='entirely/fake/course',
expected_status=status.HTTP_400_BAD_REQUEST,
min_mongo_calls=3,
max_mongo_calls=4
)
def test_get_enrollment_details_bad_course(self): def test_get_enrollment_details_bad_course(self):
resp = self.client.get( resp = self.client.get(
...@@ -797,7 +815,12 @@ class EnrollmentEmbargoTest(EnrollmentTestMixin, UrlResetMixin, ModuleStoreTestC ...@@ -797,7 +815,12 @@ class EnrollmentEmbargoTest(EnrollmentTestMixin, UrlResetMixin, ModuleStoreTestC
def setUp(self): def setUp(self):
""" Create a course and user, then log in. """ """ Create a course and user, then log in. """
super(EnrollmentEmbargoTest, self).setUp('embargo') super(EnrollmentEmbargoTest, self).setUp('embargo')
self.course = CourseFactory.create() self.course = CourseFactory.create()
# Load a CourseOverview. This initial load should result in a cache
# miss; the modulestore is queried and course metadata is cached.
__ = CourseOverview.get_from_id(self.course.id)
self.user = UserFactory.create(username=self.USERNAME, email=self.EMAIL, password=self.PASSWORD) self.user = UserFactory.create(username=self.USERNAME, email=self.EMAIL, password=self.PASSWORD)
self.client.login(username=self.USERNAME, password=self.PASSWORD) self.client.login(username=self.USERNAME, password=self.PASSWORD)
self.url = reverse('courseenrollments') self.url = reverse('courseenrollments')
......
...@@ -10,16 +10,19 @@ file and check it in at the same time as your model changes. To do that, ...@@ -10,16 +10,19 @@ file and check it in at the same time as your model changes. To do that,
2. ./manage.py lms schemamigration student --auto description_of_your_change 2. ./manage.py lms schemamigration student --auto description_of_your_change
3. Add the migration file created in edx-platform/common/djangoapps/student/migrations/ 3. Add the migration file created in edx-platform/common/djangoapps/student/migrations/
""" """
from collections import defaultdict, OrderedDict
from datetime import datetime, timedelta from datetime import datetime, timedelta
from functools import total_ordering
import hashlib import hashlib
from importlib import import_module
import json import json
import logging import logging
from pytz import UTC from pytz import UTC
import uuid
from collections import defaultdict, OrderedDict
import dogstats_wrapper as dog_stats_api
from urllib import urlencode from urllib import urlencode
import uuid
import analytics
from config_models.models import ConfigurationModel
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from django.conf import settings from django.conf import settings
from django.utils import timezone from django.utils import timezone
...@@ -33,28 +36,21 @@ from django.dispatch import receiver, Signal ...@@ -33,28 +36,21 @@ from django.dispatch import receiver, Signal
from django.core.exceptions import ObjectDoesNotExist from django.core.exceptions import ObjectDoesNotExist
from django.utils.translation import ugettext_noop from django.utils.translation import ugettext_noop
from django_countries.fields import CountryField from django_countries.fields import CountryField
from config_models.models import ConfigurationModel import dogstats_wrapper as dog_stats_api
from track import contexts
from eventtracking import tracker from eventtracking import tracker
from importlib import import_module from opaque_keys.edx.keys import CourseKey
from south.modelsinspector import add_introspection_rules
from opaque_keys.edx.locations import SlashSeparatedCourseKey from opaque_keys.edx.locations import SlashSeparatedCourseKey
from south.modelsinspector import add_introspection_rules
import lms.lib.comment_client as cc from track import contexts
from util.model_utils import emit_field_changed_events, get_changed_fields_dict
from util.query import use_read_replica_if_available
from xmodule_django.models import CourseKeyField, NoneToEmptyManager from xmodule_django.models import CourseKeyField, NoneToEmptyManager
from xmodule.modulestore.exceptions import ItemNotFoundError
from xmodule.modulestore.django import modulestore
from opaque_keys.edx.keys import CourseKey
from functools import total_ordering
from certificates.models import GeneratedCertificate from certificates.models import GeneratedCertificate
from course_modes.models import CourseMode from course_modes.models import CourseMode
import lms.lib.comment_client as cc
from openedx.core.djangoapps.content.course_overviews.models import CourseOverview
from util.model_utils import emit_field_changed_events, get_changed_fields_dict
from util.query import use_read_replica_if_available
import analytics
UNENROLL_DONE = Signal(providing_args=["course_enrollment", "skip_refund"]) UNENROLL_DONE = Signal(providing_args=["course_enrollment", "skip_refund"])
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
...@@ -1064,18 +1060,15 @@ class CourseEnrollment(models.Model): ...@@ -1064,18 +1060,15 @@ class CourseEnrollment(models.Model):
""" """
# All the server-side checks for whether a user is allowed to enroll. # All the server-side checks for whether a user is allowed to enroll.
try: try:
course = modulestore().get_course(course_key) course = CourseOverview.get_from_id(course_key)
except ItemNotFoundError: except CourseOverview.DoesNotExist:
log.warning( # This is here to preserve legacy behavior which allowed enrollment in courses
u"User %s failed to enroll in non-existent course %s", # announced before the start of content creation.
user.username, if check_access:
course_key.to_deprecated_string(), log.warning(u"User %s failed to enroll in non-existent course %s", user.username, unicode(course_key))
) raise NonExistentCourseError
raise NonExistentCourseError
if check_access: if check_access:
if course is None:
raise NonExistentCourseError
if CourseEnrollment.is_enrollment_closed(user, course): if CourseEnrollment.is_enrollment_closed(user, course):
log.warning( log.warning(
u"User %s failed to enroll in course %s because enrollment is closed", u"User %s failed to enroll in course %s because enrollment is closed",
...@@ -1320,7 +1313,8 @@ class CourseEnrollment(models.Model): ...@@ -1320,7 +1313,8 @@ class CourseEnrollment(models.Model):
@property @property
def course(self): def course(self):
return modulestore().get_course(self.course_id) # Deprecated. Please use the `course_overview` property instead.
return self.course_overview
@property @property
def course_overview(self): def course_overview(self):
...@@ -1334,7 +1328,6 @@ class CourseEnrollment(models.Model): ...@@ -1334,7 +1328,6 @@ class CourseEnrollment(models.Model):
become stale. become stale.
""" """
if not self._course_overview: if not self._course_overview:
from openedx.core.djangoapps.content.course_overviews.models import CourseOverview
try: try:
self._course_overview = CourseOverview.get_from_id(self.course_id) self._course_overview = CourseOverview.get_from_id(self.course_id)
except (CourseOverview.DoesNotExist, IOError): except (CourseOverview.DoesNotExist, IOError):
......
...@@ -2,26 +2,27 @@ ...@@ -2,26 +2,27 @@
Unit tests for getting the list of courses for a user through iterating all courses and Unit tests for getting the list of courses for a user through iterating all courses and
by reversing group name formats. by reversing group name formats.
""" """
import unittest
from django.conf import settings
from django.test.client import Client
import mock import mock
from mock import patch, Mock
from student.tests.factories import UserFactory from openedx.core.djangoapps.content.course_overviews.models import CourseOverview
from student.models import CourseEnrollment
from student.roles import GlobalStaff from student.roles import GlobalStaff
from student.tests.factories import UserFactory
from student.views import get_course_enrollments
from xmodule.error_module import ErrorDescriptor
from xmodule.modulestore import ModuleStoreEnum from xmodule.modulestore import ModuleStoreEnum
from xmodule.modulestore.django import modulestore
from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase
from xmodule.modulestore.tests.factories import CourseFactory from xmodule.modulestore.tests.factories import CourseFactory
from xmodule.modulestore.django import modulestore
from xmodule.error_module import ErrorDescriptor
from django.test.client import Client
from student.models import CourseEnrollment
from student.views import get_course_enrollments
from util.milestones_helpers import ( from util.milestones_helpers import (
get_pre_requisite_courses_not_completed, get_pre_requisite_courses_not_completed,
set_prerequisite_courses, set_prerequisite_courses,
seed_milestone_relationship_types seed_milestone_relationship_types
) )
import unittest
from django.conf import settings
class TestCourseListing(ModuleStoreTestCase): class TestCourseListing(ModuleStoreTestCase):
...@@ -91,10 +92,12 @@ class TestCourseListing(ModuleStoreTestCase): ...@@ -91,10 +92,12 @@ class TestCourseListing(ModuleStoreTestCase):
course_key = mongo_store.make_course_key('Org1', 'Course1', 'Run1') course_key = mongo_store.make_course_key('Org1', 'Course1', 'Run1')
self._create_course_with_access_groups(course_key, default_store=ModuleStoreEnum.Type.mongo) self._create_course_with_access_groups(course_key, default_store=ModuleStoreEnum.Type.mongo)
with patch('xmodule.modulestore.mongo.base.MongoKeyValueStore', Mock(side_effect=Exception)): with mock.patch('xmodule.modulestore.mongo.base.MongoKeyValueStore', mock.Mock(side_effect=Exception)):
self.assertIsInstance(modulestore().get_course(course_key), ErrorDescriptor) self.assertIsInstance(modulestore().get_course(course_key), ErrorDescriptor)
# get courses through iterating all courses # Invalidate (e.g., delete) the corresponding CourseOverview, forcing get_course to be called.
CourseOverview.objects.filter(id=course_key).delete()
courses_list = list(get_course_enrollments(self.student, None, [])) courses_list = list(get_course_enrollments(self.student, None, []))
self.assertEqual(courses_list, []) self.assertEqual(courses_list, [])
......
...@@ -112,10 +112,10 @@ class TestRecentEnrollments(ModuleStoreTestCase): ...@@ -112,10 +112,10 @@ class TestRecentEnrollments(ModuleStoreTestCase):
recent_course_list = _get_recently_enrolled_courses(courses_list) recent_course_list = _get_recently_enrolled_courses(courses_list)
self.assertEqual(len(recent_course_list), 5) self.assertEqual(len(recent_course_list), 5)
self.assertEqual(recent_course_list[1].course, courses[0]) self.assertEqual(recent_course_list[1].course.id, courses[0].id)
self.assertEqual(recent_course_list[2].course, courses[1]) self.assertEqual(recent_course_list[2].course.id, courses[1].id)
self.assertEqual(recent_course_list[3].course, courses[2]) self.assertEqual(recent_course_list[3].course.id, courses[2].id)
self.assertEqual(recent_course_list[4].course, courses[3]) self.assertEqual(recent_course_list[4].course.id, courses[3].id)
def test_dashboard_rendering(self): def test_dashboard_rendering(self):
""" """
......
...@@ -494,9 +494,9 @@ class DashboardTest(ModuleStoreTestCase): ...@@ -494,9 +494,9 @@ class DashboardTest(ModuleStoreTestCase):
""" """
Check that the student dashboard makes use of course metadata caching. Check that the student dashboard makes use of course metadata caching.
The first time the student dashboard displays a specific course, it will After enrolling a student in a course, that course's metadata should be
make a call to the module store. After that first request, though, the cached as a CourseOverview. The student dashboard should never have to make
course's metadata should be cached as a CourseOverview. calls to the modulestore.
Arguments: Arguments:
modulestore_type (ModuleStoreEnum.Type): Type of modulestore to create modulestore_type (ModuleStoreEnum.Type): Type of modulestore to create
...@@ -511,23 +511,21 @@ class DashboardTest(ModuleStoreTestCase): ...@@ -511,23 +511,21 @@ class DashboardTest(ModuleStoreTestCase):
involve adding fields to CourseOverview so that loading a full involve adding fields to CourseOverview so that loading a full
CourseDescriptor isn't necessary. CourseDescriptor isn't necessary.
""" """
# Create a course, log in the user, and enroll them in the course. # Create a course and log in the user.
test_course = CourseFactory.create(default_store=modulestore_type) test_course = CourseFactory.create(default_store=modulestore_type)
self.client.login(username="jack", password="test") self.client.login(username="jack", password="test")
CourseEnrollment.enroll(self.user, test_course.id)
# The first request will result in a modulestore query. # Enrolling the user in the course will result in a modulestore query.
with check_mongo_calls(expected_mongo_calls): with check_mongo_calls(expected_mongo_calls):
response_1 = self.client.get(reverse('dashboard')) CourseEnrollment.enroll(self.user, test_course.id)
self.assertEquals(response_1.status_code, 200)
# Subsequent requests will only result in SQL queries to load the # Subsequent requests will only result in SQL queries to load the
# CourseOverview object that has been created. # CourseOverview object that has been created.
with check_mongo_calls(0): with check_mongo_calls(0):
response_1 = self.client.get(reverse('dashboard'))
self.assertEquals(response_1.status_code, 200)
response_2 = self.client.get(reverse('dashboard')) response_2 = self.client.get(reverse('dashboard'))
self.assertEquals(response_2.status_code, 200) self.assertEquals(response_2.status_code, 200)
response_3 = self.client.get(reverse('dashboard'))
self.assertEquals(response_3.status_code, 200)
@unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms') @unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms')
@patch.dict(settings.FEATURES, {"IS_EDX_DOMAIN": True}) @patch.dict(settings.FEATURES, {"IS_EDX_DOMAIN": True})
......
...@@ -1353,7 +1353,7 @@ class ModuleStoreWriteBase(ModuleStoreReadBase, ModuleStoreWrite): ...@@ -1353,7 +1353,7 @@ class ModuleStoreWriteBase(ModuleStoreReadBase, ModuleStoreWrite):
otherwise a publish will be signalled at the end of the bulk operation otherwise a publish will be signalled at the end of the bulk operation
Arguments: Arguments:
library_updated - library_updated to which the signal applies library_key - library_key to which the signal applies
""" """
signal_handler = getattr(self, 'signal_handler', None) signal_handler = getattr(self, 'signal_handler', None)
if signal_handler: if signal_handler:
...@@ -1363,6 +1363,14 @@ class ModuleStoreWriteBase(ModuleStoreReadBase, ModuleStoreWrite): ...@@ -1363,6 +1363,14 @@ class ModuleStoreWriteBase(ModuleStoreReadBase, ModuleStoreWrite):
else: else:
signal_handler.send("library_updated", library_key=library_key) signal_handler.send("library_updated", library_key=library_key)
def _emit_course_deleted_signal(self, course_key):
"""
Helper method used to emit the course_deleted signal.
"""
signal_handler = getattr(self, 'signal_handler', None)
if signal_handler:
signal_handler.send("course_deleted", course_key=course_key)
def only_xmodules(identifier, entry_points): def only_xmodules(identifier, entry_points):
"""Only use entry_points that are supplied by the xmodule package""" """Only use entry_points that are supplied by the xmodule package"""
......
...@@ -87,11 +87,13 @@ class SignalHandler(object): ...@@ -87,11 +87,13 @@ class SignalHandler(object):
do the actual work. do the actual work.
""" """
course_published = django.dispatch.Signal(providing_args=["course_key"]) course_published = django.dispatch.Signal(providing_args=["course_key"])
course_deleted = django.dispatch.Signal(providing_args=["course_key"])
library_updated = django.dispatch.Signal(providing_args=["library_key"]) library_updated = django.dispatch.Signal(providing_args=["library_key"])
_mapping = { _mapping = {
"course_published": course_published, "course_published": course_published,
"library_updated": library_updated "course_deleted": course_deleted,
"library_updated": library_updated,
} }
def __init__(self, modulestore_class): def __init__(self, modulestore_class):
......
...@@ -167,6 +167,8 @@ class DraftModuleStore(MongoModuleStore): ...@@ -167,6 +167,8 @@ class DraftModuleStore(MongoModuleStore):
self.collection.remove(course_query, multi=True) self.collection.remove(course_query, multi=True)
self.delete_all_asset_metadata(course_key, user_id) self.delete_all_asset_metadata(course_key, user_id)
self._emit_course_deleted_signal(course_key)
def clone_course(self, source_course_id, dest_course_id, user_id, fields=None, **kwargs): def clone_course(self, source_course_id, dest_course_id, user_id, fields=None, **kwargs):
""" """
Only called if cloning within this store or if env doesn't set up mixed. Only called if cloning within this store or if env doesn't set up mixed.
......
...@@ -2443,6 +2443,8 @@ class SplitMongoModuleStore(SplitBulkWriteMixin, ModuleStoreWriteBase): ...@@ -2443,6 +2443,8 @@ class SplitMongoModuleStore(SplitBulkWriteMixin, ModuleStoreWriteBase):
# in case the course is later restored. # in case the course is later restored.
# super(SplitMongoModuleStore, self).delete_course(course_key, user_id) # super(SplitMongoModuleStore, self).delete_course(course_key, user_id)
self._emit_course_deleted_signal(course_key)
@contract(block_map="dict(BlockKey: dict)", block_key=BlockKey) @contract(block_map="dict(BlockKey: dict)", block_key=BlockKey)
def inherit_settings( def inherit_settings(
self, block_map, block_key, inherited_settings_map, inheriting_settings=None, inherited_from=None self, block_map, block_key, inherited_settings_map, inheriting_settings=None, inherited_from=None
......
...@@ -2414,6 +2414,34 @@ class TestMixedModuleStore(CommonMixedModuleStoreSetup): ...@@ -2414,6 +2414,34 @@ class TestMixedModuleStore(CommonMixedModuleStoreSetup):
self.assertEqual(receiver.call_count, 0) self.assertEqual(receiver.call_count, 0)
self.assertEqual(receiver.call_count, 0) self.assertEqual(receiver.call_count, 0)
@ddt.data(ModuleStoreEnum.Type.mongo, ModuleStoreEnum.Type.split)
def test_course_deleted_signal(self, default):
with MongoContentstoreBuilder().build() as contentstore:
self.store = MixedModuleStore(
contentstore=contentstore,
create_modulestore_instance=create_modulestore_instance,
mappings={},
signal_handler=SignalHandler(MixedModuleStore),
**self.OPTIONS
)
self.addCleanup(self.store.close_all_connections)
with self.store.default_store(default):
self.assertIsNotNone(self.store.thread_cache.default_store.signal_handler)
with mock_signal_receiver(SignalHandler.course_deleted) as receiver:
self.assertEqual(receiver.call_count, 0)
# Create a course
course = self.store.create_course('org_x', 'course_y', 'run_z', self.user_id)
course_key = course.id
# Delete the course
course = self.store.delete_course(course_key, self.user_id)
# Verify that the signal was emitted
self.assertEqual(receiver.call_count, 1)
@ddt.ddt @ddt.ddt
@attr('mongo') @attr('mongo')
......
...@@ -10,8 +10,8 @@ Note: The access control logic in this file does NOT check for enrollment in ...@@ -10,8 +10,8 @@ Note: The access control logic in this file does NOT check for enrollment in
If enrollment is to be checked, use get_course_with_access in courseware.courses. If enrollment is to be checked, use get_course_with_access in courseware.courses.
It is a wrapper around has_access that additionally checks for enrollment. It is a wrapper around has_access that additionally checks for enrollment.
""" """
import logging
from datetime import datetime, timedelta from datetime import datetime, timedelta
import logging
import pytz import pytz
from django.conf import settings from django.conf import settings
...@@ -245,6 +245,58 @@ def _can_load_course_on_mobile(user, course): ...@@ -245,6 +245,58 @@ def _can_load_course_on_mobile(user, course):
) )
def _can_enroll_courselike(user, courselike):
"""
Ascertain if the user can enroll in the given courselike object.
Arguments:
user (User): The user attempting to enroll.
courselike (CourseDescriptor or CourseOverview): The object representing the
course in which the user is trying to enroll.
Returns:
AccessResponse, indicating whether the user can enroll.
"""
enrollment_domain = courselike.enrollment_domain
# Courselike objects (e.g., course descriptors and CourseOverviews) have an attribute named `id`
# which actually points to a CourseKey. Sigh.
course_key = courselike.id
# If using a registration method to restrict enrollment (e.g., Shibboleth)
if settings.FEATURES.get('RESTRICT_ENROLL_BY_REG_METHOD') and enrollment_domain:
if user is not None and user.is_authenticated() and \
ExternalAuthMap.objects.filter(user=user, external_domain=enrollment_domain):
debug("Allow: external_auth of " + enrollment_domain)
reg_method_ok = True
else:
reg_method_ok = False
else:
reg_method_ok = True
# If the user appears in CourseEnrollmentAllowed paired with the given course key,
# they may enroll. Note that as dictated by the legacy database schema, the filter
# call includes a `course_id` kwarg which requires a CourseKey.
if user is not None and user.is_authenticated():
if CourseEnrollmentAllowed.objects.filter(email=user.email, course_id=course_key):
return ACCESS_GRANTED
if _has_staff_access_to_descriptor(user, courselike, course_key):
return ACCESS_GRANTED
if courselike.invitation_only:
debug("Deny: invitation only")
return ACCESS_DENIED
now = datetime.now(UTC())
enrollment_start = courselike.enrollment_start or datetime.min.replace(tzinfo=pytz.UTC)
enrollment_end = courselike.enrollment_end or datetime.max.replace(tzinfo=pytz.UTC)
if reg_method_ok and enrollment_start < now < enrollment_end:
debug("Allow: in enrollment period")
return ACCESS_GRANTED
return ACCESS_DENIED
def _has_access_course_desc(user, action, course): def _has_access_course_desc(user, action, course):
""" """
Check if user has access to a course descriptor. Check if user has access to a course descriptor.
...@@ -271,54 +323,7 @@ def _has_access_course_desc(user, action, course): ...@@ -271,54 +323,7 @@ def _has_access_course_desc(user, action, course):
return _has_access_descriptor(user, 'load', course, course.id) return _has_access_descriptor(user, 'load', course, course.id)
def can_enroll(): def can_enroll():
""" return _can_enroll_courselike(user, course)
First check if restriction of enrollment by login method is enabled, both
globally and by the course.
If it is, then the user must pass the criterion set by the course, e.g. that ExternalAuthMap
was set by 'shib:https://idp.stanford.edu/", in addition to requirements below.
Rest of requirements:
(CourseEnrollmentAllowed always overrides)
or
(staff can always enroll)
or
Enrollment can only happen in the course enrollment period, if one exists, and
course is not invitation only.
"""
# if using registration method to restrict (say shibboleth)
if settings.FEATURES.get('RESTRICT_ENROLL_BY_REG_METHOD') and course.enrollment_domain:
if user is not None and user.is_authenticated() and \
ExternalAuthMap.objects.filter(user=user, external_domain=course.enrollment_domain):
debug("Allow: external_auth of " + course.enrollment_domain)
reg_method_ok = True
else:
reg_method_ok = False
else:
reg_method_ok = True # if not using this access check, it's always OK.
now = datetime.now(UTC())
start = course.enrollment_start or datetime.min.replace(tzinfo=pytz.UTC)
end = course.enrollment_end or datetime.max.replace(tzinfo=pytz.UTC)
# if user is in CourseEnrollmentAllowed with right course key then can also enroll
# (note that course.id actually points to a CourseKey)
# (the filter call uses course_id= since that's the legacy database schema)
# (sorry that it's confusing :( )
if user is not None and user.is_authenticated() and CourseEnrollmentAllowed:
if CourseEnrollmentAllowed.objects.filter(email=user.email, course_id=course.id):
return ACCESS_GRANTED
if _has_staff_access_to_descriptor(user, course, course.id):
return ACCESS_GRANTED
# Invitation_only doesn't apply to CourseEnrollmentAllowed or has_staff_access_access
if course.invitation_only:
debug("Deny: invitation only")
return ACCESS_DENIED
if reg_method_ok and start < now < end:
debug("Allow: in enrollment period")
return ACCESS_GRANTED
def see_exists(): def see_exists():
""" """
...@@ -412,6 +417,7 @@ def _can_load_course_overview(user, course_overview): ...@@ -412,6 +417,7 @@ def _can_load_course_overview(user, course_overview):
) )
_COURSE_OVERVIEW_CHECKERS = { _COURSE_OVERVIEW_CHECKERS = {
'enroll': _can_enroll_courselike,
'load': _can_load_course_overview, 'load': _can_load_course_overview,
'load_mobile': lambda user, course_overview: ( 'load_mobile': lambda user, course_overview: (
_can_load_course_overview(user, course_overview) _can_load_course_overview(user, course_overview)
......
...@@ -515,6 +515,12 @@ class CourseOverviewAccessTestCase(ModuleStoreTestCase): ...@@ -515,6 +515,12 @@ class CourseOverviewAccessTestCase(ModuleStoreTestCase):
self.user_staff = UserFactory.create(is_staff=True) self.user_staff = UserFactory.create(is_staff=True)
self.user_anonymous = AnonymousUserFactory.create() self.user_anonymous = AnonymousUserFactory.create()
ENROLL_TEST_DATA = list(itertools.product(
['user_normal', 'user_staff', 'user_anonymous'],
['enroll'],
['course_default', 'course_started', 'course_not_started', 'course_staff_only'],
))
LOAD_TEST_DATA = list(itertools.product( LOAD_TEST_DATA = list(itertools.product(
['user_normal', 'user_beta_tester', 'user_staff'], ['user_normal', 'user_beta_tester', 'user_staff'],
['load'], ['load'],
...@@ -533,7 +539,7 @@ class CourseOverviewAccessTestCase(ModuleStoreTestCase): ...@@ -533,7 +539,7 @@ class CourseOverviewAccessTestCase(ModuleStoreTestCase):
['course_default', 'course_with_pre_requisite', 'course_with_pre_requisites'], ['course_default', 'course_with_pre_requisite', 'course_with_pre_requisites'],
)) ))
@ddt.data(*(LOAD_TEST_DATA + LOAD_MOBILE_TEST_DATA + PREREQUISITES_TEST_DATA)) @ddt.data(*(ENROLL_TEST_DATA + LOAD_TEST_DATA + LOAD_MOBILE_TEST_DATA + PREREQUISITES_TEST_DATA))
@ddt.unpack @ddt.unpack
def test_course_overview_access(self, user_attr_name, action, course_attr_name): def test_course_overview_access(self, user_attr_name, action, course_attr_name):
""" """
......
# -*- coding: utf-8 -*-
from south.utils import datetime_utils as datetime
from south.db import db
from south.v2 import SchemaMigration
from django.db import models
class Migration(SchemaMigration):
def forwards(self, orm):
# The default values for these new columns may not match the actual
# values of courses already present in the table. To ensure that the
# cached values are correct, we must clear the table before adding any
# new columns.
db.clear_table('course_overviews_courseoverview')
# Adding field 'CourseOverview.enrollment_start'
db.add_column('course_overviews_courseoverview', 'enrollment_start',
self.gf('django.db.models.fields.DateTimeField')(null=True),
keep_default=False)
# Adding field 'CourseOverview.enrollment_end'
db.add_column('course_overviews_courseoverview', 'enrollment_end',
self.gf('django.db.models.fields.DateTimeField')(null=True),
keep_default=False)
# Adding field 'CourseOverview.enrollment_domain'
db.add_column('course_overviews_courseoverview', 'enrollment_domain',
self.gf('django.db.models.fields.TextField')(null=True),
keep_default=False)
# Adding field 'CourseOverview.invitation_only'
db.add_column('course_overviews_courseoverview', 'invitation_only',
self.gf('django.db.models.fields.BooleanField')(default=False),
keep_default=False)
# Adding field 'CourseOverview.max_student_enrollments_allowed'
db.add_column('course_overviews_courseoverview', 'max_student_enrollments_allowed',
self.gf('django.db.models.fields.IntegerField')(null=True),
keep_default=False)
def backwards(self, orm):
# Deleting field 'CourseOverview.enrollment_start'
db.delete_column('course_overviews_courseoverview', 'enrollment_start')
# Deleting field 'CourseOverview.enrollment_end'
db.delete_column('course_overviews_courseoverview', 'enrollment_end')
# Deleting field 'CourseOverview.enrollment_domain'
db.delete_column('course_overviews_courseoverview', 'enrollment_domain')
# Deleting field 'CourseOverview.invitation_only'
db.delete_column('course_overviews_courseoverview', 'invitation_only')
# Deleting field 'CourseOverview.max_student_enrollments_allowed'
db.delete_column('course_overviews_courseoverview', 'max_student_enrollments_allowed')
models = {
'course_overviews.courseoverview': {
'Meta': {'object_name': 'CourseOverview'},
'_location': ('xmodule_django.models.UsageKeyField', [], {'max_length': '255'}),
'_pre_requisite_courses_json': ('django.db.models.fields.TextField', [], {}),
'advertised_start': ('django.db.models.fields.TextField', [], {'null': 'True'}),
'cert_html_view_enabled': ('django.db.models.fields.BooleanField', [], {'default': 'False'}),
'cert_name_long': ('django.db.models.fields.TextField', [], {}),
'cert_name_short': ('django.db.models.fields.TextField', [], {}),
'certificates_display_behavior': ('django.db.models.fields.TextField', [], {'null': 'True'}),
'certificates_show_before_end': ('django.db.models.fields.BooleanField', [], {'default': 'False'}),
'course_image_url': ('django.db.models.fields.TextField', [], {}),
'days_early_for_beta': ('django.db.models.fields.FloatField', [], {'null': 'True'}),
'display_name': ('django.db.models.fields.TextField', [], {'null': 'True'}),
'display_number_with_default': ('django.db.models.fields.TextField', [], {}),
'display_org_with_default': ('django.db.models.fields.TextField', [], {}),
'end': ('django.db.models.fields.DateTimeField', [], {'null': 'True'}),
'end_of_course_survey_url': ('django.db.models.fields.TextField', [], {'null': 'True'}),
'enrollment_domain': ('django.db.models.fields.TextField', [], {'null': 'True'}),
'enrollment_end': ('django.db.models.fields.DateTimeField', [], {'null': 'True'}),
'enrollment_start': ('django.db.models.fields.DateTimeField', [], {'null': 'True'}),
'facebook_url': ('django.db.models.fields.TextField', [], {'null': 'True'}),
'has_any_active_web_certificate': ('django.db.models.fields.BooleanField', [], {'default': 'False'}),
'id': ('xmodule_django.models.CourseKeyField', [], {'max_length': '255', 'primary_key': 'True', 'db_index': 'True'}),
'invitation_only': ('django.db.models.fields.BooleanField', [], {'default': 'False'}),
'lowest_passing_grade': ('django.db.models.fields.DecimalField', [], {'max_digits': '5', 'decimal_places': '2'}),
'max_student_enrollments_allowed': ('django.db.models.fields.IntegerField', [], {'null': 'True'}),
'mobile_available': ('django.db.models.fields.BooleanField', [], {'default': 'False'}),
'social_sharing_url': ('django.db.models.fields.TextField', [], {'null': 'True'}),
'start': ('django.db.models.fields.DateTimeField', [], {'null': 'True'}),
'visible_to_staff_only': ('django.db.models.fields.BooleanField', [], {'default': 'False'})
}
}
complete_apps = ['course_overviews']
\ No newline at end of file
...@@ -5,7 +5,7 @@ Declaration of CourseOverview model ...@@ -5,7 +5,7 @@ Declaration of CourseOverview model
import json import json
import django.db.models import django.db.models
from django.db.models.fields import BooleanField, DateTimeField, DecimalField, TextField, FloatField from django.db.models.fields import BooleanField, DateTimeField, DecimalField, TextField, FloatField, IntegerField
from django.utils.translation import ugettext from django.utils.translation import ugettext
from util.date_utils import strftime_localized from util.date_utils import strftime_localized
...@@ -60,6 +60,13 @@ class CourseOverview(django.db.models.Model): ...@@ -60,6 +60,13 @@ class CourseOverview(django.db.models.Model):
visible_to_staff_only = BooleanField() visible_to_staff_only = BooleanField()
_pre_requisite_courses_json = TextField() # JSON representation of list of CourseKey strings _pre_requisite_courses_json = TextField() # JSON representation of list of CourseKey strings
# Enrollment details
enrollment_start = DateTimeField(null=True)
enrollment_end = DateTimeField(null=True)
enrollment_domain = TextField(null=True)
invitation_only = BooleanField(default=False)
max_student_enrollments_allowed = IntegerField(null=True)
@staticmethod @staticmethod
def _create_from_course(course): def _create_from_course(course):
""" """
...@@ -114,7 +121,13 @@ class CourseOverview(django.db.models.Model): ...@@ -114,7 +121,13 @@ class CourseOverview(django.db.models.Model):
days_early_for_beta=course.days_early_for_beta, days_early_for_beta=course.days_early_for_beta,
mobile_available=course.mobile_available, mobile_available=course.mobile_available,
visible_to_staff_only=course.visible_to_staff_only, visible_to_staff_only=course.visible_to_staff_only,
_pre_requisite_courses_json=json.dumps(course.pre_requisite_courses) _pre_requisite_courses_json=json.dumps(course.pre_requisite_courses),
enrollment_start=course.enrollment_start,
enrollment_end=course.enrollment_end,
enrollment_domain=course.enrollment_domain,
invitation_only=course.invitation_only,
max_student_enrollments_allowed=course.max_student_enrollments_allowed,
) )
@staticmethod @staticmethod
......
...@@ -3,9 +3,8 @@ Signal handler for invalidating cached course overviews ...@@ -3,9 +3,8 @@ Signal handler for invalidating cached course overviews
""" """
from django.dispatch.dispatcher import receiver from django.dispatch.dispatcher import receiver
from xmodule.modulestore.django import SignalHandler
from .models import CourseOverview from .models import CourseOverview
from xmodule.modulestore.django import SignalHandler
@receiver(SignalHandler.course_published) @receiver(SignalHandler.course_published)
...@@ -15,3 +14,12 @@ def _listen_for_course_publish(sender, course_key, **kwargs): # pylint: disable ...@@ -15,3 +14,12 @@ def _listen_for_course_publish(sender, course_key, **kwargs): # pylint: disable
invalidates the corresponding CourseOverview cache entry if one exists. invalidates the corresponding CourseOverview cache entry if one exists.
""" """
CourseOverview.objects.filter(id=course_key).delete() CourseOverview.objects.filter(id=course_key).delete()
@receiver(SignalHandler.course_deleted)
def _listen_for_course_delete(sender, course_key, **kwargs): # pylint: disable=unused-argument
"""
Catches the signal that a course has been deleted from Studio and
invalidates the corresponding CourseOverview cache entry if one exists.
"""
CourseOverview.objects.filter(id=course_key).delete()
...@@ -91,6 +91,9 @@ class CourseOverviewTestCase(ModuleStoreTestCase): ...@@ -91,6 +91,9 @@ class CourseOverviewTestCase(ModuleStoreTestCase):
'display_name_with_default', 'display_name_with_default',
'start_date_is_still_default', 'start_date_is_still_default',
'pre_requisite_courses', 'pre_requisite_courses',
'enrollment_domain',
'invitation_only',
'max_student_enrollments_allowed',
] ]
for attribute_name in fields_to_test: for attribute_name in fields_to_test:
course_value = getattr(course, attribute_name) course_value = getattr(course, attribute_name)
...@@ -125,24 +128,38 @@ class CourseOverviewTestCase(ModuleStoreTestCase): ...@@ -125,24 +128,38 @@ class CourseOverviewTestCase(ModuleStoreTestCase):
# resulting values are often off by fractions of a second. So, as a # resulting values are often off by fractions of a second. So, as a
# workaround, we simply test if the start and end times are the same # workaround, we simply test if the start and end times are the same
# number of seconds from the Unix epoch. # number of seconds from the Unix epoch.
others_to_test = [( others_to_test = [
course_image_url(course), (
course_overview_cache_miss.course_image_url, course_image_url(course),
course_overview_cache_hit.course_image_url course_overview_cache_miss.course_image_url,
), ( course_overview_cache_hit.course_image_url
get_active_web_certificate(course) is not None, ),
course_overview_cache_miss.has_any_active_web_certificate, (
course_overview_cache_hit.has_any_active_web_certificate get_active_web_certificate(course) is not None,
course_overview_cache_miss.has_any_active_web_certificate,
), ( course_overview_cache_hit.has_any_active_web_certificate
get_seconds_since_epoch(course.start), ),
get_seconds_since_epoch(course_overview_cache_miss.start), (
get_seconds_since_epoch(course_overview_cache_hit.start), get_seconds_since_epoch(course.start),
), ( get_seconds_since_epoch(course_overview_cache_miss.start),
get_seconds_since_epoch(course.end), get_seconds_since_epoch(course_overview_cache_hit.start),
get_seconds_since_epoch(course_overview_cache_miss.end), ),
get_seconds_since_epoch(course_overview_cache_hit.end), (
)] get_seconds_since_epoch(course.end),
get_seconds_since_epoch(course_overview_cache_miss.end),
get_seconds_since_epoch(course_overview_cache_hit.end),
),
(
get_seconds_since_epoch(course.enrollment_start),
get_seconds_since_epoch(course_overview_cache_miss.enrollment_start),
get_seconds_since_epoch(course_overview_cache_hit.enrollment_start),
),
(
get_seconds_since_epoch(course.enrollment_end),
get_seconds_since_epoch(course_overview_cache_miss.enrollment_end),
get_seconds_since_epoch(course_overview_cache_hit.enrollment_end),
),
]
for (course_value, cache_miss_value, cache_hit_value) in others_to_test: for (course_value, cache_miss_value, cache_hit_value) in others_to_test:
self.assertEqual(course_value, cache_miss_value) self.assertEqual(course_value, cache_miss_value)
self.assertEqual(cache_miss_value, cache_hit_value) self.assertEqual(cache_miss_value, cache_hit_value)
...@@ -211,7 +228,7 @@ class CourseOverviewTestCase(ModuleStoreTestCase): ...@@ -211,7 +228,7 @@ class CourseOverviewTestCase(ModuleStoreTestCase):
@ddt.data(ModuleStoreEnum.Type.mongo, ModuleStoreEnum.Type.split) @ddt.data(ModuleStoreEnum.Type.mongo, ModuleStoreEnum.Type.split)
def test_course_overview_cache_invalidation(self, modulestore_type): def test_course_overview_cache_invalidation(self, modulestore_type):
""" """
Tests that when a course is published, the corresponding Tests that when a course is published or deleted, the corresponding
course_overview is removed from the cache. course_overview is removed from the cache.
Arguments: Arguments:
...@@ -236,6 +253,11 @@ class CourseOverviewTestCase(ModuleStoreTestCase): ...@@ -236,6 +253,11 @@ class CourseOverviewTestCase(ModuleStoreTestCase):
course_overview_2 = CourseOverview.get_from_id(course.id) course_overview_2 = CourseOverview.get_from_id(course.id)
self.assertFalse(course_overview_2.mobile_available) self.assertFalse(course_overview_2.mobile_available)
# Verify that when the course is deleted, the corresponding CourseOverview is deleted as well.
with self.assertRaises(CourseOverview.DoesNotExist):
self.store.delete_course(course.id, ModuleStoreEnum.UserID.test)
CourseOverview.get_from_id(course.id)
@ddt.data((ModuleStoreEnum.Type.mongo, 1, 1), (ModuleStoreEnum.Type.split, 3, 4)) @ddt.data((ModuleStoreEnum.Type.mongo, 1, 1), (ModuleStoreEnum.Type.split, 3, 4))
@ddt.unpack @ddt.unpack
def test_course_overview_caching(self, modulestore_type, min_mongo_calls, max_mongo_calls): def test_course_overview_caching(self, modulestore_type, min_mongo_calls, max_mongo_calls):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment