Commit d39da7fd by Calen Pennington

Merge pull request #4780 from cpennington/split-bulk-operations

Implement bulk operations on the split modulestore.
Parents: ae240ae7 0e7e266a
@@ -38,7 +38,7 @@ class Command(BaseCommand):
         print("Cloning course {0} to {1}".format(source_course_id, dest_course_id))

-        with mstore.bulk_write_operations(dest_course_id):
+        with mstore.bulk_operations(dest_course_id):
             if mstore.clone_course(source_course_id, dest_course_id, ModuleStoreEnum.UserID.mgmt_command):
                 print("copying User permissions...")
                 # purposely avoids auth.add_user b/c it doesn't have a caller to authorize
...
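The rename from bulk_write_operations to bulk_operations reflects that the context manager now batches reads as well as writes. A hedged usage sketch of the new API (the org/course/run values are hypothetical):

    # Minimal sketch of calling code, assuming a modulestore that mixes in
    # the bulk-operations support introduced by this commit.
    from opaque_keys.edx.locator import CourseLocator
    from xmodule.modulestore.django import modulestore

    course_key = CourseLocator('testX', 'demo', '2014')  # hypothetical course
    store = modulestore()

    with store.bulk_operations(course_key):
        with store.bulk_operations(course_key):  # nests; does not flush early
            pass
    # On exiting the outermost block, batched changes are written back (split)
    # or the metadata inheritance cache is refreshed (old mongo).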
@@ -867,7 +867,7 @@ class ContentStoreToyCourseTest(ContentStoreTestCase):
         # so we don't need to make an extra query to compute it.
         # set the branch to 'publish' in order to prevent extra lookups of draft versions
         with mongo_store.branch_setting(ModuleStoreEnum.Branch.published_only):
-            with check_mongo_calls(mongo_store, 4, 0):
+            with check_mongo_calls(4, 0):
                 course = mongo_store.get_course(course_id, depth=2)

         # make sure we pre-fetched a known sequential which should be at depth=2
@@ -879,7 +879,7 @@ class ContentStoreToyCourseTest(ContentStoreTestCase):
         # Now, test with the branch set to draft. No extra round trips b/c it doesn't go deep enough to get
         # beyond direct only categories
         with mongo_store.branch_setting(ModuleStoreEnum.Branch.draft_preferred):
-            with check_mongo_calls(mongo_store, 4, 0):
+            with check_mongo_calls(4, 0):
                 mongo_store.get_course(course_id, depth=2)

     def test_export_course_without_content_store(self):
...
@@ -201,10 +201,11 @@ class TestCourseListing(ModuleStoreTestCase):
         # Now count the db queries
         store = modulestore()._get_modulestore_by_type(ModuleStoreEnum.Type.mongo)
-        with check_mongo_calls(store, USER_COURSES_COUNT):
+        with check_mongo_calls(USER_COURSES_COUNT):
             _accessible_courses_list_from_groups(self.request)

-        with check_mongo_calls(store, 1):
+        # TODO: LMS-11220: Document why this takes 6 calls
+        with check_mongo_calls(6):
             _accessible_courses_list(self.request)

     def test_get_course_list_with_same_course_id(self):
...
@@ -205,10 +205,9 @@ class TemplateTests(unittest.TestCase):
             data="<problem></problem>"
         )

-        # course root only updated 2x
+        # The draft course root has 2 revisions: the published revision, and then the subsequent
+        # changes to the draft revision
         version_history = self.split_store.get_block_generations(test_course.location)
-        # create course causes 2 versions for the time being; skip the first.
-        version_history = version_history.children[0]
         self.assertEqual(version_history.locator.version_guid, test_course.location.version_guid)
         self.assertEqual(len(version_history.children), 1)
         self.assertEqual(version_history.children[0].children, [])
...
@@ -33,7 +33,7 @@ class ContentStoreImportTest(ModuleStoreTestCase):
     """
     def setUp(self):
         password = super(ContentStoreImportTest, self).setUp()

         self.client = Client()
         self.client.login(username=self.user.username, password=password)
@@ -157,15 +157,15 @@ class ContentStoreImportTest(ModuleStoreTestCase):
         store = modulestore()._get_modulestore_by_type(ModuleStoreEnum.Type.mongo)

         # we try to refresh the inheritance tree for each update_item in the import
-        with check_exact_number_of_calls(store, store.refresh_cached_metadata_inheritance_tree, 28):
+        with check_exact_number_of_calls(store, 'refresh_cached_metadata_inheritance_tree', 28):

             # _get_cached_metadata_inheritance_tree should be called only once
-            with check_exact_number_of_calls(store, store._get_cached_metadata_inheritance_tree, 1):
+            with check_exact_number_of_calls(store, '_get_cached_metadata_inheritance_tree', 1):

                 # with bulk-edit in progress, the inheritance tree should be recomputed only at the end of the import
                 # NOTE: On Jenkins, with memcache enabled, the number of calls here is only 1.
                 # Locally, without memcache, the number of calls is actually 2 (once more during the publish step)
-                with check_number_of_calls(store, store._compute_metadata_inheritance_tree, 2):
+                with check_number_of_calls(store, '_compute_metadata_inheritance_tree', 2):
                     self.load_test_import_course()

     def test_rewrite_reference_list(self):
...
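The helpers now take the method name as a string rather than the bound method, which is what patch-based call counting needs: the patcher replaces the attribute on the object by name. A minimal sketch of such a helper (not the actual edx-platform implementation), using the standalone mock package of that era:

    from contextlib import contextmanager
    import mock

    @contextmanager
    def check_exact_number_of_calls(obj, method_name, expected_calls):
        """Assert that obj.<method_name> is called exactly expected_calls times."""
        real_method = getattr(obj, method_name)
        # wraps= keeps the original behavior while the spy records calls
        with mock.patch.object(obj, method_name, wraps=real_method) as spy:
            yield
        assert spy.call_count == expected_calls, \
            '{} called {} times; expected {}'.format(
                method_name, spy.call_count, expected_calls)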
@@ -72,7 +72,7 @@ def delete_course_and_groups(course_key, user_id):
    """
    module_store = modulestore()

-    with module_store.bulk_write_operations(course_key):
+    with module_store.bulk_operations(course_key):
        module_store.delete_course(course_key, user_id)

        print 'removing User permissions from course....'
...
@@ -423,7 +423,7 @@ def course_index(request, course_key):
    """
    # A depth of None implies the whole course. The course outline needs this in order to compute has_changes.
    # A unit may not have a draft version, but one of its components could, and hence the unit itself has changes.
-    with modulestore().bulk_write_operations(course_key):
+    with modulestore().bulk_operations(course_key):
        course_module = _get_course_module(course_key, request.user, depth=None)
        lms_link = get_lms_link_for_item(course_module.location)
        sections = course_module.get_children()
...
@@ -310,6 +310,13 @@ class ModuleStoreRead(object):
         """
         pass

+    @contextmanager
+    def bulk_operations(self, course_id):
+        """
+        A context manager for notifying the store of bulk operations. This affects only the current thread.
+        """
+        yield
+

 class ModuleStoreWrite(ModuleStoreRead):
     """
@@ -543,6 +550,33 @@ class ModuleStoreReadBase(ModuleStoreRead):
             raise ValueError(u"Cannot set default store to type {}".format(store_type))
         yield

+    @contextmanager
+    def bulk_operations(self, course_id):
+        """
+        A context manager for notifying the store of bulk operations. This affects only the current thread.
+
+        In the case of Mongo, it temporarily disables refreshing the metadata inheritance tree
+        until the bulk operation is completed.
+        """
+        # TODO: Make this multi-process-safe if future operations need it.
+        try:
+            self._begin_bulk_operation(course_id)
+            yield
+        finally:
+            self._end_bulk_operation(course_id)
+
+    def _begin_bulk_operation(self, course_id):
+        """
+        Begin a bulk write operation on course_id.
+        """
+        pass
+
+    def _end_bulk_operation(self, course_id):
+        """
+        End the active bulk write operation on course_id.
+        """
+        pass
+

 class ModuleStoreWriteBase(ModuleStoreReadBase, ModuleStoreWrite):
     '''
@@ -643,29 +677,6 @@ class ModuleStoreWriteBase(ModuleStoreReadBase, ModuleStoreWrite):
             parent.children.append(item.location)
             self.update_item(parent, user_id)

-    @contextmanager
-    def bulk_write_operations(self, course_id):
-        """
-        A context manager for notifying the store of bulk write events.
-        In the case of Mongo, it temporarily disables refreshing the metadata inheritance tree
-        until the bulk operation is completed.
-        """
-        # TODO
-        # Make this multi-process-safe if future operations need it.
-        # Right now, only Import Course, Clone Course, and Delete Course use this, so
-        # it's ok if the cached metadata in the memcache is invalid when another
-        # request comes in for the same course.
-        try:
-            if hasattr(self, '_begin_bulk_write_operation'):
-                self._begin_bulk_write_operation(course_id)
-            yield
-        finally:
-            # check for the begin method here,
-            # since it's an error if an end method is not defined when a begin method is
-            if hasattr(self, '_begin_bulk_write_operation'):
-                self._end_bulk_write_operation(course_id)
-

 def only_xmodules(identifier, entry_points):
     """Only use entry_points that are supplied by the xmodule package"""
...
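With _begin_bulk_operation and _end_bulk_operation defined as no-op hooks on the base class, the hasattr probing above becomes unnecessary: subclasses override only the hooks and inherit the context-manager plumbing. A self-contained toy illustration of that contract (hypothetical classes, not edx-platform code):

    from contextlib import contextmanager

    class BaseStore(object):
        @contextmanager
        def bulk_operations(self, course_id):
            try:
                self._begin_bulk_operation(course_id)
                yield
            finally:
                self._end_bulk_operation(course_id)

        def _begin_bulk_operation(self, course_id):
            pass  # default: no-op

        def _end_bulk_operation(self, course_id):
            pass  # default: no-op

    class RecordingStore(BaseStore):
        """Hypothetical subclass: overrides only the hooks."""
        def __init__(self):
            self.events = []

        def _begin_bulk_operation(self, course_id):
            self.events.append(('begin', course_id))

        def _end_bulk_operation(self, course_id):
            self.events.append(('end', course_id))

    store = RecordingStore()
    with store.bulk_operations('demo-course'):
        pass
    assert store.events == [('begin', 'demo-course'), ('end', 'demo-course')]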
@@ -54,20 +54,16 @@ class DuplicateItemError(Exception):
             self, Exception.__str__(self, *args, **kwargs)
         )


class VersionConflictError(Exception):
    """
    The caller asked for either draft or published head and gave a version which conflicted with it.
    """
    def __init__(self, requestedLocation, currentHeadVersionGuid):
-        super(VersionConflictError, self).__init__()
-        self.requestedLocation = requestedLocation
-        self.currentHeadVersionGuid = currentHeadVersionGuid
-
-    def __str__(self, *args, **kwargs):
-        """
-        Print requested and current head info
-        """
-        return u'Requested {} but {} is current head'.format(self.requestedLocation, self.currentHeadVersionGuid)
+        super(VersionConflictError, self).__init__(u'Requested {}, but current head is {}'.format(
+            requestedLocation,
+            currentHeadVersionGuid
+        ))


class DuplicateCourseError(Exception):
...
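The simplification works because Exception.__init__ stores the message, so str(exc) and log formatting pick it up without a custom __str__. In miniature, with a hypothetical exception class:

    class DemoConflictError(Exception):
        def __init__(self, requested, current_head):
            super(DemoConflictError, self).__init__(
                u'Requested {}, but current head is {}'.format(requested, current_head)
            )

    try:
        raise DemoConflictError('block@draft-5', 'draft-7')
    except DemoConflictError as exc:
        print(str(exc))  # Requested block@draft-5, but current head is draft-7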
@@ -645,11 +645,11 @@ class MixedModuleStore(ModuleStoreDraftAndPublished, ModuleStoreWriteBase):
         yield

     @contextmanager
-    def bulk_write_operations(self, course_id):
+    def bulk_operations(self, course_id):
         """
-        A context manager for notifying the store of bulk write events.
+        A context manager for notifying the store of bulk operations.
         If course_id is None, the default store is used.
         """
         store = self._get_modulestore_for_courseid(course_id)
-        with store.bulk_write_operations(course_id):
+        with store.bulk_operations(course_id):
             yield
...
@@ -17,6 +17,7 @@ import sys
 import logging
 import copy
 import re
+import threading
 from uuid import uuid4

 from bson.son import SON
@@ -414,7 +415,7 @@ class MongoModuleStore(ModuleStoreDraftAndPublished, ModuleStoreWriteBase):
         # performance optimization to prevent updating the meta-data inheritance tree during
         # bulk write operations
-        self.ignore_write_events_on_courses = set()
+        self.ignore_write_events_on_courses = threading.local()
         self._course_run_cache = {}

     def close_connections(self):
@@ -435,27 +436,36 @@ class MongoModuleStore(ModuleStoreDraftAndPublished, ModuleStoreWriteBase):
             connection.drop_database(self.collection.database)
             connection.close()

-    def _begin_bulk_write_operation(self, course_id):
+    def _begin_bulk_operation(self, course_id):
         """
         Prevent updating the meta-data inheritance cache for the given course
         """
-        self.ignore_write_events_on_courses.add(course_id)
+        if not hasattr(self.ignore_write_events_on_courses, 'courses'):
+            self.ignore_write_events_on_courses.courses = set()
+
+        self.ignore_write_events_on_courses.courses.add(course_id)

-    def _end_bulk_write_operation(self, course_id):
+    def _end_bulk_operation(self, course_id):
         """
         Restart updating the meta-data inheritance cache for the given course.
         Refresh the meta-data inheritance cache now since it was temporarily disabled.
         """
-        if course_id in self.ignore_write_events_on_courses:
-            self.ignore_write_events_on_courses.remove(course_id)
+        if not hasattr(self.ignore_write_events_on_courses, 'courses'):
+            return
+
+        if course_id in self.ignore_write_events_on_courses.courses:
+            self.ignore_write_events_on_courses.courses.remove(course_id)
             self.refresh_cached_metadata_inheritance_tree(course_id)

     def _is_bulk_write_in_progress(self, course_id):
         """
         Returns whether a bulk write operation is in progress for the given course.
         """
+        if not hasattr(self.ignore_write_events_on_courses, 'courses'):
+            return False
+
         course_id = course_id.for_branch(None)
-        return course_id in self.ignore_write_events_on_courses
+        return course_id in self.ignore_write_events_on_courses.courses

     def fill_in_run(self, course_key):
         """
...
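The switch from a plain set() to threading.local() gives each thread its own view of in-progress bulk operations, so one request thread cannot suppress inheritance-cache refreshes for another. The 'courses' attribute must be created lazily because every new thread starts with an empty local. A standalone sketch of the pattern (module and function names are hypothetical):

    import threading

    _local = threading.local()

    def begin(course_id):
        # Each thread sees its own 'courses' attribute; create it on first use.
        if not hasattr(_local, 'courses'):
            _local.courses = set()
        _local.courses.add(course_id)

    def in_progress(course_id):
        return course_id in getattr(_local, 'courses', set())

    begin('course-1')
    assert in_progress('course-1')
    assert not in_progress('course-2')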
@@ -68,39 +68,40 @@ def path_to_location(modulestore, usage_key):
                 newpath = (next_usage, path)
                 queue.append((parent, newpath))

-    if not modulestore.has_item(usage_key):
-        raise ItemNotFoundError(usage_key)
-
-    path = find_path_to_course()
-    if path is None:
-        raise NoPathToItem(usage_key)
-
-    n = len(path)
-    course_id = path[0].course_key
-    # pull out the location names
-    chapter = path[1].name if n > 1 else None
-    section = path[2].name if n > 2 else None
-    # Figure out the position
-    position = None
-
-    # This block of code will find the position of a module within a nested tree
-    # of modules. If a problem is on tab 2 of a sequence that's on tab 3 of a
-    # sequence, the resulting position is 3_2. However, no positional modules
-    # (e.g. sequential and videosequence) currently deal with this form of
-    # representing nested positions. This needs to happen before jumping to a
-    # module nested in more than one positional module will work.
-    if n > 3:
-        position_list = []
-        for path_index in range(2, n - 1):
-            category = path[path_index].block_type
-            if category == 'sequential' or category == 'videosequence':
-                section_desc = modulestore.get_item(path[path_index])
-                # this calls get_children rather than just children b/c old mongo includes private children
-                # in children but not in get_children
-                child_locs = [c.location for c in section_desc.get_children()]
-                # positions are 1-indexed, and should be strings to be consistent with
-                # url parsing.
-                position_list.append(str(child_locs.index(path[path_index + 1]) + 1))
-        position = "_".join(position_list)
-
-    return (course_id, chapter, section, position)
+    with modulestore.bulk_operations(usage_key.course_key):
+        if not modulestore.has_item(usage_key):
+            raise ItemNotFoundError(usage_key)
+
+        path = find_path_to_course()
+        if path is None:
+            raise NoPathToItem(usage_key)
+
+        n = len(path)
+        course_id = path[0].course_key
+        # pull out the location names
+        chapter = path[1].name if n > 1 else None
+        section = path[2].name if n > 2 else None
+        # Figure out the position
+        position = None
+
+        # This block of code will find the position of a module within a nested tree
+        # of modules. If a problem is on tab 2 of a sequence that's on tab 3 of a
+        # sequence, the resulting position is 3_2. However, no positional modules
+        # (e.g. sequential and videosequence) currently deal with this form of
+        # representing nested positions. This needs to happen before jumping to a
+        # module nested in more than one positional module will work.
+        if n > 3:
+            position_list = []
+            for path_index in range(2, n - 1):
+                category = path[path_index].block_type
+                if category == 'sequential' or category == 'videosequence':
+                    section_desc = modulestore.get_item(path[path_index])
+                    # this calls get_children rather than just children b/c old mongo includes private children
+                    # in children but not in get_children
+                    child_locs = [c.location for c in section_desc.get_children()]
+                    # positions are 1-indexed, and should be strings to be consistent with
+                    # url parsing.
+                    position_list.append(str(child_locs.index(path[path_index + 1]) + 1))
+            position = "_".join(position_list)
+
+        return (course_id, chapter, section, position)
...
@@ -55,23 +55,27 @@ class SplitMigrator(object):
             new_run = source_course_key.run
         new_course_key = CourseLocator(new_org, new_course, new_run, branch=ModuleStoreEnum.BranchName.published)

-        new_fields = self._get_fields_translate_references(original_course, new_course_key, None)
-        if fields:
-            new_fields.update(fields)
-        new_course = self.split_modulestore.create_course(
-            new_org, new_course, new_run, user_id,
-            fields=new_fields,
-            master_branch=ModuleStoreEnum.BranchName.published,
-            skip_auto_publish=True,
-            **kwargs
-        )
-
-        with self.split_modulestore.bulk_write_operations(new_course.id):
+        with self.split_modulestore.bulk_operations(new_course_key):
+            new_fields = self._get_fields_translate_references(original_course, new_course_key, None)
+            if fields:
+                new_fields.update(fields)
+            new_course = self.split_modulestore.create_course(
+                new_org, new_course, new_run, user_id,
+                fields=new_fields,
+                master_branch=ModuleStoreEnum.BranchName.published,
+                skip_auto_publish=True,
+                **kwargs
+            )
+
             self._copy_published_modules_to_course(
                 new_course, original_course.location, source_course_key, user_id, **kwargs
             )

-        # create a new version for the drafts
-        with self.split_modulestore.bulk_write_operations(new_course.id):
+        # TODO: This should be merged back into the above transaction, but can't be until split.py
+        # is refactored to have more coherent access patterns
+        with self.split_modulestore.bulk_operations(new_course_key):
+            # create a new version for the drafts
             self._add_draft_modules_to_course(new_course.location, source_course_key, user_id, **kwargs)

         return new_course.id

@@ -80,7 +84,7 @@ class SplitMigrator(object):
         """
         Copy all of the modules from the 'direct' version of the course to the new split course.
         """
-        course_version_locator = new_course.id
+        course_version_locator = new_course.id.version_agnostic()

         # iterate over published course elements. Wildcarding rather than descending b/c some elements are orphaned (e.g.,
         # course about pages, conditionals)
@@ -101,7 +105,6 @@ class SplitMigrator(object):
                 fields=self._get_fields_translate_references(
                     module, course_version_locator, new_course.location.block_id
                 ),
-                continue_version=True,
                 skip_auto_publish=True,
                 **kwargs
             )
@@ -109,7 +112,7 @@ class SplitMigrator(object):
         index_info = self.split_modulestore.get_course_index_info(course_version_locator)
         versions = index_info['versions']
         versions[ModuleStoreEnum.BranchName.draft] = versions[ModuleStoreEnum.BranchName.published]
-        self.split_modulestore.update_course_index(index_info)
+        self.split_modulestore.update_course_index(course_version_locator, index_info)

         # clean up orphans in published version: in old mongo, parents pointed to the union of their published and draft
         # children which meant some pointers were to non-existent locations in 'direct'
...
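The switch to new_course.id.version_agnostic() matters inside a bulk operation: as writes accumulate, the version guid captured in the in-memory course object can go stale, so the migrator pins lookups to the branch head by dropping the version. A hedged illustration with opaque_keys (values are hypothetical):

    from opaque_keys.edx.locator import CourseLocator

    locator = CourseLocator('testX', 'mig101', '2014', branch='published')
    agnostic = locator.version_agnostic()
    # A version-agnostic locator follows the branch head rather than a
    # fixed structure snapshot.
    assert agnostic.version_guid is None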
@@ -2,7 +2,7 @@ import sys
 import logging
 from xblock.runtime import KvsFieldData
 from xblock.fields import ScopeIds
-from opaque_keys.edx.locator import BlockUsageLocator, LocalId, CourseLocator
+from opaque_keys.edx.locator import BlockUsageLocator, LocalId, CourseLocator, DefinitionLocator
 from xmodule.mako_module import MakoDescriptorSystem
 from xmodule.error_module import ErrorDescriptor
 from xmodule.errortracker import exc_info_to_str
@@ -10,6 +10,7 @@ from xmodule.modulestore.split_mongo import encode_key_for_mongo
 from ..exceptions import ItemNotFoundError
 from .split_mongo_kvs import SplitMongoKVS
 from fs.osfs import OSFS
+from .definition_lazy_loader import DefinitionLazyLoader

 log = logging.getLogger(__name__)

@@ -120,9 +121,24 @@ class CachingDescriptorSystem(MakoDescriptorSystem):
                 self.course_entry['org'] = course_entry_override['org']
                 self.course_entry['course'] = course_entry_override['course']
                 self.course_entry['run'] = course_entry_override['run']
-
-        definition = json_data.get('definition', {})
-        definition_id = self.modulestore.definition_locator(definition)
+        # most likely a lazy loader or the id directly
+        definition_id = json_data.get('definition')
+        block_type = json_data['category']
+
+        if definition_id is not None and not json_data.get('definition_loaded', False):
+            definition_loader = DefinitionLazyLoader(
+                self.modulestore, block_type, definition_id,
+                lambda fields: self.modulestore.convert_references_to_keys(
+                    course_key, self.load_block_type(block_type),
+                    fields, self.course_entry['structure']['blocks'],
+                )
+            )
+        else:
+            definition_loader = None
+
+        # If no definition id is provided, generate an in-memory id
+        if definition_id is None:
+            definition_id = LocalId()

         # If no usage id is provided, generate an in-memory id
         if block_id is None:
@@ -130,7 +146,7 @@ class CachingDescriptorSystem(MakoDescriptorSystem):
         block_locator = BlockUsageLocator(
             course_key,
-            block_type=json_data.get('category'),
+            block_type=block_type,
             block_id=block_id,
         )
@@ -138,7 +154,7 @@ class CachingDescriptorSystem(MakoDescriptorSystem):
             block_locator.course_key, class_, json_data.get('fields', {}), self.course_entry['structure']['blocks'],
         )
         kvs = SplitMongoKVS(
-            definition,
+            definition_loader,
             converted_fields,
             json_data.get('_inherited_settings'),
             **kwargs
@@ -148,7 +164,7 @@ class CachingDescriptorSystem(MakoDescriptorSystem):
         try:
             module = self.construct_xblock_from_class(
                 class_,
-                ScopeIds(None, json_data.get('category'), definition_id, block_locator),
+                ScopeIds(None, block_type, definition_id, block_locator),
                 field_data,
             )
         except Exception:
@@ -174,7 +190,7 @@ class CachingDescriptorSystem(MakoDescriptorSystem):
             module.previous_version = edit_info.get('previous_version')
             module.update_version = edit_info.get('update_version')
             module.source_version = edit_info.get('source_version', None)
-            module.definition_locator = definition_id
+            module.definition_locator = DefinitionLocator(block_type, definition_id)

             # decache any pending field settings
             module.save()
...
 from opaque_keys.edx.locator import DefinitionLocator
+from bson import SON


 class DefinitionLazyLoader(object):
@@ -24,3 +25,9 @@ class DefinitionLazyLoader(object):
         loader pointer with the result so as not to fetch more than once
         """
         return self.modulestore.db_connection.get_definition(self.definition_locator.definition_id)
+
+    def as_son(self):
+        return SON((
+            ('category', self.definition_locator.block_type),
+            ('definition', self.definition_locator.definition_id)
+        ))
...
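SON is pymongo's ordered mapping; building the (category, definition) pair with it keeps key order stable, which matters if the result is ever compared or stored as an embedded document, since Mongo matches subdocuments field-by-field in order. A quick illustration (the definition id here is a hypothetical string rather than a real ObjectId):

    from bson.son import SON

    son_doc = SON((('category', 'problem'), ('definition', 'abcdef0123456789')))
    assert list(son_doc.keys()) == ['category', 'definition']  # order preserved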
@@ -57,24 +57,42 @@ class MongoConnection(object):
         """
         return self.structures.find_one({'_id': key})

-    def find_matching_structures(self, query):
+    def find_structures_by_id(self, ids):
         """
-        Find the structure matching the query. Right now the query must be a legal mongo query
-        :param query: a mongo-style query of {key: [value|{$in ..}|..], ..}
+        Return all structures that are specified in ``ids``.
+
+        Arguments:
+            ids (list): A list of structure ids
+        """
+        return self.structures.find({'_id': {'$in': ids}})
+
+    def find_structures_derived_from(self, ids):
+        """
+        Return all structures that were immediately derived from a structure listed in ``ids``.
+
+        Arguments:
+            ids (list): A list of structure ids
         """
-        return self.structures.find(query)
+        return self.structures.find({'previous_version': {'$in': ids}})

-    def insert_structure(self, structure):
+    def find_ancestor_structures(self, original_version, block_id):
         """
-        Create the structure in the db
+        Find all structures that originated from ``original_version`` that contain ``block_id``.
+
+        Arguments:
+            original_version (str or ObjectID): The id of a structure
+            block_id (str): The id of the block in question
         """
-        self.structures.insert(structure)
+        return self.structures.find({
+            'original_version': original_version,
+            'blocks.{}.edit_info.update_version'.format(block_id): {'$exists': True}
+        })

-    def update_structure(self, structure):
+    def upsert_structure(self, structure):
         """
-        Update the db record for structure
+        Update the db record for structure, creating that record if it doesn't already exist
         """
-        self.structures.update({'_id': structure['_id']}, structure)
+        self.structures.update({'_id': structure['_id']}, structure, upsert=True)

     def get_course_index(self, key, ignore_case=False):
         """
@@ -88,11 +106,23 @@ class MongoConnection(object):
             ])
         )

-    def find_matching_course_indexes(self, query):
+    def find_matching_course_indexes(self, branch=None, search_targets=None):
         """
-        Find the course_index matching the query. Right now the query must be a legal mongo query
-        :param query: a mongo-style query of {key: [value|{$in ..}|..], ..}
+        Find the course_index matching particular conditions.
+
+        Arguments:
+            branch: If specified, this branch must exist in the returned courses
+            search_targets: If specified, this must be a dictionary specifying field values
+                that must exist in the search_targets of the returned courses
         """
+        query = son.SON()
+        if branch is not None:
+            query['versions.{}'.format(branch)] = {'$exists': True}
+
+        if search_targets:
+            for key, value in search_targets.iteritems():
+                query['search_targets.{}'.format(key)] = value
+
         return self.course_index.find(query)

     def insert_course_index(self, course_index):
@@ -101,13 +131,21 @@ class MongoConnection(object):
         """
         self.course_index.insert(course_index)

-    def update_course_index(self, course_index):
+    def update_course_index(self, course_index, from_index=None):
         """
-        Update the db record for course_index
+        Update the db record for course_index.
+
+        Arguments:
+            from_index: If set, only update an index if it matches the one specified in `from_index`.
         """
         self.course_index.update(
-            son.SON([('org', course_index['org']), ('course', course_index['course']), ('run', course_index['run'])]),
-            course_index
+            from_index or son.SON([
+                ('org', course_index['org']),
+                ('course', course_index['course']),
+                ('run', course_index['run'])
+            ]),
+            course_index,
+            upsert=False,
         )

     def delete_course_index(self, course_index):
...
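The from_index parameter turns the index update into an optimistic compare-and-swap: the filter is the full document snapshot taken when the bulk operation began, so the write only lands if nobody changed the index in the meantime. A hedged sketch of the same idea against a bare pymongo 2.x collection (database, collection, and documents are hypothetical):

    from pymongo import MongoClient

    course_index = MongoClient().demo_db.course_index

    snapshot = course_index.find_one({'org': 'testX', 'course': 'demo', 'run': '2014'})
    updated = dict(snapshot, versions={'published': 'new-structure-id'})

    # Matches only if the stored document still equals the snapshot; if a
    # concurrent writer changed it, the filter matches nothing and the
    # update is silently skipped.
    course_index.update(snapshot, updated, upsert=False)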
@@ -57,6 +57,7 @@ from path import path
 import copy
 from pytz import UTC
 from bson.objectid import ObjectId
+from pymongo.errors import DuplicateKeyError

 from xblock.core import XBlock
 from xblock.fields import Scope, Reference, ReferenceList, ReferenceValueDict
@@ -72,7 +73,6 @@ from xmodule.modulestore import (
 )

 from ..exceptions import ItemNotFoundError
-from .definition_lazy_loader import DefinitionLazyLoader
 from .caching_descriptor_system import CachingDescriptorSystem
 from xmodule.modulestore.split_mongo.mongo_connection import MongoConnection
 from xmodule.error_module import ErrorDescriptor
@@ -82,6 +82,7 @@ from types import NoneType

 log = logging.getLogger(__name__)
+

 #==============================================================================
 #
 # Known issue:
@@ -104,7 +105,406 @@ log = logging.getLogger(__name__)

 EXCLUDE_ALL = '*'

-class SplitMongoModuleStore(ModuleStoreWriteBase):
+class BulkWriteRecord(object):
+    def __init__(self):
+        self._active_count = 0
+        self.initial_index = None
+        self.index = None
+        self.structures = {}
+        self.structures_in_db = set()
+
+    @property
+    def active(self):
+        """
+        Return whether this bulk write is active.
+        """
+        return self._active_count > 0
+
+    def nest(self):
+        """
+        Record another level of nesting of this bulk write operation
+        """
+        self._active_count += 1
+
+    def unnest(self):
+        """
+        Record the completion of a level of nesting of the bulk write operation
+        """
+        self._active_count -= 1
+
+    @property
+    def is_root(self):
+        """
+        Return whether the bulk write is at the root (first) level of nesting
+        """
+        return self._active_count == 1
+
+    # TODO: This needs to track which branches have actually been modified/versioned,
+    # so that copying one branch to another doesn't update the original branch.
+    @property
+    def dirty_branches(self):
+        """
+        Return a list of which branch version ids differ from what was stored
+        in the database at the beginning of this bulk operation.
+        """
+        # If no course index has been set, then no branches have changed
+        if self.index is None:
+            return []
+
+        # If there was no index in the database to start with, then all branches
+        # are dirty by definition
+        if self.initial_index is None:
+            return self.index.get('versions', {}).keys()
+
+        # Return branches whose ids differ between self.index and self.initial_index
+        return [
+            branch
+            for branch, _id
+            in self.index.get('versions', {}).items()
+            if self.initial_index.get('versions', {}).get(branch) != _id
+        ]
+
+    def structure_for_branch(self, branch):
+        return self.structures.get(self.index.get('versions', {}).get(branch))
+
+    def set_structure_for_branch(self, branch, structure):
+        self.index.get('versions', {})[branch] = structure['_id']
+        self.structures[structure['_id']] = structure
+
+    def __repr__(self):
+        return u"BulkWriteRecord<{!r}, {!r}, {!r}, {!r}, {!r}>".format(
+            self._active_count,
+            self.initial_index,
+            self.index,
+            self.structures,
+            self.structures_in_db,
+        )
+
+
+class BulkWriteMixin(object):
+    """
+    This implements the :meth:`bulk_operations` modulestore semantics for the :class:`SplitMongoModuleStore`.
+
+    In particular, it implements :meth:`_begin_bulk_operation` and
+    :meth:`_end_bulk_operation` to provide the external interface, and then exposes a set of methods
+    for interacting with course_indexes and structures that can be used by :class:`SplitMongoModuleStore`.
+
+    Internally, this mixin records the set of all active bulk operations (keyed on the active course),
+    and only writes those values to ``self.mongo_connection`` when :meth:`_end_bulk_operation` is called.
+    If a bulk write operation isn't active, then the changes are immediately written to the underlying
+    mongo_connection.
+    """
+    def __init__(self, *args, **kwargs):
+        super(BulkWriteMixin, self).__init__(*args, **kwargs)
+        self._active_bulk_writes = threading.local()
+
+    def _get_bulk_write_record(self, course_key, ignore_case=False):
+        """
+        Return the :class:`.BulkWriteRecord` for this course.
+        """
+        if course_key is None:
+            return BulkWriteRecord()
+
+        if not isinstance(course_key, CourseLocator):
+            raise TypeError(u'{!r} is not a CourseLocator'.format(course_key))
+
+        if not hasattr(self._active_bulk_writes, 'records'):
+            self._active_bulk_writes.records = defaultdict(BulkWriteRecord)
+
+        # Retrieve the bulk record based on matching org/course/run (possibly ignoring case)
+        if course_key.org and course_key.course and course_key.run:
+            if ignore_case:
+                for key, record in self._active_bulk_writes.records.iteritems():
+                    if (
+                        key.org.lower() == course_key.org.lower() and
+                        key.course.lower() == course_key.course.lower() and
+                        key.run.lower() == course_key.run.lower()
+                    ):
+                        return record
+                # If nothing matches case-insensitively, fall through to creating a new record with the passed in case
+            return self._active_bulk_writes.records[course_key.replace(branch=None, version_guid=None)]
+        else:
+            # If org/course/run aren't set, use a bulk record that is identified just by the version_guid
+            return self._active_bulk_writes.records[course_key.replace(org=None, course=None, run=None, branch=None)]
+
+    @property
+    def _active_records(self):
+        """
+        Yield all active (CourseLocator, BulkWriteRecord) tuples.
+        """
+        for course_key, record in getattr(self._active_bulk_writes, 'records', {}).iteritems():
+            if record.active:
+                yield (course_key, record)
+
+    def _clear_bulk_write_record(self, course_key):
+        if not isinstance(course_key, CourseLocator):
+            raise TypeError('{!r} is not a CourseLocator'.format(course_key))
+
+        if not hasattr(self._active_bulk_writes, 'records'):
+            return
+
+        if course_key.org and course_key.course and course_key.run:
+            del self._active_bulk_writes.records[course_key.replace(branch=None, version_guid=None)]
+        else:
+            del self._active_bulk_writes.records[course_key.replace(org=None, course=None, run=None, branch=None)]
+
+    def _begin_bulk_operation(self, course_key):
+        """
+        Begin a bulk write operation on course_key.
+        """
+        bulk_write_record = self._get_bulk_write_record(course_key)
+
+        # Increment the number of active bulk operations (bulk operations
+        # on the same course can be nested)
+        bulk_write_record.nest()
+
+        # If this is the highest level bulk operation, then initialize it
+        if bulk_write_record.is_root:
+            bulk_write_record.initial_index = self.db_connection.get_course_index(course_key)
+            # Ensure that any edits to the index don't pollute the initial_index
+            bulk_write_record.index = copy.deepcopy(bulk_write_record.initial_index)
+
+    def _end_bulk_operation(self, course_key):
+        """
+        End the active bulk write operation on course_key.
+        """
+        # If no bulk write is active, return
+        bulk_write_record = self._get_bulk_write_record(course_key)
+        if not bulk_write_record.active:
+            return
+
+        bulk_write_record.unnest()
+
+        # If this wasn't the outermost context, then don't close out the
+        # bulk write operation.
+        if bulk_write_record.active:
+            return
+
+        # This is the last active bulk write. If the content is dirty,
+        # then update the database
+        for _id in bulk_write_record.structures.viewkeys() - bulk_write_record.structures_in_db:
+            self.db_connection.upsert_structure(bulk_write_record.structures[_id])
+
+        if bulk_write_record.index is not None and bulk_write_record.index != bulk_write_record.initial_index:
+            if bulk_write_record.initial_index is None:
+                self.db_connection.insert_course_index(bulk_write_record.index)
+            else:
+                self.db_connection.update_course_index(bulk_write_record.index, from_index=bulk_write_record.initial_index)
+
+        self._clear_bulk_write_record(course_key)
+
+    def _is_in_bulk_write_operation(self, course_key, ignore_case=False):
+        """
+        Return whether a bulk write is active on `course_key`.
+        """
+        return self._get_bulk_write_record(course_key, ignore_case).active
+
+    def get_course_index(self, course_key, ignore_case=False):
+        """
+        Return the index for course_key.
+        """
+        if self._is_in_bulk_write_operation(course_key, ignore_case):
+            return self._get_bulk_write_record(course_key, ignore_case).index
+        else:
+            return self.db_connection.get_course_index(course_key, ignore_case)
+
+    def insert_course_index(self, course_key, index_entry):
+        bulk_write_record = self._get_bulk_write_record(course_key)
+        if bulk_write_record.active:
+            bulk_write_record.index = index_entry
+        else:
+            self.db_connection.insert_course_index(index_entry)
+
+    def update_course_index(self, course_key, updated_index_entry):
+        """
+        Change the given course's index entry.
+
+        Note, this operation can be dangerous and break running courses.
+
+        Does not return anything useful.
+        """
+        bulk_write_record = self._get_bulk_write_record(course_key)
+        if bulk_write_record.active:
+            bulk_write_record.index = updated_index_entry
+        else:
+            self.db_connection.update_course_index(updated_index_entry)
+
+    def get_structure(self, course_key, version_guid):
+        bulk_write_record = self._get_bulk_write_record(course_key)
+        if bulk_write_record.active:
+            structure = bulk_write_record.structures.get(version_guid)
+
+            # The structure hasn't been loaded from the db yet, so load it
+            if structure is None:
+                structure = self.db_connection.get_structure(version_guid)
+                bulk_write_record.structures[version_guid] = structure
+                if structure is not None:
+                    bulk_write_record.structures_in_db.add(version_guid)
+
+            return structure
+        else:
+            # cast string to ObjectId if necessary
+            version_guid = course_key.as_object_id(version_guid)
+            return self.db_connection.get_structure(version_guid)
+
+    def update_structure(self, course_key, structure):
+        """
+        Update a course structure, respecting the current bulk operation status
+        (no data will be written to the database if a bulk operation is active.)
+        """
+        self._clear_cache(structure['_id'])
+        bulk_write_record = self._get_bulk_write_record(course_key)
+        if bulk_write_record.active:
+            bulk_write_record.structures[structure['_id']] = structure
+        else:
+            self.db_connection.upsert_structure(structure)
+
+    def version_structure(self, course_key, structure, user_id):
+        """
+        Copy the structure and update the history info (edited_by, edited_on, previous_version)
+        """
+        bulk_write_record = self._get_bulk_write_record(course_key)
+
+        # If we have an active bulk write, and it's already been edited, then just use that structure
+        if bulk_write_record.active and course_key.branch in bulk_write_record.dirty_branches:
+            return bulk_write_record.structure_for_branch(course_key.branch)
+
+        # Otherwise, make a new structure
+        new_structure = copy.deepcopy(structure)
+        new_structure['_id'] = ObjectId()
+        new_structure['previous_version'] = structure['_id']
+        new_structure['edited_by'] = user_id
+        new_structure['edited_on'] = datetime.datetime.now(UTC)
+        new_structure['schema_version'] = self.SCHEMA_VERSION
+
+        # If we're in a bulk write, update the structure used there, and mark it as dirty
+        if bulk_write_record.active:
+            bulk_write_record.set_structure_for_branch(course_key.branch, new_structure)
+
+        return new_structure
+
+    def version_block(self, block_info, user_id, update_version):
+        """
+        Update the block_info dictionary based on it having been edited
+        """
+        if block_info['edit_info'].get('update_version') == update_version:
+            return
+
+        block_info['edit_info'] = {
+            'edited_on': datetime.datetime.now(UTC),
+            'edited_by': user_id,
+            'previous_version': block_info['edit_info']['update_version'],
+            'update_version': update_version,
+        }
+
+    def find_matching_course_indexes(self, branch=None, search_targets=None):
+        """
+        Find the course_indexes which have the specified branch and search_targets.
+        """
+        indexes = self.db_connection.find_matching_course_indexes(branch, search_targets)
+
+        for _, record in self._active_records:
+            if branch and branch not in record.index.get('versions', {}):
+                continue
+
+            if search_targets:
+                if any(
+                    'search_targets' not in record.index or
+                    field not in record.index['search_targets'] or
+                    record.index['search_targets'][field] != value
+                    for field, value in search_targets.iteritems()
+                ):
+                    continue
+
+            indexes.append(record.index)
+
+        return indexes
+
+    def find_structures_by_id(self, ids):
+        """
+        Return all structures that are specified in ``ids``.
+
+        If a structure with the same id is in both the cache and the database,
+        the cached version will be preferred.
+
+        Arguments:
+            ids (list): A list of structure ids
+        """
+        structures = []
+        ids = set(ids)
+
+        for _, record in self._active_records:
+            for structure in record.structures.values():
+                structure_id = structure.get('_id')
+                if structure_id in ids:
+                    ids.remove(structure_id)
+                    structures.append(structure)
+
+        structures.extend(self.db_connection.find_structures_by_id(list(ids)))
+        return structures
+
+    def find_structures_derived_from(self, ids):
+        """
+        Return all structures that were immediately derived from a structure listed in ``ids``.
+
+        Arguments:
+            ids (list): A list of structure ids
+        """
+        found_structure_ids = set()
+        structures = []
+
+        for _, record in self._active_records:
+            for structure in record.structures.values():
+                if structure.get('previous_version') in ids:
+                    structures.append(structure)
+                    if '_id' in structure:
+                        found_structure_ids.add(structure['_id'])
+
+        structures.extend(
+            structure
+            for structure in self.db_connection.find_structures_derived_from(ids)
+            if structure['_id'] not in found_structure_ids
+        )
+        return structures
+
+    def find_ancestor_structures(self, original_version, block_id):
+        """
+        Find all structures that originated from ``original_version`` that contain ``block_id``.
+
+        Any structure found in the cache will be preferred to a structure with the same id from the database.
+
+        Arguments:
+            original_version (str or ObjectID): The id of a structure
+            block_id (str): The id of the block in question
+        """
+        found_structure_ids = set()
+        structures = []
+
+        for _, record in self._active_records:
+            for structure in record.structures.values():
+                if 'original_version' not in structure:
+                    continue
+
+                if structure['original_version'] != original_version:
+                    continue
+
+                if block_id not in structure.get('blocks', {}):
+                    continue
+
+                if 'update_version' not in structure['blocks'][block_id].get('edit_info', {}):
+                    continue
+
+                structures.append(structure)
+                found_structure_ids.add(structure['_id'])
+
+        structures.extend(
+            structure
+            for structure in self.db_connection.find_ancestor_structures(original_version, block_id)
+            if structure['_id'] not in found_structure_ids
+        )
+        return structures
+
+
+class SplitMongoModuleStore(BulkWriteMixin, ModuleStoreWriteBase):
     """
     A Mongodb backed ModuleStore supporting versions, inheritance,
     and sharing.
@@ -173,50 +573,45 @@ class SplitMongoModuleStore(ModuleStoreWriteBase):
         '''
         Handles caching of items once inheritance and any other one time
        per course per fetch operations are done.
-        :param system: a CachingDescriptorSystem
-        :param base_block_ids: list of block_ids to fetch
-        :param course_key: the destination course providing the context
-        :param depth: how deep below these to prefetch
-        :param lazy: whether to fetch definitions or use placeholders
-        '''
-        new_module_data = {}
-        for block_id in base_block_ids:
-            new_module_data = self.descendants(
-                system.course_entry['structure']['blocks'],
-                block_id,
-                depth,
-                new_module_data
-            )
-
-        if lazy:
-            for block in new_module_data.itervalues():
-                block['definition'] = DefinitionLazyLoader(
-                    self, block['category'], block['definition'],
-                    lambda fields: self.convert_references_to_keys(
-                        course_key, system.load_block_type(block['category']),
-                        fields, system.course_entry['structure']['blocks'],
-                    )
-                )
-        else:
-            # Load all descendants by id
-            descendent_definitions = self.db_connection.find_matching_definitions({
-                '_id': {'$in': [block['definition']
-                                for block in new_module_data.itervalues()]}})
-            # turn into a map
-            definitions = {definition['_id']: definition
-                           for definition in descendent_definitions}
-
-            for block in new_module_data.itervalues():
-                if block['definition'] in definitions:
-                    converted_fields = self.convert_references_to_keys(
-                        course_key, system.load_block_type(block['category']),
-                        definitions[block['definition']].get('fields'),
-                        system.course_entry['structure']['blocks'],
-                    )
-                    block['fields'].update(converted_fields)
-
-        system.module_data.update(new_module_data)
-        return system.module_data
+
+        Arguments:
+            system: a CachingDescriptorSystem
+            base_block_ids: list of block_ids to fetch
+            course_key: the destination course providing the context
+            depth: how deep below these to prefetch
+            lazy: whether to fetch definitions or use placeholders
+        '''
+        with self.bulk_operations(course_key):
+            new_module_data = {}
+            for block_id in base_block_ids:
+                new_module_data = self.descendants(
+                    system.course_entry['structure']['blocks'],
+                    block_id,
+                    depth,
+                    new_module_data
+                )
+
+            if not lazy:
+                # Load all descendants by id
+                descendent_definitions = self.db_connection.find_matching_definitions({
+                    '_id': {'$in': [block['definition']
+                                    for block in new_module_data.itervalues()]}})
+                # turn into a map
+                definitions = {definition['_id']: definition
+                               for definition in descendent_definitions}
+
+                for block in new_module_data.itervalues():
+                    if block['definition'] in definitions:
+                        converted_fields = self.convert_references_to_keys(
+                            course_key, system.load_block_type(block['category']),
+                            definitions[block['definition']].get('fields'),
+                            system.course_entry['structure']['blocks'],
+                        )
+                        block['fields'].update(converted_fields)
+                        block['definition_loaded'] = True
+
+            system.module_data.update(new_module_data)
+            return system.module_data
     def _load_items(self, course_entry, block_ids, depth=0, lazy=True, **kwargs):
         '''
@@ -265,6 +660,8 @@ class SplitMongoModuleStore(ModuleStoreWriteBase):
         :param course_version_guid: if provided, clear only this entry
         """
         if course_version_guid:
+            if not hasattr(self.thread_cache, 'course_cache'):
+                self.thread_cache.course_cache = {}
             try:
                 del self.thread_cache.course_cache[course_version_guid]
             except KeyError:
@@ -272,7 +669,7 @@ class SplitMongoModuleStore(ModuleStoreWriteBase):
         else:
             self.thread_cache.course_cache = {}

-    def _lookup_course(self, course_locator):
+    def _lookup_course(self, course_key):
         '''
         Decode the locator into the right series of db access. Does not
         return the CourseDescriptor! It returns the actual db json from
@@ -283,45 +680,50 @@ class SplitMongoModuleStore(ModuleStoreWriteBase):
         it raises VersionConflictError (the version now differs from what it was when you got your
         reference)

-        :param course_locator: any subclass of CourseLocator
+        :param course_key: any subclass of CourseLocator
         '''
-        if course_locator.org and course_locator.course and course_locator.run:
-            if course_locator.branch is None:
-                raise InsufficientSpecificationError(course_locator)
+        if course_key.org and course_key.course and course_key.run:
+            if course_key.branch is None:
+                raise InsufficientSpecificationError(course_key)

             # use the course id
-            index = self.db_connection.get_course_index(course_locator)
+            index = self.get_course_index(course_key)

             if index is None:
-                raise ItemNotFoundError(course_locator)
-            if course_locator.branch not in index['versions']:
-                raise ItemNotFoundError(course_locator)
-            version_guid = index['versions'][course_locator.branch]
-            if course_locator.version_guid is not None and version_guid != course_locator.version_guid:
+                raise ItemNotFoundError(course_key)
+            if course_key.branch not in index['versions']:
+                raise ItemNotFoundError(course_key)
+
+            version_guid = index['versions'][course_key.branch]
+
+            if course_key.version_guid is not None and version_guid != course_key.version_guid:
                 # This may be a bit too touchy but it's hard to infer intent
-                raise VersionConflictError(course_locator, version_guid)
-        elif course_locator.version_guid is None:
-            raise InsufficientSpecificationError(course_locator)
+                raise VersionConflictError(course_key, version_guid)
+
+        elif course_key.version_guid is None:
+            raise InsufficientSpecificationError(course_key)
         else:
             # TODO should this raise an exception if branch was provided?
-            version_guid = course_locator.version_guid
+            version_guid = course_key.version_guid

-        # cast string to ObjectId if necessary
-        version_guid = course_locator.as_object_id(version_guid)
-        entry = self.db_connection.get_structure(version_guid)
+        entry = self.get_structure(course_key, version_guid)
+        if entry is None:
+            raise ItemNotFoundError('Structure: {}'.format(version_guid))

         # b/c more than one course can use same structure, the 'org', 'course',
         # 'run', and 'branch' are not intrinsic to structure
         # and the one assoc'd w/ it by another fetch may not be the one relevant to this fetch; so,
         # add it in the envelope for the structure.
         envelope = {
-            'org': course_locator.org,
-            'course': course_locator.course,
-            'run': course_locator.run,
-            'branch': course_locator.branch,
+            'org': course_key.org,
+            'course': course_key.course,
+            'run': course_key.run,
+            'branch': course_key.branch,
             'structure': entry,
         }
         return envelope
-    def get_courses(self, branch, qualifiers=None, **kwargs):
+    def get_courses(self, branch, **kwargs):
         '''
         Returns a list of course descriptors matching any given qualifiers.
@@ -332,12 +734,8 @@ class SplitMongoModuleStore(ModuleStoreWriteBase):
         To get specific versions via guid use get_course.

         :param branch: the branch for which to return courses.
-        :param qualifiers: an optional dict restricting which elements should match
         '''
-        if qualifiers is None:
-            qualifiers = {}
-        qualifiers.update({"versions.{}".format(branch): {"$exists": True}})
-        matching_indexes = self.db_connection.find_matching_course_indexes(qualifiers)
+        matching_indexes = self.find_matching_course_indexes(branch)

         # collect ids and then query for those
         version_guids = []
@@ -347,7 +745,7 @@ class SplitMongoModuleStore(ModuleStoreWriteBase):
             version_guids.append(version_guid)
             id_version_map[version_guid] = course_index

-        matching_structures = self.db_connection.find_matching_structures({'_id': {'$in': version_guids}})
+        matching_structures = self.find_structures_by_id(version_guids)

         # get the blocks for each course index (s/b the root)
         result = []
@@ -401,7 +799,7 @@ class SplitMongoModuleStore(ModuleStoreWriteBase):
             # The supplied CourseKey is of the wrong type, so it can't possibly be stored in this modulestore.
             return False
-        course_index = self.db_connection.get_course_index(course_id, ignore_case)
+        course_index = self.get_course_index(course_id, ignore_case)
         return CourseLocator(course_index['org'], course_index['course'], course_index['run'], course_id.branch) if course_index else None
     def has_item(self, usage_key):
@@ -413,7 +811,7 @@ class SplitMongoModuleStore(ModuleStoreWriteBase):
         if usage_key.block_id is None:
             raise InsufficientSpecificationError(usage_key)
         try:
-            course_structure = self._lookup_course(usage_key)['structure']
+            course_structure = self._lookup_course(usage_key.course_key)['structure']
         except ItemNotFoundError:
             # this error only occurs if the course does not exist
             return False
@@ -433,7 +831,7 @@ class SplitMongoModuleStore(ModuleStoreWriteBase):
             # The supplied UsageKey is of the wrong type, so it can't possibly be stored in this modulestore.
             raise ItemNotFoundError(usage_key)

-        course = self._lookup_course(usage_key)
+        course = self._lookup_course(usage_key.course_key)
         items = self._load_items(course, [usage_key.block_id], depth, lazy=True, **kwargs)
         if len(items) == 0:
             raise ItemNotFoundError(usage_key)
@@ -511,7 +909,7 @@ class SplitMongoModuleStore(ModuleStoreWriteBase):
         :param locator: BlockUsageLocator restricting search scope
         '''
-        course = self._lookup_course(locator)
+        course = self._lookup_course(locator.course_key)
         parent_id = self._get_parent_from_structure(locator.block_id, course['structure'])
         if parent_id is None:
             return None
@@ -541,12 +939,12 @@ class SplitMongoModuleStore(ModuleStoreWriteBase):
                 for block_id in items
             ]
-    def get_course_index_info(self, course_locator):
+    def get_course_index_info(self, course_key):
         """
         The index records the initial creation of the indexed course and tracks the current version
         heads. This function is primarily for test verification but may serve some
         more general purpose.
-        :param course_locator: must have a org, course, and run set
+        :param course_key: must have a org, course, and run set
         :return {'org': string,
             versions: {'draft': the head draft version id,
                 'published': the head published version id if any,
@@ -555,24 +953,24 @@ class SplitMongoModuleStore(ModuleStoreWriteBase):
             'edited_on': when the course was originally created
         }
         """
-        if not (course_locator.course and course_locator.run and course_locator.org):
+        if not (course_key.course and course_key.run and course_key.org):
             return None
-        index = self.db_connection.get_course_index(course_locator)
+        index = self.get_course_index(course_key)
         return index

     # TODO figure out a way to make this info accessible from the course descriptor
-    def get_course_history_info(self, course_locator):
+    def get_course_history_info(self, course_key):
         """
         Because xblocks doesn't give a means to separate the course structure's meta information from
         the course xblock's, this method will get that info for the structure as a whole.
-        :param course_locator:
+        :param course_key:
         :return {'original_version': the version guid of the original version of this course,
             'previous_version': the version guid of the previous version,
             'edited_by': who made the last change,
             'edited_on': when the change was made
         }
         """
-        course = self._lookup_course(course_locator)['structure']
+        course = self._lookup_course(course_key)['structure']
         return {
             'original_version': course['original_version'],
             'previous_version': course['previous_version'],
@@ -613,22 +1011,20 @@ class SplitMongoModuleStore(ModuleStoreWriteBase):
         # TODO if depth is significant, it may make sense to get all that have the same original_version
         # and reconstruct the subtree from version_guid
-        next_entries = self.db_connection.find_matching_structures({'previous_version': version_guid})
+        next_entries = self.find_structures_derived_from([version_guid])
         # must only scan cursor's once
         next_versions = [struct for struct in next_entries]
         result = {version_guid: [CourseLocator(version_guid=struct['_id']) for struct in next_versions]}
         depth = 1
         while depth < version_history_depth and len(next_versions) > 0:
             depth += 1
-            next_entries = self.db_connection.find_matching_structures({'previous_version':
-                {'$in': [struct['_id'] for struct in next_versions]}})
+            next_entries = self.find_structures_derived_from([struct['_id'] for struct in next_versions])
             next_versions = [struct for struct in next_entries]
             for course_structure in next_versions:
                 result.setdefault(course_structure['previous_version'], []).append(
                     CourseLocator(version_guid=struct['_id']))
         return VersionTree(course_locator, result)
     def get_block_generations(self, block_locator):
         '''
         Find the history of this block. Return as a VersionTree of each place the block changed (except
@@ -639,14 +1035,11 @@ class SplitMongoModuleStore(ModuleStoreWriteBase):
         '''
         # course_agnostic means we don't care if the head and version don't align, trust the version
-        course_struct = self._lookup_course(block_locator.course_agnostic())['structure']
+        course_struct = self._lookup_course(block_locator.course_key.course_agnostic())['structure']
         block_id = block_locator.block_id
-        update_version_field = 'blocks.{}.edit_info.update_version'.format(block_id)
-        all_versions_with_block = self.db_connection.find_matching_structures(
-            {
-                'original_version': course_struct['original_version'],
-                update_version_field: {'$exists': True},
-            }
+        all_versions_with_block = self.find_ancestor_structures(
+            original_version=course_struct['original_version'],
+            block_id=block_id
         )
         # find (all) root versions and build map {previous: {successors}..}
         possible_roots = []
@@ -772,7 +1165,7 @@ class SplitMongoModuleStore(ModuleStoreWriteBase):
     def create_item(
         self, user_id, course_key, block_type, block_id=None,
         definition_locator=None, fields=None,
-        force=False, continue_version=False, **kwargs
+        force=False, **kwargs
     ):
         """
         Add a descriptor to persistence as an element
@@ -799,10 +1192,8 @@ class SplitMongoModuleStore(ModuleStoreWriteBase):
         :param block_id: if provided, must not already exist in the structure. Provides the block id for the
             new item in this structure. Otherwise, one is computed using the category appended w/ a few digits.

-        :param continue_version: continue changing the current structure at the head of the course. Very dangerous
-            unless used in the same request as started the change! See below about version conflicts.
-
-        This method creates a new version of the course structure unless continue_version is True.
+        This method creates a new version of the course structure unless the course has a bulk_write operation
+        active.
         It creates and inserts the new block, makes the block point
         to the definition which may be new or a new version of an existing or an existing.
@@ -818,91 +1209,77 @@ class SplitMongoModuleStore(ModuleStoreWriteBase):
         the course id'd by version_guid but instead in one w/ a new version_guid. Ensure in this case that you get
         the new version_guid from the locator in the returned object!
         """
-        # split handles all the fields in one dict not separated by scope
-        fields = fields or {}
-        fields.update(kwargs.pop('metadata', {}) or {})
-        definition_data = kwargs.pop('definition_data', {})
-        if definition_data:
-            if not isinstance(definition_data, dict):
-                definition_data = {'data': definition_data}  # backward compatibility to mongo's hack
-            fields.update(definition_data)
-
-        # find course_index entry if applicable and structures entry
-        index_entry = self._get_index_if_valid(course_key, force, continue_version)
-        structure = self._lookup_course(course_key)['structure']
-
-        partitioned_fields = self.partition_fields_by_scope(block_type, fields)
-        new_def_data = partitioned_fields.get(Scope.content, {})
-        # persist the definition if persisted != passed
-        if (definition_locator is None or isinstance(definition_locator.definition_id, LocalId)):
-            definition_locator = self.create_definition_from_data(new_def_data, block_type, user_id)
-        elif new_def_data is not None:
-            definition_locator, _ = self.update_definition_from_data(definition_locator, new_def_data, user_id)
-
-        # copy the structure and modify the new one
-        if continue_version:
-            new_structure = structure
-        else:
-            new_structure = self._version_structure(structure, user_id)
-
-        new_id = new_structure['_id']
-
-        edit_info = {
-            'edited_on': datetime.datetime.now(UTC),
-            'edited_by': user_id,
-            'previous_version': None,
-            'update_version': new_id,
-        }
-
-        # generate usage id
-        if block_id is not None:
-            if encode_key_for_mongo(block_id) in new_structure['blocks']:
-                raise DuplicateItemError(block_id, self, 'structures')
-            else:
-                new_block_id = block_id
-        else:
-            new_block_id = self._generate_block_id(new_structure['blocks'], block_type)
-
-        block_fields = partitioned_fields.get(Scope.settings, {})
-        if Scope.children in partitioned_fields:
-            block_fields.update(partitioned_fields[Scope.children])
-        self._update_block_in_structure(new_structure, new_block_id, {
-            "category": block_type,
-            "definition": definition_locator.definition_id,
-            "fields": self._serialize_fields(block_type, block_fields),
-            'edit_info': edit_info,
-        })
-
-        if continue_version:
-            # db update
-            self.db_connection.update_structure(new_structure)
-            # clear cache so things get refetched and inheritance recomputed
-            self._clear_cache(new_id)
-        else:
-            self.db_connection.insert_structure(new_structure)
-
-        # update the index entry if appropriate
-        if index_entry is not None:
-            # see if any search targets changed
-            if fields is not None:
-                self._update_search_targets(index_entry, fields)
-            if not continue_version:
-                self._update_head(index_entry, course_key.branch, new_id)
-            item_loc = BlockUsageLocator(
-                course_key.version_agnostic(),
-                block_type=block_type,
-                block_id=new_block_id,
-            )
-        else:
-            item_loc = BlockUsageLocator(
-                CourseLocator(version_guid=new_id),
-                block_type=block_type,
-                block_id=new_block_id,
-            )
-
-        # reconstruct the new_item from the cache
-        return self.get_item(item_loc)
+        with self.bulk_operations(course_key):
+            # split handles all the fields in one dict not separated by scope
+            fields = fields or {}
+            fields.update(kwargs.pop('metadata', {}) or {})
+            definition_data = kwargs.pop('definition_data', {})
+            if definition_data:
+                if not isinstance(definition_data, dict):
+                    definition_data = {'data': definition_data}  # backward compatibility to mongo's hack
+                fields.update(definition_data)
+
+            # find course_index entry if applicable and structures entry
+            index_entry = self._get_index_if_valid(course_key, force)
+            structure = self._lookup_course(course_key)['structure']
+
+            partitioned_fields = self.partition_fields_by_scope(block_type, fields)
+            new_def_data = partitioned_fields.get(Scope.content, {})
+            # persist the definition if persisted != passed
+            if (definition_locator is None or isinstance(definition_locator.definition_id, LocalId)):
+                definition_locator = self.create_definition_from_data(new_def_data, block_type, user_id)
+            elif new_def_data is not None:
+                definition_locator, _ = self.update_definition_from_data(definition_locator, new_def_data, user_id)
+
+            # copy the structure and modify the new one
+            new_structure = self.version_structure(course_key, structure, user_id)
+
+            new_id = new_structure['_id']
+
+            # generate usage id
+            if block_id is not None:
+                if encode_key_for_mongo(block_id) in new_structure['blocks']:
+                    raise DuplicateItemError(block_id, self, 'structures')
+                else:
+                    new_block_id = block_id
+            else:
+                new_block_id = self._generate_block_id(new_structure['blocks'], block_type)
+
+            block_fields = partitioned_fields.get(Scope.settings, {})
+            if Scope.children in partitioned_fields:
+                block_fields.update(partitioned_fields[Scope.children])
+            self._update_block_in_structure(new_structure, new_block_id, self._new_block(
+                user_id,
+                block_type,
+                block_fields,
+                definition_locator.definition_id,
+                new_id,
+            ))
+
+            self.update_structure(course_key, new_structure)
+
+            # update the index entry if appropriate
+            if index_entry is not None:
+                # see if any search targets changed
+                if fields is not None:
+                    self._update_search_targets(index_entry, fields)
+                self._update_head(course_key, index_entry, course_key.branch, new_id)
+                item_loc = BlockUsageLocator(
+                    course_key.version_agnostic(),
+                    block_type=block_type,
+                    block_id=new_block_id,
+                )
+            else:
+                item_loc = BlockUsageLocator(
+                    CourseLocator(version_guid=new_id),
+                    block_type=block_type,
+                    block_id=new_block_id,
+                )

+            # reconstruct the new_item from the cache
+            return self.get_item(item_loc)
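The with-block above is the heart of this change: writes issued while a bulk operation is active are batched, and the branch head moves once instead of per call. A hedged usage sketch (store, course_key, and user_id are hypothetical placeholders, not code from this commit):

    with store.bulk_operations(course_key):
        chapter = store.create_item(user_id, course_key, 'chapter',
                                    fields={'display_name': 'Week 1'})
        problem = store.create_item(user_id, course_key, 'problem')
    # when the outermost bulk_operations block exits, the accumulated
    # structure and index writes are committed together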
-    def create_child(self, user_id, parent_usage_key, block_type, block_id=None, fields=None, continue_version=False, **kwargs):
+    def create_child(self, user_id, parent_usage_key, block_type, block_id=None, fields=None, **kwargs):
         """
         Creates and saves a new xblock that as a child of the specified block
@@ -918,32 +1295,28 @@ class SplitMongoModuleStore(ModuleStoreWriteBase):
             fields (dict): A dictionary specifying initial values for some or all fields
                 in the newly created block
         """
-        xblock = self.create_item(
-            user_id, parent_usage_key.course_key, block_type, block_id=block_id, fields=fields,
-            continue_version=continue_version,
-            **kwargs)
-
-        # don't version the structure as create_item handled that already.
-        new_structure = self._lookup_course(xblock.location.course_key)['structure']
-
-        # add new block as child and update parent's version
-        encoded_block_id = encode_key_for_mongo(parent_usage_key.block_id)
-        parent = new_structure['blocks'][encoded_block_id]
-        parent['fields'].setdefault('children', []).append(xblock.location.block_id)
-        if parent['edit_info']['update_version'] != new_structure['_id']:
-            # if the parent hadn't been previously changed in this bulk transaction, indicate that it's
-            # part of the bulk transaction
-            parent['edit_info'] = {
-                'edited_on': datetime.datetime.now(UTC),
-                'edited_by': user_id,
-                'previous_version': parent['edit_info']['update_version'],
-                'update_version': new_structure['_id'],
-            }
-
-        # db update
-        self.db_connection.update_structure(new_structure)
-        # clear cache so things get refetched and inheritance recomputed
-        self._clear_cache(new_structure['_id'])
+        with self.bulk_operations(parent_usage_key.course_key):
+            xblock = self.create_item(
+                user_id, parent_usage_key.course_key, block_type, block_id=block_id, fields=fields,
+                **kwargs)
+
+            # don't version the structure as create_item handled that already.
+            new_structure = self._lookup_course(xblock.location.course_key)['structure']
+
+            # add new block as child and update parent's version
+            encoded_block_id = encode_key_for_mongo(parent_usage_key.block_id)
+            if encoded_block_id not in new_structure['blocks']:
+                raise ItemNotFoundError(parent_usage_key)
+
+            parent = new_structure['blocks'][encoded_block_id]
+            parent['fields'].setdefault('children', []).append(xblock.location.block_id)
+            if parent['edit_info']['update_version'] != new_structure['_id']:
+                # if the parent hadn't been previously changed in this bulk transaction, indicate that it's
+                # part of the bulk transaction
+                self.version_block(parent, user_id, new_structure['_id'])
+
+            # db update
+            self.update_structure(parent_usage_key.course_key, new_structure)

         # don't need to update the index b/c create_item did it for this version
         return xblock
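`version_block` replaces the inline edit_info bookkeeping removed on the left. A sketch of its plausible behavior, inferred from the edit_info updates still visible in delete_item further below (the helper itself is defined elsewhere in this changeset, so details may differ):

    import datetime
    from pytz import UTC  # assuming pytz, as used elsewhere in this module

    def version_block(block_info, user_id, update_version):
        # stamp the block as last changed in structure version `update_version`
        block_info['edit_info'] = {
            'edited_on': datetime.datetime.now(UTC),
            'edited_by': user_id,
            'previous_version': block_info['edit_info']['update_version'],
            'update_version': update_version,
        }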
@@ -1023,7 +1396,7 @@ class SplitMongoModuleStore(ModuleStoreWriteBase):
         assert master_branch is not None
         # check course and run's uniqueness
         locator = CourseLocator(org=org, course=course, run=run, branch=master_branch)
-        index = self.db_connection.get_course_index(locator)
+        index = self.get_course_index(locator)
         if index is not None:
             raise DuplicateCourseError(locator, index)
@@ -1061,18 +1434,16 @@ class SplitMongoModuleStore(ModuleStoreWriteBase):
             )
             new_id = draft_structure['_id']
-            self.db_connection.insert_structure(draft_structure)

             if versions_dict is None:
                 versions_dict = {master_branch: new_id}
             else:
                 versions_dict[master_branch] = new_id

-        elif definition_fields or block_fields:  # pointing to existing course w/ some overrides
+        elif block_fields or definition_fields:  # pointing to existing course w/ some overrides
             # just get the draft_version structure
             draft_version = CourseLocator(version_guid=versions_dict[master_branch])
             draft_structure = self._lookup_course(draft_version)['structure']
-            draft_structure = self._version_structure(draft_structure, user_id)
+            draft_structure = self.version_structure(locator, draft_structure, user_id)
             new_id = draft_structure['_id']
             encoded_block_id = encode_key_for_mongo(draft_structure['root'])
             root_block = draft_structure['blocks'][encoded_block_id]
@@ -1093,27 +1464,33 @@ class SplitMongoModuleStore(ModuleStoreWriteBase):
             root_block['edit_info']['previous_version'] = root_block['edit_info'].get('update_version')
             root_block['edit_info']['update_version'] = new_id
-            self.db_connection.insert_structure(draft_structure)
             versions_dict[master_branch] = new_id
+        else:  # Pointing to an existing course structure
+            new_id = versions_dict[master_branch]
+            draft_version = CourseLocator(version_guid=new_id)
+            draft_structure = self._lookup_course(draft_version)['structure']

-        index_entry = {
-            '_id': ObjectId(),
-            'org': org,
-            'course': course,
-            'run': run,
-            'edited_by': user_id,
-            'edited_on': datetime.datetime.now(UTC),
-            'versions': versions_dict,
-            'schema_version': self.SCHEMA_VERSION,
-            'search_targets': search_targets or {},
-        }
-        if fields is not None:
-            self._update_search_targets(index_entry, fields)
-        self.db_connection.insert_course_index(index_entry)
+        locator = locator.replace(version_guid=new_id)
+        with self.bulk_operations(locator):
+            self.update_structure(locator, draft_structure)
+            index_entry = {
+                '_id': ObjectId(),
+                'org': org,
+                'course': course,
+                'run': run,
+                'edited_by': user_id,
+                'edited_on': datetime.datetime.now(UTC),
+                'versions': versions_dict,
+                'schema_version': self.SCHEMA_VERSION,
+                'search_targets': search_targets or {},
+            }
+            if fields is not None:
+                self._update_search_targets(index_entry, fields)
+            self.insert_course_index(locator, index_entry)

-        # expensive hack to persist default field values set in __init__ method (e.g., wiki_slug)
-        course = self.get_course(locator, **kwargs)
-        return self.update_item(course, user_id, **kwargs)
+            # expensive hack to persist default field values set in __init__ method (e.g., wiki_slug)
+            course = self.get_course(locator, **kwargs)
+            return self.update_item(course, user_id, **kwargs)
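A hedged sketch of the uniqueness check this hunk relies on: creating the same org/course/run twice should raise, because the index lookup above finds the first entry (store, ids, and user_id are hypothetical):

    store.create_course('testX', 'split101', '2014', user_id, master_branch='draft')
    try:
        store.create_course('testX', 'split101', '2014', user_id, master_branch='draft')
    except DuplicateCourseError:
        pass  # a course index for this locator already exists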
     def update_item(self, descriptor, user_id, allow_not_found=False, force=False, **kwargs):
         """
@@ -1143,87 +1520,83 @@ class SplitMongoModuleStore(ModuleStoreWriteBase):
         """
         Broke out guts of update_item for short-circuited internal use only
         """
-        if allow_not_found and isinstance(block_id, (LocalId, NoneType)):
-            fields = {}
-            for subfields in partitioned_fields.itervalues():
-                fields.update(subfields)
-            return self.create_item(
-                user_id, course_key, block_type, fields=fields, force=force
-            )
-
-        original_structure = self._lookup_course(course_key)['structure']
-        index_entry = self._get_index_if_valid(course_key, force)
-
-        original_entry = self._get_block_from_structure(original_structure, block_id)
-        if original_entry is None:
-            if allow_not_found:
-                fields = {}
-                for subfields in partitioned_fields.itervalues():
-                    fields.update(subfields)
-                return self.create_item(
-                    user_id, course_key, block_type, block_id=block_id, fields=fields, force=force,
-                )
-            else:
-                raise ItemNotFoundError(course_key.make_usage_key(block_type, block_id))
-
-        is_updated = False
-        definition_fields = partitioned_fields[Scope.content]
-        if definition_locator is None:
-            definition_locator = DefinitionLocator(original_entry['category'], original_entry['definition'])
-        if definition_fields:
-            definition_locator, is_updated = self.update_definition_from_data(
-                definition_locator, definition_fields, user_id
-            )
-
-        # check metadata
-        settings = partitioned_fields[Scope.settings]
-        settings = self._serialize_fields(block_type, settings)
-        if not is_updated:
-            is_updated = self._compare_settings(settings, original_entry['fields'])
-
-        # check children
-        if partitioned_fields.get(Scope.children, {}):  # purposely not 'is not None'
-            serialized_children = [child.block_id for child in partitioned_fields[Scope.children]['children']]
-            is_updated = is_updated or original_entry['fields'].get('children', []) != serialized_children
-            if is_updated:
-                settings['children'] = serialized_children
-
-        # if updated, rev the structure
-        if is_updated:
-            new_structure = self._version_structure(original_structure, user_id)
-            block_data = self._get_block_from_structure(new_structure, block_id)
-
-            block_data["definition"] = definition_locator.definition_id
-            block_data["fields"] = settings
-
-            new_id = new_structure['_id']
-            block_data['edit_info'] = {
-                'edited_on': datetime.datetime.now(UTC),
-                'edited_by': user_id,
-                'previous_version': block_data['edit_info']['update_version'],
-                'update_version': new_id,
-            }
-            self.db_connection.insert_structure(new_structure)
-
-            # update the index entry if appropriate
-            if index_entry is not None:
-                self._update_search_targets(index_entry, definition_fields)
-                self._update_search_targets(index_entry, settings)
-                self._update_head(index_entry, course_key.branch, new_id)
-                course_key = CourseLocator(
-                    org=index_entry['org'],
-                    course=index_entry['course'],
-                    run=index_entry['run'],
-                    branch=course_key.branch,
-                    version_guid=new_id
-                )
-            else:
-                course_key = CourseLocator(version_guid=new_id)
-
-            # fetch and return the new item--fetching is unnecessary but a good qc step
-            new_locator = course_key.make_usage_key(block_type, block_id)
-            return self.get_item(new_locator, **kwargs)
-        else:
-            return None
+        with self.bulk_operations(course_key):
+            if allow_not_found and isinstance(block_id, (LocalId, NoneType)):
+                fields = {}
+                for subfields in partitioned_fields.itervalues():
+                    fields.update(subfields)
+                return self.create_item(
+                    user_id, course_key, block_type, fields=fields, force=force
+                )
+
+            original_structure = self._lookup_course(course_key)['structure']
+            index_entry = self._get_index_if_valid(course_key, force)
+
+            original_entry = self._get_block_from_structure(original_structure, block_id)
+            if original_entry is None:
+                if allow_not_found:
+                    fields = {}
+                    for subfields in partitioned_fields.itervalues():
+                        fields.update(subfields)
+                    return self.create_item(
+                        user_id, course_key, block_type, block_id=block_id, fields=fields, force=force,
+                    )
+                else:
+                    raise ItemNotFoundError(course_key.make_usage_key(block_type, block_id))
+
+            is_updated = False
+            definition_fields = partitioned_fields[Scope.content]
+            if definition_locator is None:
+                definition_locator = DefinitionLocator(original_entry['category'], original_entry['definition'])
+            if definition_fields:
+                definition_locator, is_updated = self.update_definition_from_data(
+                    definition_locator, definition_fields, user_id
+                )
+
+            # check metadata
+            settings = partitioned_fields[Scope.settings]
+            settings = self._serialize_fields(block_type, settings)
+            if not is_updated:
+                is_updated = self._compare_settings(settings, original_entry['fields'])
+
+            # check children
+            if partitioned_fields.get(Scope.children, {}):  # purposely not 'is not None'
+                serialized_children = [child.block_id for child in partitioned_fields[Scope.children]['children']]
+                is_updated = is_updated or original_entry['fields'].get('children', []) != serialized_children
+                if is_updated:
+                    settings['children'] = serialized_children
+
+            # if updated, rev the structure
+            if is_updated:
+                new_structure = self.version_structure(course_key, original_structure, user_id)
+                block_data = self._get_block_from_structure(new_structure, block_id)
+
+                block_data["definition"] = definition_locator.definition_id
+                block_data["fields"] = settings
+
+                new_id = new_structure['_id']
+                self.version_block(block_data, user_id, new_id)
+                self.update_structure(course_key, new_structure)
+
+                # update the index entry if appropriate
+                if index_entry is not None:
+                    self._update_search_targets(index_entry, definition_fields)
+                    self._update_search_targets(index_entry, settings)
+                    course_key = CourseLocator(
+                        org=index_entry['org'],
+                        course=index_entry['course'],
+                        run=index_entry['run'],
+                        branch=course_key.branch,
+                        version_guid=new_id
+                    )
+                    self._update_head(course_key, index_entry, course_key.branch, new_id)
+                else:
+                    course_key = CourseLocator(version_guid=new_id)

+                # fetch and return the new item--fetching is unnecessary but a good qc step
+                new_locator = course_key.make_usage_key(block_type, block_id)
+                return self.get_item(new_locator, **kwargs)
+            else:
+                return None
     # pylint: disable=unused-argument
     def create_xblock(
@@ -1286,23 +1659,25 @@ class SplitMongoModuleStore(ModuleStoreWriteBase):
         :param user_id: who's doing the change
         """
         # find course_index entry if applicable and structures entry
-        index_entry = self._get_index_if_valid(xblock.location, force)
-        structure = self._lookup_course(xblock.location)['structure']
-        new_structure = self._version_structure(structure, user_id)
-        new_id = new_structure['_id']
-        is_updated = self._persist_subdag(xblock, user_id, new_structure['blocks'], new_id)
+        course_key = xblock.location.course_key
+        with self.bulk_operations(course_key):
+            index_entry = self._get_index_if_valid(course_key, force)
+            structure = self._lookup_course(course_key)['structure']
+            new_structure = self.version_structure(course_key, structure, user_id)
+            new_id = new_structure['_id']
+            is_updated = self._persist_subdag(xblock, user_id, new_structure['blocks'], new_id)

-        if is_updated:
-            self.db_connection.insert_structure(new_structure)
+            if is_updated:
+                self.update_structure(course_key, new_structure)

-            # update the index entry if appropriate
-            if index_entry is not None:
-                self._update_head(index_entry, xblock.location.branch, new_id)
+                # update the index entry if appropriate
+                if index_entry is not None:
+                    self._update_head(course_key, index_entry, xblock.location.branch, new_id)

-            # fetch and return the new item--fetching is unnecessary but a good qc step
-            return self.get_item(xblock.location.for_version(new_id))
-        else:
-            return xblock
+                # fetch and return the new item--fetching is unnecessary but a good qc step
+                return self.get_item(xblock.location.for_version(new_id))
+            else:
+                return xblock
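A hedged caller's-eye sketch of what this wrapping buys (names hypothetical): several edits issued under one bulk operation are batched instead of each inserting its own structure document:

    with store.bulk_operations(course_key):
        for block in edited_blocks:          # hypothetical list of xblocks
            block.display_name += ' (reviewed)'
            store.update_item(block, user_id)
    # one combined commit when the outermost with-block exits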
     def _persist_subdag(self, xblock, user_id, structure_blocks, new_id):
         # persist the definition if persisted != passed
@@ -1350,18 +1725,22 @@ class SplitMongoModuleStore(ModuleStoreWriteBase):
                 block_fields['children'] = children

         if is_updated:
-            previous_version = None if is_new else structure_blocks[encoded_block_id]['edit_info'].get('update_version')
-            structure_blocks[encoded_block_id] = {
-                "category": xblock.category,
-                "definition": xblock.definition_locator.definition_id,
-                "fields": block_fields,
-                'edit_info': {
-                    'previous_version': previous_version,
-                    'update_version': new_id,
-                    'edited_by': user_id,
-                    'edited_on': datetime.datetime.now(UTC)
-                }
-            }
+            if is_new:
+                block_info = self._new_block(
+                    user_id,
+                    xblock.category,
+                    block_fields,
+                    xblock.definition_locator.definition_id,
+                    new_id,
+                    raw=True
+                )
+            else:
+                block_info = structure_blocks[encoded_block_id]
+                block_info['fields'] = block_fields
+                block_info['definition'] = xblock.definition_locator.definition_id
+                self.version_block(block_info, user_id, new_id)
+
+            structure_blocks[encoded_block_id] = block_info

         return is_updated
@@ -1415,71 +1794,63 @@ class SplitMongoModuleStore(ModuleStoreWriteBase):
            subtree but the ancestors up to and including the course root are not published.
         """
         # get the destination's index, and source and destination structures.
-        source_structure = self._lookup_course(source_course)['structure']
-        index_entry = self.db_connection.get_course_index(destination_course)
-        if index_entry is None:
-            # brand new course
-            raise ItemNotFoundError(destination_course)
-        if destination_course.branch not in index_entry['versions']:
-            # must be copying the dag root if there's no current dag
-            root_block_id = source_structure['root']
-            if not any(root_block_id == subtree.block_id for subtree in subtree_list):
-                raise ItemNotFoundError(u'Must publish course root {}'.format(root_block_id))
-            root_source = source_structure['blocks'][root_block_id]
-            # create branch
-            destination_structure = self._new_structure(
-                user_id, root_block_id, root_category=root_source['category'],
-                # leave off the fields b/c the children must be filtered
-                definition_id=root_source['definition'],
-            )
-        else:
-            destination_structure = self._lookup_course(destination_course)['structure']
-            destination_structure = self._version_structure(destination_structure, user_id)
-
-        if blacklist != EXCLUDE_ALL:
-            blacklist = [shunned.block_id for shunned in blacklist or []]
-        # iterate over subtree list filtering out blacklist.
-        orphans = set()
-        destination_blocks = destination_structure['blocks']
-        for subtree_root in subtree_list:
-            if subtree_root.block_id != source_structure['root']:
-                # find the parents and put root in the right sequence
-                parent = self._get_parent_from_structure(subtree_root.block_id, source_structure)
-                if parent is not None:  # may be a detached category xblock
-                    if not parent in destination_blocks:
-                        raise ItemNotFoundError(parent)
-                    orphans.update(
-                        self._sync_children(
-                            source_structure['blocks'][parent],
-                            destination_blocks[parent],
-                            subtree_root.block_id
-                        )
-                    )
-            # update/create the subtree and its children in destination (skipping blacklist)
-            orphans.update(
-                self._copy_subdag(
-                    user_id, destination_structure['_id'],
-                    subtree_root.block_id, source_structure['blocks'], destination_blocks, blacklist
-                )
-            )
-        # remove any remaining orphans
-        for orphan in orphans:
-            # orphans will include moved as well as deleted xblocks. Only delete the deleted ones.
-            self._delete_if_true_orphan(orphan, destination_structure)
-
-        # update the db
-        self.db_connection.insert_structure(destination_structure)
-        self._update_head(index_entry, destination_course.branch, destination_structure['_id'])
+        with self.bulk_operations(source_course):
+            with self.bulk_operations(destination_course):
+                source_structure = self._lookup_course(source_course)['structure']
+                index_entry = self.get_course_index(destination_course)
+                if index_entry is None:
+                    # brand new course
+                    raise ItemNotFoundError(destination_course)
+                if destination_course.branch not in index_entry['versions']:
+                    # must be copying the dag root if there's no current dag
+                    root_block_id = source_structure['root']
+                    if not any(root_block_id == subtree.block_id for subtree in subtree_list):
+                        raise ItemNotFoundError(u'Must publish course root {}'.format(root_block_id))
+                    root_source = source_structure['blocks'][root_block_id]
+                    # create branch
+                    destination_structure = self._new_structure(
+                        user_id, root_block_id, root_category=root_source['category'],
+                        # leave off the fields b/c the children must be filtered
+                        definition_id=root_source['definition'],
+                    )
+                else:
+                    destination_structure = self._lookup_course(destination_course)['structure']
+                    destination_structure = self.version_structure(destination_course, destination_structure, user_id)
+
+                if blacklist != EXCLUDE_ALL:
+                    blacklist = [shunned.block_id for shunned in blacklist or []]
+                # iterate over subtree list filtering out blacklist.
+                orphans = set()
+                destination_blocks = destination_structure['blocks']
+                for subtree_root in subtree_list:
+                    if subtree_root.block_id != source_structure['root']:
+                        # find the parents and put root in the right sequence
+                        parent = self._get_parent_from_structure(subtree_root.block_id, source_structure)
+                        if parent is not None:  # may be a detached category xblock
+                            if not parent in destination_blocks:
+                                raise ItemNotFoundError(parent)
+                            orphans.update(
+                                self._sync_children(
+                                    source_structure['blocks'][parent],
+                                    destination_blocks[parent],
+                                    subtree_root.block_id
+                                )
+                            )
+                    # update/create the subtree and its children in destination (skipping blacklist)
+                    orphans.update(
+                        self._copy_subdag(
+                            user_id, destination_structure['_id'],
+                            subtree_root.block_id, source_structure['blocks'], destination_blocks, blacklist
+                        )
+                    )
+                # remove any remaining orphans
+                for orphan in orphans:
+                    # orphans will include moved as well as deleted xblocks. Only delete the deleted ones.
+                    self._delete_if_true_orphan(orphan, destination_structure)
+
+                # update the db
+                self.update_structure(destination_course, destination_structure)
+                self._update_head(destination_course, index_entry, destination_course.branch, destination_structure['_id'])
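The nested with-blocks batch reads from the source branch and writes to the destination branch together. A hedged sketch of a call into this routine (assuming the enclosing method is split's copy()/publish entry point, as its parameters suggest; all names hypothetical):

    # publish one draft subtree into the published branch
    store.copy(
        user_id,
        course_key.for_branch(ModuleStoreEnum.BranchName.draft),
        course_key.for_branch(ModuleStoreEnum.BranchName.published),
        [sequential.location],   # subtree_list
        None,                    # blacklist: None publishes the children too
    )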
-    def update_course_index(self, updated_index_entry):
-        """
-        Change the given course's index entry.
-
-        Note, this operation can be dangerous and break running courses.
-
-        Does not return anything useful.
-        """
-        self.db_connection.update_course_index(updated_index_entry)
     def delete_item(self, usage_locator, user_id, force=False):
         """
@@ -1500,46 +1871,47 @@ class SplitMongoModuleStore(ModuleStoreWriteBase):
             # The supplied UsageKey is of the wrong type, so it can't possibly be stored in this modulestore.
             raise ItemNotFoundError(usage_locator)

-        original_structure = self._lookup_course(usage_locator.course_key)['structure']
-        if original_structure['root'] == usage_locator.block_id:
-            raise ValueError("Cannot delete the root of a course")
-        if encode_key_for_mongo(usage_locator.block_id) not in original_structure['blocks']:
-            raise ValueError("Cannot delete a block that does not exist")
-        index_entry = self._get_index_if_valid(usage_locator, force)
-        new_structure = self._version_structure(original_structure, user_id)
-        new_blocks = new_structure['blocks']
-        new_id = new_structure['_id']
-        encoded_block_id = self._get_parent_from_structure(usage_locator.block_id, original_structure)
-        if encoded_block_id:
-            parent_block = new_blocks[encoded_block_id]
-            parent_block['fields']['children'].remove(usage_locator.block_id)
-            parent_block['edit_info']['edited_on'] = datetime.datetime.now(UTC)
-            parent_block['edit_info']['edited_by'] = user_id
-            parent_block['edit_info']['previous_version'] = parent_block['edit_info']['update_version']
-            parent_block['edit_info']['update_version'] = new_id
-
-        def remove_subtree(block_id):
-            """
-            Remove the subtree rooted at block_id
-            """
-            encoded_block_id = encode_key_for_mongo(block_id)
-            for child in new_blocks[encoded_block_id]['fields'].get('children', []):
-                remove_subtree(child)
-            del new_blocks[encoded_block_id]
-
-        remove_subtree(usage_locator.block_id)
-
-        # update index if appropriate and structures
-        self.db_connection.insert_structure(new_structure)
-
-        if index_entry is not None:
-            # update the index entry if appropriate
-            self._update_head(index_entry, usage_locator.branch, new_id)
-            result = usage_locator.course_key.for_version(new_id)
-        else:
-            result = CourseLocator(version_guid=new_id)
-
-        return result
+        with self.bulk_operations(usage_locator.course_key):
+            original_structure = self._lookup_course(usage_locator.course_key)['structure']
+            if original_structure['root'] == usage_locator.block_id:
+                raise ValueError("Cannot delete the root of a course")
+            if encode_key_for_mongo(usage_locator.block_id) not in original_structure['blocks']:
+                raise ValueError("Cannot delete a block that does not exist")
+            index_entry = self._get_index_if_valid(usage_locator.course_key, force)
+            new_structure = self.version_structure(usage_locator.course_key, original_structure, user_id)
+            new_blocks = new_structure['blocks']
+            new_id = new_structure['_id']
+            encoded_block_id = self._get_parent_from_structure(usage_locator.block_id, original_structure)
+            if encoded_block_id:
+                parent_block = new_blocks[encoded_block_id]
+                parent_block['fields']['children'].remove(usage_locator.block_id)
+                parent_block['edit_info']['edited_on'] = datetime.datetime.now(UTC)
+                parent_block['edit_info']['edited_by'] = user_id
+                parent_block['edit_info']['previous_version'] = parent_block['edit_info']['update_version']
+                parent_block['edit_info']['update_version'] = new_id

+            def remove_subtree(block_id):
+                """
+                Remove the subtree rooted at block_id
+                """
+                encoded_block_id = encode_key_for_mongo(block_id)
+                for child in new_blocks[encoded_block_id]['fields'].get('children', []):
+                    remove_subtree(child)
+                del new_blocks[encoded_block_id]
+
+            remove_subtree(usage_locator.block_id)
+
+            # update index if appropriate and structures
+            self.update_structure(usage_locator.course_key, new_structure)
+
+            if index_entry is not None:
+                # update the index entry if appropriate
+                self._update_head(usage_locator.course_key, index_entry, usage_locator.branch, new_id)
+                result = usage_locator.course_key.for_version(new_id)
+            else:
+                result = CourseLocator(version_guid=new_id)
+
+            return result
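Because delete_item versions the structure rather than mutating it in place, the pre-delete version stays retrievable. A hedged sketch of the return value's use (names hypothetical):

    result_key = store.delete_item(block.location, user_id)
    # result_key carries the new version_guid for the structure without the
    # block; the previous structure version remains addressable for history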
     def delete_course(self, course_key, user_id):
         """
@@ -1549,7 +1921,7 @@ class SplitMongoModuleStore(ModuleStoreWriteBase):
         with a versions hash to restore the course; however, the edited_on and
         edited_by won't reflect the originals, of course.
         """
-        index = self.db_connection.get_course_index(course_key)
+        index = self.get_course_index(course_key)
         if index is None:
             raise ItemNotFoundError(course_key)
         # this is the only real delete in the system. should it do something else?
@@ -1610,23 +1982,10 @@ class SplitMongoModuleStore(ModuleStoreWriteBase):
         if depth is None or depth > 0:
             depth = depth - 1 if depth is not None else None
             for child in descendent_map[block_id]['fields'].get('children', []):
-                descendent_map = self.descendants(block_map, child, depth,
-                                                  descendent_map)
+                descendent_map = self.descendants(block_map, child, depth, descendent_map)

         return descendent_map

-    def definition_locator(self, definition):
-        '''
-        Pull the id out of the definition w/ correct semantics for its
-        representation
-        '''
-        if isinstance(definition, DefinitionLazyLoader):
-            return definition.definition_locator
-        elif '_id' not in definition:
-            return DefinitionLocator(definition.get('category'), LocalId())
-        else:
-            return DefinitionLocator(definition['category'], definition['_id'])
-
     def get_modulestore_type(self, course_key=None):
         """
         Returns an enumeration-like type reflecting the type of this modulestore, per ModuleStoreEnum.Type.
@@ -1651,9 +2010,7 @@ class SplitMongoModuleStore(ModuleStoreWriteBase):
                 block_id for block_id in block['fields']["children"]
                 if encode_key_for_mongo(block_id) in original_structure['blocks']
             ]
-        self.db_connection.update_structure(original_structure)
-        # clear cache again b/c inheritance may be wrong over orphans
-        self._clear_cache(original_structure['_id'])
+        self.update_structure(course_locator, original_structure)
     def convert_references_to_keys(self, course_key, xblock_class, jsonfields, blocks):
         """
@@ -1681,69 +2038,50 @@ class SplitMongoModuleStore(ModuleStoreWriteBase):
                 return course_key.make_usage_key('unknown', block_id)

         xblock_class = self.mixologist.mix(xblock_class)
-        for field_name, value in jsonfields.iteritems():
+        # Make a shallow copy, so that we aren't manipulating a cached field dictionary
+        output_fields = dict(jsonfields)
+        for field_name, value in output_fields.iteritems():
             if value:
                 field = xblock_class.fields.get(field_name)
                 if field is None:
                     continue
                 elif isinstance(field, Reference):
-                    jsonfields[field_name] = robust_usage_key(value)
+                    output_fields[field_name] = robust_usage_key(value)
                 elif isinstance(field, ReferenceList):
-                    jsonfields[field_name] = [robust_usage_key(ele) for ele in value]
+                    output_fields[field_name] = [robust_usage_key(ele) for ele in value]
                 elif isinstance(field, ReferenceValueDict):
                     for key, subvalue in value.iteritems():
                         assert isinstance(subvalue, basestring)
                         value[key] = robust_usage_key(subvalue)
-        return jsonfields
+        return output_fields
-    def _get_index_if_valid(self, locator, force=False, continue_version=False):
+    def _get_index_if_valid(self, course_key, force=False):
         """
-        If the locator identifies a course and points to its draft (or plausibly its draft),
+        If the course_key identifies a course and points to its draft (or plausibly its draft),
         then return the index entry.

         raises VersionConflictError if not the right version

-        :param locator: a courselocator
+        :param course_key: a CourseLocator
         :param force: if false, raises VersionConflictError if the current head of the course != the one identified
-            by locator. Cannot be True if continue_version is True
-        :param continue_version: if True, assumes this operation requires a head version and will not create a new
-            version but instead continue an existing transaction on this version. This flag cannot be True if force is True.
+            by course_key
         """
-        if locator.org is None or locator.course is None or locator.run is None or locator.branch is None:
-            if continue_version:
-                raise InsufficientSpecificationError(
-                    "To continue a version, the locator must point to one ({}).".format(locator)
-                )
-            else:
-                return None
+        if course_key.org is None or course_key.course is None or course_key.run is None or course_key.branch is None:
+            return None
         else:
-            index_entry = self.db_connection.get_course_index(locator)
+            index_entry = self.get_course_index(course_key)
             is_head = (
-                locator.version_guid is None or
-                index_entry['versions'][locator.branch] == locator.version_guid
+                course_key.version_guid is None or
+                index_entry['versions'][course_key.branch] == course_key.version_guid
             )
-            if (is_head or (force and not continue_version)):
+            if (is_head or force):
                 return index_entry
             else:
                 raise VersionConflictError(
-                    locator,
-                    index_entry['versions'][locator.branch]
+                    course_key,
+                    index_entry['versions'][course_key.branch]
                 )
-    def _version_structure(self, structure, user_id):
-        """
-        Copy the structure and update the history info (edited_by, edited_on, previous_version)
-        :param structure:
-        :param user_id:
-        """
-        new_structure = copy.deepcopy(structure)
-        new_structure['_id'] = ObjectId()
-        new_structure['previous_version'] = structure['_id']
-        new_structure['edited_by'] = user_id
-        new_structure['edited_on'] = datetime.datetime.now(UTC)
-        new_structure['schema_version'] = self.SCHEMA_VERSION
-        return new_structure
-
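The helper deleted here reappears in this changeset as `version_structure`, now also keyed by course so pending copies can be tracked per bulk-write scope. A sketch of the removed logic, kept for reference (the new method's body may differ in how it interacts with the bulk cache):

    import copy
    import datetime
    from bson.objectid import ObjectId
    from pytz import UTC  # assuming pytz, as used elsewhere in this module

    def version_structure(structure, user_id, schema_version):
        # copy-on-write: fresh _id, with a link back to the source structure
        new_structure = copy.deepcopy(structure)
        new_structure['_id'] = ObjectId()
        new_structure['previous_version'] = structure['_id']
        new_structure['edited_by'] = user_id
        new_structure['edited_on'] = datetime.datetime.now(UTC)
        new_structure['schema_version'] = schema_version
        return new_structure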
     def _find_local_root(self, element_to_find, possibility, tree):
         if possibility not in tree:
             return False
@@ -1766,7 +2104,7 @@ class SplitMongoModuleStore(ModuleStoreWriteBase):
             if field_name in self.SEARCH_TARGET_DICT:
                 index_entry.setdefault('search_targets', {})[field_name] = field_value

-    def _update_head(self, index_entry, branch, new_id):
+    def _update_head(self, course_key, index_entry, branch, new_id):
         """
         Update the active index for the given course's branch to point to new_id
@@ -1774,8 +2112,10 @@ class SplitMongoModuleStore(ModuleStoreWriteBase):
         :param course_locator:
         :param new_id:
         """
+        if not isinstance(new_id, ObjectId):
+            raise TypeError('new_id must be an ObjectId, but is {!r}'.format(new_id))
         index_entry['versions'][branch] = new_id
-        self.db_connection.update_course_index(index_entry)
+        self.insert_course_index(course_key, index_entry)
     def partition_xblock_fields_by_scope(self, xblock):
         """
@@ -2010,8 +2350,8 @@ class SplitMongoModuleStore(ModuleStoreWriteBase):
         Returns: list of branch-agnostic course_keys
         """
-        entries = self.db_connection.find_matching_course_indexes(
-            {'search_targets.{}'.format(field_name): field_value}
+        entries = self.find_matching_course_indexes(
+            search_targets={field_name: field_value}
         )
         return [
             CourseLocator(entry['org'], entry['course'], entry['run'])  # Branch agnostic
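A hedged usage sketch of this search-target lookup (wiki_slug is one of the SEARCH_TARGET_DICT fields referenced earlier in this diff; the store name and value are hypothetical):

    course_keys = store.find_courses_by_search_target('wiki_slug', 'GreekHero')
    # each entry is a branch-agnostic CourseLocator(org, course, run)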
...
@@ -9,6 +9,7 @@ from xmodule.modulestore.exceptions import InsufficientSpecificationError
 from xmodule.modulestore.draft_and_published import (
     ModuleStoreDraftAndPublished, DIRECT_ONLY_CATEGORIES, UnsupportedRevisionError
 )
+from opaque_keys.edx.locator import CourseLocator


 class DraftVersioningModuleStore(ModuleStoreDraftAndPublished, SplitMongoModuleStore):
@@ -30,24 +31,25 @@ class DraftVersioningModuleStore(ModuleStoreDraftAndPublished, SplitMongoModuleS
         Returns: a CourseDescriptor
         """
         master_branch = kwargs.pop('master_branch', ModuleStoreEnum.BranchName.draft)
-        item = super(DraftVersioningModuleStore, self).create_course(
-            org, course, run, user_id, master_branch=master_branch, **kwargs
-        )
-        if master_branch == ModuleStoreEnum.BranchName.draft and not skip_auto_publish:
-            # any other value is hopefully only cloning or doing something which doesn't want this value add
-            self._auto_publish_no_children(item.location, item.location.category, user_id, **kwargs)
-
-            # create any other necessary things as a side effect: ensure they populate the draft branch
-            # and rely on auto publish to populate the published branch: split's create course doesn't
-            # call super b/c it needs the auto publish above to have happened before any of the create_items
-            # in this. The explicit use of SplitMongoModuleStore is intentional
-            with self.branch_setting(ModuleStoreEnum.Branch.draft_preferred, item.id):
-                # pylint: disable=bad-super-call
-                super(SplitMongoModuleStore, self).create_course(
-                    org, course, run, user_id, runtime=item.runtime, **kwargs
-                )
-
-        return item
+        with self.bulk_operations(CourseLocator(org, course, run)):
+            item = super(DraftVersioningModuleStore, self).create_course(
+                org, course, run, user_id, master_branch=master_branch, **kwargs
+            )
+            if master_branch == ModuleStoreEnum.BranchName.draft and not skip_auto_publish:
+                # any other value is hopefully only cloning or doing something which doesn't want this value add
+                self._auto_publish_no_children(item.location, item.location.category, user_id, **kwargs)
+
+                # create any other necessary things as a side effect: ensure they populate the draft branch
+                # and rely on auto publish to populate the published branch: split's create course doesn't
+                # call super b/c it needs the auto publish above to have happened before any of the create_items
+                # in this. The explicit use of SplitMongoModuleStore is intentional
+                with self.branch_setting(ModuleStoreEnum.Branch.draft_preferred, item.id):
+                    # pylint: disable=bad-super-call
+                    super(SplitMongoModuleStore, self).create_course(
+                        org, course, run, user_id, runtime=item.runtime, **kwargs
+                    )

+            return item
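Opening the bulk operation before super().create_course runs means the course index does not exist yet when the scope begins; the inline CourseLocator(org, course, run) above is what keys that scope. A hedged caller sketch (store name and ids are hypothetical):

    course = draft_store.create_course('testX', 'bulk101', '2014', user_id)
    # by the time the with-block exits, auto-publish has seeded both branches
    published_key = course.id.for_branch(ModuleStoreEnum.BranchName.published)
    assert draft_store.has_course(published_key) is not None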
     def get_course(self, course_id, depth=0, **kwargs):
         course_id = self._map_revision_to_branch(course_id)
@@ -87,45 +89,48 @@ class DraftVersioningModuleStore(ModuleStoreDraftAndPublished, SplitMongoModuleS
     def update_item(self, descriptor, user_id, allow_not_found=False, force=False, **kwargs):
         descriptor.location = self._map_revision_to_branch(descriptor.location)
-        item = super(DraftVersioningModuleStore, self).update_item(
-            descriptor,
-            user_id,
-            allow_not_found=allow_not_found,
-            force=force,
-            **kwargs
-        )
-        self._auto_publish_no_children(item.location, item.location.category, user_id, **kwargs)
-        return item
+        with self.bulk_operations(descriptor.location.course_key):
+            item = super(DraftVersioningModuleStore, self).update_item(
+                descriptor,
+                user_id,
+                allow_not_found=allow_not_found,
+                force=force,
+                **kwargs
+            )
+            self._auto_publish_no_children(item.location, item.location.category, user_id, **kwargs)
+            return item
     def create_item(
         self, user_id, course_key, block_type, block_id=None,
         definition_locator=None, fields=None,
-        force=False, continue_version=False, skip_auto_publish=False, **kwargs
+        force=False, skip_auto_publish=False, **kwargs
     ):
         """
         See :py:meth `ModuleStoreDraftAndPublished.create_item`
         """
         course_key = self._map_revision_to_branch(course_key)
-        item = super(DraftVersioningModuleStore, self).create_item(
-            user_id, course_key, block_type, block_id=block_id,
-            definition_locator=definition_locator, fields=fields,
-            force=force, continue_version=continue_version, **kwargs
-        )
-        if not skip_auto_publish:
-            self._auto_publish_no_children(item.location, item.location.category, user_id, **kwargs)
-        return item
+        with self.bulk_operations(course_key):
+            item = super(DraftVersioningModuleStore, self).create_item(
+                user_id, course_key, block_type, block_id=block_id,
+                definition_locator=definition_locator, fields=fields,
+                force=force, **kwargs
+            )
+            if not skip_auto_publish:
+                self._auto_publish_no_children(item.location, item.location.category, user_id, **kwargs)
+            return item
    def create_child(
        self, user_id, parent_usage_key, block_type, block_id=None,
-        fields=None, continue_version=False, **kwargs
+        fields=None, **kwargs
    ):
        parent_usage_key = self._map_revision_to_branch(parent_usage_key)
+        with self.bulk_operations(parent_usage_key.course_key):
            item = super(DraftVersioningModuleStore, self).create_child(
                user_id, parent_usage_key, block_type, block_id=block_id,
                fields=fields, **kwargs
            )
            self._auto_publish_no_children(parent_usage_key, item.location.category, user_id, **kwargs)
            return item
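With every CRUD entry point opening its own course-scoped block, callers that make several related edits can add one outer block and all the nested blocks coalesce into a single write-back. A hypothetical caller-side sketch; `store`, `user_id`, `course_key`, and `course_root` are placeholders:

    # Hypothetical call site; store, user_id, course_key, course_root are placeholders.
    with store.bulk_operations(course_key):
        chapter = store.create_child(user_id, course_root, 'chapter')
        for display_name in ('Vertical 1', 'Vertical 2'):
            store.create_child(user_id, chapter.location, 'vertical',
                               fields={'display_name': display_name})
    # each create_child opens a nested bulk block; the course structure is
    # written back once, when this outer block exits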
    def delete_item(self, location, user_id, revision=None, **kwargs):
        """
@@ -141,26 +146,27 @@ class DraftVersioningModuleStore(ModuleStoreDraftAndPublished, SplitMongoModuleStore):
            currently only provided by contentstore.views.item.orphan_handler
        Otherwise, raises a ValueError.
        """
+        with self.bulk_operations(location.course_key):
            if revision == ModuleStoreEnum.RevisionOption.published_only:
                branches_to_delete = [ModuleStoreEnum.BranchName.published]
            elif revision == ModuleStoreEnum.RevisionOption.all:
                branches_to_delete = [ModuleStoreEnum.BranchName.published, ModuleStoreEnum.BranchName.draft]
            elif revision is None:
                branches_to_delete = [ModuleStoreEnum.BranchName.draft]
            else:
                raise UnsupportedRevisionError(
                    [
                        None,
                        ModuleStoreEnum.RevisionOption.published_only,
                        ModuleStoreEnum.RevisionOption.all
                    ]
                )

            for branch in branches_to_delete:
                branched_location = location.for_branch(branch)
                parent_loc = self.get_parent_location(branched_location)
                SplitMongoModuleStore.delete_item(self, branched_location, user_id)
                self._auto_publish_no_children(parent_loc, parent_loc.category, user_id, **kwargs)
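The branch selection above reduces to a small table. Hypothetical call sites (`store`, `location`, and `user_id` are placeholders):

    # revision=None: delete only the draft-branch copy
    store.delete_item(location, user_id)
    # published_only: delete only the published-branch copy
    store.delete_item(location, user_id, revision=ModuleStoreEnum.RevisionOption.published_only)
    # all: delete both the draft and published copies
    store.delete_item(location, user_id, revision=ModuleStoreEnum.RevisionOption.all)
    # any other revision value raises UnsupportedRevisionError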
    def _map_revision_to_branch(self, key, revision=None):
        """
@@ -231,7 +237,7 @@ class DraftVersioningModuleStore(ModuleStoreDraftAndPublished, SplitMongoModuleStore):
        :return: True if the draft and published versions differ
        """
        def get_block(branch_name):
-            course_structure = self._lookup_course(xblock.location.for_branch(branch_name))['structure']
+            course_structure = self._lookup_course(xblock.location.course_key.for_branch(branch_name))['structure']
            return self._get_block_from_structure(course_structure, xblock.location.block_id)

        draft_block = get_block(ModuleStoreEnum.BranchName.draft)
@@ -255,7 +261,9 @@ class DraftVersioningModuleStore(ModuleStoreDraftAndPublished, SplitMongoModuleStore):
            # because for_branch obliterates the version_guid and will lead to missed version conflicts.
            # TODO Instead, the for_branch implementation should be fixed in the Opaque Keys library.
            location.course_key.replace(branch=ModuleStoreEnum.BranchName.draft),
-            location.course_key.for_branch(ModuleStoreEnum.BranchName.published),
+            # We clear out the version_guid here because the location here is from the draft branch, and that
+            # won't have the same version guid
+            location.course_key.replace(branch=ModuleStoreEnum.BranchName.published, version_guid=None),
            [location],
            blacklist=blacklist
        )
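The `version_guid=None` matters because the source key comes from the draft branch: its version guid names a draft structure that the published branch will never have. A small illustration, assuming the opaque-keys `CourseLocator` API of this era (the org/course/run values and the guid are made up):

    from opaque_keys.edx.locator import CourseLocator

    draft_key = CourseLocator('Org', 'Course', 'Run', branch='draft',
                              version_guid='519665f6223ebd6980884f2b')  # made-up guid
    # keeping the draft's version_guid would pin the published lookup to a
    # version that only ever existed on the draft branch:
    stale = draft_key.replace(branch='published')
    # clearing it lets the store resolve the published branch's current head:
    current = draft_key.replace(branch='published', version_guid=None)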
@@ -266,8 +274,9 @@ class DraftVersioningModuleStore(ModuleStoreDraftAndPublished, SplitMongoModuleStore):
        Deletes the published version of the item.
        Returns the newly unpublished item.
        """
-        self.delete_item(location, user_id, revision=ModuleStoreEnum.RevisionOption.published_only)
-        return self.get_item(location.for_branch(ModuleStoreEnum.BranchName.draft), **kwargs)
+        with self.bulk_operations(location.course_key):
+            self.delete_item(location, user_id, revision=ModuleStoreEnum.RevisionOption.published_only)
+            return self.get_item(location.for_branch(ModuleStoreEnum.BranchName.draft), **kwargs)
    def revert_to_published(self, location, user_id):
        """
@@ -348,32 +357,33 @@ class DraftVersioningModuleStore(ModuleStoreDraftAndPublished, SplitMongoModuleStore):
        """
        Split-based modulestores need to import published blocks to both branches
        """
+        with self.bulk_operations(course_key):
            # hardcode course root block id
            if block_type == 'course':
                block_id = self.DEFAULT_ROOT_BLOCK_ID
            new_usage_key = course_key.make_usage_key(block_type, block_id)

            if self.get_branch_setting() == ModuleStoreEnum.Branch.published_only:
                # if importing a direct only, override existing draft
                if block_type in DIRECT_ONLY_CATEGORIES:
                    draft_course = course_key.for_branch(ModuleStoreEnum.BranchName.draft)
                    with self.branch_setting(ModuleStoreEnum.Branch.draft_preferred, draft_course):
                        draft = self.import_xblock(user_id, draft_course, block_type, block_id, fields, runtime)
                        self._auto_publish_no_children(draft.location, block_type, user_id)
                    return self.get_item(new_usage_key.for_branch(ModuleStoreEnum.BranchName.published))
                # if new to published
                elif not self.has_item(new_usage_key.for_branch(ModuleStoreEnum.BranchName.published)):
                    # check whether it's new to draft
                    if not self.has_item(new_usage_key.for_branch(ModuleStoreEnum.BranchName.draft)):
                        # add to draft too
                        draft_course = course_key.for_branch(ModuleStoreEnum.BranchName.draft)
                        with self.branch_setting(ModuleStoreEnum.Branch.draft_preferred, draft_course):
                            draft = self.import_xblock(user_id, draft_course, block_type, block_id, fields, runtime)
                    return self.publish(draft.location, user_id, blacklist=EXCLUDE_ALL)

            # do the import
            partitioned_fields = self.partition_fields_by_scope(block_type, fields)
            course_key = self._map_revision_to_branch(course_key)  # cast to branch_setting
            return self._update_item_from_fields(
                user_id, course_key, block_type, block_id, partitioned_fields, None, allow_not_found=True, force=True
            )
@@ -287,7 +287,7 @@ class ModuleStoreTestCase(TestCase):
            course_loc: the CourseKey for the created course
        """
        with self.store.branch_setting(ModuleStoreEnum.Branch.draft_preferred, None):
-            # with self.store.bulk_write_operations(self.store.make_course_key(org, course, run)):
+            # with self.store.bulk_operations(self.store.make_course_key(org, course, run)):
            course = self.store.create_course(org, course, run, self.user.id, fields=course_fields)
            self.course_loc = course.location

@@ -314,7 +314,7 @@ class ModuleStoreTestCase(TestCase):
        """
        Create an equivalent to the toy xml course
        """
-        # with self.store.bulk_write_operations(self.store.make_course_key(org, course, run)):
+        # with self.store.bulk_operations(self.store.make_course_key(org, course, run)):
        self.toy_loc = self.create_sample_course(
            org, course, run, TOY_BLOCK_INFO_TREE,
            {
...
+import pprint
+import pymongo.message
from factory import Factory, lazy_attribute_sequence, lazy_attribute
from factory.containers import CyclicDefinitionError
from uuid import uuid4
@@ -214,88 +217,83 @@ class ItemFactory(XModuleFactory):
@contextmanager
-def check_exact_number_of_calls(object_with_method, method, num_calls, method_name=None):
+def check_exact_number_of_calls(object_with_method, method_name, num_calls):
    """
    Instruments the given method on the given object to verify the number of calls to the
    method is exactly equal to 'num_calls'.
    """
-    with check_number_of_calls(object_with_method, method, num_calls, num_calls, method_name):
+    with check_number_of_calls(object_with_method, method_name, num_calls, num_calls):
        yield


-@contextmanager
-def check_number_of_calls(object_with_method, method, maximum_calls, minimum_calls=1, method_name=None):
+def check_number_of_calls(object_with_method, method_name, maximum_calls, minimum_calls=1):
    """
    Instruments the given method on the given object to verify the number of calls to the method is
    less than or equal to the expected maximum_calls and greater than or equal to the expected minimum_calls.
    """
-    method_wrap = Mock(wraps=method)
-    wrap_patch = patch.object(object_with_method, method_name or method.__name__, method_wrap)
-
-    try:
-        wrap_patch.start()
-        yield
-    finally:
-        wrap_patch.stop()
+    return check_sum_of_calls(object_with_method, [method_name], maximum_calls, minimum_calls)
+
+
+@contextmanager
+def check_sum_of_calls(object_, methods, maximum_calls, minimum_calls=1):
+    """
+    Instruments the given methods on the given object to verify that the total sum of calls made to the
+    methods falls between minimum_calls and maximum_calls.
+    """
+    mocks = {
+        method: Mock(wraps=getattr(object_, method))
+        for method in methods
+    }
+
+    with patch.multiple(object_, **mocks):
+        yield
+
+    call_count = sum(mock.call_count for mock in mocks.values())
+    calls = pprint.pformat({
+        method_name: mock.call_args_list
+        for method_name, mock in mocks.items()
+    })
+
+    # Assertion errors don't handle multi-line values, so pretty-print to std-out instead
+    if not minimum_calls <= call_count <= maximum_calls:
+        print "Expected between {} and {} calls, {} were made. Calls: {}".format(
+            minimum_calls,
+            maximum_calls,
+            call_count,
+            calls,
+        )

    # verify the counter actually worked by ensuring we have counted greater than (or equal to) the minimum calls
-    assert_greater_equal(method_wrap.call_count, minimum_calls)
+    assert_greater_equal(call_count, minimum_calls)

    # now verify the number of actual calls is less than (or equal to) the expected maximum
-    assert_less_equal(method_wrap.call_count, maximum_calls)
+    assert_less_equal(call_count, maximum_calls)
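A hypothetical use of the new `check_sum_of_calls` helper; `store` and `course_key` are placeholders and the counts are arbitrary:

    import pymongo.message

    # expect exactly 4 pymongo queries (query/get_more) inside the block
    with check_sum_of_calls(pymongo.message, ['query', 'get_more'], 4, 4):
        course = store.get_course(course_key, depth=2)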
@contextmanager
-def check_mongo_calls(mongo_store, num_finds=0, num_sends=None):
+def check_mongo_calls(num_finds=0, num_sends=None):
    """
    Instruments the given store to count the number of calls to find (incl find_one) and the number
    of calls to send_message which is for insert, update, and remove (if you provide num_sends). At the
    end of the with statement, it compares the counts to the num_finds and num_sends.

-    :param mongo_store: the MongoModulestore or subclass to watch or a SplitMongoModuleStore
    :param num_finds: the exact number of find calls expected
    :param num_sends: If none, don't instrument the send calls. If non-none, count and compare to
        the given int value.
    """
-    if mongo_store.get_modulestore_type() == ModuleStoreEnum.Type.mongo:
-        with check_exact_number_of_calls(mongo_store.collection, mongo_store.collection.find, num_finds):
-            if num_sends is not None:
-                with check_exact_number_of_calls(
-                    mongo_store.database.connection,
-                    mongo_store.database.connection._send_message,  # pylint: disable=protected-access
-                    num_sends,
-                ):
-                    yield
-            else:
-                yield
-    elif mongo_store.get_modulestore_type() == ModuleStoreEnum.Type.split:
-        collections = [
-            mongo_store.db_connection.course_index,
-            mongo_store.db_connection.structures,
-            mongo_store.db_connection.definitions,
-        ]
-        # could add else clause which raises exception or just rely on the below to suss that out
-        try:
-            find_wraps = []
-            wrap_patches = []
-            for collection in collections:
-                find_wrap = Mock(wraps=collection.find)
-                find_wraps.append(find_wrap)
-                wrap_patch = patch.object(collection, 'find', find_wrap)
-                wrap_patches.append(wrap_patch)
-                wrap_patch.start()
-            if num_sends is not None:
-                connection = mongo_store.db_connection.database.connection
-                with check_exact_number_of_calls(
-                    connection,
-                    connection._send_message,  # pylint: disable=protected-access
-                    num_sends,
-                ):
-                    yield
-            else:
-                yield
-        finally:
-            map(lambda wrap_patch: wrap_patch.stop(), wrap_patches)
-            call_count = sum([find_wrap.call_count for find_wrap in find_wraps])
-            assert_equal(call_count, num_finds)
+    with check_sum_of_calls(
+        pymongo.message,
+        ['query', 'get_more'],
+        num_finds,
+        num_finds
+    ):
+        if num_sends is not None:
+            with check_sum_of_calls(
+                pymongo.message,
+                ['insert', 'update', 'delete'],
+                num_sends,
+                num_sends
+            ):
+                yield
+        else:
+            yield
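Because the rewritten `check_mongo_calls` patches `pymongo.message` rather than one store's collections, call sites drop the store argument, which is exactly the mechanical change repeated through the test diffs below. A before/after sketch (`store` and `course_key` are placeholders):

    # before: with check_mongo_calls(mongo_store, 4, 0):
    # after:
    with check_mongo_calls(4, 0):
        course = store.get_course(course_key, depth=2)  # expects 4 finds, 0 sends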
@@ -127,15 +127,16 @@ class TestMixedModuleStore(unittest.TestCase):
        Create a course w/ one item in the persistence store using the given course & item location.
        """
        # create course
+        with self.store.bulk_operations(course_key):
            self.course = self.store.create_course(course_key.org, course_key.course, course_key.run, self.user_id)
            if isinstance(self.course.id, CourseLocator):
                self.course_locations[self.MONGO_COURSEID] = self.course.location
            else:
                self.assertEqual(self.course.id, course_key)

            # create chapter
            chapter = self.store.create_child(self.user_id, self.course.location, 'chapter', block_id='Overview')
            self.writable_chapter_location = chapter.location

    def _create_block_hierarchy(self):
        """
@@ -188,8 +189,9 @@ class TestMixedModuleStore(unittest.TestCase):
                create_sub_tree(block, tree)
            setattr(self, block_info.field_name, block.location)

+        with self.store.bulk_operations(self.course.id):
            for tree in trees:
                create_sub_tree(self.course, tree)
    def _course_key_from_string(self, string):
        """
@@ -270,8 +272,13 @@ class TestMixedModuleStore(unittest.TestCase):
        with self.assertRaises(DuplicateCourseError):
            self.store.create_course('org_x', 'course_y', 'run_z', self.user_id)

+    # Draft:
+    # - One lookup to locate an item that exists
+    # - Two lookups to determine an item doesn't exist (one to check mongo, one to check split)
    # split has one lookup for the course and then one for the course items
-    @ddt.data(('draft', 1, 0), ('split', 2, 0))
+    # TODO: LMS-11220: Document why draft find count is [1, 1]
+    # TODO: LMS-11220: Document why split find count is [2, 2]
+    @ddt.data(('draft', [1, 1], 0), ('split', [2, 2], 0))
    @ddt.unpack
    def test_has_item(self, default_ms, max_find, max_send):
        self.initdb(default_ms)

        self.assertTrue(self.store.has_item(self.course_locations[self.XML_COURSEID1]))

-        mongo_store = self.store._get_modulestore_for_courseid(self._course_key_from_string(self.MONGO_COURSEID))
-        with check_mongo_calls(mongo_store, max_find, max_send):
+        with check_mongo_calls(max_find.pop(0), max_send):
            self.assertTrue(self.store.has_item(self.problem_x1a_1))

        # try negative cases
        self.assertFalse(self.store.has_item(
            self.course_locations[self.XML_COURSEID1].replace(name='not_findable', category='problem')
        ))
-        with check_mongo_calls(mongo_store, max_find, max_send):
+        with check_mongo_calls(max_find.pop(0), max_send):
            self.assertFalse(self.store.has_item(self.fake_location))

        # verify that an error is raised when the revision is not valid
@@ -296,7 +302,9 @@ class TestMixedModuleStore(unittest.TestCase):
    # draft is 2 to compute inheritance
    # split is 2 (would be 3 on course b/c it looks up the wiki_slug in definitions)
-    @ddt.data(('draft', 2, 0), ('split', 2, 0))
+    # TODO: LMS-11220: Document why draft find count is [2, 2]
+    # TODO: LMS-11220: Document why split find count is [3, 3]
+    @ddt.data(('draft', [2, 2], 0), ('split', [3, 3], 0))
    @ddt.unpack
    def test_get_item(self, default_ms, max_find, max_send):
        self.initdb(default_ms)
@@ -304,8 +312,7 @@ class TestMixedModuleStore(unittest.TestCase):
        self.assertIsNotNone(self.store.get_item(self.course_locations[self.XML_COURSEID1]))

-        mongo_store = self.store._get_modulestore_for_courseid(self._course_key_from_string(self.MONGO_COURSEID))
-        with check_mongo_calls(mongo_store, max_find, max_send):
+        with check_mongo_calls(max_find.pop(0), max_send):
            self.assertIsNotNone(self.store.get_item(self.problem_x1a_1))

        # try negative cases
@@ -313,7 +320,7 @@ class TestMixedModuleStore(unittest.TestCase):
            self.store.get_item(
                self.course_locations[self.XML_COURSEID1].replace(name='not_findable', category='problem')
            )
-        with check_mongo_calls(mongo_store, max_find, max_send):
+        with check_mongo_calls(max_find.pop(0), max_send):
            with self.assertRaises(ItemNotFoundError):
                self.store.get_item(self.fake_location)
@@ -322,7 +329,8 @@ class TestMixedModuleStore(unittest.TestCase):
            self.store.get_item(self.fake_location, revision=ModuleStoreEnum.RevisionOption.draft_preferred)

    # compared to get_item for the course, draft asks for both draft and published
-    @ddt.data(('draft', 8, 0), ('split', 2, 0))
+    # TODO: LMS-11220: Document why split find count is 3
+    @ddt.data(('draft', 8, 0), ('split', 3, 0))
    @ddt.unpack
    def test_get_items(self, default_ms, max_find, max_send):
        self.initdb(default_ms)
@@ -334,9 +342,8 @@ class TestMixedModuleStore(unittest.TestCase):
        self.assertEqual(len(modules), 1)
        self.assertEqual(modules[0].location, course_locn)

-        mongo_store = self.store._get_modulestore_for_courseid(self._course_key_from_string(self.MONGO_COURSEID))
        course_locn = self.course_locations[self.MONGO_COURSEID]
-        with check_mongo_calls(mongo_store, max_find, max_send):
+        with check_mongo_calls(max_find, max_send):
            # NOTE: use get_course if you just want the course. get_items is expensive
            modules = self.store.get_items(course_locn.course_key, qualifiers={'category': 'problem'})
            self.assertEqual(len(modules), 6)
@@ -349,10 +356,9 @@ class TestMixedModuleStore(unittest.TestCase):
    # draft: 2 to look in draft and then published and then 5 for updating ancestors.
-    # split: 3 to get the course structure & the course definition (show_calculator is scope content)
-    # before the change. 1 during change to refetch the definition. 3 afterward (b/c it calls get_item to return the "new" object).
-    # 2 sends to update index & structure (calculator is a setting field)
-    @ddt.data(('draft', 7, 5), ('split', 6, 2))
+    # split: 1 for the course index, 1 for the course structure before the change, 1 for the structure after the change
+    # 2 sends: insert structure, update index_entry
+    @ddt.data(('draft', 11, 5), ('split', 3, 2))
    @ddt.unpack
    def test_update_item(self, default_ms, max_find, max_send):
        """
@@ -368,12 +374,11 @@ class TestMixedModuleStore(unittest.TestCase):
            self.store.update_item(course, self.user_id)

        # now do it for a r/w db
-        mongo_store = self.store._get_modulestore_for_courseid(self._course_key_from_string(self.MONGO_COURSEID))
        problem = self.store.get_item(self.problem_x1a_1)
        # if following raised, then the test is really a noop, change it
        self.assertNotEqual(problem.max_attempts, 2, "Default changed making test meaningless")
        problem.max_attempts = 2
-        with check_mongo_calls(mongo_store, max_find, max_send):
+        with check_mongo_calls(max_find, max_send):
            problem = self.store.update_item(problem, self.user_id)
            self.assertEqual(problem.max_attempts, 2, "Update didn't persist")
@@ -434,7 +439,10 @@ class TestMixedModuleStore(unittest.TestCase):
        component = self.store.publish(component.location, self.user_id)
        self.assertFalse(self.store.has_changes(component))

-    @ddt.data(('draft', 7, 2), ('split', 13, 4))
+    # TODO: LMS-11220: Document why split find count is 4
+    # TODO: LMS-11220: Document why draft find count is 8
+    # TODO: LMS-11220: Document why split send count is 3
+    @ddt.data(('draft', 8, 2), ('split', 4, 3))
    @ddt.unpack
    def test_delete_item(self, default_ms, max_find, max_send):
        """
@@ -446,14 +454,16 @@ class TestMixedModuleStore(unittest.TestCase):
        with self.assertRaises(NotImplementedError):
            self.store.delete_item(self.xml_chapter_location, self.user_id)

-        mongo_store = self.store._get_modulestore_for_courseid(self._course_key_from_string(self.MONGO_COURSEID))
-        with check_mongo_calls(mongo_store, max_find, max_send):
+        with check_mongo_calls(max_find, max_send):
            self.store.delete_item(self.writable_chapter_location, self.user_id)

        # verify it's gone
        with self.assertRaises(ItemNotFoundError):
            self.store.get_item(self.writable_chapter_location)
-    @ddt.data(('draft', 8, 2), ('split', 13, 4))
+    # TODO: LMS-11220: Document why draft find count is 9
+    # TODO: LMS-11220: Document why split send count is 3
+    @ddt.data(('draft', 9, 2), ('split', 4, 3))
    @ddt.unpack
    def test_delete_private_vertical(self, default_ms, max_find, max_send):
        """
@@ -484,8 +494,7 @@ class TestMixedModuleStore(unittest.TestCase):
        self.assertIn(vert_loc, course.children)

        # delete the vertical and ensure the course no longer points to it
-        mongo_store = self.store._get_modulestore_for_courseid(self._course_key_from_string(self.MONGO_COURSEID))
-        with check_mongo_calls(mongo_store, max_find, max_send):
+        with check_mongo_calls(max_find, max_send):
            self.store.delete_item(vert_loc, self.user_id)
        course = self.store.get_course(self.course_locations[self.MONGO_COURSEID].course_key, 0)
        if hasattr(private_vert.location, 'version_guid'):
@@ -499,7 +508,9 @@ class TestMixedModuleStore(unittest.TestCase):
            self.assertFalse(self.store.has_item(leaf_loc))
        self.assertNotIn(vert_loc, course.children)

-    @ddt.data(('draft', 4, 1), ('split', 5, 2))
+    # TODO: LMS-11220: Document why split send count is 2
+    # TODO: LMS-11220: Document why draft find count is 5
+    @ddt.data(('draft', 5, 1), ('split', 2, 2))
    @ddt.unpack
    def test_delete_draft_vertical(self, default_ms, max_find, max_send):
        """
@@ -528,24 +539,24 @@ class TestMixedModuleStore(unittest.TestCase):
        self.store.publish(private_vert.location, self.user_id)
        private_leaf.display_name = 'change me'
        private_leaf = self.store.update_item(private_leaf, self.user_id)

-        mongo_store = self.store._get_modulestore_for_courseid(self._course_key_from_string(self.MONGO_COURSEID))
        # test succeeds if delete succeeds w/o error
-        with check_mongo_calls(mongo_store, max_find, max_send):
+        with check_mongo_calls(max_find, max_send):
            self.store.delete_item(private_leaf.location, self.user_id)
-    @ddt.data(('draft', 2, 0), ('split', 3, 0))
+    # TODO: LMS-11220: Document why split find count is 5
+    # TODO: LMS-11220: Document why draft find count is 4
+    @ddt.data(('draft', 4, 0), ('split', 5, 0))
    @ddt.unpack
    def test_get_courses(self, default_ms, max_find, max_send):
        self.initdb(default_ms)
        # we should have 3 total courses across all stores
-        mongo_store = self.store._get_modulestore_for_courseid(self._course_key_from_string(self.MONGO_COURSEID))
-        with check_mongo_calls(mongo_store, max_find, max_send):
+        with check_mongo_calls(max_find, max_send):
            courses = self.store.get_courses()
            course_ids = [course.location for course in courses]
            self.assertEqual(len(courses), 3, "Not 3 courses: {}".format(course_ids))
            self.assertIn(self.course_locations[self.MONGO_COURSEID], course_ids)
            self.assertIn(self.course_locations[self.XML_COURSEID1], course_ids)
            self.assertIn(self.course_locations[self.XML_COURSEID2], course_ids)

        with self.store.branch_setting(ModuleStoreEnum.Branch.draft_preferred):
            draft_courses = self.store.get_courses(remove_branch=True)
@@ -579,8 +590,9 @@ class TestMixedModuleStore(unittest.TestCase):
            xml_store.create_course("org", "course", "run", self.user_id)

    # draft is 2 to compute inheritance
-    # split is 3 b/c it gets the definition to check whether wiki is set
-    @ddt.data(('draft', 2, 0), ('split', 3, 0))
+    # split is 3 (one for the index, one for the definition to check if the wiki is set, and one for the course structure)
+    # TODO: LMS-11220: Document why split find count is 4
+    @ddt.data(('draft', 2, 0), ('split', 4, 0))
    @ddt.unpack
    def test_get_course(self, default_ms, max_find, max_send):
        """
@@ -588,8 +600,7 @@ class TestMixedModuleStore(unittest.TestCase):
        of getting an item whose scope.content fields are looked at.
        """
        self.initdb(default_ms)
-        mongo_store = self.store._get_modulestore_for_courseid(self._course_key_from_string(self.MONGO_COURSEID))
-        with check_mongo_calls(mongo_store, max_find, max_send):
+        with check_mongo_calls(max_find, max_send):
            course = self.store.get_item(self.course_locations[self.MONGO_COURSEID])
            self.assertEqual(course.id, self.course_locations[self.MONGO_COURSEID].course_key)
@@ -598,7 +609,8 @@ class TestMixedModuleStore(unittest.TestCase):
    # notice this doesn't test getting a public item via draft_preferred which draft would have 2 hits (split
    # still only 2)
-    @ddt.data(('draft', 1, 0), ('split', 2, 0))
+    # TODO: LMS-11220: Document why draft find count is 2
+    @ddt.data(('draft', 2, 0), ('split', 2, 0))
    @ddt.unpack
    def test_get_parent_locations(self, default_ms, max_find, max_send):
        """
@@ -607,10 +619,9 @@ class TestMixedModuleStore(unittest.TestCase):
        self.initdb(default_ms)
        self._create_block_hierarchy()

-        mongo_store = self.store._get_modulestore_for_courseid(self._course_key_from_string(self.MONGO_COURSEID))
-        with check_mongo_calls(mongo_store, max_find, max_send):
+        with check_mongo_calls(max_find, max_send):
            parent = self.store.get_parent_location(self.problem_x1a_1)
            self.assertEqual(parent, self.vertical_x1a)

        parent = self.store.get_parent_location(self.xml_chapter_location)
        self.assertEqual(parent, self.course_locations[self.XML_COURSEID1])
@@ -692,7 +703,21 @@ class TestMixedModuleStore(unittest.TestCase):
            (child_to_delete_location, None, ModuleStoreEnum.RevisionOption.published_only),
        ])

-    @ddt.data(('draft', [10, 3], 0), ('split', [14, 6], 0))
+    # Mongo reads:
+    # First location:
+    # - count problem (1)
+    # - For each level of ancestors: (5)
+    #    - Count ancestor
+    #    - retrieve ancestor
+    #    - compute inheritable data
+    # Second location:
+    # - load vertical
+    # - load inheritance data
+    # TODO: LMS-11220: Document why draft send count is 6
+    # TODO: LMS-11220: Document why draft find count is 18
+    # TODO: LMS-11220: Document why split find count is 16
+    @ddt.data(('draft', [19, 6], 0), ('split', [2, 2], 0))
    @ddt.unpack
    def test_path_to_location(self, default_ms, num_finds, num_sends):
        """
@@ -711,9 +736,8 @@ class TestMixedModuleStore(unittest.TestCase):
                (course_key, "Chapter_x", None, None)),
        )

-        mongo_store = self.store._get_modulestore_for_courseid(self._course_key_from_string(self.MONGO_COURSEID))
        for location, expected in should_work:
-            with check_mongo_calls(mongo_store, num_finds.pop(0), num_sends):
+            with check_mongo_calls(num_finds.pop(0), num_sends):
                self.assertEqual(path_to_location(self.store, location), expected)

        not_found = (
@@ -881,10 +905,9 @@ class TestMixedModuleStore(unittest.TestCase):
                block_id=location.block_id
            )

-        mongo_store = self.store._get_modulestore_for_courseid(self._course_key_from_string(self.MONGO_COURSEID))
-        with check_mongo_calls(mongo_store, max_find, max_send):
+        with check_mongo_calls(max_find, max_send):
            found_orphans = self.store.get_orphans(self.course_locations[self.MONGO_COURSEID].course_key)
-        self.assertEqual(set(found_orphans), set(orphan_locations))
+        self.assertItemsEqual(found_orphans, orphan_locations)

    @ddt.data('draft')
    def test_create_item_from_parent_location(self, default_ms):
@@ -924,7 +947,9 @@ class TestMixedModuleStore(unittest.TestCase):
        self.assertEqual(self.user_id, block.subtree_edited_by)
        self.assertGreater(datetime.datetime.now(UTC), block.subtree_edited_on)
-    @ddt.data(('draft', 1, 0), ('split', 1, 0))
+    # TODO: LMS-11220: Document why split find count is 2
+    # TODO: LMS-11220: Document why draft find count is 2
+    @ddt.data(('draft', 2, 0), ('split', 2, 0))
    @ddt.unpack
    def test_get_courses_for_wiki(self, default_ms, max_find, max_send):
        """
@@ -941,8 +966,7 @@ class TestMixedModuleStore(unittest.TestCase):
        self.assertIn(self.course_locations[self.XML_COURSEID2].course_key, wiki_courses)

        # Test Mongo wiki
-        mongo_store = self.store._get_modulestore_for_courseid(self._course_key_from_string(self.MONGO_COURSEID))
-        with check_mongo_calls(mongo_store, max_find, max_send):
+        with check_mongo_calls(max_find, max_send):
            wiki_courses = self.store.get_courses_for_wiki('999')
        self.assertEqual(len(wiki_courses), 1)
        self.assertIn(
@@ -953,7 +977,15 @@ class TestMixedModuleStore(unittest.TestCase):
        self.assertEqual(len(self.store.get_courses_for_wiki('edX.simple.2012_Fall')), 0)
        self.assertEqual(len(self.store.get_courses_for_wiki('no_such_wiki')), 0)
-    @ddt.data(('draft', 2, 6), ('split', 7, 2))
+    # Mongo reads:
+    # - load vertical
+    # - load vertical children
+    # - get last error
+    # Split takes 1 query to read the course structure, deletes all of the entries in memory, and loads the module from an in-memory cache
+    # Sends:
+    # - insert structure
+    # - write index entry
+    @ddt.data(('draft', 3, 6), ('split', 3, 2))
    @ddt.unpack
    def test_unpublish(self, default_ms, max_find, max_send):
        """
@@ -971,8 +1003,7 @@ class TestMixedModuleStore(unittest.TestCase):
        self.assertIsNotNone(published_xblock)

        # unpublish
-        mongo_store = self.store._get_modulestore_for_courseid(self._course_key_from_string(self.MONGO_COURSEID))
-        with check_mongo_calls(mongo_store, max_find, max_send):
+        with check_mongo_calls(max_find, max_send):
            self.store.unpublish(self.vertical_x1a, self.user_id)

        with self.assertRaises(ItemNotFoundError):
@@ -1000,8 +1031,7 @@ class TestMixedModuleStore(unittest.TestCase):
        # start off as Private
        item = self.store.create_child(self.user_id, self.writable_chapter_location, 'problem', 'test_compute_publish_state')
        item_location = item.location
-        mongo_store = self.store._get_modulestore_for_courseid(self._course_key_from_string(self.MONGO_COURSEID))
-        with check_mongo_calls(mongo_store, max_find, max_send):
+        with check_mongo_calls(max_find, max_send):
            self.assertFalse(self.store.has_published_version(item))

        # Private -> Public
...
@@ -19,23 +19,35 @@ class TestPublish(SplitWMongoCourseBoostrapper):
        # There are 12 created items and 7 parent updates
        # create course: finds: 1 to verify uniqueness, 1 to find parents
        # sends: 1 to create course, 1 to create overview
-        with check_mongo_calls(self.draft_mongo, 5, 2):
+        with check_mongo_calls(5, 2):
            super(TestPublish, self)._create_course(split=False)  # 2 inserts (course and overview)

        # with bulk will delay all inheritance computations which won't be added into the mongo_calls
-        with self.draft_mongo.bulk_write_operations(self.old_course_key):
+        with self.draft_mongo.bulk_operations(self.old_course_key):
            # finds: 1 for parent to add child
            # sends: 1 for insert, 1 for parent (add child)
-            with check_mongo_calls(self.draft_mongo, 1, 2):
+            with check_mongo_calls(1, 2):
                self._create_item('chapter', 'Chapter1', {}, {'display_name': 'Chapter 1'}, 'course', 'runid', split=False)

-            with check_mongo_calls(self.draft_mongo, 2, 2):
+            with check_mongo_calls(2, 2):
                self._create_item('chapter', 'Chapter2', {}, {'display_name': 'Chapter 2'}, 'course', 'runid', split=False)

-            # update info propagation is 2 levels. create looks for draft and then published and then creates
-            with check_mongo_calls(self.draft_mongo, 8, 6):
+            # For each vertical (2) created:
+            # - load draft
+            # - load non-draft
+            # - get last error
+            # - load parent
+            # - load inheritable data
+            with check_mongo_calls(10, 6):
                self._create_item('vertical', 'Vert1', {}, {'display_name': 'Vertical 1'}, 'chapter', 'Chapter1', split=False)
                self._create_item('vertical', 'Vert2', {}, {'display_name': 'Vertical 2'}, 'chapter', 'Chapter1', split=False)

-            with check_mongo_calls(self.draft_mongo, 20, 12):
+            # For each (4) item created
+            # - load draft
+            # - load non-draft
+            # - get last error
+            # - load parent
+            # - load inheritable data
+            # - load parent
+            with check_mongo_calls(24, 12):
                self._create_item('html', 'Html1', "<p>Goodbye</p>", {'display_name': 'Parented Html'}, 'vertical', 'Vert1', split=False)
                self._create_item(
                    'discussion', 'Discussion1',
@@ -63,7 +75,7 @@ class TestPublish(SplitWMongoCourseBoostrapper):
                    split=False
                )

-            with check_mongo_calls(self.draft_mongo, 0, 2):
+            with check_mongo_calls(0, 2):
                # 2 finds b/c looking for non-existent parents
                self._create_item('static_tab', 'staticuno', "<p>tab</p>", {'display_name': 'Tab uno'}, None, None, split=False)
                self._create_item('course_info', 'updates', "<ol><li><h2>Sep 22</h2><p>test</p></li></ol>", {}, None, None, split=False)
@@ -76,10 +88,11 @@ class TestPublish(SplitWMongoCourseBoostrapper):
        vert_location = self.old_course_key.make_usage_key('vertical', block_id='Vert1')
        item = self.draft_mongo.get_item(vert_location, 2)

        # Vert1 has 3 children; so, publishes 4 nodes which may mean 4 inserts & 1 bulk remove
+        # TODO: LMS-11220: Document why find count is 25
        # 25-June-2014 find calls are 19. Probably due to inheritance recomputation?
        # 02-July-2014 send calls are 7. 5 from above, plus 2 for updating subtree edit info for Chapter1 and course
        #              find calls are 22. 19 from above, plus 3 for finding the parent of Vert1, Chapter1, and course
-        with check_mongo_calls(self.draft_mongo, 22, 7):
+        with check_mongo_calls(25, 7):
            self.draft_mongo.publish(item.location, self.user_id)

        # verify status
...
@@ -2,12 +2,12 @@
Test split modulestore w/o using any django stuff.
"""
import datetime
+import random
+import re
import unittest
import uuid
from importlib import import_module
from path import path
-import re
-import random

from xmodule.course_module import CourseDescriptor
from xmodule.modulestore import ModuleStoreEnum
@@ -492,13 +492,14 @@ class SplitModuleTest(unittest.TestCase):
                    new_ele_dict[spec['id']] = child
            course = split_store.persist_xblock_dag(course, revision['user_id'])
        # publish "testx.wonderful"
+        source_course = CourseLocator(org="testx", course="wonderful", run="run", branch=BRANCH_NAME_DRAFT)
        to_publish = BlockUsageLocator(
-            CourseLocator(org="testx", course="wonderful", run="run", branch=BRANCH_NAME_DRAFT),
+            source_course,
            block_type='course',
            block_id="head23456"
        )
        destination = CourseLocator(org="testx", course="wonderful", run="run", branch=BRANCH_NAME_PUBLISHED)
-        split_store.copy("test@edx.org", to_publish, destination, [to_publish], None)
+        split_store.copy("test@edx.org", source_course, destination, [to_publish], None)

    def setUp(self):
        self.user_id = random.getrandbits(32)
@@ -607,13 +608,6 @@ class SplitModuleCourseTests(SplitModuleTest):
            _verify_published_course(modulestore().get_courses(branch=BRANCH_NAME_PUBLISHED))

-    def test_search_qualifiers(self):
-        # query w/ search criteria
-        courses = modulestore().get_courses(branch=BRANCH_NAME_DRAFT, qualifiers={'org': 'testx'})
-        self.assertEqual(len(courses), 2)
-        self.assertIsNotNone(self.findByIdInResult(courses, "head12345"))
-        self.assertIsNotNone(self.findByIdInResult(courses, "head23456"))
-
    def test_has_course(self):
        '''
        Test the various calling forms for has_course
@@ -985,7 +979,7 @@ class TestItemCrud(SplitModuleTest):
        # grab link to course to ensure new versioning works
        locator = CourseLocator(org='testx', course='GreekHero', run="run", branch=BRANCH_NAME_DRAFT)
        premod_course = modulestore().get_course(locator)
-        premod_history = modulestore().get_course_history_info(premod_course.location)
+        premod_history = modulestore().get_course_history_info(locator)
        # add minimal one w/o a parent
        category = 'sequential'
        new_module = modulestore().create_item(
@@ -999,7 +993,7 @@ class TestItemCrud(SplitModuleTest):
        current_course = modulestore().get_course(locator)
        self.assertEqual(new_module.location.version_guid, current_course.location.version_guid)

-        history_info = modulestore().get_course_history_info(current_course.location)
+        history_info = modulestore().get_course_history_info(current_course.location.course_key)
        self.assertEqual(history_info['previous_version'], premod_course.location.version_guid)
        self.assertEqual(history_info['original_version'], premod_history['original_version'])
        self.assertEqual(history_info['edited_by'], "user123")
...@@ -1112,84 +1106,82 @@ class TestItemCrud(SplitModuleTest): ...@@ -1112,84 +1106,82 @@ class TestItemCrud(SplitModuleTest):
chapter = modulestore().get_item(chapter_locator) chapter = modulestore().get_item(chapter_locator)
self.assertIn(problem_locator, version_agnostic(chapter.children)) self.assertIn(problem_locator, version_agnostic(chapter.children))
-    def test_create_continue_version(self):
+    def test_create_bulk_operations(self):
         """
-        Test create_item using the continue_version flag
+        Test create_item using bulk_operations
         """
         # start transaction w/ simple creation
         user = random.getrandbits(32)
-        new_course = modulestore().create_course('test_org', 'test_transaction', 'test_run', user, BRANCH_NAME_DRAFT)
-        new_course_locator = new_course.id
-        index_history_info = modulestore().get_course_history_info(new_course.location)
-        course_block_prev_version = new_course.previous_version
-        course_block_update_version = new_course.update_version
-        self.assertIsNotNone(new_course_locator.version_guid, "Want to test a definite version")
-        versionless_course_locator = new_course_locator.version_agnostic()
+        course_key = CourseLocator('test_org', 'test_transaction', 'test_run')
+        with modulestore().bulk_operations(course_key):
+            new_course = modulestore().create_course('test_org', 'test_transaction', 'test_run', user, BRANCH_NAME_DRAFT)
+            new_course_locator = new_course.id
+            index_history_info = modulestore().get_course_history_info(new_course.location.course_key)
+            course_block_prev_version = new_course.previous_version
+            course_block_update_version = new_course.update_version
+            self.assertIsNotNone(new_course_locator.version_guid, "Want to test a definite version")
+            versionless_course_locator = new_course_locator.version_agnostic()

-        # positive simple case: no force, add chapter
-        new_ele = modulestore().create_child(
-            user, new_course.location, 'chapter',
-            fields={'display_name': 'chapter 1'},
-            continue_version=True
-        )
-        # version info shouldn't change
-        self.assertEqual(new_ele.update_version, course_block_update_version)
-        self.assertEqual(new_ele.update_version, new_ele.location.version_guid)
-        refetch_course = modulestore().get_course(versionless_course_locator)
-        self.assertEqual(refetch_course.location.version_guid, new_course.location.version_guid)
-        self.assertEqual(refetch_course.previous_version, course_block_prev_version)
-        self.assertEqual(refetch_course.update_version, course_block_update_version)
-        refetch_index_history_info = modulestore().get_course_history_info(refetch_course.location)
-        self.assertEqual(refetch_index_history_info, index_history_info)
-        self.assertIn(new_ele.location.version_agnostic(), version_agnostic(refetch_course.children))
-        # try to create existing item
-        with self.assertRaises(DuplicateItemError):
-            _fail = modulestore().create_child(
-                user, new_course.location, 'chapter',
-                block_id=new_ele.location.block_id,
-                fields={'display_name': 'chapter 2'},
-                continue_version=True
-            )
+            # positive simple case: no force, add chapter
+            new_ele = modulestore().create_child(
+                user, new_course.location, 'chapter',
+                fields={'display_name': 'chapter 1'},
+            )
+            # version info shouldn't change
+            self.assertEqual(new_ele.update_version, course_block_update_version)
+            self.assertEqual(new_ele.update_version, new_ele.location.version_guid)
+            refetch_course = modulestore().get_course(versionless_course_locator)
+            self.assertEqual(refetch_course.location.version_guid, new_course.location.version_guid)
+            self.assertEqual(refetch_course.previous_version, course_block_prev_version)
+            self.assertEqual(refetch_course.update_version, course_block_update_version)
+            refetch_index_history_info = modulestore().get_course_history_info(refetch_course.location.course_key)
+            self.assertEqual(refetch_index_history_info, index_history_info)
+            self.assertIn(new_ele.location.version_agnostic(), version_agnostic(refetch_course.children))
+            # try to create existing item
+            with self.assertRaises(DuplicateItemError):
+                _fail = modulestore().create_child(
+                    user, new_course.location, 'chapter',
+                    block_id=new_ele.location.block_id,
+                    fields={'display_name': 'chapter 2'},
+                )

         # start a new transaction
-        new_ele = modulestore().create_child(
-            user, new_course.location, 'chapter',
-            fields={'display_name': 'chapter 2'},
-            continue_version=False
-        )
-        transaction_guid = new_ele.location.version_guid
-        # ensure force w/ continue gives exception
-        with self.assertRaises(VersionConflictError):
-            _fail = modulestore().create_child(
-                user, new_course.location, 'chapter',
-                fields={'display_name': 'chapter 2'},
-                force=True, continue_version=True
-            )
-        # ensure trying to continue the old one gives exception
-        with self.assertRaises(VersionConflictError):
-            _fail = modulestore().create_child(
-                user, new_course.location, 'chapter',
-                fields={'display_name': 'chapter 3'},
-                continue_version=True
-            )
-        # add new child to old parent in continued (leave off version_guid)
-        course_module_locator = new_course.location.version_agnostic()
-        new_ele = modulestore().create_child(
-            user, course_module_locator, 'chapter',
-            fields={'display_name': 'chapter 4'},
-            continue_version=True
-        )
-        self.assertNotEqual(new_ele.update_version, course_block_update_version)
-        self.assertEqual(new_ele.location.version_guid, transaction_guid)
+        with modulestore().bulk_operations(course_key):
+            new_ele = modulestore().create_child(
+                user, new_course.location, 'chapter',
+                fields={'display_name': 'chapter 2'},
+            )
+            transaction_guid = new_ele.location.version_guid
+            # ensure force w/ continue gives exception
+            with self.assertRaises(VersionConflictError):
+                _fail = modulestore().create_child(
+                    user, new_course.location, 'chapter',
+                    fields={'display_name': 'chapter 2'},
+                    force=True
+                )
+            # ensure trying to continue the old one gives exception
+            with self.assertRaises(VersionConflictError):
+                _fail = modulestore().create_child(
+                    user, new_course.location, 'chapter',
+                    fields={'display_name': 'chapter 3'},
+                )
+            # add new child to old parent in continued (leave off version_guid)
+            course_module_locator = new_course.location.version_agnostic()
+            new_ele = modulestore().create_child(
+                user, course_module_locator, 'chapter',
+                fields={'display_name': 'chapter 4'},
+            )
+            self.assertNotEqual(new_ele.update_version, course_block_update_version)
+            self.assertEqual(new_ele.location.version_guid, transaction_guid)

         # check children, previous_version
         refetch_course = modulestore().get_course(versionless_course_locator)
         self.assertIn(new_ele.location.version_agnostic(), version_agnostic(refetch_course.children))
         self.assertEqual(refetch_course.previous_version, course_block_update_version)
         self.assertEqual(refetch_course.update_version, transaction_guid)
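
The renamed test reflects the API change this PR makes: instead of threading `continue_version` through every `create_child` call, callers open a course-scoped `bulk_operations` context, and all writes inside it share one structure version. A minimal sketch of the pattern, assuming `store = modulestore()` and a numeric `user_id` are in scope; the course identifiers below are placeholders, not part of the change:

from opaque_keys.edx.locator import CourseLocator

course_key = CourseLocator('demo_org', 'demo_course', 'demo_run')  # placeholder ids
with store.bulk_operations(course_key):
    # every write inside the block is buffered into the same new version
    course = store.create_course('demo_org', 'demo_course', 'demo_run', user_id, BRANCH_NAME_DRAFT)
    chapter = store.create_child(user_id, course.location, 'chapter', fields={'display_name': 'intro'})
# on exit, the buffered structure and course index are flushed together
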
     def test_update_metadata(self):
         """
@@ -1221,7 +1213,7 @@ class TestItemCrud(SplitModuleTest):
         current_course = modulestore().get_course(locator.course_key)
         self.assertEqual(updated_problem.location.version_guid, current_course.location.version_guid)

-        history_info = modulestore().get_course_history_info(current_course.location)
+        history_info = modulestore().get_course_history_info(current_course.location.course_key)
         self.assertEqual(history_info['previous_version'], pre_version_guid)
         self.assertEqual(history_info['edited_by'], self.user_id)

@@ -1396,16 +1388,13 @@ class TestCourseCreation(SplitModuleTest):
         )
         new_locator = new_course.location
         # check index entry
-        index_info = modulestore().get_course_index_info(new_locator)
+        index_info = modulestore().get_course_index_info(new_locator.course_key)
         self.assertEqual(index_info['org'], 'test_org')
         self.assertEqual(index_info['edited_by'], 'create_user')
         # check structure info
-        structure_info = modulestore().get_course_history_info(new_locator)
-        # TODO LMS-11098 "Implement bulk_write in Split"
-        # Right now, these assertions will not pass because create_course calls update_item,
-        # resulting in two versions. Bulk updater will fix this.
-        # self.assertEqual(structure_info['original_version'], index_info['versions'][BRANCH_NAME_DRAFT])
-        # self.assertIsNone(structure_info['previous_version'])
+        structure_info = modulestore().get_course_history_info(new_locator.course_key)
+        self.assertEqual(structure_info['original_version'], index_info['versions'][BRANCH_NAME_DRAFT])
+        self.assertIsNone(structure_info['previous_version'])
         self.assertEqual(structure_info['edited_by'], 'create_user')
         # check the returned course object

@@ -1433,7 +1422,7 @@ class TestCourseCreation(SplitModuleTest):
         self.assertEqual(new_draft.edited_by, 'test@edx.org')
         self.assertEqual(new_draft_locator.version_guid, original_index['versions'][BRANCH_NAME_DRAFT])
         # however the edited_by and other meta fields on course_index will be this one
-        new_index = modulestore().get_course_index_info(new_draft_locator)
+        new_index = modulestore().get_course_index_info(new_draft_locator.course_key)
         self.assertEqual(new_index['edited_by'], 'leech_master')

         new_published_locator = new_draft_locator.course_key.for_branch(BRANCH_NAME_PUBLISHED)

@@ -1483,7 +1472,7 @@ class TestCourseCreation(SplitModuleTest):
         self.assertEqual(new_draft.edited_by, 'leech_master')
         self.assertNotEqual(new_draft_locator.version_guid, original_index['versions'][BRANCH_NAME_DRAFT])
         # however the edited_by and other meta fields on course_index will be this one
-        new_index = modulestore().get_course_index_info(new_draft_locator)
+        new_index = modulestore().get_course_index_info(new_draft_locator.course_key)
         self.assertEqual(new_index['edited_by'], 'leech_master')
         self.assertEqual(new_draft.display_name, fields['display_name'])
         self.assertDictEqual(

@@ -1504,13 +1493,13 @@ class TestCourseCreation(SplitModuleTest):
         head_course = modulestore().get_course(locator)
         versions = course_info['versions']
         versions[BRANCH_NAME_DRAFT] = head_course.previous_version
-        modulestore().update_course_index(course_info)
+        modulestore().update_course_index(None, course_info)
         course = modulestore().get_course(locator)
         self.assertEqual(course.location.version_guid, versions[BRANCH_NAME_DRAFT])

         # an allowed but not recommended way to publish a course
         versions[BRANCH_NAME_PUBLISHED] = versions[BRANCH_NAME_DRAFT]
-        modulestore().update_course_index(course_info)
+        modulestore().update_course_index(None, course_info)
         course = modulestore().get_course(locator.for_branch(BRANCH_NAME_PUBLISHED))
         self.assertEqual(course.location.version_guid, versions[BRANCH_NAME_DRAFT])

@@ -1715,6 +1704,7 @@ class TestPublish(SplitModuleTest):
                 dest_cursor += 1
             self.assertEqual(dest_cursor, len(dest_children))

 class TestSchema(SplitModuleTest):
     """
     Test the db schema (and possibly eventually migrations?)

@@ -1736,6 +1726,7 @@ class TestSchema(SplitModuleTest):
                 "{0.name} has records with wrong schema_version".format(collection)
             )

 #===========================================
 def modulestore():
     """
......
import copy
import ddt
import unittest
from bson.objectid import ObjectId
from mock import MagicMock, Mock, call
from xmodule.modulestore.split_mongo.split import BulkWriteMixin
from xmodule.modulestore.split_mongo.mongo_connection import MongoConnection
from opaque_keys.edx.locator import CourseLocator, BlockUsageLocator, VersionTree, LocalId
class TestBulkWriteMixin(unittest.TestCase):
def setUp(self):
super(TestBulkWriteMixin, self).setUp()
self.bulk = BulkWriteMixin()
self.bulk.SCHEMA_VERSION = 1
self.clear_cache = self.bulk._clear_cache = Mock(name='_clear_cache')
self.conn = self.bulk.db_connection = MagicMock(name='db_connection', spec=MongoConnection)
self.conn.get_course_index.return_value = {'initial': 'index'}
self.course_key = CourseLocator('org', 'course', 'run-a')
self.course_key_b = CourseLocator('org', 'course', 'run-b')
self.structure = {'this': 'is', 'a': 'structure', '_id': ObjectId()}
self.index_entry = {'this': 'is', 'an': 'index'}
def assertConnCalls(self, *calls):
self.assertEqual(list(calls), self.conn.mock_calls)
def assertCacheNotCleared(self):
self.assertFalse(self.clear_cache.called)
class TestBulkWriteMixinPreviousTransaction(TestBulkWriteMixin):
"""
Verify that opening and closing a transaction doesn't affect later behaviour.
"""
def setUp(self):
super(TestBulkWriteMixinPreviousTransaction, self).setUp()
self.bulk._begin_bulk_operation(self.course_key)
self.bulk.insert_course_index(self.course_key, MagicMock('prev-index-entry'))
self.bulk.update_structure(self.course_key, {'this': 'is', 'the': 'previous structure', '_id': ObjectId()})
self.bulk._end_bulk_operation(self.course_key)
self.conn.reset_mock()
self.clear_cache.reset_mock()
@ddt.ddt
class TestBulkWriteMixinClosed(TestBulkWriteMixin):
"""
Tests of the bulk write mixin when bulk operations aren't active.
"""
@ddt.data('deadbeef1234' * 2, u'deadbeef1234' * 2, ObjectId())
def test_no_bulk_read_structure(self, version_guid):
# Reading a structure when no bulk operation is active should just call
# through to the db_connection
result = self.bulk.get_structure(self.course_key, version_guid)
self.assertConnCalls(call.get_structure(self.course_key.as_object_id(version_guid)))
self.assertEqual(result, self.conn.get_structure.return_value)
self.assertCacheNotCleared()
def test_no_bulk_write_structure(self):
# Writing a structure when no bulk operation is active should just
# call through to the db_connection. It should also clear the
# system cache
self.bulk.update_structure(self.course_key, self.structure)
self.assertConnCalls(call.upsert_structure(self.structure))
self.clear_cache.assert_called_once_with(self.structure['_id'])
@ddt.data(True, False)
def test_no_bulk_read_index(self, ignore_case):
# Reading a course index when no bulk operation is active should just call
# through to the db_connection
result = self.bulk.get_course_index(self.course_key, ignore_case=ignore_case)
self.assertConnCalls(call.get_course_index(self.course_key, ignore_case))
self.assertEqual(result, self.conn.get_course_index.return_value)
self.assertCacheNotCleared()
def test_no_bulk_write_index(self):
# Writing a course index when no bulk operation is active should just call
# through to the db_connection
self.bulk.insert_course_index(self.course_key, self.index_entry)
self.assertConnCalls(call.insert_course_index(self.index_entry))
self.assertCacheNotCleared()
def test_out_of_order_end(self):
# Calling _end_bulk_operation without a corresponding _begin...
# is a noop
self.bulk._end_bulk_operation(self.course_key)
def test_write_new_index_on_close(self):
self.conn.get_course_index.return_value = None
self.bulk._begin_bulk_operation(self.course_key)
self.conn.reset_mock()
self.bulk.insert_course_index(self.course_key, self.index_entry)
self.assertConnCalls()
self.bulk._end_bulk_operation(self.course_key)
self.conn.insert_course_index.assert_called_once_with(self.index_entry)
def test_write_updated_index_on_close(self):
old_index = {'this': 'is', 'an': 'old index'}
self.conn.get_course_index.return_value = old_index
self.bulk._begin_bulk_operation(self.course_key)
self.conn.reset_mock()
self.bulk.insert_course_index(self.course_key, self.index_entry)
self.assertConnCalls()
self.bulk._end_bulk_operation(self.course_key)
self.conn.update_course_index.assert_called_once_with(self.index_entry, from_index=old_index)
def test_write_structure_on_close(self):
self.conn.get_course_index.return_value = None
self.bulk._begin_bulk_operation(self.course_key)
self.conn.reset_mock()
self.bulk.update_structure(self.course_key, self.structure)
self.assertConnCalls()
self.bulk._end_bulk_operation(self.course_key)
self.assertConnCalls(call.upsert_structure(self.structure))
def test_write_multiple_structures_on_close(self):
self.conn.get_course_index.return_value = None
self.bulk._begin_bulk_operation(self.course_key)
self.conn.reset_mock()
self.bulk.update_structure(self.course_key.replace(branch='a'), self.structure)
other_structure = {'another': 'structure', '_id': ObjectId()}
self.bulk.update_structure(self.course_key.replace(branch='b'), other_structure)
self.assertConnCalls()
self.bulk._end_bulk_operation(self.course_key)
self.assertItemsEqual(
[call.upsert_structure(self.structure), call.upsert_structure(other_structure)],
self.conn.mock_calls
)
def test_write_index_and_structure_on_close(self):
original_index = {'versions': {}}
self.conn.get_course_index.return_value = copy.deepcopy(original_index)
self.bulk._begin_bulk_operation(self.course_key)
self.conn.reset_mock()
self.bulk.update_structure(self.course_key, self.structure)
self.bulk.insert_course_index(self.course_key, {'versions': {self.course_key.branch: self.structure['_id']}})
self.assertConnCalls()
self.bulk._end_bulk_operation(self.course_key)
self.assertConnCalls(
call.upsert_structure(self.structure),
call.update_course_index(
{'versions': {self.course_key.branch: self.structure['_id']}},
from_index=original_index
)
)
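
Taken together, the `*_on_close` tests pin down the flush protocol: nothing is sent to Mongo while a bulk operation is open, and on `_end_bulk_operation` dirty structures are upserted before the index is written, with insert vs. update chosen by whether an index existed when the operation began. A sketch of that rule, using hypothetical buffer fields (`_active_records`, `dirty_structures`, `initial_index`) rather than the mixin's real internals:

def _end_bulk_operation(self, course_key):
    record = self._active_records.pop(course_key)       # hypothetical buffer
    for structure in record.dirty_structures:
        self.db_connection.upsert_structure(structure)  # structures first
    if record.index is not None and record.index != record.initial_index:
        if record.initial_index is None:
            self.db_connection.insert_course_index(record.index)
        else:
            # index last, so a reader never sees a version id
            # whose structure hasn't been persisted yet
            self.db_connection.update_course_index(record.index, from_index=record.initial_index)
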
def test_write_index_and_multiple_structures_on_close(self):
original_index = {'versions': {'a': ObjectId(), 'b': ObjectId()}}
self.conn.get_course_index.return_value = copy.deepcopy(original_index)
self.bulk._begin_bulk_operation(self.course_key)
self.conn.reset_mock()
self.bulk.update_structure(self.course_key.replace(branch='a'), self.structure)
other_structure = {'another': 'structure', '_id': ObjectId()}
self.bulk.update_structure(self.course_key.replace(branch='b'), other_structure)
self.bulk.insert_course_index(self.course_key, {'versions': {'a': self.structure['_id'], 'b': other_structure['_id']}})
self.bulk._end_bulk_operation(self.course_key)
self.assertItemsEqual(
[
call.upsert_structure(self.structure),
call.upsert_structure(other_structure),
call.update_course_index(
{'versions': {'a': self.structure['_id'], 'b': other_structure['_id']}},
from_index=original_index
)
],
self.conn.mock_calls
)
def test_version_structure_creates_new_version(self):
self.assertNotEquals(
self.bulk.version_structure(self.course_key, self.structure, 'user_id')['_id'],
self.structure['_id']
)
class TestBulkWriteMixinClosedAfterPrevTransaction(TestBulkWriteMixinClosed, TestBulkWriteMixinPreviousTransaction):
"""
Test that operations with a closed transaction aren't affected by a previously executed transaction.
"""
pass
@ddt.ddt
class TestBulkWriteMixinFindMethods(TestBulkWriteMixin):
"""
Tests of BulkWriteMixin methods for finding many structures or indexes
"""
def test_no_bulk_find_matching_course_indexes(self):
branch = Mock(name='branch')
search_targets = MagicMock(name='search_targets')
self.conn.find_matching_course_indexes.return_value = [Mock(name='result')]
result = self.bulk.find_matching_course_indexes(branch, search_targets)
self.assertConnCalls(call.find_matching_course_indexes(branch, search_targets))
self.assertEqual(result, self.conn.find_matching_course_indexes.return_value)
self.assertCacheNotCleared()
@ddt.data(
(None, None, [], []),
(
'draft',
None,
[{'versions': {'draft': '123'}}],
[
{'versions': {'published': '123'}},
{}
],
),
(
'draft',
{'f1': 'v1'},
[{'versions': {'draft': '123'}, 'search_targets': {'f1': 'v1'}}],
[
{'versions': {'draft': '123'}, 'search_targets': {'f1': 'value2'}},
{'versions': {'published': '123'}, 'search_targets': {'f1': 'v1'}},
{'search_targets': {'f1': 'v1'}},
{'versions': {'draft': '123'}},
],
),
(
None,
{'f1': 'v1'},
[
{'versions': {'draft': '123'}, 'search_targets': {'f1': 'v1'}},
{'versions': {'published': '123'}, 'search_targets': {'f1': 'v1'}},
{'search_targets': {'f1': 'v1'}},
],
[
{'versions': {'draft': '123'}, 'search_targets': {'f1': 'v2'}},
{'versions': {'draft': '123'}, 'search_targets': {'f2': 'v1'}},
{'versions': {'draft': '123'}},
],
),
(
None,
{'f1': 'v1', 'f2': 2},
[
{'search_targets': {'f1': 'v1', 'f2': 2}},
{'search_targets': {'f1': 'v1', 'f2': 2}},
],
[
{'versions': {'draft': '123'}, 'search_targets': {'f1': 'v1'}},
{'search_targets': {'f1': 'v1'}},
{'versions': {'draft': '123'}, 'search_targets': {'f1': 'v2'}},
{'versions': {'draft': '123'}},
],
),
)
@ddt.unpack
def test_find_matching_course_indexes(self, branch, search_targets, matching, unmatching):
db_indexes = [Mock(name='from_db')]
for n, index in enumerate(matching + unmatching):
course_key = CourseLocator('org', 'course', 'run{}'.format(n))
self.bulk._begin_bulk_operation(course_key)
self.bulk.insert_course_index(course_key, index)
expected = matching + db_indexes
self.conn.find_matching_course_indexes.return_value = db_indexes
result = self.bulk.find_matching_course_indexes(branch, search_targets)
self.assertItemsEqual(result, expected)
for item in unmatching:
self.assertNotIn(item, result)
def test_no_bulk_find_structures_by_id(self):
ids = [Mock(name='id')]
self.conn.find_structures_by_id.return_value = [MagicMock(name='result')]
result = self.bulk.find_structures_by_id(ids)
self.assertConnCalls(call.find_structures_by_id(ids))
self.assertEqual(result, self.conn.find_structures_by_id.return_value)
self.assertCacheNotCleared()
@ddt.data(
([], [], []),
([1, 2, 3], [1, 2], [1, 2]),
([1, 2, 3], [1], [1, 2]),
([1, 2, 3], [], [1, 2]),
)
@ddt.unpack
def test_find_structures_by_id(self, search_ids, active_ids, db_ids):
db_structure = lambda _id: {'db': 'structure', '_id': _id}
active_structure = lambda _id: {'active': 'structure', '_id': _id}
db_structures = [db_structure(_id) for _id in db_ids if _id not in active_ids]
for n, _id in enumerate(active_ids):
course_key = CourseLocator('org', 'course', 'run{}'.format(n))
self.bulk._begin_bulk_operation(course_key)
self.bulk.update_structure(course_key, active_structure(_id))
self.conn.find_structures_by_id.return_value = db_structures
results = self.bulk.find_structures_by_id(search_ids)
self.conn.find_structures_by_id.assert_called_once_with(list(set(search_ids) - set(active_ids)))
for _id in active_ids:
if _id in search_ids:
self.assertIn(active_structure(_id), results)
else:
self.assertNotIn(active_structure(_id), results)
for _id in db_ids:
if _id in search_ids and _id not in active_ids:
self.assertIn(db_structure(_id), results)
else:
self.assertNotIn(db_structure(_id), results)
def test_no_bulk_find_structures_derived_from(self):
ids = [Mock(name='id')]
self.conn.find_structures_derived_from.return_value = [MagicMock(name='result')]
result = self.bulk.find_structures_derived_from(ids)
self.assertConnCalls(call.find_structures_derived_from(ids))
self.assertEqual(result, self.conn.find_structures_derived_from.return_value)
self.assertCacheNotCleared()
@ddt.data(
# Test values are:
# - previous_versions to search for
# - documents in the cache with $previous_version.$_id
# - documents in the db with $previous_version.$_id
([], [], []),
(['1', '2', '3'], ['1.a', '1.b', '2.c'], ['1.a', '2.c']),
(['1', '2', '3'], ['1.a'], ['1.a', '2.c']),
(['1', '2', '3'], [], ['1.a', '2.c']),
(['1', '2', '3'], ['4.d'], ['1.a', '2.c']),
)
@ddt.unpack
def test_find_structures_derived_from(self, search_ids, active_ids, db_ids):
def db_structure(_id):
previous, _, current = _id.partition('.')
return {'db': 'structure', 'previous_version': previous, '_id': current}
def active_structure(_id):
previous, _, current = _id.partition('.')
return {'active': 'structure', 'previous_version': previous, '_id': current}
db_structures = [db_structure(_id) for _id in db_ids]
active_structures = []
for n, _id in enumerate(active_ids):
course_key = CourseLocator('org', 'course', 'run{}'.format(n))
self.bulk._begin_bulk_operation(course_key)
structure = active_structure(_id)
self.bulk.update_structure(course_key, structure)
active_structures.append(structure)
self.conn.find_structures_derived_from.return_value = db_structures
results = self.bulk.find_structures_derived_from(search_ids)
self.conn.find_structures_derived_from.assert_called_once_with(search_ids)
for structure in active_structures:
if structure['previous_version'] in search_ids:
self.assertIn(structure, results)
else:
self.assertNotIn(structure, results)
for structure in db_structures:
if (
structure['previous_version'] in search_ids and # We're searching for this document
not any(active.endswith(structure['_id']) for active in active_ids) # This document doesn't match any active _ids
):
self.assertIn(structure, results)
else:
self.assertNotIn(structure, results)
def test_no_bulk_find_ancestor_structures(self):
original_version = Mock(name='original_version')
block_id = Mock(name='block_id')
self.conn.find_ancestor_structures.return_value = [MagicMock(name='result')]
result = self.bulk.find_ancestor_structures(original_version, block_id)
self.assertConnCalls(call.find_ancestor_structures(original_version, block_id))
self.assertEqual(result, self.conn.find_ancestor_structures.return_value)
self.assertCacheNotCleared()
@ddt.data(
# Test values are:
# - original_version
# - block_id
# - matching documents in the cache
# - non-matching documents in the cache
# - expected documents returned from the db
# - unexpected documents returned from the db
('ov', 'bi', [{'original_version': 'ov', 'blocks': {'bi': {'edit_info': {'update_version': 'foo'}}}}], [], [], []),
('ov', 'bi', [{'original_version': 'ov', 'blocks': {'bi': {'edit_info': {'update_version': 'foo'}}}, '_id': 'foo'}], [], [], [{'_id': 'foo'}]),
('ov', 'bi', [], [{'blocks': {'bi': {'edit_info': {'update_version': 'foo'}}}}], [], []),
('ov', 'bi', [], [{'original_version': 'ov'}], [], []),
('ov', 'bi', [], [], [{'original_version': 'ov', 'blocks': {'bi': {'edit_info': {'update_version': 'foo'}}}}], []),
(
'ov',
'bi',
[{'original_version': 'ov', 'blocks': {'bi': {'edit_info': {'update_version': 'foo'}}}}],
[],
[{'original_version': 'ov', 'blocks': {'bi': {'edit_info': {'update_version': 'bar'}}}}],
[]
),
)
@ddt.unpack
def test_find_ancestor_structures(self, original_version, block_id, active_match, active_unmatch, db_match, db_unmatch):
for structure in active_match + active_unmatch + db_match + db_unmatch:
structure.setdefault('_id', ObjectId())
for n, structure in enumerate(active_match + active_unmatch):
course_key = CourseLocator('org', 'course', 'run{}'.format(n))
self.bulk._begin_bulk_operation(course_key)
self.bulk.update_structure(course_key, structure)
self.conn.find_ancestor_structures.return_value = db_match + db_unmatch
results = self.bulk.find_ancestor_structures(original_version, block_id)
self.conn.find_ancestor_structures.assert_called_once_with(original_version, block_id)
self.assertItemsEqual(active_match + db_match, results)
@ddt.ddt
class TestBulkWriteMixinOpen(TestBulkWriteMixin):
"""
Tests of the bulk write mixin when bulk write operations are open
"""
def setUp(self):
super(TestBulkWriteMixinOpen, self).setUp()
self.bulk._begin_bulk_operation(self.course_key)
@ddt.data('deadbeef1234' * 2, u'deadbeef1234' * 2, ObjectId())
def test_read_structure_without_write_from_db(self, version_guid):
# Reading a structure before it's been written (while in bulk operation mode)
# returns the structure from the database
result = self.bulk.get_structure(self.course_key, version_guid)
self.assertEquals(self.conn.get_structure.call_count, 1)
self.assertEqual(result, self.conn.get_structure.return_value)
self.assertCacheNotCleared()
@ddt.data('deadbeef1234' * 2, u'deadbeef1234' * 2, ObjectId())
def test_read_structure_without_write_only_reads_once(self, version_guid):
# Reading the same structure multiple times shouldn't hit the database
# more than once
for _ in xrange(2):
result = self.bulk.get_structure(self.course_key, version_guid)
self.assertEquals(self.conn.get_structure.call_count, 1)
self.assertEqual(result, self.conn.get_structure.return_value)
self.assertCacheNotCleared()
@ddt.data('deadbeef1234' * 2, u'deadbeef1234' * 2, ObjectId())
def test_read_structure_after_write_no_db(self, version_guid):
# Reading a structure that's already been written shouldn't hit the db at all
self.structure['_id'] = version_guid
self.bulk.update_structure(self.course_key, self.structure)
result = self.bulk.get_structure(self.course_key, version_guid)
self.assertEquals(self.conn.get_structure.call_count, 0)
self.assertEqual(result, self.structure)
@ddt.data('deadbeef1234' * 2, u'deadbeef1234' * 2, ObjectId())
def test_read_structure_after_write_after_read(self, version_guid):
# Reading a structure that's been updated after being pulled from the db should
# still get the updated value
self.structure['_id'] = version_guid
self.bulk.get_structure(self.course_key, version_guid)
self.bulk.update_structure(self.course_key, self.structure)
result = self.bulk.get_structure(self.course_key, version_guid)
self.assertEquals(self.conn.get_structure.call_count, 1)
self.assertEqual(result, self.structure)
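
The open-transaction read tests above describe a read-through cache: the first `get_structure` for a version costs one DB round trip, repeat reads are served from the cache, and a structure written during the operation shadows whatever the DB returned. One way that behaviour could be implemented, with `record.structures` as an assumed dict keyed on version guid (hypothetical names; the tests only fix the observable behaviour):

def get_structure(self, course_key, version_guid):
    record = self._active_records[course_key]   # hypothetical per-course buffer
    structure = record.structures.get(version_guid)
    if structure is None:
        # first read during this bulk operation: hit Mongo once, then cache
        structure = self.db_connection.get_structure(version_guid)
        record.structures[version_guid] = structure
    return structure
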
@ddt.data(True, False)
def test_read_index_without_write_from_db(self, ignore_case):
# Reading the index without writing to it should pull from the database
result = self.bulk.get_course_index(self.course_key, ignore_case=ignore_case)
self.assertEquals(self.conn.get_course_index.call_count, 1)
self.assertEquals(self.conn.get_course_index.return_value, result)
@ddt.data(True, False)
def test_read_index_without_write_only_reads_once(self, ignore_case):
# Reading the index multiple times should only result in one read from
# the database
for _ in xrange(2):
result = self.bulk.get_course_index(self.course_key, ignore_case=ignore_case)
self.assertEquals(self.conn.get_course_index.call_count, 1)
self.assertEquals(self.conn.get_course_index.return_value, result)
@ddt.data(True, False)
def test_read_index_after_write(self, ignore_case):
# Reading the index after a write still should hit the database once to fetch the
# initial index, and should return the written index_entry
self.bulk.insert_course_index(self.course_key, self.index_entry)
result = self.bulk.get_course_index(self.course_key, ignore_case=ignore_case)
self.assertEquals(self.conn.get_course_index.call_count, 1)
self.assertEquals(self.index_entry, result)
def test_read_index_ignore_case(self):
# Reading using ignore case should find an already written entry with a different case
self.bulk.insert_course_index(self.course_key, self.index_entry)
result = self.bulk.get_course_index(
self.course_key.replace(
org=self.course_key.org.upper(),
course=self.course_key.course.title(),
run=self.course_key.run.upper()
),
ignore_case=True
)
self.assertEquals(self.conn.get_course_index.call_count, 1)
self.assertEquals(self.index_entry, result)
def test_version_structure_creates_new_version_before_read(self):
self.assertNotEquals(
self.bulk.version_structure(self.course_key, self.structure, 'user_id')['_id'],
self.structure['_id']
)
def test_version_structure_creates_new_version_after_read(self):
self.conn.get_structure.return_value = copy.deepcopy(self.structure)
self.bulk.get_structure(self.course_key, self.structure['_id'])
self.assertNotEquals(
self.bulk.version_structure(self.course_key, self.structure, 'user_id')['_id'],
self.structure['_id']
)
def test_copy_branch_versions(self):
# Directly updating an index so that the draft branch points to the published index
# version should work, and should only persist a single structure
self.maxDiff = None
published_structure = {'published': 'structure', '_id': ObjectId()}
self.bulk.update_structure(self.course_key, published_structure)
index = {'versions': {'published': published_structure['_id']}}
self.bulk.insert_course_index(self.course_key, index)
index_copy = copy.deepcopy(index)
index_copy['versions']['draft'] = index['versions']['published']
self.bulk.update_course_index(self.course_key, index_copy)
self.bulk._end_bulk_operation(self.course_key)
self.conn.upsert_structure.assert_called_once_with(published_structure)
self.conn.update_course_index.assert_called_once_with(index_copy, from_index=self.conn.get_course_index.return_value)
self.conn.get_course_index.assert_called_once_with(self.course_key)
class TestBulkWriteMixinOpenAfterPrevTransaction(TestBulkWriteMixinOpen, TestBulkWriteMixinPreviousTransaction):
"""
Test that operations with an open transaction aren't affected by a previously executed transaction.
"""
pass
@@ -206,7 +206,7 @@ def import_from_xml(
             )
             continue

-        with store.bulk_write_operations(dest_course_id):
+        with store.bulk_operations(dest_course_id):
             source_course = xml_module_store.get_course(course_key)
             # STEP 1: find and import course module
             course, course_data_path = _import_course_module(

@@ -607,7 +607,7 @@ def _import_course_draft(
                     _import_module(descriptor)
                 except Exception:
-                    logging.exception('There while importing draft descriptor %s', descriptor)
+                    logging.exception('while importing draft descriptor %s', descriptor)

 def allowed_metadata_by_category(category):
......
@@ -287,9 +287,9 @@ class CourseComparisonTest(unittest.TestCase):
         self.assertEqual(expected_item.has_children, actual_item.has_children)
         if expected_item.has_children:
             expected_children = [
-                (course1_item_child.location.block_type, course1_item_child.location.block_id)
+                (expected_item_child.location.block_type, expected_item_child.location.block_id)
                 # get_children() rather than children to strip privates from public parents
-                for course1_item_child in expected_item.get_children()
+                for expected_item_child in expected_item.get_children()
             ]
             actual_children = [
                 (item_child.location.block_type, item_child.location.block_id)
......
@@ -21,6 +21,7 @@ from django.contrib.auth.models import User
 from xblock.runtime import KeyValueStore
 from xblock.exceptions import KeyValueMultiSaveError, InvalidScopeError
 from xblock.fields import Scope, UserScope
+from xmodule.modulestore.django import modulestore

 log = logging.getLogger(__name__)

@@ -109,7 +110,8 @@ class FieldDataCache(object):
             return descriptors

-        descriptors = get_child_descriptors(descriptor, depth, descriptor_filter)
+        with modulestore().bulk_operations(descriptor.location.course_key):
+            descriptors = get_child_descriptors(descriptor, depth, descriptor_filter)

         return FieldDataCache(descriptors, course_id, user, select_for_update)
......
@@ -326,13 +326,15 @@ class TestTOC(ModuleStoreTestCase):
         self.request = factory.get(chapter_url)
         self.request.user = UserFactory()
         self.modulestore = self.store._get_modulestore_for_courseid(self.course_key)
-        with check_mongo_calls(self.modulestore, num_finds, num_sends):
-            self.toy_course = self.store.get_course(self.toy_loc, depth=2)
-            self.field_data_cache = FieldDataCache.cache_for_descriptor_descendents(
-                self.toy_loc, self.request.user, self.toy_course, depth=2)
+        self.toy_course = self.store.get_course(self.toy_loc, depth=2)
+        with check_mongo_calls(num_finds, num_sends):
+            self.field_data_cache = FieldDataCache.cache_for_descriptor_descendents(
+                self.toy_loc, self.request.user, self.toy_course, depth=2
+            )

-    @ddt.data((ModuleStoreEnum.Type.mongo, 3, 0), (ModuleStoreEnum.Type.split, 7, 0))
+    # TODO: LMS-11220: Document why split find count is 21
+    @ddt.data((ModuleStoreEnum.Type.mongo, 1, 0), (ModuleStoreEnum.Type.split, 5, 0))
     @ddt.unpack
     def test_toc_toy_from_chapter(self, default_ms, num_finds, num_sends):
         with self.store.default_store(default_ms):

@@ -352,14 +354,15 @@ class TestTOC(ModuleStoreTestCase):
                          'format': '', 'due': None, 'active': False}],
                 'url_name': 'secret:magic', 'display_name': 'secret:magic'}])

-            with check_mongo_calls(self.modulestore, 0, 0):
+            with check_mongo_calls(0, 0):
                 actual = render.toc_for_course(
                     self.request.user, self.request, self.toy_course, self.chapter, None, self.field_data_cache
                 )
             for toc_section in expected:
                 self.assertIn(toc_section, actual)

-    @ddt.data((ModuleStoreEnum.Type.mongo, 3, 0), (ModuleStoreEnum.Type.split, 7, 0))
+    # TODO: LMS-11220: Document why split find count is 21
+    @ddt.data((ModuleStoreEnum.Type.mongo, 1, 0), (ModuleStoreEnum.Type.split, 5, 0))
     @ddt.unpack
     def test_toc_toy_from_section(self, default_ms, num_finds, num_sends):
         with self.store.default_store(default_ms):
......
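
Note the signature change running through these tests: `check_mongo_calls` no longer takes the store as its first argument, only the expected find and send counts. Combined with `bulk_operations`, it gives a test a query budget to enforce; a short sketch, assuming `course` and the count variables are already in scope:

with modulestore().bulk_operations(course.id):      # cache reads inside the block
    with check_mongo_calls(num_finds, num_sends):   # fails if either budget is exceeded
        modulestore().get_course(course.id, depth=2)
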
@@ -91,7 +91,10 @@ def test_lib(options):
         }

         if test_id:
-            lib = '/'.join(test_id.split('/')[0:3])
+            if '/' in test_id:
+                lib = '/'.join(test_id.split('/')[0:3])
+            else:
+                lib = 'common/lib/' + test_id.split('.')[0]
             opts['test_id'] = test_id
             lib_tests = [suites.LibTestSuite(lib, **opts)]
         else:
......
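
The paver change above lets `test_id` be either a file path or a dotted module name, and derives the library directory accordingly. Worked examples (the test names are illustrative):

test_id = 'common/lib/xmodule/xmodule/tests/test_stringify.py'
'/'.join(test_id.split('/')[0:3])       # -> 'common/lib/xmodule'

test_id = 'xmodule.tests.test_stringify'
'common/lib/' + test_id.split('.')[0]   # -> 'common/lib/xmodule'
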
@@ -4,3 +4,4 @@ psutil==1.2.1
 lazy==1.1
 path.py==3.0.1
 watchdog==0.7.1
+python-memcached