Commit e3d820a7 by Calen Pennington

Merge pull request #8519 from jazkarta/check-ccx-enabled

WIP: Limit application of CCX overrides to enabled courses and providers
parents 11e10499 b8e63cbb
...@@ -47,6 +47,14 @@ class CustomCoursesForEdxOverrideProvider(FieldOverrideProvider): ...@@ -47,6 +47,14 @@ class CustomCoursesForEdxOverrideProvider(FieldOverrideProvider):
return get_override_for_ccx(ccx, block, name, default) return get_override_for_ccx(ccx, block, name, default)
return default return default
@classmethod
def enabled_for(cls, course):
"""CCX field overrides are enabled per-course
protect against missing attributes
"""
return getattr(course, 'enable_ccx', False)
def get_current_ccx(course_key): def get_current_ccx(course_key):
""" """
......
...@@ -7,6 +7,7 @@ import itertools ...@@ -7,6 +7,7 @@ import itertools
import mock import mock
from courseware.views import progress # pylint: disable=import-error from courseware.views import progress # pylint: disable=import-error
from courseware.field_overrides import OverrideFieldData
from datetime import datetime from datetime import datetime
from django.conf import settings from django.conf import settings
from django.core.cache import get_cache from django.core.cache import get_cache
...@@ -56,7 +57,7 @@ class FieldOverridePerformanceTestCase(ProceduralCourseTestMixin, ...@@ -56,7 +57,7 @@ class FieldOverridePerformanceTestCase(ProceduralCourseTestMixin,
MakoMiddleware().process_request(self.request) MakoMiddleware().process_request(self.request)
def setup_course(self, size): def setup_course(self, size, enable_ccx):
""" """
Build a gradable course where each node has `size` children. Build a gradable course where each node has `size` children.
""" """
...@@ -98,7 +99,8 @@ class FieldOverridePerformanceTestCase(ProceduralCourseTestMixin, ...@@ -98,7 +99,8 @@ class FieldOverridePerformanceTestCase(ProceduralCourseTestMixin,
self.course = CourseFactory.create( self.course = CourseFactory.create(
graded=True, graded=True,
start=datetime.now(UTC), start=datetime.now(UTC),
grading_policy=grading_policy grading_policy=grading_policy,
enable_ccx=enable_ccx,
) )
self.populate_course(size) self.populate_course(size)
...@@ -117,11 +119,11 @@ class FieldOverridePerformanceTestCase(ProceduralCourseTestMixin, ...@@ -117,11 +119,11 @@ class FieldOverridePerformanceTestCase(ProceduralCourseTestMixin,
student_id=self.student.id student_id=self.student.id
) )
def instrument_course_progress_render(self, dataset_index, queries, reads, xblocks): def instrument_course_progress_render(self, course_width, enable_ccx, queries, reads, xblocks):
""" """
Renders the progress page, instrumenting Mongo reads and SQL queries. Renders the progress page, instrumenting Mongo reads and SQL queries.
""" """
self.setup_course(dataset_index + 1) self.setup_course(course_width, enable_ccx)
# Switch to published-only mode to simulate the LMS # Switch to published-only mode to simulate the LMS
with self.settings(MODULESTORE_BRANCH='published-only'): with self.settings(MODULESTORE_BRANCH='published-only'):
...@@ -135,17 +137,21 @@ class FieldOverridePerformanceTestCase(ProceduralCourseTestMixin, ...@@ -135,17 +137,21 @@ class FieldOverridePerformanceTestCase(ProceduralCourseTestMixin,
# We clear the request cache to simulate a new request in the LMS. # We clear the request cache to simulate a new request in the LMS.
RequestCache.clear_request_cache() RequestCache.clear_request_cache()
# Reset the list of provider classes, so that our django settings changes
# can actually take affect.
OverrideFieldData.provider_classes = None
with self.assertNumQueries(queries): with self.assertNumQueries(queries):
with check_mongo_calls(reads): with check_mongo_calls(reads):
with check_sum_of_calls(XBlock, ['__init__'], xblocks): with check_sum_of_calls(XBlock, ['__init__'], xblocks):
self.grade_course(self.course) self.grade_course(self.course)
@ddt.data(*itertools.product(('no_overrides', 'ccx'), range(3))) @ddt.data(*itertools.product(('no_overrides', 'ccx'), range(1, 4), (True, False)))
@ddt.unpack @ddt.unpack
@override_settings( @override_settings(
FIELD_OVERRIDE_PROVIDERS=(), FIELD_OVERRIDE_PROVIDERS=(),
) )
def test_field_overrides(self, overrides, dataset_index): def test_field_overrides(self, overrides, course_width, enable_ccx):
""" """
Test without any field overrides. Test without any field overrides.
""" """
...@@ -154,8 +160,8 @@ class FieldOverridePerformanceTestCase(ProceduralCourseTestMixin, ...@@ -154,8 +160,8 @@ class FieldOverridePerformanceTestCase(ProceduralCourseTestMixin,
'ccx': ('ccx.overrides.CustomCoursesForEdxOverrideProvider',) 'ccx': ('ccx.overrides.CustomCoursesForEdxOverrideProvider',)
} }
with self.settings(FIELD_OVERRIDE_PROVIDERS=providers[overrides]): with self.settings(FIELD_OVERRIDE_PROVIDERS=providers[overrides]):
queries, reads, xblocks = self.TEST_DATA[overrides][dataset_index] queries, reads, xblocks = self.TEST_DATA[(overrides, course_width, enable_ccx)]
self.instrument_course_progress_render(dataset_index, queries, reads, xblocks) self.instrument_course_progress_render(course_width, enable_ccx, queries, reads, xblocks)
class TestFieldOverrideMongoPerformance(FieldOverridePerformanceTestCase): class TestFieldOverrideMongoPerformance(FieldOverridePerformanceTestCase):
...@@ -166,12 +172,19 @@ class TestFieldOverrideMongoPerformance(FieldOverridePerformanceTestCase): ...@@ -166,12 +172,19 @@ class TestFieldOverrideMongoPerformance(FieldOverridePerformanceTestCase):
__test__ = True __test__ = True
TEST_DATA = { TEST_DATA = {
'no_overrides': [ # (providers, course_width, enable_ccx): # of sql queries, # of mongo queries, # of xblocks
(27, 7, 19), (135, 7, 131), (595, 7, 537) ('no_overrides', 1, True): (27, 7, 19),
], ('no_overrides', 2, True): (135, 7, 131),
'ccx': [ ('no_overrides', 3, True): (595, 7, 537),
(27, 7, 47), (135, 7, 455), (595, 7, 2037) ('ccx', 1, True): (27, 7, 47),
], ('ccx', 2, True): (135, 7, 455),
('ccx', 3, True): (595, 7, 2037),
('no_overrides', 1, False): (27, 7, 19),
('no_overrides', 2, False): (135, 7, 131),
('no_overrides', 3, False): (595, 7, 537),
('ccx', 1, False): (27, 7, 19),
('ccx', 2, False): (135, 7, 131),
('ccx', 3, False): (595, 7, 537),
} }
...@@ -183,10 +196,16 @@ class TestFieldOverrideSplitPerformance(FieldOverridePerformanceTestCase): ...@@ -183,10 +196,16 @@ class TestFieldOverrideSplitPerformance(FieldOverridePerformanceTestCase):
__test__ = True __test__ = True
TEST_DATA = { TEST_DATA = {
'no_overrides': [ ('no_overrides', 1, True): (27, 4, 9),
(27, 4, 9), (135, 19, 54), (595, 84, 215) ('no_overrides', 2, True): (135, 19, 54),
], ('no_overrides', 3, True): (595, 84, 215),
'ccx': [ ('ccx', 1, True): (27, 4, 9),
(27, 4, 9), (135, 19, 54), (595, 84, 215) ('ccx', 2, True): (135, 19, 54),
] ('ccx', 3, True): (595, 84, 215),
('no_overrides', 1, False): (27, 4, 9),
('no_overrides', 2, False): (135, 19, 54),
('no_overrides', 3, False): (595, 84, 215),
('ccx', 1, False): (27, 4, 9),
('ccx', 2, False): (135, 19, 54),
('ccx', 3, False): (595, 84, 215),
} }
...@@ -36,6 +36,7 @@ class TestFieldOverrides(ModuleStoreTestCase): ...@@ -36,6 +36,7 @@ class TestFieldOverrides(ModuleStoreTestCase):
""" """
super(TestFieldOverrides, self).setUp() super(TestFieldOverrides, self).setUp()
self.course = course = CourseFactory.create() self.course = course = CourseFactory.create()
self.course.enable_ccx = True
# Create a course outline # Create a course outline
self.mooc_start = start = datetime.datetime( self.mooc_start = start = datetime.datetime(
...@@ -71,7 +72,7 @@ class TestFieldOverrides(ModuleStoreTestCase): ...@@ -71,7 +72,7 @@ class TestFieldOverrides(ModuleStoreTestCase):
OverrideFieldData.provider_classes = None OverrideFieldData.provider_classes = None
for block in iter_blocks(ccx.course): for block in iter_blocks(ccx.course):
block._field_data = OverrideFieldData.wrap( # pylint: disable=protected-access block._field_data = OverrideFieldData.wrap( # pylint: disable=protected-access
AdminFactory.create(), block._field_data) # pylint: disable=protected-access AdminFactory.create(), course, block._field_data) # pylint: disable=protected-access
def cleanup_provider_classes(): def cleanup_provider_classes():
""" """
......
...@@ -466,7 +466,11 @@ class TestCCXGrades(ModuleStoreTestCase, LoginEnrollmentTestCase): ...@@ -466,7 +466,11 @@ class TestCCXGrades(ModuleStoreTestCase, LoginEnrollmentTestCase):
Set up tests Set up tests
""" """
super(TestCCXGrades, self).setUp() super(TestCCXGrades, self).setUp()
course = CourseFactory.create() self.course = course = CourseFactory.create(enable_ccx=True)
# Create instructor account
self.coach = coach = AdminFactory.create()
self.client.login(username=coach.username, password="test")
# Create a course outline # Create a course outline
self.mooc_start = start = datetime.datetime( self.mooc_start = start = datetime.datetime(
...@@ -491,9 +495,6 @@ class TestCCXGrades(ModuleStoreTestCase, LoginEnrollmentTestCase): ...@@ -491,9 +495,6 @@ class TestCCXGrades(ModuleStoreTestCase, LoginEnrollmentTestCase):
] for section in sections ] for section in sections
] ]
# Create instructor account
self.coach = coach = AdminFactory.create()
# Create CCX # Create CCX
role = CourseCcxCoachRole(course.id) role = CourseCcxCoachRole(course.id)
role.add_users(coach) role.add_users(coach)
...@@ -505,7 +506,7 @@ class TestCCXGrades(ModuleStoreTestCase, LoginEnrollmentTestCase): ...@@ -505,7 +506,7 @@ class TestCCXGrades(ModuleStoreTestCase, LoginEnrollmentTestCase):
OverrideFieldData.provider_classes = None OverrideFieldData.provider_classes = None
# pylint: disable=protected-access # pylint: disable=protected-access
for block in iter_blocks(course): for block in iter_blocks(course):
block._field_data = OverrideFieldData.wrap(coach, block._field_data) block._field_data = OverrideFieldData.wrap(coach, course, block._field_data)
new_cache = {'tabs': [], 'discussion_topics': []} new_cache = {'tabs': [], 'discussion_topics': []}
if 'grading_policy' in block._field_data_cache: if 'grading_policy' in block._field_data_cache:
new_cache['grading_policy'] = block._field_data_cache['grading_policy'] new_cache['grading_policy'] = block._field_data_cache['grading_policy']
...@@ -559,6 +560,7 @@ class TestCCXGrades(ModuleStoreTestCase, LoginEnrollmentTestCase): ...@@ -559,6 +560,7 @@ class TestCCXGrades(ModuleStoreTestCase, LoginEnrollmentTestCase):
@patch('ccx.views.render_to_response', intercept_renderer) @patch('ccx.views.render_to_response', intercept_renderer)
def test_gradebook(self): def test_gradebook(self):
self.course.enable_ccx = True
url = reverse( url = reverse(
'ccx_gradebook', 'ccx_gradebook',
kwargs={'course_id': self.ccx_key} kwargs={'course_id': self.ccx_key}
...@@ -574,6 +576,7 @@ class TestCCXGrades(ModuleStoreTestCase, LoginEnrollmentTestCase): ...@@ -574,6 +576,7 @@ class TestCCXGrades(ModuleStoreTestCase, LoginEnrollmentTestCase):
len(student_info['grade_summary']['section_breakdown']), 4) len(student_info['grade_summary']['section_breakdown']), 4)
def test_grades_csv(self): def test_grades_csv(self):
self.course.enable_ccx = True
url = reverse( url = reverse(
'ccx_grades_csv', 'ccx_grades_csv',
kwargs={'course_id': self.ccx_key} kwargs={'course_id': self.ccx_key}
...@@ -593,6 +596,7 @@ class TestCCXGrades(ModuleStoreTestCase, LoginEnrollmentTestCase): ...@@ -593,6 +596,7 @@ class TestCCXGrades(ModuleStoreTestCase, LoginEnrollmentTestCase):
@patch('courseware.views.render_to_response', intercept_renderer) @patch('courseware.views.render_to_response', intercept_renderer)
def test_student_progress(self): def test_student_progress(self):
self.course.enable_ccx = True
patch_context = patch('courseware.views.get_course_with_access') patch_context = patch('courseware.views.get_course_with_access')
get_course = patch_context.start() get_course = patch_context.start()
get_course.return_value = self.course get_course.return_value = self.course
......
...@@ -477,7 +477,8 @@ def prep_course_for_grading(course, request): ...@@ -477,7 +477,8 @@ def prep_course_for_grading(course, request):
field_data_cache = FieldDataCache.cache_for_descriptor_descendents( field_data_cache = FieldDataCache.cache_for_descriptor_descendents(
course.id, request.user, course, depth=2) course.id, request.user, course, depth=2)
course = get_module_for_descriptor( course = get_module_for_descriptor(
request.user, request, course, field_data_cache, course.id) request.user, request, course, field_data_cache, course.id, course=course
)
course._field_data_cache = {} # pylint: disable=protected-access course._field_data_cache = {} # pylint: disable=protected-access
course.set_grading_policy(course.grading_policy) course.set_grading_policy(course.grading_policy)
......
...@@ -555,7 +555,8 @@ class CourseBlocksAndNavigation(ListAPIView): ...@@ -555,7 +555,8 @@ class CourseBlocksAndNavigation(ListAPIView):
request_info.request, request_info.request,
block_info.block, block_info.block,
request_info.field_data_cache, request_info.field_data_cache,
request_info.course.id request_info.course.id,
course=request_info.course
) )
# verify the user has access to this block # verify the user has access to this block
......
...@@ -203,7 +203,8 @@ def get_course_about_section(course, section_key): ...@@ -203,7 +203,8 @@ def get_course_about_section(course, section_key):
field_data_cache, field_data_cache,
log_if_not_found=False, log_if_not_found=False,
wrap_xmodule_display=False, wrap_xmodule_display=False,
static_asset_path=course.static_asset_path static_asset_path=course.static_asset_path,
course=course
) )
html = '' html = ''
...@@ -256,7 +257,8 @@ def get_course_info_section_module(request, course, section_key): ...@@ -256,7 +257,8 @@ def get_course_info_section_module(request, course, section_key):
field_data_cache, field_data_cache,
log_if_not_found=False, log_if_not_found=False,
wrap_xmodule_display=False, wrap_xmodule_display=False,
static_asset_path=course.static_asset_path static_asset_path=course.static_asset_path,
course=course
) )
......
...@@ -144,7 +144,8 @@ def get_entrance_exam_score(request, course): ...@@ -144,7 +144,8 @@ def get_entrance_exam_score(request, course):
request, request,
descriptor, descriptor,
field_data_cache, field_data_cache,
course.id course.id,
course=course
) )
exam_module_generators = yield_dynamic_descriptor_descendants( exam_module_generators = yield_dynamic_descriptor_descendants(
......
...@@ -19,11 +19,12 @@ import threading ...@@ -19,11 +19,12 @@ import threading
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from contextlib import contextmanager from contextlib import contextmanager
from django.conf import settings from django.conf import settings
from request_cache.middleware import RequestCache
from xblock.field_data import FieldData from xblock.field_data import FieldData
from xmodule.modulestore.inheritance import InheritanceMixin from xmodule.modulestore.inheritance import InheritanceMixin
NOTSET = object() NOTSET = object()
ENABLED_OVERRIDE_PROVIDERS_KEY = "courseware.field_overrides.enabled_providers"
def resolve_dotted(name): def resolve_dotted(name):
...@@ -61,7 +62,7 @@ class OverrideFieldData(FieldData): ...@@ -61,7 +62,7 @@ class OverrideFieldData(FieldData):
provider_classes = None provider_classes = None
@classmethod @classmethod
def wrap(cls, user, wrapped): def wrap(cls, user, course, wrapped):
""" """
Will return a :class:`OverrideFieldData` which wraps the field data Will return a :class:`OverrideFieldData` which wraps the field data
given in `wrapped` for the given `user`, if override providers are given in `wrapped` for the given `user`, if override providers are
...@@ -75,14 +76,42 @@ class OverrideFieldData(FieldData): ...@@ -75,14 +76,42 @@ class OverrideFieldData(FieldData):
(resolve_dotted(name) for name in (resolve_dotted(name) for name in
settings.FIELD_OVERRIDE_PROVIDERS)) settings.FIELD_OVERRIDE_PROVIDERS))
if cls.provider_classes: enabled_providers = cls._providers_for_course(course)
return cls(user, wrapped)
if enabled_providers:
# TODO: we might not actually want to return here. Might be better
# to check for instance.providers after the instance is built. This
# would allow for the case where we have registered providers but
# none are enabled for the provided course
return cls(user, wrapped, enabled_providers)
return wrapped return wrapped
def __init__(self, user, fallback): @classmethod
def _providers_for_course(cls, course):
"""
Return a filtered list of enabled providers based
on the course passed in. Cache this result per request to avoid
needing to call the provider filter api hundreds of times.
Arguments:
course: The course XBlock
"""
request_cache = RequestCache.get_request_cache()
enabled_providers = request_cache.data.get(
ENABLED_OVERRIDE_PROVIDERS_KEY, NOTSET
)
if enabled_providers == NOTSET:
enabled_providers = tuple(
(provider_class for provider_class in cls.provider_classes if provider_class.enabled_for(course))
)
request_cache.data[ENABLED_OVERRIDE_PROVIDERS_KEY] = enabled_providers
return enabled_providers
def __init__(self, user, fallback, providers):
self.fallback = fallback self.fallback = fallback
self.providers = tuple((cls(user) for cls in self.provider_classes)) self.providers = tuple(provider(user) for provider in providers)
def get_override(self, block, name): def get_override(self, block, name):
""" """
...@@ -109,6 +138,9 @@ class OverrideFieldData(FieldData): ...@@ -109,6 +138,9 @@ class OverrideFieldData(FieldData):
self.fallback.delete(block, name) self.fallback.delete(block, name)
def has(self, block, name): def has(self, block, name):
if not self.providers:
return self.fallback.has(block, name)
has = self.get_override(block, name) has = self.get_override(block, name)
if has is NOTSET: if has is NOTSET:
# If this is an inheritable field and an override is set above, # If this is an inheritable field and an override is set above,
...@@ -128,7 +160,7 @@ class OverrideFieldData(FieldData): ...@@ -128,7 +160,7 @@ class OverrideFieldData(FieldData):
def default(self, block, name): def default(self, block, name):
# The `default` method is overloaded by the field storage system to # The `default` method is overloaded by the field storage system to
# also handle inheritance. # also handle inheritance.
if not overrides_disabled(): if self.providers and not overrides_disabled():
inheritable = InheritanceMixin.fields.keys() inheritable = InheritanceMixin.fields.keys()
if name in inheritable: if name in inheritable:
for ancestor in _lineage(block): for ancestor in _lineage(block):
...@@ -192,6 +224,17 @@ class FieldOverrideProvider(object): ...@@ -192,6 +224,17 @@ class FieldOverrideProvider(object):
""" """
raise NotImplementedError raise NotImplementedError
@abstractmethod
def enabled_for(self, course): # pragma no cover
"""
Return True if this provider should be enabled for a given course
Return False otherwise
Concrete implementations are responsible for implementing this method
"""
return False
def _lineage(block): def _lineage(block):
""" """
......
...@@ -207,7 +207,9 @@ def _grade(student, request, course, keep_raw_scores): ...@@ -207,7 +207,9 @@ def _grade(student, request, course, keep_raw_scores):
# would be simpler # would be simpler
with manual_transaction(): with manual_transaction():
field_data_cache = FieldDataCache([descriptor], course.id, student) field_data_cache = FieldDataCache([descriptor], course.id, student)
return get_module_for_descriptor(student, request, descriptor, field_data_cache, course.id) return get_module_for_descriptor(
student, request, descriptor, field_data_cache, course.id, course=course
)
for module_descriptor in yield_dynamic_descriptor_descendants( for module_descriptor in yield_dynamic_descriptor_descendants(
section_descriptor, student.id, create_module section_descriptor, student.id, create_module
...@@ -337,7 +339,9 @@ def _progress_summary(student, request, course): ...@@ -337,7 +339,9 @@ def _progress_summary(student, request, course):
) )
# TODO: We need the request to pass into here. If we could # TODO: We need the request to pass into here. If we could
# forego that, our arguments would be simpler # forego that, our arguments would be simpler
course_module = get_module_for_descriptor(student, request, course, field_data_cache, course.id) course_module = get_module_for_descriptor(
student, request, course, field_data_cache, course.id, course=course
)
if not course_module: if not course_module:
# This student must not have access to the course. # This student must not have access to the course.
return None return None
......
...@@ -138,7 +138,9 @@ def toc_for_course(request, course, active_chapter, active_section, field_data_c ...@@ -138,7 +138,9 @@ def toc_for_course(request, course, active_chapter, active_section, field_data_c
''' '''
with modulestore().bulk_operations(course.id): with modulestore().bulk_operations(course.id):
course_module = get_module_for_descriptor(request.user, request, course, field_data_cache, course.id) course_module = get_module_for_descriptor(
request.user, request, course, field_data_cache, course.id, course=course
)
if course_module is None: if course_module is None:
return None return None
...@@ -190,7 +192,7 @@ def toc_for_course(request, course, active_chapter, active_section, field_data_c ...@@ -190,7 +192,7 @@ def toc_for_course(request, course, active_chapter, active_section, field_data_c
def get_module(user, request, usage_key, field_data_cache, def get_module(user, request, usage_key, field_data_cache,
position=None, log_if_not_found=True, wrap_xmodule_display=True, position=None, log_if_not_found=True, wrap_xmodule_display=True,
grade_bucket_type=None, depth=0, grade_bucket_type=None, depth=0,
static_asset_path=''): static_asset_path='', course=None):
""" """
Get an instance of the xmodule class identified by location, Get an instance of the xmodule class identified by location,
setting the state based on an existing StudentModule, or creating one if none setting the state based on an existing StudentModule, or creating one if none
...@@ -224,7 +226,8 @@ def get_module(user, request, usage_key, field_data_cache, ...@@ -224,7 +226,8 @@ def get_module(user, request, usage_key, field_data_cache,
position=position, position=position,
wrap_xmodule_display=wrap_xmodule_display, wrap_xmodule_display=wrap_xmodule_display,
grade_bucket_type=grade_bucket_type, grade_bucket_type=grade_bucket_type,
static_asset_path=static_asset_path) static_asset_path=static_asset_path,
course=course)
except ItemNotFoundError: except ItemNotFoundError:
if log_if_not_found: if log_if_not_found:
log.debug("Error in get_module: ItemNotFoundError") log.debug("Error in get_module: ItemNotFoundError")
...@@ -253,7 +256,8 @@ def get_xqueue_callback_url_prefix(request): ...@@ -253,7 +256,8 @@ def get_xqueue_callback_url_prefix(request):
def get_module_for_descriptor(user, request, descriptor, field_data_cache, course_key, def get_module_for_descriptor(user, request, descriptor, field_data_cache, course_key,
position=None, wrap_xmodule_display=True, grade_bucket_type=None, position=None, wrap_xmodule_display=True, grade_bucket_type=None,
static_asset_path='', disable_staff_debug_info=False): static_asset_path='', disable_staff_debug_info=False,
course=None):
""" """
Implements get_module, extracting out the request-specific functionality. Implements get_module, extracting out the request-specific functionality.
...@@ -280,6 +284,7 @@ def get_module_for_descriptor(user, request, descriptor, field_data_cache, cours ...@@ -280,6 +284,7 @@ def get_module_for_descriptor(user, request, descriptor, field_data_cache, cours
user_location=user_location, user_location=user_location,
request_token=xblock_request_token(request), request_token=xblock_request_token(request),
disable_staff_debug_info=disable_staff_debug_info, disable_staff_debug_info=disable_staff_debug_info,
course=course
) )
...@@ -287,7 +292,8 @@ def get_module_system_for_user(user, field_data_cache, # TODO # pylint: disabl ...@@ -287,7 +292,8 @@ def get_module_system_for_user(user, field_data_cache, # TODO # pylint: disabl
# Arguments preceding this comment have user binding, those following don't # Arguments preceding this comment have user binding, those following don't
descriptor, course_id, track_function, xqueue_callback_url_prefix, descriptor, course_id, track_function, xqueue_callback_url_prefix,
request_token, position=None, wrap_xmodule_display=True, grade_bucket_type=None, request_token, position=None, wrap_xmodule_display=True, grade_bucket_type=None,
static_asset_path='', user_location=None, disable_staff_debug_info=False): static_asset_path='', user_location=None, disable_staff_debug_info=False,
course=None):
""" """
Helper function that returns a module system and student_data bound to a user and a descriptor. Helper function that returns a module system and student_data bound to a user and a descriptor.
...@@ -382,6 +388,7 @@ def get_module_system_for_user(user, field_data_cache, # TODO # pylint: disabl ...@@ -382,6 +388,7 @@ def get_module_system_for_user(user, field_data_cache, # TODO # pylint: disabl
static_asset_path=static_asset_path, static_asset_path=static_asset_path,
user_location=user_location, user_location=user_location,
request_token=request_token, request_token=request_token,
course=course
) )
def _fulfill_content_milestones(user, course_key, content_key): def _fulfill_content_milestones(user, course_key, content_key):
...@@ -508,14 +515,15 @@ def get_module_system_for_user(user, field_data_cache, # TODO # pylint: disabl ...@@ -508,14 +515,15 @@ def get_module_system_for_user(user, field_data_cache, # TODO # pylint: disabl
grade_bucket_type=grade_bucket_type, grade_bucket_type=grade_bucket_type,
static_asset_path=static_asset_path, static_asset_path=static_asset_path,
user_location=user_location, user_location=user_location,
request_token=request_token request_token=request_token,
course=course
) )
module.descriptor.bind_for_student( module.descriptor.bind_for_student(
inner_system, inner_system,
real_user.id, real_user.id,
[ [
partial(OverrideFieldData.wrap, real_user), partial(OverrideFieldData.wrap, real_user, course),
partial(LmsFieldData, student_data=inner_student_data), partial(LmsFieldData, student_data=inner_student_data),
], ],
) )
...@@ -681,10 +689,13 @@ def get_module_system_for_user(user, field_data_cache, # TODO # pylint: disabl ...@@ -681,10 +689,13 @@ def get_module_system_for_user(user, field_data_cache, # TODO # pylint: disabl
return system, field_data return system, field_data
# TODO: Find all the places that this method is called and figure out how to
# get a loaded course passed into it
def get_module_for_descriptor_internal(user, descriptor, field_data_cache, course_id, # pylint: disable=invalid-name def get_module_for_descriptor_internal(user, descriptor, field_data_cache, course_id, # pylint: disable=invalid-name
track_function, xqueue_callback_url_prefix, request_token, track_function, xqueue_callback_url_prefix, request_token,
position=None, wrap_xmodule_display=True, grade_bucket_type=None, position=None, wrap_xmodule_display=True, grade_bucket_type=None,
static_asset_path='', user_location=None, disable_staff_debug_info=False): static_asset_path='', user_location=None, disable_staff_debug_info=False,
course=None):
""" """
Actually implement get_module, without requiring a request. Actually implement get_module, without requiring a request.
...@@ -708,13 +719,14 @@ def get_module_for_descriptor_internal(user, descriptor, field_data_cache, cours ...@@ -708,13 +719,14 @@ def get_module_for_descriptor_internal(user, descriptor, field_data_cache, cours
user_location=user_location, user_location=user_location,
request_token=request_token, request_token=request_token,
disable_staff_debug_info=disable_staff_debug_info, disable_staff_debug_info=disable_staff_debug_info,
course=course
) )
descriptor.bind_for_student( descriptor.bind_for_student(
system, system,
user.id, user.id,
[ [
partial(OverrideFieldData.wrap, user), partial(OverrideFieldData.wrap, user, course),
partial(LmsFieldData, student_data=student_data), partial(LmsFieldData, student_data=student_data),
], ],
) )
...@@ -732,7 +744,7 @@ def get_module_for_descriptor_internal(user, descriptor, field_data_cache, cours ...@@ -732,7 +744,7 @@ def get_module_for_descriptor_internal(user, descriptor, field_data_cache, cours
return descriptor return descriptor
def load_single_xblock(request, user_id, course_id, usage_key_string): def load_single_xblock(request, user_id, course_id, usage_key_string, course=None):
""" """
Load a single XBlock identified by usage_key_string. Load a single XBlock identified by usage_key_string.
""" """
...@@ -746,7 +758,7 @@ def load_single_xblock(request, user_id, course_id, usage_key_string): ...@@ -746,7 +758,7 @@ def load_single_xblock(request, user_id, course_id, usage_key_string):
modulestore().get_item(usage_key), modulestore().get_item(usage_key),
depth=0, depth=0,
) )
instance = get_module(user, request, usage_key, field_data_cache, grade_bucket_type='xqueue') instance = get_module(user, request, usage_key, field_data_cache, grade_bucket_type='xqueue', course=course)
if instance is None: if instance is None:
msg = "No module {0} for user {1}--access denied?".format(usage_key_string, user) msg = "No module {0} for user {1}--access denied?".format(usage_key_string, user)
log.debug(msg) log.debug(msg)
...@@ -772,7 +784,12 @@ def xqueue_callback(request, course_id, userid, mod_id, dispatch): ...@@ -772,7 +784,12 @@ def xqueue_callback(request, course_id, userid, mod_id, dispatch):
if not isinstance(header, dict) or 'lms_key' not in header: if not isinstance(header, dict) or 'lms_key' not in header:
raise Http404 raise Http404
instance = load_single_xblock(request, userid, course_id, mod_id) course_key = CourseKey.from_string(course_id)
with modulestore().bulk_operations(course_key):
course = modulestore().get_course(course_key, depth=0)
instance = load_single_xblock(request, userid, course_id, mod_id, course=course)
# Transfer 'queuekey' from xqueue response header to the data. # Transfer 'queuekey' from xqueue response header to the data.
# This is required to use the interface defined by 'handle_ajax' # This is required to use the interface defined by 'handle_ajax'
...@@ -799,7 +816,10 @@ def handle_xblock_callback_noauth(request, course_id, usage_id, handler, suffix= ...@@ -799,7 +816,10 @@ def handle_xblock_callback_noauth(request, course_id, usage_id, handler, suffix=
""" """
request.user.known = False request.user.known = False
return _invoke_xblock_handler(request, course_id, usage_id, handler, suffix) course_key = CourseKey.from_string(course_id)
with modulestore().bulk_operations(course_key):
course = modulestore().get_course(course_key, depth=0)
return _invoke_xblock_handler(request, course_id, usage_id, handler, suffix, course=course)
def handle_xblock_callback(request, course_id, usage_id, handler, suffix=None): def handle_xblock_callback(request, course_id, usage_id, handler, suffix=None):
...@@ -820,7 +840,18 @@ def handle_xblock_callback(request, course_id, usage_id, handler, suffix=None): ...@@ -820,7 +840,18 @@ def handle_xblock_callback(request, course_id, usage_id, handler, suffix=None):
if not request.user.is_authenticated(): if not request.user.is_authenticated():
return HttpResponse('Unauthenticated', status=403) return HttpResponse('Unauthenticated', status=403)
return _invoke_xblock_handler(request, course_id, usage_id, handler, suffix) try:
course_key = CourseKey.from_string(course_id)
except InvalidKeyError:
raise Http404("Invalid location")
with modulestore().bulk_operations(course_key):
try:
course = modulestore().get_course(course_key)
except ItemNotFoundError:
raise Http404("invalid location")
return _invoke_xblock_handler(request, course_id, usage_id, handler, suffix, course=course)
def xblock_resource(request, block_type, uri): # pylint: disable=unused-argument def xblock_resource(request, block_type, uri): # pylint: disable=unused-argument
...@@ -840,7 +871,7 @@ def xblock_resource(request, block_type, uri): # pylint: disable=unused-argumen ...@@ -840,7 +871,7 @@ def xblock_resource(request, block_type, uri): # pylint: disable=unused-argumen
return HttpResponse(content, mimetype=mimetype) return HttpResponse(content, mimetype=mimetype)
def get_module_by_usage_id(request, course_id, usage_id, disable_staff_debug_info=False): def get_module_by_usage_id(request, course_id, usage_id, disable_staff_debug_info=False, course=None):
""" """
Gets a module instance based on its `usage_id` in a course, for a given request/user Gets a module instance based on its `usage_id` in a course, for a given request/user
...@@ -890,7 +921,8 @@ def get_module_by_usage_id(request, course_id, usage_id, disable_staff_debug_inf ...@@ -890,7 +921,8 @@ def get_module_by_usage_id(request, course_id, usage_id, disable_staff_debug_inf
descriptor, descriptor,
field_data_cache, field_data_cache,
usage_key.course_key, usage_key.course_key,
disable_staff_debug_info=disable_staff_debug_info disable_staff_debug_info=disable_staff_debug_info,
course=course
) )
if instance is None: if instance is None:
# Either permissions just changed, or someone is trying to be clever # Either permissions just changed, or someone is trying to be clever
...@@ -901,7 +933,7 @@ def get_module_by_usage_id(request, course_id, usage_id, disable_staff_debug_inf ...@@ -901,7 +933,7 @@ def get_module_by_usage_id(request, course_id, usage_id, disable_staff_debug_inf
return (instance, tracking_context) return (instance, tracking_context)
def _invoke_xblock_handler(request, course_id, usage_id, handler, suffix): def _invoke_xblock_handler(request, course_id, usage_id, handler, suffix, course=None):
""" """
Invoke an XBlock handler, either authenticated or not. Invoke an XBlock handler, either authenticated or not.
...@@ -926,7 +958,7 @@ def _invoke_xblock_handler(request, course_id, usage_id, handler, suffix): ...@@ -926,7 +958,7 @@ def _invoke_xblock_handler(request, course_id, usage_id, handler, suffix):
raise Http404 raise Http404
with modulestore().bulk_operations(course_key): with modulestore().bulk_operations(course_key):
instance, tracking_context = get_module_by_usage_id(request, course_id, usage_id) instance, tracking_context = get_module_by_usage_id(request, course_id, usage_id, course=course)
# Name the transaction so that we can view XBlock handlers separately in # Name the transaction so that we can view XBlock handlers separately in
# New Relic. The suffix is necessary for XModule handlers because the # New Relic. The suffix is necessary for XModule handlers because the
...@@ -991,7 +1023,14 @@ def xblock_view(request, course_id, usage_id, view_name): ...@@ -991,7 +1023,14 @@ def xblock_view(request, course_id, usage_id, view_name):
if not request.user.is_authenticated(): if not request.user.is_authenticated():
raise PermissionDenied raise PermissionDenied
instance, _ = get_module_by_usage_id(request, course_id, usage_id) try:
course_key = SlashSeparatedCourseKey.from_deprecated_string(course_id)
except InvalidKeyError:
raise Http404("Invalid location")
with modulestore().bulk_operations(course_key):
course = modulestore().get_course(course_key)
instance, _ = get_module_by_usage_id(request, course_id, usage_id, course=course)
try: try:
fragment = instance.render(view_name, context=request.GET) fragment = instance.render(view_name, context=request.GET)
......
...@@ -17,6 +17,11 @@ class IndividualStudentOverrideProvider(FieldOverrideProvider): ...@@ -17,6 +17,11 @@ class IndividualStudentOverrideProvider(FieldOverrideProvider):
def get(self, block, name, default): def get(self, block, name, default):
return get_override_for_user(self.user, block, name, default) return get_override_for_user(self.user, block, name, default)
@classmethod
def enabled_for(cls, course):
"""This simple override provider is always enabled"""
return True
def get_override_for_user(user, block, name, default=None): def get_override_for_user(user, block, name, default=None):
""" """
......
...@@ -262,7 +262,8 @@ class CourseInstantiationTests(ModuleStoreTestCase): ...@@ -262,7 +262,8 @@ class CourseInstantiationTests(ModuleStoreTestCase):
fake_request, fake_request,
course, course,
field_data_cache, field_data_cache,
course.id course.id,
course=course
) )
for chapter in course_module.get_children(): for chapter in course_module.get_children():
for section in chapter.get_children(): for section in chapter.get_children():
......
...@@ -4,9 +4,12 @@ Tests for `field_overrides` module. ...@@ -4,9 +4,12 @@ Tests for `field_overrides` module.
import unittest import unittest
from nose.plugins.attrib import attr from nose.plugins.attrib import attr
from django.test import TestCase
from django.test.utils import override_settings from django.test.utils import override_settings
from xblock.field_data import DictFieldData from xblock.field_data import DictFieldData
from xmodule.modulestore.tests.factories import CourseFactory
from xmodule.modulestore.tests.django_utils import (
ModuleStoreTestCase,
)
from ..field_overrides import ( from ..field_overrides import (
disable_overrides, disable_overrides,
...@@ -22,13 +25,14 @@ TESTUSER = "testuser" ...@@ -22,13 +25,14 @@ TESTUSER = "testuser"
@attr('shard_1') @attr('shard_1')
@override_settings(FIELD_OVERRIDE_PROVIDERS=( @override_settings(FIELD_OVERRIDE_PROVIDERS=(
'courseware.tests.test_field_overrides.TestOverrideProvider',)) 'courseware.tests.test_field_overrides.TestOverrideProvider',))
class OverrideFieldDataTests(TestCase): class OverrideFieldDataTests(ModuleStoreTestCase):
""" """
Tests for `OverrideFieldData`. Tests for `OverrideFieldData`.
""" """
def setUp(self): def setUp(self):
super(OverrideFieldDataTests, self).setUp() super(OverrideFieldDataTests, self).setUp()
self.course = CourseFactory.create(enable_ccx=True)
OverrideFieldData.provider_classes = None OverrideFieldData.provider_classes = None
def tearDown(self): def tearDown(self):
...@@ -39,7 +43,7 @@ class OverrideFieldDataTests(TestCase): ...@@ -39,7 +43,7 @@ class OverrideFieldDataTests(TestCase):
""" """
Factory method. Factory method.
""" """
return OverrideFieldData.wrap(TESTUSER, DictFieldData({ return OverrideFieldData.wrap(TESTUSER, self.course, DictFieldData({
'foo': 'bar', 'foo': 'bar',
'bees': 'knees', 'bees': 'knees',
})) }))
...@@ -124,3 +128,7 @@ class TestOverrideProvider(FieldOverrideProvider): ...@@ -124,3 +128,7 @@ class TestOverrideProvider(FieldOverrideProvider):
if name == 'oh': if name == 'oh':
return 'man' return 'man'
return default return default
@classmethod
def enabled_for(cls, course):
return True
...@@ -184,7 +184,13 @@ class ModuleRenderTestCase(ModuleStoreTestCase, LoginEnrollmentTestCase): ...@@ -184,7 +184,13 @@ class ModuleRenderTestCase(ModuleStoreTestCase, LoginEnrollmentTestCase):
with patch('courseware.module_render.load_single_xblock', return_value=self.mock_module): with patch('courseware.module_render.load_single_xblock', return_value=self.mock_module):
# call xqueue_callback with our mocked information # call xqueue_callback with our mocked information
request = self.request_factory.post(self.callback_url, data) request = self.request_factory.post(self.callback_url, data)
render.xqueue_callback(request, self.course_key, self.mock_user.id, self.mock_module.id, self.dispatch) render.xqueue_callback(
request,
unicode(self.course_key),
self.mock_user.id,
self.mock_module.id,
self.dispatch
)
# Verify that handle ajax is called with the correct data # Verify that handle ajax is called with the correct data
request.POST['queuekey'] = fake_key request.POST['queuekey'] = fake_key
...@@ -200,12 +206,24 @@ class ModuleRenderTestCase(ModuleStoreTestCase, LoginEnrollmentTestCase): ...@@ -200,12 +206,24 @@ class ModuleRenderTestCase(ModuleStoreTestCase, LoginEnrollmentTestCase):
# Test with missing xqueue data # Test with missing xqueue data
with self.assertRaises(Http404): with self.assertRaises(Http404):
request = self.request_factory.post(self.callback_url, {}) request = self.request_factory.post(self.callback_url, {})
render.xqueue_callback(request, self.course_key, self.mock_user.id, self.mock_module.id, self.dispatch) render.xqueue_callback(
request,
unicode(self.course_key),
self.mock_user.id,
self.mock_module.id,
self.dispatch
)
# Test with missing xqueue_header # Test with missing xqueue_header
with self.assertRaises(Http404): with self.assertRaises(Http404):
request = self.request_factory.post(self.callback_url, data) request = self.request_factory.post(self.callback_url, data)
render.xqueue_callback(request, self.course_key, self.mock_user.id, self.mock_module.id, self.dispatch) render.xqueue_callback(
request,
unicode(self.course_key),
self.mock_user.id,
self.mock_module.id,
self.dispatch
)
def test_get_score_bucket(self): def test_get_score_bucket(self):
self.assertEquals(render.get_score_bucket(0, 10), 'incorrect') self.assertEquals(render.get_score_bucket(0, 10), 'incorrect')
...@@ -275,11 +293,28 @@ class ModuleRenderTestCase(ModuleStoreTestCase, LoginEnrollmentTestCase): ...@@ -275,11 +293,28 @@ class ModuleRenderTestCase(ModuleStoreTestCase, LoginEnrollmentTestCase):
course = CourseFactory() course = CourseFactory()
descriptor = ItemFactory(category=block_type, parent=course) descriptor = ItemFactory(category=block_type, parent=course)
field_data_cache = FieldDataCache([self.toy_course, descriptor], self.toy_course.id, self.mock_user) field_data_cache = FieldDataCache([self.toy_course, descriptor], self.toy_course.id, self.mock_user)
render.get_module_for_descriptor(self.mock_user, request, descriptor, field_data_cache, self.toy_course.id) # This is verifying that caching doesn't cause an error during get_module_for_descriptor, which
render.get_module_for_descriptor(self.mock_user, request, descriptor, field_data_cache, self.toy_course.id) # is why it calls the method twice identically.
render.get_module_for_descriptor(
self.mock_user,
request,
descriptor,
field_data_cache,
self.toy_course.id,
course=self.toy_course
)
render.get_module_for_descriptor(
self.mock_user,
request,
descriptor,
field_data_cache,
self.toy_course.id,
course=self.toy_course
)
@override_settings(FIELD_OVERRIDE_PROVIDERS=( @override_settings(FIELD_OVERRIDE_PROVIDERS=(
'ccx.overrides.CustomCoursesForEdxOverrideProvider',)) 'ccx.overrides.CustomCoursesForEdxOverrideProvider',
))
def test_rebind_different_users_ccx(self): def test_rebind_different_users_ccx(self):
""" """
This tests the rebinding a descriptor to a student does not result This tests the rebinding a descriptor to a student does not result
...@@ -287,18 +322,18 @@ class ModuleRenderTestCase(ModuleStoreTestCase, LoginEnrollmentTestCase): ...@@ -287,18 +322,18 @@ class ModuleRenderTestCase(ModuleStoreTestCase, LoginEnrollmentTestCase):
""" """
request = self.request_factory.get('') request = self.request_factory.get('')
request.user = self.mock_user request.user = self.mock_user
course = CourseFactory() course = CourseFactory.create(enable_ccx=True)
descriptor = ItemFactory(category='html', parent=course) descriptor = ItemFactory(category='html', parent=course)
field_data_cache = FieldDataCache( field_data_cache = FieldDataCache(
[self.toy_course, descriptor], self.toy_course.id, self.mock_user [course, descriptor], course.id, self.mock_user
) )
# grab what _field_data was originally set to # grab what _field_data was originally set to
original_field_data = descriptor._field_data # pylint: disable=protected-access, no-member original_field_data = descriptor._field_data # pylint: disable=protected-access, no-member
render.get_module_for_descriptor( render.get_module_for_descriptor(
self.mock_user, request, descriptor, field_data_cache, self.toy_course.id self.mock_user, request, descriptor, field_data_cache, course.id, course=course
) )
# check that _unwrapped_field_data is the same as the original # check that _unwrapped_field_data is the same as the original
...@@ -314,7 +349,8 @@ class ModuleRenderTestCase(ModuleStoreTestCase, LoginEnrollmentTestCase): ...@@ -314,7 +349,8 @@ class ModuleRenderTestCase(ModuleStoreTestCase, LoginEnrollmentTestCase):
request, request,
descriptor, descriptor,
field_data_cache, field_data_cache,
self.toy_course.id course.id,
course=course
) )
# _field_data should now be wrapped by LmsFieldData # _field_data should now be wrapped by LmsFieldData
...@@ -832,6 +868,7 @@ class JsonInitDataTest(ModuleStoreTestCase): ...@@ -832,6 +868,7 @@ class JsonInitDataTest(ModuleStoreTestCase):
descriptor, descriptor,
field_data_cache, field_data_cache,
course.id, # pylint: disable=no-member course.id, # pylint: disable=no-member
course=course
) )
html = module.render(STUDENT_VIEW).content html = module.render(STUDENT_VIEW).content
self.assertIn(json_output, html) self.assertIn(json_output, html)
...@@ -1098,6 +1135,8 @@ class TestAnonymousStudentId(ModuleStoreTestCase, LoginEnrollmentTestCase): ...@@ -1098,6 +1135,8 @@ class TestAnonymousStudentId(ModuleStoreTestCase, LoginEnrollmentTestCase):
def setUp(self): def setUp(self):
super(TestAnonymousStudentId, self).setUp(create_user=False) super(TestAnonymousStudentId, self).setUp(create_user=False)
self.user = UserFactory() self.user = UserFactory()
self.course_key = self.create_toy_course()
self.course = modulestore().get_course(self.course_key)
@patch('courseware.module_render.has_access', Mock(return_value=True)) @patch('courseware.module_render.has_access', Mock(return_value=True))
def _get_anonymous_id(self, course_id, xblock_class): def _get_anonymous_id(self, course_id, xblock_class):
...@@ -1135,6 +1174,7 @@ class TestAnonymousStudentId(ModuleStoreTestCase, LoginEnrollmentTestCase): ...@@ -1135,6 +1174,7 @@ class TestAnonymousStudentId(ModuleStoreTestCase, LoginEnrollmentTestCase):
track_function=Mock(name='track_function'), # Track Function track_function=Mock(name='track_function'), # Track Function
xqueue_callback_url_prefix=Mock(name='xqueue_callback_url_prefix'), # XQueue Callback Url Prefix xqueue_callback_url_prefix=Mock(name='xqueue_callback_url_prefix'), # XQueue Callback Url Prefix
request_token='request_token', request_token='request_token',
course=self.course,
).xmodule_runtime.anonymous_student_id ).xmodule_runtime.anonymous_student_id
@ddt.data(*PER_STUDENT_ANONYMIZED_DESCRIPTORS) @ddt.data(*PER_STUDENT_ANONYMIZED_DESCRIPTORS)
...@@ -1444,7 +1484,8 @@ class LMSXBlockServiceBindingTest(ModuleStoreTestCase): ...@@ -1444,7 +1484,8 @@ class LMSXBlockServiceBindingTest(ModuleStoreTestCase):
self.course.id, self.course.id,
self.track_function, self.track_function,
self.xqueue_callback_url_prefix, self.xqueue_callback_url_prefix,
self.request_token self.request_token,
course=self.course
) )
service = runtime.service(descriptor, expected_service) service = runtime.service(descriptor, expected_service)
self.assertIsNotNone(service) self.assertIsNotNone(service)
...@@ -1462,7 +1503,8 @@ class LMSXBlockServiceBindingTest(ModuleStoreTestCase): ...@@ -1462,7 +1503,8 @@ class LMSXBlockServiceBindingTest(ModuleStoreTestCase):
self.course.id, self.course.id,
self.track_function, self.track_function,
self.xqueue_callback_url_prefix, self.xqueue_callback_url_prefix,
self.request_token self.request_token,
course=self.course
) )
self.assertFalse(getattr(runtime, u'user_is_beta_tester')) self.assertFalse(getattr(runtime, u'user_is_beta_tester'))
...@@ -1607,6 +1649,7 @@ class TestFilteredChildren(ModuleStoreTestCase): ...@@ -1607,6 +1649,7 @@ class TestFilteredChildren(ModuleStoreTestCase):
block, block,
field_data_cache, field_data_cache,
course_id, course_id,
course=self.course
) )
def _has_access(self, user, action, obj, course_key=None): def _has_access(self, user, action, obj, course_key=None):
......
...@@ -310,7 +310,8 @@ class SplitTestPosition(ModuleStoreTestCase): ...@@ -310,7 +310,8 @@ class SplitTestPosition(ModuleStoreTestCase):
MagicMock(name='request'), MagicMock(name='request'),
self.course, self.course,
mock_field_data_cache, mock_field_data_cache,
self.course.id self.course.id,
course=self.course
) )
# Now that we have the course, change the position and save, nothing should explode! # Now that we have the course, change the position and save, nothing should explode!
......
...@@ -253,7 +253,7 @@ def save_child_position(seq_module, child_name): ...@@ -253,7 +253,7 @@ def save_child_position(seq_module, child_name):
seq_module.save() seq_module.save()
def save_positions_recursively_up(user, request, field_data_cache, xmodule): def save_positions_recursively_up(user, request, field_data_cache, xmodule, course=None):
""" """
Recurses up the course tree starting from a leaf Recurses up the course tree starting from a leaf
Saving the position property based on the previous node as it goes Saving the position property based on the previous node as it goes
...@@ -265,7 +265,14 @@ def save_positions_recursively_up(user, request, field_data_cache, xmodule): ...@@ -265,7 +265,14 @@ def save_positions_recursively_up(user, request, field_data_cache, xmodule):
parent = None parent = None
if parent_location: if parent_location:
parent_descriptor = modulestore().get_item(parent_location) parent_descriptor = modulestore().get_item(parent_location)
parent = get_module_for_descriptor(user, request, parent_descriptor, field_data_cache, current_module.location.course_key) parent = get_module_for_descriptor(
user,
request,
parent_descriptor,
field_data_cache,
current_module.location.course_key,
course=course
)
if parent and hasattr(parent, 'position'): if parent and hasattr(parent, 'position'):
save_child_position(parent, current_module.location.name) save_child_position(parent, current_module.location.name)
...@@ -412,7 +419,9 @@ def _index_bulk_op(request, course_key, chapter, section, position): ...@@ -412,7 +419,9 @@ def _index_bulk_op(request, course_key, chapter, section, position):
field_data_cache = FieldDataCache.cache_for_descriptor_descendents( field_data_cache = FieldDataCache.cache_for_descriptor_descendents(
course_key, user, course, depth=2) course_key, user, course, depth=2)
course_module = get_module_for_descriptor(user, request, course, field_data_cache, course_key) course_module = get_module_for_descriptor(
user, request, course, field_data_cache, course_key, course=course
)
if course_module is None: if course_module is None:
log.warning(u'If you see this, something went wrong: if we got this' log.warning(u'If you see this, something went wrong: if we got this'
u' far, should have gotten a course module for this user') u' far, should have gotten a course module for this user')
...@@ -532,7 +541,8 @@ def _index_bulk_op(request, course_key, chapter, section, position): ...@@ -532,7 +541,8 @@ def _index_bulk_op(request, course_key, chapter, section, position):
section_descriptor, section_descriptor,
field_data_cache, field_data_cache,
course_key, course_key,
position position,
course=course
) )
if section_module is None: if section_module is None:
...@@ -1180,7 +1190,7 @@ def get_static_tab_contents(request, course, tab): ...@@ -1180,7 +1190,7 @@ def get_static_tab_contents(request, course, tab):
course.id, request.user, modulestore().get_item(loc), depth=0 course.id, request.user, modulestore().get_item(loc), depth=0
) )
tab_module = get_module( tab_module = get_module(
request.user, request, loc, field_data_cache, static_asset_path=course.static_asset_path request.user, request, loc, field_data_cache, static_asset_path=course.static_asset_path, course=course
) )
logging.debug('course_module = {0}'.format(tab_module)) logging.debug('course_module = {0}'.format(tab_module))
...@@ -1238,7 +1248,8 @@ def get_course_lti_endpoints(request, course_id): ...@@ -1238,7 +1248,8 @@ def get_course_lti_endpoints(request, course_id):
anonymous_user, anonymous_user,
descriptor descriptor
), ),
course_key course_key,
course=course
) )
for descriptor in lti_descriptors for descriptor in lti_descriptors
] ]
...@@ -1411,7 +1422,7 @@ def render_xblock(request, usage_key_string, check_if_enrolled=True): ...@@ -1411,7 +1422,7 @@ def render_xblock(request, usage_key_string, check_if_enrolled=True):
# get the block, which verifies whether the user has access to the block. # get the block, which verifies whether the user has access to the block.
block, _ = get_module_by_usage_id( block, _ = get_module_by_usage_id(
request, unicode(course_key), unicode(usage_key), disable_staff_debug_info=True request, unicode(course_key), unicode(usage_key), disable_staff_debug_info=True, course=course
) )
context = { context = {
......
...@@ -831,7 +831,9 @@ class EdxNotesViewsTest(ModuleStoreTestCase): ...@@ -831,7 +831,9 @@ class EdxNotesViewsTest(ModuleStoreTestCase):
Returns the course module. Returns the course module.
""" """
field_data_cache = FieldDataCache([self.course], self.course.id, self.user) field_data_cache = FieldDataCache([self.course], self.course.id, self.user)
return get_module_for_descriptor(self.user, MagicMock(), self.course, field_data_cache, self.course.id) return get_module_for_descriptor(
self.user, MagicMock(), self.course, field_data_cache, self.course.id, course=self.course
)
def test_edxnotes_tab(self): def test_edxnotes_tab(self):
""" """
......
...@@ -53,7 +53,9 @@ def edxnotes(request, course_id): ...@@ -53,7 +53,9 @@ def edxnotes(request, course_id):
field_data_cache = FieldDataCache.cache_for_descriptor_descendents( field_data_cache = FieldDataCache.cache_for_descriptor_descendents(
course.id, request.user, course, depth=2 course.id, request.user, course, depth=2
) )
course_module = get_module_for_descriptor(request.user, request, course, field_data_cache, course_key) course_module = get_module_for_descriptor(
request.user, request, course, field_data_cache, course_key, course=course
)
position = get_course_position(course_module) position = get_course_position(course_module)
if position: if position:
context.update({ context.update({
...@@ -103,7 +105,9 @@ def edxnotes_visibility(request, course_id): ...@@ -103,7 +105,9 @@ def edxnotes_visibility(request, course_id):
course_key = CourseKey.from_string(course_id) course_key = CourseKey.from_string(course_id)
course = get_course_with_access(request.user, "load", course_key) course = get_course_with_access(request.user, "load", course_key)
field_data_cache = FieldDataCache([course], course_key, request.user) field_data_cache = FieldDataCache([course], course_key, request.user)
course_module = get_module_for_descriptor(request.user, request, course, field_data_cache, course_key) course_module = get_module_for_descriptor(
request.user, request, course, field_data_cache, course_key, course=course
)
if not is_feature_enabled(course): if not is_feature_enabled(course):
raise Http404 raise Http404
......
...@@ -31,7 +31,7 @@ def hint_manager(request, course_id): ...@@ -31,7 +31,7 @@ def hint_manager(request, course_id):
""" """
course_key = SlashSeparatedCourseKey.from_deprecated_string(course_id) course_key = SlashSeparatedCourseKey.from_deprecated_string(course_id)
try: try:
get_course_with_access(request.user, 'staff', course_key, depth=None) course = get_course_with_access(request.user, 'staff', course_key, depth=None)
except Http404: except Http404:
out = 'Sorry, but students are not allowed to access the hint manager!' out = 'Sorry, but students are not allowed to access the hint manager!'
return HttpResponse(out) return HttpResponse(out)
...@@ -57,13 +57,13 @@ def hint_manager(request, course_id): ...@@ -57,13 +57,13 @@ def hint_manager(request, course_id):
error_text = switch_dict[request.POST['op']](request, course_key, field) error_text = switch_dict[request.POST['op']](request, course_key, field)
if error_text is None: if error_text is None:
error_text = '' error_text = ''
render_dict = get_hints(request, course_key, field) render_dict = get_hints(request, course_key, field, course=course)
render_dict.update({'error': error_text}) render_dict.update({'error': error_text})
rendered_html = render_to_string('instructor/hint_manager_inner.html', render_dict) rendered_html = render_to_string('instructor/hint_manager_inner.html', render_dict)
return HttpResponse(json.dumps({'success': True, 'contents': rendered_html})) return HttpResponse(json.dumps({'success': True, 'contents': rendered_html}))
def get_hints(request, course_id, field): def get_hints(request, course_id, field, course=None): # pylint: disable=unused-argument
""" """
Load all of the hints submitted to the course. Load all of the hints submitted to the course.
...@@ -148,7 +148,7 @@ def location_to_problem_name(course_id, loc): ...@@ -148,7 +148,7 @@ def location_to_problem_name(course_id, loc):
return None return None
def delete_hints(request, course_id, field): def delete_hints(request, course_id, field, course=None): # pylint: disable=unused-argument
""" """
Deletes the hints specified. Deletes the hints specified.
...@@ -176,7 +176,7 @@ def delete_hints(request, course_id, field): ...@@ -176,7 +176,7 @@ def delete_hints(request, course_id, field):
this_problem.save() this_problem.save()
def change_votes(request, course_id, field): def change_votes(request, course_id, field, course=None): # pylint: disable=unused-argument
""" """
Updates the number of votes. Updates the number of votes.
...@@ -203,7 +203,7 @@ def change_votes(request, course_id, field): ...@@ -203,7 +203,7 @@ def change_votes(request, course_id, field):
this_problem.save() this_problem.save()
def add_hint(request, course_id, field): def add_hint(request, course_id, field, course=None):
""" """
Add a new hint. `request.POST`: Add a new hint. `request.POST`:
op op
...@@ -226,7 +226,14 @@ def add_hint(request, course_id, field): ...@@ -226,7 +226,14 @@ def add_hint(request, course_id, field):
except ItemNotFoundError: except ItemNotFoundError:
descriptors = [] descriptors = []
field_data_cache = model_data.FieldDataCache(descriptors, course_id, request.user) field_data_cache = model_data.FieldDataCache(descriptors, course_id, request.user)
hinter_module = module_render.get_module(request.user, request, problem_key, field_data_cache, course_id) hinter_module = module_render.get_module(
request.user,
request,
problem_key,
field_data_cache,
course_id,
course=course
)
if not hinter_module.validate_answer(answer): if not hinter_module.validate_answer(answer):
# Invalid answer. Don't add it to the database, or else the # Invalid answer. Don't add it to the database, or else the
# hinter will crash when we encounter it. # hinter will crash when we encounter it.
...@@ -247,7 +254,7 @@ def add_hint(request, course_id, field): ...@@ -247,7 +254,7 @@ def add_hint(request, course_id, field):
this_problem.save() this_problem.save()
def approve(request, course_id, field): def approve(request, course_id, field, course=None): # pylint: disable=unused-argument
""" """
Approve a list of hints, moving them from the mod_queue to the real Approve a list of hints, moving them from the mod_queue to the real
hint list. POST: hint list. POST:
......
...@@ -77,7 +77,7 @@ def post_submission_for_student(student, course, location, task_number, dry_run= ...@@ -77,7 +77,7 @@ def post_submission_for_student(student, course, location, task_number, dry_run=
request.host = hostname request.host = hostname
try: try:
module = get_module_for_student(student, location, request=request) module = get_module_for_student(student, location, request=request, course=course)
if module is None: if module is None:
print " WARNING: No state found." print " WARNING: No state found."
return False return False
......
...@@ -89,7 +89,7 @@ def calculate_task_statistics(students, course, location, task_number, write_to_ ...@@ -89,7 +89,7 @@ def calculate_task_statistics(students, course, location, task_number, write_to_
student = student_module.student student = student_module.student
print "{0}:{1}".format(student.id, student.username) print "{0}:{1}".format(student.id, student.username)
module = get_module_for_student(student, location) module = get_module_for_student(student, location, course=course)
if module is None: if module is None:
print " WARNING: No state found" print " WARNING: No state found"
students_with_no_state.append(student) students_with_no_state.append(student)
......
...@@ -227,7 +227,7 @@ class TestSetDueDateExtension(ModuleStoreTestCase): ...@@ -227,7 +227,7 @@ class TestSetDueDateExtension(ModuleStoreTestCase):
# just inject the override field storage in this brute force manner. # just inject the override field storage in this brute force manner.
for block in (course, week1, week2, week3, homework, assignment): for block in (course, week1, week2, week3, homework, assignment):
block._field_data = OverrideFieldData.wrap( # pylint: disable=protected-access block._field_data = OverrideFieldData.wrap( # pylint: disable=protected-access
user, block._field_data) # pylint: disable=protected-access user, course, block._field_data) # pylint: disable=protected-access
def tearDown(self): def tearDown(self):
super(TestSetDueDateExtension, self).tearDown() super(TestSetDueDateExtension, self).tearDown()
......
...@@ -27,7 +27,7 @@ class DummyRequest(object): ...@@ -27,7 +27,7 @@ class DummyRequest(object):
return False return False
def get_module_for_student(student, usage_key, request=None): def get_module_for_student(student, usage_key, request=None, course=None):
"""Return the module for the (student, location) using a DummyRequest.""" """Return the module for the (student, location) using a DummyRequest."""
if request is None: if request is None:
request = DummyRequest() request = DummyRequest()
...@@ -35,4 +35,4 @@ def get_module_for_student(student, usage_key, request=None): ...@@ -35,4 +35,4 @@ def get_module_for_student(student, usage_key, request=None):
descriptor = modulestore().get_item(usage_key, depth=0) descriptor = modulestore().get_item(usage_key, depth=0)
field_data_cache = FieldDataCache([descriptor], usage_key.course_key, student) field_data_cache = FieldDataCache([descriptor], usage_key.course_key, student)
return get_module(student, request, usage_key, field_data_cache) return get_module(student, request, usage_key, field_data_cache, course=course)
...@@ -406,7 +406,7 @@ def _get_track_function_for_task(student, xmodule_instance_args=None, source_pag ...@@ -406,7 +406,7 @@ def _get_track_function_for_task(student, xmodule_instance_args=None, source_pag
def _get_module_instance_for_task(course_id, student, module_descriptor, xmodule_instance_args=None, def _get_module_instance_for_task(course_id, student, module_descriptor, xmodule_instance_args=None,
grade_bucket_type=None): grade_bucket_type=None, course=None):
""" """
Fetches a StudentModule instance for a given `course_id`, `student` object, and `module_descriptor`. Fetches a StudentModule instance for a given `course_id`, `student` object, and `module_descriptor`.
...@@ -445,6 +445,8 @@ def _get_module_instance_for_task(course_id, student, module_descriptor, xmodule ...@@ -445,6 +445,8 @@ def _get_module_instance_for_task(course_id, student, module_descriptor, xmodule
grade_bucket_type=grade_bucket_type, grade_bucket_type=grade_bucket_type,
# This module isn't being used for front-end rendering # This module isn't being used for front-end rendering
request_token=None, request_token=None,
# pass in a loaded course for override enabling
course=course
) )
...@@ -465,13 +467,28 @@ def rescore_problem_module_state(xmodule_instance_args, module_descriptor, stude ...@@ -465,13 +467,28 @@ def rescore_problem_module_state(xmodule_instance_args, module_descriptor, stude
course_id = student_module.course_id course_id = student_module.course_id
student = student_module.student student = student_module.student
usage_key = student_module.module_state_key usage_key = student_module.module_state_key
instance = _get_module_instance_for_task(course_id, student, module_descriptor, xmodule_instance_args, grade_bucket_type='rescore')
with modulestore().bulk_operations(course_id):
course = get_course_by_id(course_id)
# TODO: Here is a call site where we could pass in a loaded course. I
# think we certainly need it since grading is happening here, and field
# overrides would be important in handling that correctly
instance = _get_module_instance_for_task(
course_id,
student,
module_descriptor,
xmodule_instance_args,
grade_bucket_type='rescore',
course=course
)
if instance is None: if instance is None:
# Either permissions just changed, or someone is trying to be clever # Either permissions just changed, or someone is trying to be clever
# and load something they shouldn't have access to. # and load something they shouldn't have access to.
msg = "No module {loc} for student {student}--access denied?".format(loc=usage_key, msg = "No module {loc} for student {student}--access denied?".format(
student=student) loc=usage_key,
student=student
)
TASK_LOG.debug(msg) TASK_LOG.debug(msg)
raise UpdateProblemModuleStateError(msg) raise UpdateProblemModuleStateError(msg)
...@@ -485,16 +502,40 @@ def rescore_problem_module_state(xmodule_instance_args, module_descriptor, stude ...@@ -485,16 +502,40 @@ def rescore_problem_module_state(xmodule_instance_args, module_descriptor, stude
instance.save() instance.save()
if 'success' not in result: if 'success' not in result:
# don't consider these fatal, but false means that the individual call didn't complete: # don't consider these fatal, but false means that the individual call didn't complete:
TASK_LOG.warning(u"error processing rescore call for course {course}, problem {loc} and student {student}: " TASK_LOG.warning(
u"unexpected response {msg}".format(msg=result, course=course_id, loc=usage_key, student=student)) u"error processing rescore call for course %(course)s, problem %(loc)s "
u"and student %(student)s: unexpected response %(msg)s",
dict(
msg=result,
course=course_id,
loc=usage_key,
student=student
)
)
return UPDATE_STATUS_FAILED return UPDATE_STATUS_FAILED
elif result['success'] not in ['correct', 'incorrect']: elif result['success'] not in ['correct', 'incorrect']:
TASK_LOG.warning(u"error processing rescore call for course {course}, problem {loc} and student {student}: " TASK_LOG.warning(
u"{msg}".format(msg=result['success'], course=course_id, loc=usage_key, student=student)) u"error processing rescore call for course %(course)s, problem %(loc)s "
u"and student %(student)s: %(msg)s",
dict(
msg=result['success'],
course=course_id,
loc=usage_key,
student=student
)
)
return UPDATE_STATUS_FAILED return UPDATE_STATUS_FAILED
else: else:
TASK_LOG.debug(u"successfully processed rescore call for course {course}, problem {loc} and student {student}: " TASK_LOG.debug(
u"{msg}".format(msg=result['success'], course=course_id, loc=usage_key, student=student)) u"successfully processed rescore call for course %(course)s, problem %(loc)s "
u"and student %(student)s: %(msg)s",
dict(
msg=result['success'],
course=course_id,
loc=usage_key,
student=student
)
)
return UPDATE_STATUS_SUCCEEDED return UPDATE_STATUS_SUCCEEDED
......
...@@ -106,7 +106,9 @@ class UserCourseStatus(views.APIView): ...@@ -106,7 +106,9 @@ class UserCourseStatus(views.APIView):
field_data_cache = FieldDataCache.cache_for_descriptor_descendents( field_data_cache = FieldDataCache.cache_for_descriptor_descendents(
course.id, request.user, course, depth=2) course.id, request.user, course, depth=2)
course_module = get_module_for_descriptor(request.user, request, course, field_data_cache, course.id) course_module = get_module_for_descriptor(
request.user, request, course, field_data_cache, course.id, course=course
)
path = [course_module] path = [course_module]
chapter = get_current_child(course_module, min_depth=2) chapter = get_current_child(course_module, min_depth=2)
...@@ -140,7 +142,9 @@ class UserCourseStatus(views.APIView): ...@@ -140,7 +142,9 @@ class UserCourseStatus(views.APIView):
module_descriptor = modulestore().get_item(module_key) module_descriptor = modulestore().get_item(module_key)
except ItemNotFoundError: except ItemNotFoundError:
return Response(errors.ERROR_INVALID_MODULE_ID, status=400) return Response(errors.ERROR_INVALID_MODULE_ID, status=400)
module = get_module_for_descriptor(request.user, request, module_descriptor, field_data_cache, course.id) module = get_module_for_descriptor(
request.user, request, module_descriptor, field_data_cache, course.id, course=course
)
if modification_date: if modification_date:
key = KeyValueStore.Key( key = KeyValueStore.Key(
...@@ -154,7 +158,7 @@ class UserCourseStatus(views.APIView): ...@@ -154,7 +158,7 @@ class UserCourseStatus(views.APIView):
# old modification date so skip update # old modification date so skip update
return self._get_course_info(request, course) return self._get_course_info(request, course)
save_positions_recursively_up(request.user, request, field_data_cache, module) save_positions_recursively_up(request.user, request, field_data_cache, module, course=course)
return self._get_course_info(request, course) return self._get_course_info(request, course)
@mobile_course_access(depth=2) @mobile_course_access(depth=2)
......
...@@ -4,7 +4,9 @@ Serializer for video outline ...@@ -4,7 +4,9 @@ Serializer for video outline
from rest_framework.reverse import reverse from rest_framework.reverse import reverse
from xmodule.modulestore.mongo.base import BLOCK_TYPES_WITH_CHILDREN from xmodule.modulestore.mongo.base import BLOCK_TYPES_WITH_CHILDREN
from xmodule.modulestore.django import modulestore
from courseware.access import has_access from courseware.access import has_access
from courseware.courses import get_course_by_id
from courseware.model_data import FieldDataCache from courseware.model_data import FieldDataCache
from courseware.module_render import get_module_for_descriptor from courseware.module_render import get_module_for_descriptor
from util.module_utils import get_dynamic_descriptor_children from util.module_utils import get_dynamic_descriptor_children
...@@ -49,10 +51,12 @@ class BlockOutline(object): ...@@ -49,10 +51,12 @@ class BlockOutline(object):
field_data_cache = FieldDataCache.cache_for_descriptor_descendents( field_data_cache = FieldDataCache.cache_for_descriptor_descendents(
self.course_id, self.request.user, descriptor, depth=0, self.course_id, self.request.user, descriptor, depth=0,
) )
course = get_course_by_id(self.course_id)
return get_module_for_descriptor( return get_module_for_descriptor(
self.request.user, self.request, descriptor, field_data_cache, self.course_id self.request.user, self.request, descriptor, field_data_cache, self.course_id, course=course
) )
with modulestore().bulk_operations(self.course_id):
child_to_parent = {} child_to_parent = {}
stack = [self.start_block] stack = [self.start_block]
while stack: while stack:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment