Commit 1fbb4b8b by Nimisha Asthagiri

Update BlockStructure API

parent f016647f
...@@ -17,10 +17,14 @@ class BlockSerializer(serializers.Serializer): # pylint: disable=abstract-metho ...@@ -17,10 +17,14 @@ class BlockSerializer(serializers.Serializer): # pylint: disable=abstract-metho
Get the field value requested. The field may be an XBlock field, a Get the field value requested. The field may be an XBlock field, a
transformer block field, or an entire tranformer block data dict. transformer block field, or an entire tranformer block data dict.
""" """
value = None
if transformer is None: if transformer is None:
value = self.context['block_structure'].get_xblock_field(block_key, field_name) value = self.context['block_structure'].get_xblock_field(block_key, field_name)
elif field_name is None: elif field_name is None:
value = self.context['block_structure'].get_transformer_block_data(block_key, transformer) try:
value = self.context['block_structure'].get_transformer_block_data(block_key, transformer).fields
except KeyError:
pass
else: else:
value = self.context['block_structure'].get_transformer_block_field(block_key, transformer, field_name) value = self.context['block_structure'].get_transformer_block_field(block_key, transformer, field_name)
......
...@@ -41,10 +41,11 @@ class TestBlockSerializerBase(SharedModuleStoreTestCase): ...@@ -41,10 +41,11 @@ class TestBlockSerializerBase(SharedModuleStoreTestCase):
block_types_to_count=['video'], block_types_to_count=['video'],
requested_student_view_data=['video'], requested_student_view_data=['video'],
) )
self.transformers = BlockStructureTransformers(COURSE_BLOCK_ACCESS_TRANSFORMERS + [blocks_api_transformer])
self.block_structure = get_course_blocks( self.block_structure = get_course_blocks(
self.user, self.user,
self.course.location, self.course.location,
BlockStructureTransformers(COURSE_BLOCK_ACCESS_TRANSFORMERS + [blocks_api_transformer]), self.transformers,
) )
self.serializer_context = { self.serializer_context = {
'request': MagicMock(), 'request': MagicMock(),
...@@ -92,7 +93,7 @@ class TestBlockSerializerBase(SharedModuleStoreTestCase): ...@@ -92,7 +93,7 @@ class TestBlockSerializerBase(SharedModuleStoreTestCase):
{ {
'id', 'type', 'lms_web_url', 'student_view_url', 'id', 'type', 'lms_web_url', 'student_view_url',
'display_name', 'graded', 'display_name', 'graded',
'block_counts', 'student_view_multi_device', 'student_view_multi_device',
'lti_url', 'lti_url',
'visible_to_staff_only', 'visible_to_staff_only',
}, },
...@@ -108,6 +109,13 @@ class TestBlockSerializerBase(SharedModuleStoreTestCase): ...@@ -108,6 +109,13 @@ class TestBlockSerializerBase(SharedModuleStoreTestCase):
self.assertIn('student_view_multi_device', serialized_block) self.assertIn('student_view_multi_device', serialized_block)
self.assertTrue(serialized_block['student_view_multi_device']) self.assertTrue(serialized_block['student_view_multi_device'])
# chapters with video should have block_counts
if serialized_block['type'] == 'chapter':
if serialized_block['display_name'] not in ('poll_test', 'handout_container'):
self.assertIn('block_counts', serialized_block)
else:
self.assertNotIn('block_counts', serialized_block)
def create_staff_context(self): def create_staff_context(self):
""" """
Create staff user and course blocks accessible by that user Create staff user and course blocks accessible by that user
...@@ -119,7 +127,7 @@ class TestBlockSerializerBase(SharedModuleStoreTestCase): ...@@ -119,7 +127,7 @@ class TestBlockSerializerBase(SharedModuleStoreTestCase):
block_structure = get_course_blocks( block_structure = get_course_blocks(
staff_user, staff_user,
self.course.location, self.course.location,
BlockStructureTransformers(COURSE_BLOCK_ACCESS_TRANSFORMERS), self.transformers,
) )
return { return {
'request': MagicMock(), 'request': MagicMock(),
...@@ -156,12 +164,14 @@ class TestBlockSerializer(TestBlockSerializerBase): ...@@ -156,12 +164,14 @@ class TestBlockSerializer(TestBlockSerializerBase):
serializer = self.create_serializer() serializer = self.create_serializer()
for serialized_block in serializer.data: for serialized_block in serializer.data:
self.assert_basic_block(serialized_block['id'], serialized_block) self.assert_basic_block(serialized_block['id'], serialized_block)
self.assertEquals(len(serializer.data), 28)
def test_additional_requested_fields(self): def test_additional_requested_fields(self):
self.add_additional_requested_fields() self.add_additional_requested_fields()
serializer = self.create_serializer() serializer = self.create_serializer()
for serialized_block in serializer.data: for serialized_block in serializer.data:
self.assert_extended_block(serialized_block) self.assert_extended_block(serialized_block)
self.assertEquals(len(serializer.data), 28)
def test_staff_fields(self): def test_staff_fields(self):
""" """
...@@ -173,6 +183,7 @@ class TestBlockSerializer(TestBlockSerializerBase): ...@@ -173,6 +183,7 @@ class TestBlockSerializer(TestBlockSerializerBase):
for serialized_block in serializer.data: for serialized_block in serializer.data:
self.assert_extended_block(serialized_block) self.assert_extended_block(serialized_block)
self.assert_staff_fields(serialized_block) self.assert_staff_fields(serialized_block)
self.assertEquals(len(serializer.data), 29)
class TestBlockDictSerializer(TestBlockSerializerBase): class TestBlockDictSerializer(TestBlockSerializerBase):
...@@ -200,12 +211,14 @@ class TestBlockDictSerializer(TestBlockSerializerBase): ...@@ -200,12 +211,14 @@ class TestBlockDictSerializer(TestBlockSerializerBase):
for block_key_string, serialized_block in serializer.data['blocks'].iteritems(): for block_key_string, serialized_block in serializer.data['blocks'].iteritems():
self.assertEquals(serialized_block['id'], block_key_string) self.assertEquals(serialized_block['id'], block_key_string)
self.assert_basic_block(block_key_string, serialized_block) self.assert_basic_block(block_key_string, serialized_block)
self.assertEquals(len(serializer.data['blocks']), 28)
def test_additional_requested_fields(self): def test_additional_requested_fields(self):
self.add_additional_requested_fields() self.add_additional_requested_fields()
serializer = self.create_serializer() serializer = self.create_serializer()
for serialized_block in serializer.data['blocks'].itervalues(): for serialized_block in serializer.data['blocks'].itervalues():
self.assert_extended_block(serialized_block) self.assert_extended_block(serialized_block)
self.assertEquals(len(serializer.data['blocks']), 28)
def test_staff_fields(self): def test_staff_fields(self):
""" """
...@@ -217,3 +230,4 @@ class TestBlockDictSerializer(TestBlockSerializerBase): ...@@ -217,3 +230,4 @@ class TestBlockDictSerializer(TestBlockSerializerBase):
for serialized_block in serializer.data['blocks'].itervalues(): for serialized_block in serializer.data['blocks'].itervalues():
self.assert_extended_block(serialized_block) self.assert_extended_block(serialized_block)
self.assert_staff_fields(serialized_block) self.assert_staff_fields(serialized_block)
self.assertEquals(len(serializer.data['blocks']), 29)
...@@ -38,13 +38,13 @@ class TestBlockCountsTransformer(ModuleStoreTestCase): ...@@ -38,13 +38,13 @@ class TestBlockCountsTransformer(ModuleStoreTestCase):
) )
# verify count of chapters # verify count of chapters
self.assertEquals(block_counts_for_course['chapter'], 2) self.assertEquals(block_counts_for_course.chapter, 2)
# verify count of problems # verify count of problems
self.assertEquals(block_counts_for_course['problem'], 6) self.assertEquals(block_counts_for_course.problem, 6)
self.assertEquals(block_counts_for_chapter_x['problem'], 3) self.assertEquals(block_counts_for_chapter_x.problem, 3)
# verify other block types are not counted # verify other block types are not counted
for block_type in ['course', 'html', 'video']: for block_type in ['course', 'html', 'video']:
self.assertNotIn(block_type, block_counts_for_course) self.assertFalse(hasattr(block_counts_for_course, block_type))
self.assertNotIn(block_type, block_counts_for_chapter_x) self.assertFalse(hasattr(block_counts_for_chapter_x, block_type))
...@@ -72,6 +72,9 @@ class BlockStructure(object): ...@@ -72,6 +72,9 @@ class BlockStructure(object):
""" """
return self.get_block_keys() return self.get_block_keys()
def __len__(self):
return len(self._block_relations)
#--- Block structure relation methods ---# #--- Block structure relation methods ---#
def get_parents(self, usage_key): def get_parents(self, usage_key):
...@@ -149,6 +152,7 @@ class BlockStructure(object): ...@@ -149,6 +152,7 @@ class BlockStructure(object):
self, self,
filter_func=None, filter_func=None,
yield_descendants_of_unyielded=False, yield_descendants_of_unyielded=False,
start_node=None,
): ):
""" """
Performs a topological sort of the block structure and yields Performs a topological sort of the block structure and yields
...@@ -163,7 +167,7 @@ class BlockStructure(object): ...@@ -163,7 +167,7 @@ class BlockStructure(object):
traverse_topologically method. traverse_topologically method.
""" """
return traverse_topologically( return traverse_topologically(
start_node=self.root_block_usage_key, start_node=start_node or self.root_block_usage_key,
get_parents=self.get_parents, get_parents=self.get_parents,
get_children=self.get_children, get_children=self.get_children,
filter_func=filter_func, filter_func=filter_func,
...@@ -173,6 +177,7 @@ class BlockStructure(object): ...@@ -173,6 +177,7 @@ class BlockStructure(object):
def post_order_traversal( def post_order_traversal(
self, self,
filter_func=None, filter_func=None,
start_node=None,
): ):
""" """
Performs a post-order sort of the block structure and yields Performs a post-order sort of the block structure and yields
...@@ -187,7 +192,7 @@ class BlockStructure(object): ...@@ -187,7 +192,7 @@ class BlockStructure(object):
traverse_post_order method. traverse_post_order method.
""" """
return traverse_post_order( return traverse_post_order(
start_node=self.root_block_usage_key, start_node=start_node or self.root_block_usage_key,
get_children=self.get_children, get_children=self.get_children,
filter_func=filter_func, filter_func=filter_func,
) )
...@@ -267,19 +272,119 @@ class BlockStructure(object): ...@@ -267,19 +272,119 @@ class BlockStructure(object):
block_relations[usage_key] = _BlockRelations() block_relations[usage_key] = _BlockRelations()
class _BlockData(object): class FieldData(object):
""" """
Data structure to encapsulate collected data for a single block. Data structure to encapsulate collected fields.
"""
def class_field_names(self):
""" """
Returns list of names of fields that are defined directly
on the class. Can be overridden by subclasses. All other
fields are assumed to be stored in the self.fields dict.
"""
return ['fields']
def __init__(self): def __init__(self):
# Map of xblock field name to the field's value for this block. # Map of field name to the field's value for this block.
# dict {string: any picklable type} # dict {string: any picklable type}
self.xblock_fields = {} self.fields = {}
def __getattr__(self, field_name):
if self._is_own_field(field_name):
return super(FieldData, self).__getattr__(field_name)
try:
return self.fields[field_name]
except KeyError:
raise AttributeError("Field {0} does not exist".format(field_name))
def __setattr__(self, field_name, field_value):
if self._is_own_field(field_name):
return super(FieldData, self).__setattr__(field_name, field_value)
else:
self.fields[field_name] = field_value
def __delattr__(self, field_name):
if self._is_own_field(field_name):
return super(FieldData, self).__delattr__(field_name)
else:
delattr(self.fields, field_name)
def _is_own_field(self, field_name):
"""
Returns whether the given field_name is the name of an
actual field of this class.
"""
return field_name in self.class_field_names()
class TransformerData(FieldData):
"""
Data structure to encapsulate collected data for a transformer.
"""
pass
class TransformerDataMap(dict):
"""
A map of Transformer name to its corresponding TransformerData.
The map can be accessed by the Transformer's name or the
Transformer's class type.
"""
def __getitem__(self, key):
key = self._translate_key(key)
return dict.__getitem__(self, key)
def __setitem__(self, key, value):
key = self._translate_key(key)
dict.__setitem__(self, key, value)
def __delitem__(self, key):
key = self._translate_key(key)
dict.__delitem__(self, key)
# Map of transformer name to the transformer's data for this def get_or_create(self, key):
# block. """
# defaultdict {string: dict} Returns the TransformerData associated with the given
self.transformer_data = defaultdict(dict) key. If not found, creates and returns a new TransformerData
and maps it to the given key.
"""
try:
return self[key]
except KeyError:
new_transformer_data = TransformerData()
self[key] = new_transformer_data
return new_transformer_data
def _translate_key(self, key):
"""
Allows the given key to be either the transformer's class or name,
always returning the transformer's name. This allows
TransformerDataMap to be accessed in either of the following ways:
map[TransformerClass] or
map['transformer_name']
"""
try:
return key.name()
except AttributeError:
return key
class BlockData(FieldData):
"""
Data structure to encapsulate collected data for a single block.
"""
def class_field_names(self):
return super(BlockData, self).class_field_names() + ['location', 'transformer_data']
def __init__(self, usage_key):
super(BlockData, self).__init__()
# Location (or usage key) of the block.
self.location = usage_key
# Map of transformer name to its block-specific data.
self.transformer_data = TransformerDataMap()
class BlockStructureBlockData(BlockStructure): class BlockStructureBlockData(BlockStructure):
...@@ -292,12 +397,31 @@ class BlockStructureBlockData(BlockStructure): ...@@ -292,12 +397,31 @@ class BlockStructureBlockData(BlockStructure):
# Map of a block's usage key to its collected data, including # Map of a block's usage key to its collected data, including
# its xBlock fields and block-specific transformer data. # its xBlock fields and block-specific transformer data.
# defaultdict {UsageKey: _BlockData} # dict {UsageKey: BlockData}
self._block_data_map = defaultdict(_BlockData) self._block_data_map = {}
# Map of a transformer's name to its non-block-specific data. # Map of a transformer's name to its non-block-specific data.
# defaultdict {string: dict} self.transformer_data = TransformerDataMap()
self._transformer_data = defaultdict(dict)
def iteritems(self):
"""
Returns iterator of (UsageKey, BlockData) pairs for all
blocks in the BlockStructure.
"""
return self._block_data_map.iteritems()
def itervalues(self):
"""
Returns iterator of BlockData for all blocks in the
BlockStructure.
"""
return self._block_data_map.itervalues()
def __getitem__(self, usage_key):
"""
Returns the BlockData associated with the given key.
"""
return self._block_data_map.get(usage_key)
def get_xblock_field(self, usage_key, field_name, default=None): def get_xblock_field(self, usage_key, field_name, default=None):
""" """
...@@ -316,7 +440,7 @@ class BlockStructureBlockData(BlockStructure): ...@@ -316,7 +440,7 @@ class BlockStructureBlockData(BlockStructure):
not found. not found.
""" """
block_data = self._block_data_map.get(usage_key) block_data = self._block_data_map.get(usage_key)
return block_data.xblock_fields.get(field_name, default) if block_data else default return getattr(block_data, field_name, default) if block_data else default
def get_transformer_data(self, transformer, key, default=None): def get_transformer_data(self, transformer, key, default=None):
""" """
...@@ -330,7 +454,10 @@ class BlockStructureBlockData(BlockStructure): ...@@ -330,7 +454,10 @@ class BlockStructureBlockData(BlockStructure):
key (string) - A dictionary key to the transformer's data key (string) - A dictionary key to the transformer's data
that is requested. that is requested.
""" """
return self._transformer_data.get(transformer.name(), {}).get(key, default) try:
return getattr(self.transformer_data[transformer], key, default)
except KeyError:
return default
def set_transformer_data(self, transformer, key, value): def set_transformer_data(self, transformer, key, value):
""" """
...@@ -346,7 +473,23 @@ class BlockStructureBlockData(BlockStructure): ...@@ -346,7 +473,23 @@ class BlockStructureBlockData(BlockStructure):
value (any picklable type) - The value to associate with the value (any picklable type) - The value to associate with the
given key for the given transformer's data. given key for the given transformer's data.
""" """
self._transformer_data[transformer.name()][key] = value setattr(self.transformer_data.get_or_create(transformer), key, value)
def get_transformer_block_data(self, usage_key, transformer):
"""
Returns the TransformerData for the given
transformer for the block identified by the given usage_key.
Raises KeyError if not found.
Arguments:
usage_key (UsageKey) - Usage key of the block whose
transformer data is requested.
transformer (BlockStructureTransformer) - The transformer
whose dictionary data is requested.
"""
return self._block_data_map[usage_key].transformer_data[transformer]
def get_transformer_block_field(self, usage_key, transformer, key, default=None): def get_transformer_block_field(self, usage_key, transformer, key, default=None):
""" """
...@@ -367,8 +510,11 @@ class BlockStructureBlockData(BlockStructure): ...@@ -367,8 +510,11 @@ class BlockStructureBlockData(BlockStructure):
default (any type) - The value to return if a dictionary default (any type) - The value to return if a dictionary
entry is not found. entry is not found.
""" """
try:
transformer_data = self.get_transformer_block_data(usage_key, transformer) transformer_data = self.get_transformer_block_data(usage_key, transformer)
return transformer_data.get(key, default) except KeyError:
return default
return getattr(transformer_data, key, default)
def set_transformer_block_field(self, usage_key, transformer, key, value): def set_transformer_block_field(self, usage_key, transformer, key, value):
""" """
...@@ -388,30 +534,11 @@ class BlockStructureBlockData(BlockStructure): ...@@ -388,30 +534,11 @@ class BlockStructureBlockData(BlockStructure):
given key for the given transformer's data for the given key for the given transformer's data for the
requested block. requested block.
""" """
self._block_data_map[usage_key].transformer_data[transformer.name()][key] = value setattr(
self._get_or_create_block(usage_key).transformer_data.get_or_create(transformer),
def get_transformer_block_data(self, usage_key, transformer): key,
""" value,
Returns the entire transformer data dict for the given )
transformer for the block identified by the given usage_key;
returns an empty dict {} if not found.
Arguments:
usage_key (UsageKey) - Usage key of the block whose
transformer data is requested.
transformer (BlockStructureTransformer) - The transformer
whose dictionary data is requested.
key (string) - A dictionary key to the transformer's data
that is requested.
"""
default = {}
block_data = self._block_data_map.get(usage_key)
if not block_data:
return default
else:
return block_data.transformer_data.get(transformer.name(), default)
def remove_transformer_block_field(self, usage_key, transformer, key): def remove_transformer_block_field(self, usage_key, transformer, key):
""" """
...@@ -425,8 +552,11 @@ class BlockStructureBlockData(BlockStructure): ...@@ -425,8 +552,11 @@ class BlockStructureBlockData(BlockStructure):
transformer (BlockStructureTransformer) - The transformer transformer (BlockStructureTransformer) - The transformer
whose data entry is to be deleted. whose data entry is to be deleted.
""" """
try:
transformer_block_data = self.get_transformer_block_data(usage_key, transformer) transformer_block_data = self.get_transformer_block_data(usage_key, transformer)
transformer_block_data.pop(key, None) delattr(transformer_block_data, key)
except (AttributeError, KeyError):
pass
def remove_block(self, usage_key, keep_descendants): def remove_block(self, usage_key, keep_descendants):
""" """
...@@ -527,6 +657,19 @@ class BlockStructureBlockData(BlockStructure): ...@@ -527,6 +657,19 @@ class BlockStructureBlockData(BlockStructure):
raise TransformerException('VERSION attribute is not set on transformer {0}.', transformer.name()) raise TransformerException('VERSION attribute is not set on transformer {0}.', transformer.name())
self.set_transformer_data(transformer, TRANSFORMER_VERSION_KEY, transformer.VERSION) self.set_transformer_data(transformer, TRANSFORMER_VERSION_KEY, transformer.VERSION)
def _get_or_create_block(self, usage_key):
"""
Returns the BlockData associated with the given usage_key.
If not found, creates and returns a new BlockData and
maps it to the given key.
"""
try:
return self._block_data_map[usage_key]
except KeyError:
block_data = BlockData(usage_key)
self._block_data_map[usage_key] = block_data
return block_data
class BlockStructureModulestoreData(BlockStructureBlockData): class BlockStructureModulestoreData(BlockStructureBlockData):
""" """
...@@ -599,23 +742,19 @@ class BlockStructureModulestoreData(BlockStructureBlockData): ...@@ -599,23 +742,19 @@ class BlockStructureModulestoreData(BlockStructureBlockData):
Iterates through all instantiated xBlocks that were added and Iterates through all instantiated xBlocks that were added and
collects all xBlock fields that were requested. collects all xBlock fields that were requested.
""" """
if not self._requested_xblock_fields:
return
for xblock_usage_key, xblock in self._xblock_map.iteritems(): for xblock_usage_key, xblock in self._xblock_map.iteritems():
block_data = self._get_or_create_block(xblock_usage_key)
for field_name in self._requested_xblock_fields: for field_name in self._requested_xblock_fields:
self._set_xblock_field(xblock_usage_key, xblock, field_name) self._set_xblock_field(block_data, xblock, field_name)
def _set_xblock_field(self, usage_key, xblock, field_name): def _set_xblock_field(self, block_data, xblock, field_name):
""" """
Updates the given block's xBlock fields data with the xBlock Updates the given block's xBlock fields data with the xBlock
value for the given field name. value for the given field name.
Arguments: Arguments:
usage_key (UsageKey) - Usage key of the given xBlock. This block_data (BlockData) - A BlockStructure BlockData
value is passed in separately as opposed to retrieving object.
it from the given xBlock since this interface is
agnostic to and decoupled from the xBlock interface.
xblock (XBlock) - An instantiated XBlock object whose xblock (XBlock) - An instantiated XBlock object whose
field is being accessed and collected for later field is being accessed and collected for later
...@@ -625,4 +764,4 @@ class BlockStructureModulestoreData(BlockStructureBlockData): ...@@ -625,4 +764,4 @@ class BlockStructureModulestoreData(BlockStructureBlockData):
being collected and stored. being collected and stored.
""" """
if hasattr(xblock, field_name): if hasattr(xblock, field_name):
self._block_data_map[usage_key].xblock_fields[field_name] = getattr(xblock, field_name) setattr(block_data, field_name, getattr(xblock, field_name))
...@@ -40,7 +40,7 @@ class BlockStructureCache(object): ...@@ -40,7 +40,7 @@ class BlockStructureCache(object):
""" """
data_to_cache = ( data_to_cache = (
block_structure._block_relations, block_structure._block_relations,
block_structure._transformer_data, block_structure.transformer_data,
block_structure._block_data_map, block_structure._block_data_map,
) )
zp_data_to_cache = zpickle(data_to_cache) zp_data_to_cache = zpickle(data_to_cache)
...@@ -99,7 +99,7 @@ class BlockStructureCache(object): ...@@ -99,7 +99,7 @@ class BlockStructureCache(object):
block_relations, transformer_data, block_data_map = zunpickle(zp_data_from_cache) block_relations, transformer_data, block_data_map = zunpickle(zp_data_from_cache)
block_structure = BlockStructureModulestoreData(root_block_usage_key) block_structure = BlockStructureModulestoreData(root_block_usage_key)
block_structure._block_relations = block_relations block_structure._block_relations = block_relations
block_structure._transformer_data = transformer_data block_structure.transformer_data = transformer_data
block_structure._block_data_map = block_data_map block_structure._block_data_map = block_data_map
return block_structure return block_structure
......
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
Top-level module for the Block Structure framework with a class for managing Top-level module for the Block Structure framework with a class for managing
BlockStructures. BlockStructures.
""" """
from contextlib import contextmanager
from .cache import BlockStructureCache from .cache import BlockStructureCache
from .factory import BlockStructureFactory from .factory import BlockStructureFactory
from .exceptions import UsageKeyNotInBlockStructure from .exceptions import UsageKeyNotInBlockStructure
...@@ -87,6 +89,7 @@ class BlockStructureManager(object): ...@@ -87,6 +89,7 @@ class BlockStructureManager(object):
) )
cache_miss = block_structure is None cache_miss = block_structure is None
if cache_miss or BlockStructureTransformers.is_collected_outdated(block_structure): if cache_miss or BlockStructureTransformers.is_collected_outdated(block_structure):
with self._bulk_operations():
block_structure = BlockStructureFactory.create_from_modulestore( block_structure = BlockStructureFactory.create_from_modulestore(
self.root_block_usage_key, self.root_block_usage_key,
self.modulestore self.modulestore
...@@ -111,3 +114,15 @@ class BlockStructureManager(object): ...@@ -111,3 +114,15 @@ class BlockStructureManager(object):
root block key. root block key.
""" """
self.block_structure_cache.delete(self.root_block_usage_key) self.block_structure_cache.delete(self.root_block_usage_key)
@contextmanager
def _bulk_operations(self):
"""
A context manager for notifying the store of bulk operations.
"""
try:
course_key = self.root_block_usage_key.course_key
except AttributeError:
course_key = None
with self.modulestore.bulk_operations(course_key):
yield
...@@ -68,6 +68,13 @@ class MockModulestore(object): ...@@ -68,6 +68,13 @@ class MockModulestore(object):
raise ItemNotFoundError raise ItemNotFoundError
return item return item
@contextmanager
def bulk_operations(self, ignore): # pylint: disable=unused-argument
"""
A context manager for notifying the store of bulk operations.
"""
yield
class MockCache(object): class MockCache(object):
""" """
......
...@@ -138,17 +138,19 @@ class TestBlockStructureData(TestCase, ChildrenMapTestMixin): ...@@ -138,17 +138,19 @@ class TestBlockStructureData(TestCase, ChildrenMapTestMixin):
# verify fields have not been collected yet # verify fields have not been collected yet
for block in blocks: for block in blocks:
bs_block = block_structure[block.location]
for field in fields: for field in fields:
self.assertIsNone(block_structure.get_xblock_field(block.location, field)) self.assertIsNone(getattr(bs_block, field, None))
# collect fields # collect fields
block_structure._collect_requested_xblock_fields() block_structure._collect_requested_xblock_fields()
# verify values of collected fields # verify values of collected fields
for block in blocks: for block in blocks:
bs_block = block_structure[block.location]
for field in fields: for field in fields:
self.assertEquals( self.assertEquals(
block_structure.get_xblock_field(block.location, field), getattr(bs_block, field, None),
block.field_map.get(field), block.field_map.get(field),
) )
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment