Commit 2d64c54c by Nimisha Asthagiri

fixup! SplitTestTransformer skip split_test module.

parent 7a3b4104
"""
Transformers helpers functions.
"""
def get_user_partition_groups(course_key, user_partitions, user):
"""
Collect group ID for each partition in this course for this user.
Arguments:
course_key (CourseKey)
user_partitions (list[UserPartition])
user (User)
Returns:
dict[int: Group]: Mapping from user partitions to the group to which
the user belongs in each partition. If the user isn't in a group
for a particular partition, then that partition's ID will not be
in the dict.
"""
partition_groups = {}
for partition in user_partitions:
group = partition.scheme.get_group_for_user(
course_key,
user,
partition,
)
if group is not None:
partition_groups[partition.id] = group
return partition_groups
...@@ -56,4 +56,7 @@ class SplitTestTransformer(BlockStructureTransformer): ...@@ -56,4 +56,7 @@ class SplitTestTransformer(BlockStructureTransformer):
user_info (object) user_info (object)
block_structure (BlockStructureCollectedData) block_structure (BlockStructureCollectedData)
""" """
pass block_structure.remove_block_if(
lambda block_key: block_key.block_type == 'split_test',
keep_descendants=True,
)
...@@ -5,12 +5,9 @@ from openedx.core.djangoapps.user_api.partition_schemes import RandomUserPartiti ...@@ -5,12 +5,9 @@ from openedx.core.djangoapps.user_api.partition_schemes import RandomUserPartiti
from student.tests.factories import CourseEnrollmentFactory from student.tests.factories import CourseEnrollmentFactory
from xmodule.partitions.partitions import Group, UserPartition from xmodule.partitions.partitions import Group, UserPartition
from course_blocks.transformers.split_test import SplitTestTransformer from ...api import get_course_blocks
from course_blocks.transformers.user_partitions import UserPartitionTransformer from ..user_partitions import UserPartitionTransformer, get_user_partition_groups
from course_blocks.api import get_course_blocks from .test_helpers import CourseStructureTestCase
from lms.djangoapps.course_blocks.transformers.tests.test_helpers import CourseStructureTestCase
from course_blocks.transformers.helpers import get_user_partition_groups
class SplitTestTransformerTestCase(CourseStructureTestCase): class SplitTestTransformerTestCase(CourseStructureTestCase):
...@@ -132,7 +129,7 @@ class SplitTestTransformerTestCase(CourseStructureTestCase): ...@@ -132,7 +129,7 @@ class SplitTestTransformerTestCase(CourseStructureTestCase):
self.assertEquals(len(user_groups), 1) self.assertEquals(len(user_groups), 1)
group = user_groups[self.split_test_user_partition_id] group = user_groups[self.split_test_user_partition_id]
expected_blocks = ['course', 'chapter1', 'lesson1', 'vertical1', 'split_test1'] expected_blocks = ['course', 'chapter1', 'lesson1', 'vertical1']
if group.id == 3: if group.id == 3:
expected_blocks += ['vertical2', 'html1'] expected_blocks += ['vertical2', 'html1']
else: else:
......
...@@ -3,7 +3,6 @@ User Partitions Transformer ...@@ -3,7 +3,6 @@ User Partitions Transformer
""" """
from openedx.core.lib.block_cache.transformer import BlockStructureTransformer from openedx.core.lib.block_cache.transformer import BlockStructureTransformer
from .helpers import get_user_partition_groups
from .split_test import SplitTestTransformer from .split_test import SplitTestTransformer
...@@ -108,6 +107,33 @@ class MergedGroupAccess(object): ...@@ -108,6 +107,33 @@ class MergedGroupAccess(object):
return True return True
def get_user_partition_groups(course_key, user_partitions, user):
"""
Collect group ID for each partition in this course for this user.
Arguments:
course_key (CourseKey)
user_partitions (list[UserPartition])
user (User)
Returns:
dict[int: Group]: Mapping from user partitions to the group to which
the user belongs in each partition. If the user isn't in a group
for a particular partition, then that partition's ID will not be
in the dict.
"""
partition_groups = {}
for partition in user_partitions:
group = partition.scheme.get_group_for_user(
course_key,
user,
partition,
)
if group is not None:
partition_groups[partition.id] = group
return partition_groups
class UserPartitionTransformer(BlockStructureTransformer): class UserPartitionTransformer(BlockStructureTransformer):
""" """
... ...
......
...@@ -34,7 +34,7 @@ class BlockStructure(object): ...@@ -34,7 +34,7 @@ class BlockStructure(object):
self._add_relation(self._block_relations, parent_key, child_key) self._add_relation(self._block_relations, parent_key, child_key)
def get_parents(self, usage_key): def get_parents(self, usage_key):
return self._block_relations.get(usage_key).parents if self.has_block(usage_key) else [] return self._block_relations[usage_key].parents if self.has_block(usage_key) else []
def get_children(self, usage_key): def get_children(self, usage_key):
return self._block_relations[usage_key].children if self.has_block(usage_key) else [] return self._block_relations[usage_key].children if self.has_block(usage_key) else []
...@@ -134,23 +134,30 @@ class BlockStructureBlockData(BlockStructure): ...@@ -134,23 +134,30 @@ class BlockStructureBlockData(BlockStructure):
def remove_transformer_block_data(self, usage_key, transformer, key): def remove_transformer_block_data(self, usage_key, transformer, key):
self._block_data_map[usage_key]._transformer_data.get(transformer.name(), {}).pop(key, None) self._block_data_map[usage_key]._transformer_data.get(transformer.name(), {}).pop(key, None)
def remove_block(self, usage_key): def remove_block(self, usage_key, keep_descendants):
children = self._block_relations[usage_key].children
parents = self._block_relations[usage_key].parents
# Remove block from its children. # Remove block from its children.
for child in self._block_relations[usage_key].children: for child in children:
self._block_relations[child].parents.remove(usage_key) self._block_relations[child].parents.remove(usage_key)
# Remove block from its parents. # Remove block from its parents.
for parent_key in self._block_relations[usage_key].parents: for parent in parents:
self._block_relations[parent_key].children.remove(usage_key) self._block_relations[parent].children.remove(usage_key)
# Remove block. # Remove block.
self._block_relations.pop(usage_key, None) self._block_relations.pop(usage_key, None)
self._block_data_map.pop(usage_key, None) self._block_data_map.pop(usage_key, None)
def remove_block_if(self, removal_condition, **kwargs): # Recreate the graph connections if descendants are to be kept.
if keep_descendants:
[self.add_relation(parent, child) for child in children for parent in parents]
def remove_block_if(self, removal_condition, keep_descendants=False, **kwargs):
def predicate(block_key): def predicate(block_key):
if removal_condition(block_key): if removal_condition(block_key):
self.remove_block(block_key) self.remove_block(block_key, keep_descendants)
return False return False
return True return True
list(self.topological_traversal(predicate=predicate, **kwargs)) list(self.topological_traversal(predicate=predicate, **kwargs))
......
...@@ -5,11 +5,11 @@ Tests for block_cache.py ...@@ -5,11 +5,11 @@ Tests for block_cache.py
from mock import patch from mock import patch
from unittest import TestCase from unittest import TestCase
from .test_utils import ( from .test_utils import (
MockModulestoreFactory, MockCache, MockUserInfo, MockTransformer, SIMPLE_CHILDREN_MAP, BlockStructureTestMixin MockModulestoreFactory, MockCache, MockUserInfo, MockTransformer, ChildrenMapTestMixin
) )
from ..block_cache import get_blocks from ..block_cache import get_blocks
class TestBlockCache(TestCase, BlockStructureTestMixin): class TestBlockCache(TestCase, ChildrenMapTestMixin):
class TestTransformer1(MockTransformer): class TestTransformer1(MockTransformer):
@classmethod @classmethod
...@@ -45,7 +45,7 @@ class TestBlockCache(TestCase, BlockStructureTestMixin): ...@@ -45,7 +45,7 @@ class TestBlockCache(TestCase, BlockStructureTestMixin):
@patch('openedx.core.lib.block_cache.transformer.BlockStructureTransformers.get_available_plugins') @patch('openedx.core.lib.block_cache.transformer.BlockStructureTransformers.get_available_plugins')
def test_get_blocks(self, mock_available_transforms): def test_get_blocks(self, mock_available_transforms):
children_map = SIMPLE_CHILDREN_MAP children_map = self.SIMPLE_CHILDREN_MAP
cache = MockCache() cache = MockCache()
user_info = MockUserInfo() user_info = MockUserInfo()
modulestore = MockModulestoreFactory.create(children_map) modulestore = MockModulestoreFactory.create(children_map)
...@@ -57,4 +57,4 @@ class TestBlockCache(TestCase, BlockStructureTestMixin): ...@@ -57,4 +57,4 @@ class TestBlockCache(TestCase, BlockStructureTestMixin):
mock_available_transforms.return_value = {transformer.name(): transformer for transformer in transformers} mock_available_transforms.return_value = {transformer.name(): transformer for transformer in transformers}
block_structure = get_blocks(cache, modulestore, user_info, root_block_key=0, transformers=transformers) block_structure = get_blocks(cache, modulestore, user_info, root_block_key=0, transformers=transformers)
self.verify_block_structure(block_structure, children_map) self.assert_block_structure(block_structure, children_map)
...@@ -2,62 +2,36 @@ ...@@ -2,62 +2,36 @@
Tests for block_structure.py Tests for block_structure.py
""" """
from collections import namedtuple from collections import namedtuple
from copy import deepcopy
import ddt import ddt
import itertools
from mock import patch from mock import patch
from unittest import TestCase from unittest import TestCase
from ..block_structure import ( from ..block_structure import (
BlockStructure, BlockStructureCollectedData, BlockStructureBlockData, BlockStructureFactory BlockStructure, BlockStructureCollectedData, BlockStructureBlockData, BlockStructureFactory
) )
from ..graph_traversals import traverse_post_order
from ..transformer import BlockStructureTransformer, BlockStructureTransformers from ..transformer import BlockStructureTransformer, BlockStructureTransformers
from .test_utils import ( from .test_utils import (
MockCache, MockXBlock, MockModulestoreFactory, MockTransformer, SIMPLE_CHILDREN_MAP, BlockStructureTestMixin MockCache, MockXBlock, MockModulestoreFactory, MockTransformer,
ChildrenMapTestMixin
) )
@ddt.ddt @ddt.ddt
class TestBlockStructure(TestCase): class TestBlockStructure(TestCase, ChildrenMapTestMixin):
""" """
Tests for BlockStructure Tests for BlockStructure
""" """
def get_parents_map(self, children_map):
parent_map = [[] for node in children_map]
for parent, children in enumerate(children_map):
for child in children:
parent_map[child].append(parent)
return parent_map
@ddt.data( @ddt.data(
[], [],
# 0 ChildrenMapTestMixin.SIMPLE_CHILDREN_MAP,
# / \ ChildrenMapTestMixin.LINEAR_CHILDREN_MAP,
# 1 2 ChildrenMapTestMixin.DAG_CHILDREN_MAP,
# / \
# 3 4
SIMPLE_CHILDREN_MAP,
# 0
# /
# 1
# /
# 2
# /
# 3
[[1], [2], [3], []],
# 0
# / \
# 1 2
# \ /
# 3
[[1, 2], [3], [3], []],
) )
def test_relations(self, children_map): def test_relations(self, children_map):
# create block structure block_structure = self.create_block_structure(BlockStructure, children_map)
block_structure = BlockStructure(root_block_key=0)
# add_relation
for parent, children in enumerate(children_map):
for child in children:
block_structure.add_relation(parent, child)
# get_children # get_children
for parent, children in enumerate(children_map): for parent, children in enumerate(children_map):
...@@ -73,7 +47,8 @@ class TestBlockStructure(TestCase): ...@@ -73,7 +47,8 @@ class TestBlockStructure(TestCase):
self.assertFalse(block_structure.has_block(len(children_map) + 1)) self.assertFalse(block_structure.has_block(len(children_map) + 1))
class TestBlockStructureData(TestCase): @ddt.ddt
class TestBlockStructureData(TestCase, ChildrenMapTestMixin):
""" """
Tests for BlockStructureBlockData and BlockStructureCollectedData Tests for BlockStructureBlockData and BlockStructureCollectedData
""" """
...@@ -176,41 +151,88 @@ class TestBlockStructureData(TestCase): ...@@ -176,41 +151,88 @@ class TestBlockStructureData(TestCase):
block.field_map.get(field), block.field_map.get(field),
) )
def test_remove_block(self): @ddt.data(
block_structure = BlockStructureBlockData(root_block_key=0) *itertools.product(
for parent, children in enumerate(SIMPLE_CHILDREN_MAP): [True, False],
for child in children: range(7),
block_structure.add_relation(parent, child) [
# ChildrenMapTestMixin.SIMPLE_CHILDREN_MAP,
self.assertTrue(block_structure.has_block(1)) ChildrenMapTestMixin.LINEAR_CHILDREN_MAP,
self.assertTrue(1 in block_structure.get_children(0)) # ChildrenMapTestMixin.DAG_CHILDREN_MAP,
],
block_structure.remove_block(1) )
)
self.assertFalse(block_structure.has_block(1)) @ddt.unpack
self.assertFalse(1 in block_structure.get_children(0)) def test_remove_block(self, keep_descendants, block_to_remove, children_map):
### skip test if invalid
if (
(block_to_remove >= len(children_map)) or
(keep_descendants and block_to_remove == 0)
):
return
### create structure
block_structure = self.create_block_structure(BlockStructureBlockData, children_map)
parents_map = self.get_parents_map(children_map)
### verify blocks pre-exist
self.assert_block_structure(block_structure, children_map)
### remove block
block_structure.remove_block(block_to_remove, keep_descendants)
missing_blocks = [block_to_remove]
### compute and verify updated children_map
removed_children_map = deepcopy(children_map)
removed_children_map[block_to_remove] = []
[removed_children_map[parent].remove(block_to_remove) for parent in parents_map[block_to_remove]]
if keep_descendants:
# update the graph connecting the old parents to the old children
[
removed_children_map[parent].append(child)
for child in children_map[block_to_remove]
for parent in parents_map[block_to_remove]
]
self.assert_block_structure(block_structure, removed_children_map, missing_blocks)
### prune the structure
block_structure.prune()
self.assertTrue(block_structure.has_block(3)) ### compute and verify updated children_map
self.assertTrue(block_structure.has_block(4)) pruned_children_map = deepcopy(removed_children_map)
block_structure.prune() if not keep_descendants:
def update_descendant(block):
"""
add descendant to missing blocks and empty its children
"""
missing_blocks.append(block)
pruned_children_map[block] = []
self.assertFalse(block_structure.has_block(3)) # update all descendants
self.assertFalse(block_structure.has_block(4)) for child in children_map[block_to_remove]:
list(traverse_post_order(
child,
get_children=lambda block: pruned_children_map[block],
get_result=update_descendant,
))
self.assert_block_structure(block_structure, pruned_children_map, missing_blocks)
class TestBlockStructureFactory(TestCase, BlockStructureTestMixin): class TestBlockStructureFactory(TestCase, ChildrenMapTestMixin):
""" """
Tests for BlockStructureFactory Tests for BlockStructureFactory
""" """
def test_factory_methods(self): def test_factory_methods(self):
children_map = SIMPLE_CHILDREN_MAP children_map = self.SIMPLE_CHILDREN_MAP
modulestore = MockModulestoreFactory.create(children_map) modulestore = MockModulestoreFactory.create(children_map)
cache = MockCache() cache = MockCache()
# test create from modulestore # test create from modulestore
block_structure = BlockStructureFactory.create_from_modulestore(root_block_key=0, modulestore=modulestore) block_structure = BlockStructureFactory.create_from_modulestore(root_block_key=0, modulestore=modulestore)
self.verify_block_structure(block_structure, children_map) self.assert_block_structure(block_structure, children_map)
# test not in cache # test not in cache
self.assertIsNone(BlockStructureFactory.create_from_cache(root_block_key=0, cache=cache)) self.assertIsNone(BlockStructureFactory.create_from_cache(root_block_key=0, cache=cache))
...@@ -232,7 +254,7 @@ class TestBlockStructureFactory(TestCase, BlockStructureTestMixin): ...@@ -232,7 +254,7 @@ class TestBlockStructureFactory(TestCase, BlockStructureTestMixin):
# test re-create from cache # test re-create from cache
from_cache_block_structure = BlockStructureFactory.create_from_cache(root_block_key=0, cache=cache) from_cache_block_structure = BlockStructureFactory.create_from_cache(root_block_key=0, cache=cache)
self.assertIsNotNone(from_cache_block_structure) self.assertIsNotNone(from_cache_block_structure)
self.verify_block_structure(from_cache_block_structure, children_map) self.assert_block_structure(from_cache_block_structure, children_map)
# test remove from cache # test remove from cache
BlockStructureFactory.remove_from_cache(root_block_key=0, cache=cache) BlockStructureFactory.remove_from_cache(root_block_key=0, cache=cache)
......
...@@ -74,21 +74,68 @@ class MockTransformer(BlockStructureTransformer): ...@@ -74,21 +74,68 @@ class MockTransformer(BlockStructureTransformer):
pass pass
# 0 class ChildrenMapTestMixin(object):
# / \ # 0
# 1 2 # / \
# / \ # 1 2
# 3 4 # / \
SIMPLE_CHILDREN_MAP = [[1, 2], [3, 4], [], [], []] # 3 4
SIMPLE_CHILDREN_MAP = [[1, 2], [3, 4], [], [], []]
# 0
# /
# 1
# /
# 2
# /
# 3
LINEAR_CHILDREN_MAP = [[1], [2], [3], []]
# 0
# / \
# 1 2
# \ / \
# 3 4
# / \
# 5 6
DAG_CHILDREN_MAP = [[1, 2], [3], [3, 4], [5, 6], [], [], []]
def create_block_structure(self, block_structure_cls, children_map):
# create block structure
block_structure = block_structure_cls(root_block_key=0)
# add_relation
for parent, children in enumerate(children_map):
for child in children:
block_structure.add_relation(parent, child)
return block_structure
def get_parents_map(self, children_map):
parent_map = [[] for _ in children_map]
for parent, children in enumerate(children_map):
for child in children:
parent_map[child].append(parent)
return parent_map
def assert_block_structure(self, block_structure, children_map, missing_blocks=None):
if not missing_blocks:
missing_blocks = []
class BlockStructureTestMixin(object):
def verify_block_structure(self, block_structure, children_map):
for block_key, children in enumerate(children_map): for block_key, children in enumerate(children_map):
self.assertTrue(
block_structure.has_block(block_key)
)
self.assertEquals( self.assertEquals(
set(block_structure.get_children(block_key)), block_structure.has_block(block_key),
set(children), block_key not in missing_blocks,
) )
if block_key not in missing_blocks:
self.assertEquals(
set(block_structure.get_children(block_key)),
set(children),
)
parents_map = self.get_parents_map(children_map)
for block_key, parents in enumerate(parents_map):
if block_key not in missing_blocks:
self.assertEquals(
set(block_structure.get_parents(block_key)),
set(parents),
)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment