Commit 1fbb4b8b by Nimisha Asthagiri

Update BlockStructure API

parent f016647f
...@@ -17,10 +17,14 @@ class BlockSerializer(serializers.Serializer): # pylint: disable=abstract-metho ...@@ -17,10 +17,14 @@ class BlockSerializer(serializers.Serializer): # pylint: disable=abstract-metho
Get the field value requested. The field may be an XBlock field, a Get the field value requested. The field may be an XBlock field, a
transformer block field, or an entire tranformer block data dict. transformer block field, or an entire tranformer block data dict.
""" """
value = None
if transformer is None: if transformer is None:
value = self.context['block_structure'].get_xblock_field(block_key, field_name) value = self.context['block_structure'].get_xblock_field(block_key, field_name)
elif field_name is None: elif field_name is None:
value = self.context['block_structure'].get_transformer_block_data(block_key, transformer) try:
value = self.context['block_structure'].get_transformer_block_data(block_key, transformer).fields
except KeyError:
pass
else: else:
value = self.context['block_structure'].get_transformer_block_field(block_key, transformer, field_name) value = self.context['block_structure'].get_transformer_block_field(block_key, transformer, field_name)
......
...@@ -41,10 +41,11 @@ class TestBlockSerializerBase(SharedModuleStoreTestCase): ...@@ -41,10 +41,11 @@ class TestBlockSerializerBase(SharedModuleStoreTestCase):
block_types_to_count=['video'], block_types_to_count=['video'],
requested_student_view_data=['video'], requested_student_view_data=['video'],
) )
self.transformers = BlockStructureTransformers(COURSE_BLOCK_ACCESS_TRANSFORMERS + [blocks_api_transformer])
self.block_structure = get_course_blocks( self.block_structure = get_course_blocks(
self.user, self.user,
self.course.location, self.course.location,
BlockStructureTransformers(COURSE_BLOCK_ACCESS_TRANSFORMERS + [blocks_api_transformer]), self.transformers,
) )
self.serializer_context = { self.serializer_context = {
'request': MagicMock(), 'request': MagicMock(),
...@@ -92,7 +93,7 @@ class TestBlockSerializerBase(SharedModuleStoreTestCase): ...@@ -92,7 +93,7 @@ class TestBlockSerializerBase(SharedModuleStoreTestCase):
{ {
'id', 'type', 'lms_web_url', 'student_view_url', 'id', 'type', 'lms_web_url', 'student_view_url',
'display_name', 'graded', 'display_name', 'graded',
'block_counts', 'student_view_multi_device', 'student_view_multi_device',
'lti_url', 'lti_url',
'visible_to_staff_only', 'visible_to_staff_only',
}, },
...@@ -108,6 +109,13 @@ class TestBlockSerializerBase(SharedModuleStoreTestCase): ...@@ -108,6 +109,13 @@ class TestBlockSerializerBase(SharedModuleStoreTestCase):
self.assertIn('student_view_multi_device', serialized_block) self.assertIn('student_view_multi_device', serialized_block)
self.assertTrue(serialized_block['student_view_multi_device']) self.assertTrue(serialized_block['student_view_multi_device'])
# chapters with video should have block_counts
if serialized_block['type'] == 'chapter':
if serialized_block['display_name'] not in ('poll_test', 'handout_container'):
self.assertIn('block_counts', serialized_block)
else:
self.assertNotIn('block_counts', serialized_block)
def create_staff_context(self): def create_staff_context(self):
""" """
Create staff user and course blocks accessible by that user Create staff user and course blocks accessible by that user
...@@ -119,7 +127,7 @@ class TestBlockSerializerBase(SharedModuleStoreTestCase): ...@@ -119,7 +127,7 @@ class TestBlockSerializerBase(SharedModuleStoreTestCase):
block_structure = get_course_blocks( block_structure = get_course_blocks(
staff_user, staff_user,
self.course.location, self.course.location,
BlockStructureTransformers(COURSE_BLOCK_ACCESS_TRANSFORMERS), self.transformers,
) )
return { return {
'request': MagicMock(), 'request': MagicMock(),
...@@ -156,12 +164,14 @@ class TestBlockSerializer(TestBlockSerializerBase): ...@@ -156,12 +164,14 @@ class TestBlockSerializer(TestBlockSerializerBase):
serializer = self.create_serializer() serializer = self.create_serializer()
for serialized_block in serializer.data: for serialized_block in serializer.data:
self.assert_basic_block(serialized_block['id'], serialized_block) self.assert_basic_block(serialized_block['id'], serialized_block)
self.assertEquals(len(serializer.data), 28)
def test_additional_requested_fields(self): def test_additional_requested_fields(self):
self.add_additional_requested_fields() self.add_additional_requested_fields()
serializer = self.create_serializer() serializer = self.create_serializer()
for serialized_block in serializer.data: for serialized_block in serializer.data:
self.assert_extended_block(serialized_block) self.assert_extended_block(serialized_block)
self.assertEquals(len(serializer.data), 28)
def test_staff_fields(self): def test_staff_fields(self):
""" """
...@@ -173,6 +183,7 @@ class TestBlockSerializer(TestBlockSerializerBase): ...@@ -173,6 +183,7 @@ class TestBlockSerializer(TestBlockSerializerBase):
for serialized_block in serializer.data: for serialized_block in serializer.data:
self.assert_extended_block(serialized_block) self.assert_extended_block(serialized_block)
self.assert_staff_fields(serialized_block) self.assert_staff_fields(serialized_block)
self.assertEquals(len(serializer.data), 29)
class TestBlockDictSerializer(TestBlockSerializerBase): class TestBlockDictSerializer(TestBlockSerializerBase):
...@@ -200,12 +211,14 @@ class TestBlockDictSerializer(TestBlockSerializerBase): ...@@ -200,12 +211,14 @@ class TestBlockDictSerializer(TestBlockSerializerBase):
for block_key_string, serialized_block in serializer.data['blocks'].iteritems(): for block_key_string, serialized_block in serializer.data['blocks'].iteritems():
self.assertEquals(serialized_block['id'], block_key_string) self.assertEquals(serialized_block['id'], block_key_string)
self.assert_basic_block(block_key_string, serialized_block) self.assert_basic_block(block_key_string, serialized_block)
self.assertEquals(len(serializer.data['blocks']), 28)
def test_additional_requested_fields(self): def test_additional_requested_fields(self):
self.add_additional_requested_fields() self.add_additional_requested_fields()
serializer = self.create_serializer() serializer = self.create_serializer()
for serialized_block in serializer.data['blocks'].itervalues(): for serialized_block in serializer.data['blocks'].itervalues():
self.assert_extended_block(serialized_block) self.assert_extended_block(serialized_block)
self.assertEquals(len(serializer.data['blocks']), 28)
def test_staff_fields(self): def test_staff_fields(self):
""" """
...@@ -217,3 +230,4 @@ class TestBlockDictSerializer(TestBlockSerializerBase): ...@@ -217,3 +230,4 @@ class TestBlockDictSerializer(TestBlockSerializerBase):
for serialized_block in serializer.data['blocks'].itervalues(): for serialized_block in serializer.data['blocks'].itervalues():
self.assert_extended_block(serialized_block) self.assert_extended_block(serialized_block)
self.assert_staff_fields(serialized_block) self.assert_staff_fields(serialized_block)
self.assertEquals(len(serializer.data['blocks']), 29)
...@@ -38,13 +38,13 @@ class TestBlockCountsTransformer(ModuleStoreTestCase): ...@@ -38,13 +38,13 @@ class TestBlockCountsTransformer(ModuleStoreTestCase):
) )
# verify count of chapters # verify count of chapters
self.assertEquals(block_counts_for_course['chapter'], 2) self.assertEquals(block_counts_for_course.chapter, 2)
# verify count of problems # verify count of problems
self.assertEquals(block_counts_for_course['problem'], 6) self.assertEquals(block_counts_for_course.problem, 6)
self.assertEquals(block_counts_for_chapter_x['problem'], 3) self.assertEquals(block_counts_for_chapter_x.problem, 3)
# verify other block types are not counted # verify other block types are not counted
for block_type in ['course', 'html', 'video']: for block_type in ['course', 'html', 'video']:
self.assertNotIn(block_type, block_counts_for_course) self.assertFalse(hasattr(block_counts_for_course, block_type))
self.assertNotIn(block_type, block_counts_for_chapter_x) self.assertFalse(hasattr(block_counts_for_chapter_x, block_type))
...@@ -40,7 +40,7 @@ class BlockStructureCache(object): ...@@ -40,7 +40,7 @@ class BlockStructureCache(object):
""" """
data_to_cache = ( data_to_cache = (
block_structure._block_relations, block_structure._block_relations,
block_structure._transformer_data, block_structure.transformer_data,
block_structure._block_data_map, block_structure._block_data_map,
) )
zp_data_to_cache = zpickle(data_to_cache) zp_data_to_cache = zpickle(data_to_cache)
...@@ -99,7 +99,7 @@ class BlockStructureCache(object): ...@@ -99,7 +99,7 @@ class BlockStructureCache(object):
block_relations, transformer_data, block_data_map = zunpickle(zp_data_from_cache) block_relations, transformer_data, block_data_map = zunpickle(zp_data_from_cache)
block_structure = BlockStructureModulestoreData(root_block_usage_key) block_structure = BlockStructureModulestoreData(root_block_usage_key)
block_structure._block_relations = block_relations block_structure._block_relations = block_relations
block_structure._transformer_data = transformer_data block_structure.transformer_data = transformer_data
block_structure._block_data_map = block_data_map block_structure._block_data_map = block_data_map
return block_structure return block_structure
......
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
Top-level module for the Block Structure framework with a class for managing Top-level module for the Block Structure framework with a class for managing
BlockStructures. BlockStructures.
""" """
from contextlib import contextmanager
from .cache import BlockStructureCache from .cache import BlockStructureCache
from .factory import BlockStructureFactory from .factory import BlockStructureFactory
from .exceptions import UsageKeyNotInBlockStructure from .exceptions import UsageKeyNotInBlockStructure
...@@ -87,12 +89,13 @@ class BlockStructureManager(object): ...@@ -87,12 +89,13 @@ class BlockStructureManager(object):
) )
cache_miss = block_structure is None cache_miss = block_structure is None
if cache_miss or BlockStructureTransformers.is_collected_outdated(block_structure): if cache_miss or BlockStructureTransformers.is_collected_outdated(block_structure):
block_structure = BlockStructureFactory.create_from_modulestore( with self._bulk_operations():
self.root_block_usage_key, block_structure = BlockStructureFactory.create_from_modulestore(
self.modulestore self.root_block_usage_key,
) self.modulestore
BlockStructureTransformers.collect(block_structure) )
self.block_structure_cache.add(block_structure) BlockStructureTransformers.collect(block_structure)
self.block_structure_cache.add(block_structure)
return block_structure return block_structure
def update_collected(self): def update_collected(self):
...@@ -111,3 +114,15 @@ class BlockStructureManager(object): ...@@ -111,3 +114,15 @@ class BlockStructureManager(object):
root block key. root block key.
""" """
self.block_structure_cache.delete(self.root_block_usage_key) self.block_structure_cache.delete(self.root_block_usage_key)
@contextmanager
def _bulk_operations(self):
"""
A context manager for notifying the store of bulk operations.
"""
try:
course_key = self.root_block_usage_key.course_key
except AttributeError:
course_key = None
with self.modulestore.bulk_operations(course_key):
yield
...@@ -68,6 +68,13 @@ class MockModulestore(object): ...@@ -68,6 +68,13 @@ class MockModulestore(object):
raise ItemNotFoundError raise ItemNotFoundError
return item return item
@contextmanager
def bulk_operations(self, ignore): # pylint: disable=unused-argument
"""
A context manager for notifying the store of bulk operations.
"""
yield
class MockCache(object): class MockCache(object):
""" """
......
...@@ -138,17 +138,19 @@ class TestBlockStructureData(TestCase, ChildrenMapTestMixin): ...@@ -138,17 +138,19 @@ class TestBlockStructureData(TestCase, ChildrenMapTestMixin):
# verify fields have not been collected yet # verify fields have not been collected yet
for block in blocks: for block in blocks:
bs_block = block_structure[block.location]
for field in fields: for field in fields:
self.assertIsNone(block_structure.get_xblock_field(block.location, field)) self.assertIsNone(getattr(bs_block, field, None))
# collect fields # collect fields
block_structure._collect_requested_xblock_fields() block_structure._collect_requested_xblock_fields()
# verify values of collected fields # verify values of collected fields
for block in blocks: for block in blocks:
bs_block = block_structure[block.location]
for field in fields: for field in fields:
self.assertEquals( self.assertEquals(
block_structure.get_xblock_field(block.location, field), getattr(bs_block, field, None),
block.field_map.get(field), block.field_map.get(field),
) )
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment