Commit 7a3b4104 by Nimisha Asthagiri

fixup! SplitTestTransformer to handle deeper hierarchies and DAGs.

parent 183f2f82
...@@ -20,7 +20,6 @@ LMS_COURSE_TRANSFORMERS = [ ...@@ -20,7 +20,6 @@ LMS_COURSE_TRANSFORMERS = [
visibility.VisibilityTransformer(), visibility.VisibilityTransformer(),
start_date.StartDateTransformer(), start_date.StartDateTransformer(),
user_partitions.UserPartitionTransformer(), user_partitions.UserPartitionTransformer(),
split_test.SplitTestTransformer(),
library_content.ContentLibraryTransformer(), library_content.ContentLibraryTransformer(),
] ]
......
""" """
Split Test Block Transformer, used to filter course structure per user. Split Test Block Transformer
""" """
from openedx.core.lib.block_cache.transformer import BlockStructureTransformer from openedx.core.lib.block_cache.transformer import BlockStructureTransformer
from .helpers import get_user_partition_groups
class SplitTestTransformer(BlockStructureTransformer): class SplitTestTransformer(BlockStructureTransformer):
...@@ -11,25 +10,6 @@ class SplitTestTransformer(BlockStructureTransformer): ...@@ -11,25 +10,6 @@ class SplitTestTransformer(BlockStructureTransformer):
""" """
VERSION = 1 VERSION = 1
@staticmethod
def check_split_access(split_test_groups, user_groups):
"""
Check that user has access to specific split test group.
Arguments:
split_test_groups (list)
user_groups (dict[Partition Id: Group])
Returns:
bool
"""
if split_test_groups:
for _, group in user_groups.iteritems():
if group.id in split_test_groups:
return True
return False
return True
@classmethod @classmethod
def collect(cls, block_structure): def collect(cls, block_structure):
""" """
...@@ -40,46 +20,33 @@ class SplitTestTransformer(BlockStructureTransformer): ...@@ -40,46 +20,33 @@ class SplitTestTransformer(BlockStructureTransformer):
block_structure (BlockStructureCollectedData) block_structure (BlockStructureCollectedData)
""" """
# Check potential previously set values for user_partitions and split_test_partitions root_block = block_structure.get_xblock(block_structure.root_block_key)
xblock = block_structure.get_xblock(block_structure.root_block_key) user_partitions = getattr(root_block, 'user_partitions', [])
user_partitions = getattr(xblock, 'user_partitions', [])
split_test_partitions = getattr(xblock, 'split_test_partition', []) or []
# For each block, check if it is a split_test block. for block_key in block_structure.topological_traversal(
# If split_test is found, check its user_partition value and get children. predicate=lambda block_key: block_key.block_type == 'split_test',
# Set split_test_group on each of the children for fast retrieval in transform phase. yield_descendants_of_unyielded=True,
# Add same group to children's children, because due to structure restrictions first level ):
# children are verticals.
for block_key in block_structure.topological_traversal():
xblock = block_structure.get_xblock(block_key) xblock = block_structure.get_xblock(block_key)
category = getattr(xblock, 'category', None) partition_for_this_block = next(
if category == 'split_test': (
for user_partition in user_partitions: partition for partition in user_partitions
if user_partition.id == xblock.user_partition_id: if partition.id == xblock.user_partition_id
if user_partition not in split_test_partitions: ),
split_test_partitions.append(user_partition) None
for child in xblock.children: )
for group in user_partition.groups: if not partition_for_this_block:
child_location = xblock.group_id_to_child.get( continue
unicode(group.id),
None # create dict of child location to group_id
) child_to_group = {
if child_location == child: xblock.group_id_to_child.get(unicode(group.id), None): group.id
block_structure.set_transformer_block_data( for group in partition_for_this_block.groups
child, }
cls, # set group access for each child
'split_test_groups', for child_location in xblock.children:
[group.id] child = block_structure.get_xblock(child_location)
) child.group_access[partition_for_this_block.id] = [child_to_group[child_location]]
for component in block_structure.get_xblock(child).children:
block_structure.set_transformer_block_data(
component,
cls,
'split_test_groups',
[group.id]
)
block_structure.set_transformer_data(cls, 'split_test_partition', split_test_partitions)
def transform(self, user_info, block_structure): def transform(self, user_info, block_structure):
""" """
...@@ -89,22 +56,4 @@ class SplitTestTransformer(BlockStructureTransformer): ...@@ -89,22 +56,4 @@ class SplitTestTransformer(BlockStructureTransformer):
user_info (object) user_info (object)
block_structure (BlockStructureCollectedData) block_structure (BlockStructureCollectedData)
""" """
user_partitions = block_structure.get_transformer_data(self, 'split_test_partition') pass
# If there are no split test user partitions, this transformation is a no-op,
# so there is nothing to transform.
if not user_partitions:
return
if not user_info.has_staff_access:
user_groups = get_user_partition_groups(
user_info.course_key, user_partitions, user_info.user
)
block_structure.remove_block_if(
lambda block_key: not SplitTestTransformer.check_split_access(
block_structure.get_transformer_block_data(
block_key, self, 'split_test_groups', default=[]
), user_groups
)
)
...@@ -2,12 +2,12 @@ ...@@ -2,12 +2,12 @@
Tests for SplitTestTransformer. Tests for SplitTestTransformer.
""" """
from openedx.core.djangoapps.user_api.partition_schemes import RandomUserPartitionScheme from openedx.core.djangoapps.user_api.partition_schemes import RandomUserPartitionScheme
from opaque_keys.edx.keys import CourseKey
from student.tests.factories import CourseEnrollmentFactory from student.tests.factories import CourseEnrollmentFactory
from xmodule.partitions.partitions import Group, UserPartition from xmodule.partitions.partitions import Group, UserPartition
from course_blocks.transformers.split_test import SplitTestTransformer from course_blocks.transformers.split_test import SplitTestTransformer
from course_blocks.api import get_course_blocks, clear_course_from_cache from course_blocks.transformers.user_partitions import UserPartitionTransformer
from course_blocks.api import get_course_blocks
from lms.djangoapps.course_blocks.transformers.tests.test_helpers import CourseStructureTestCase from lms.djangoapps.course_blocks.transformers.tests.test_helpers import CourseStructureTestCase
from course_blocks.transformers.helpers import get_user_partition_groups from course_blocks.transformers.helpers import get_user_partition_groups
...@@ -43,9 +43,9 @@ class SplitTestTransformerTestCase(CourseStructureTestCase): ...@@ -43,9 +43,9 @@ class SplitTestTransformerTestCase(CourseStructureTestCase):
# Enroll user in course. # Enroll user in course.
CourseEnrollmentFactory.create(user=self.user, course_id=self.course.id, is_active=True) CourseEnrollmentFactory.create(user=self.user, course_id=self.course.id, is_active=True)
self.transformer = SplitTestTransformer()
self.transformer = UserPartitionTransformer()
def get_course_hierarchy(self): def get_course_hierarchy(self):
""" """
Get a course hierarchy to test with. Get a course hierarchy to test with.
...@@ -148,6 +148,3 @@ class SplitTestTransformerTestCase(CourseStructureTestCase): ...@@ -148,6 +148,3 @@ class SplitTestTransformerTestCase(CourseStructureTestCase):
) )
self.assertEqual(set(reloaded_structure.get_block_keys()), set(self.get_block_key_set(*expected_blocks))) self.assertEqual(set(reloaded_structure.get_block_keys()), set(self.get_block_key_set(*expected_blocks)))
def test_staff_user(self):
self.assert_staff_access_to_all_blocks(self.staff, self.course, self.blocks, self.transformer)
""" """
User Partitions Transformer, used to filter course structure per user. User Partitions Transformer
""" """
from openedx.core.lib.block_cache.transformer import BlockStructureTransformer from openedx.core.lib.block_cache.transformer import BlockStructureTransformer
from .helpers import get_user_partition_groups from .helpers import get_user_partition_groups
from .split_test import SplitTestTransformer
class MergedGroupAccess(object): class MergedGroupAccess(object):
...@@ -122,6 +123,9 @@ class UserPartitionTransformer(BlockStructureTransformer): ...@@ -122,6 +123,9 @@ class UserPartitionTransformer(BlockStructureTransformer):
Arguments: Arguments:
block_structure (BlockStructureCollectedData) block_structure (BlockStructureCollectedData)
""" """
# First have the split test transformer setup its group access data for each block.
SplitTestTransformer.collect(block_structure)
# Because user partitions are course-wide, only store data for them on the root block. # Because user partitions are course-wide, only store data for them on the root block.
root_block = block_structure.get_xblock(block_structure.root_block_key) root_block = block_structure.get_xblock(block_structure.root_block_key)
user_partitions = getattr(root_block, 'user_partitions', []) or [] user_partitions = getattr(root_block, 'user_partitions', []) or []
...@@ -153,6 +157,8 @@ class UserPartitionTransformer(BlockStructureTransformer): ...@@ -153,6 +157,8 @@ class UserPartitionTransformer(BlockStructureTransformer):
user_info (object) user_info (object)
block_structure (BlockStructureCollectedData) block_structure (BlockStructureCollectedData)
""" """
SplitTestTransformer().transform(user_info, block_structure)
user_partitions = block_structure.get_transformer_data(self, 'user_partitions') user_partitions = block_structure.get_transformer_data(self, 'user_partitions')
if not user_partitions or user_info.has_staff_access: if not user_partitions or user_info.has_staff_access:
......
...@@ -45,21 +45,19 @@ class BlockStructure(object): ...@@ -45,21 +45,19 @@ class BlockStructure(object):
def get_block_keys(self): def get_block_keys(self):
return self._block_relations.iterkeys() return self._block_relations.iterkeys()
def topological_traversal(self, get_result=None, predicate=None): def topological_traversal(self, **kwargs):
return traverse_topologically( return traverse_topologically(
start_node=self.root_block_key, start_node=self.root_block_key,
get_parents=self.get_parents, get_parents=self.get_parents,
get_children=self.get_children, get_children=self.get_children,
get_result=get_result, **kwargs
predicate=predicate,
) )
def post_order_traversal(self, get_result=None, predicate=None): def post_order_traversal(self, **kwargs):
return traverse_post_order( return traverse_post_order(
start_node=self.root_block_key, start_node=self.root_block_key,
get_children=self.get_children, get_children=self.get_children,
get_result=get_result, **kwargs
predicate=predicate,
) )
def prune(self): def prune(self):
...@@ -67,14 +65,6 @@ class BlockStructure(object): ...@@ -67,14 +65,6 @@ class BlockStructure(object):
pruned_block_relations = defaultdict(self.BlockRelations) pruned_block_relations = defaultdict(self.BlockRelations)
old_block_relations = self._block_relations old_block_relations = self._block_relations
# def do_for_each_block(block_key):
# if block_key in old_block_relations:
# self._add_block(pruned_block_relations, block_key)
#
# for parent in old_block_relations[block_key].parents:
# if parent in pruned_block_relations:
# self._add_relation(pruned_block_relations, parent, block_key)
def do_for_each_block(block_key): def do_for_each_block(block_key):
if block_key in old_block_relations: if block_key in old_block_relations:
self._add_block(pruned_block_relations, block_key) self._add_block(pruned_block_relations, block_key)
...@@ -157,13 +147,13 @@ class BlockStructureBlockData(BlockStructure): ...@@ -157,13 +147,13 @@ class BlockStructureBlockData(BlockStructure):
self._block_relations.pop(usage_key, None) self._block_relations.pop(usage_key, None)
self._block_data_map.pop(usage_key, None) self._block_data_map.pop(usage_key, None)
def remove_block_if(self, removal_condition): def remove_block_if(self, removal_condition, **kwargs):
def predicate(block_key): def predicate(block_key):
if removal_condition(block_key): if removal_condition(block_key):
self.remove_block(block_key) self.remove_block(block_key)
return False return False
return True return True
list(self.topological_traversal(predicate=predicate)) list(self.topological_traversal(predicate=predicate, **kwargs))
class BlockStructureCollectedData(BlockStructureBlockData): class BlockStructureCollectedData(BlockStructureBlockData):
......
...@@ -4,7 +4,9 @@ ...@@ -4,7 +4,9 @@
from collections import deque from collections import deque
def _traverse_generic(start_node, get_parents, get_children, get_result=None, predicate=None): def _traverse_generic(
start_node, get_parents, get_children, get_result=None, predicate=None, yield_descendants_of_unyielded=False
):
""" """
Helper function to avoid duplicating functionality between Helper function to avoid duplicating functionality between
traverse_depth_first and traverse_topologically. traverse_depth_first and traverse_topologically.
...@@ -21,6 +23,9 @@ def _traverse_generic(start_node, get_parents, get_children, get_result=None, pr ...@@ -21,6 +23,9 @@ def _traverse_generic(start_node, get_parents, get_children, get_result=None, pr
get_children - function that returns a list of children nodes for the given node get_children - function that returns a list of children nodes for the given node
get_result - function that computes and returns the resulting value to be yielded for the given node get_result - function that computes and returns the resulting value to be yielded for the given node
predicate - function that returns whether or not to yield the given node predicate - function that returns whether or not to yield the given node
yield_descendants_of_unyielded -
if False, all descendants of an unyielded node are not yielded.
if True, descendants of an unyielded node are yielded even if none of their parents were yielded.
""" """
# If get_result or predicate aren't provided, just make them to no-ops. # If get_result or predicate aren't provided, just make them to no-ops.
get_result = get_result or (lambda node_: node_) get_result = get_result or (lambda node_: node_)
...@@ -44,7 +49,7 @@ def _traverse_generic(start_node, get_parents, get_children, get_result=None, pr ...@@ -44,7 +49,7 @@ def _traverse_generic(start_node, get_parents, get_children, get_result=None, pr
parents = get_parents(curr_node) parents = get_parents(curr_node)
all_parents_visited = all(parent in yield_results for parent in parents) all_parents_visited = all(parent in yield_results for parent in parents)
any_parent_yielded = any(yield_results[parent] for parent in parents) if all_parents_visited else False any_parent_yielded = any(yield_results[parent] for parent in parents) if all_parents_visited else False
if not all_parents_visited or not any_parent_yielded: if not all_parents_visited or (not yield_descendants_of_unyielded and not any_parent_yielded):
continue continue
# Add its unvisited children to the stack in reverse order so that # Add its unvisited children to the stack in reverse order so that
...@@ -71,23 +76,21 @@ def _traverse_generic(start_node, get_parents, get_children, get_result=None, pr ...@@ -71,23 +76,21 @@ def _traverse_generic(start_node, get_parents, get_children, get_result=None, pr
yield get_result(curr_node) yield get_result(curr_node)
def traverse_topologically(start_node, get_parents, get_children, get_result=None, predicate=None): def traverse_topologically(start_node, get_parents, get_children, **kwargs):
return _traverse_generic( return _traverse_generic(
start_node, start_node,
get_parents=get_parents, get_parents=get_parents,
get_children=get_children, get_children=get_children,
get_result=get_result, **kwargs
predicate=predicate
) )
def traverse_pre_order(start_node, get_children, get_result=None, predicate=None): def traverse_pre_order(start_node, get_children, **kwargs):
return _traverse_generic( return _traverse_generic(
start_node, start_node,
get_parents=None, get_parents=None,
get_children=get_children, get_children=get_children,
get_result=get_result, **kwargs
predicate=predicate
) )
......
...@@ -59,9 +59,6 @@ class GraphTraversalsTestCase(TestCase): ...@@ -59,9 +59,6 @@ class GraphTraversalsTestCase(TestCase):
return result return result
def test_pre_order(self): def test_pre_order(self):
"""
...
"""
self.assertEqual( self.assertEqual(
list(traverse_pre_order( list(traverse_pre_order(
start_node='b1', start_node='b1',
...@@ -73,9 +70,6 @@ class GraphTraversalsTestCase(TestCase): ...@@ -73,9 +70,6 @@ class GraphTraversalsTestCase(TestCase):
) )
def test_post_order(self): def test_post_order(self):
"""
...
"""
self.assertEqual( self.assertEqual(
list(traverse_post_order( list(traverse_post_order(
start_node='b1', start_node='b1',
...@@ -87,9 +81,6 @@ class GraphTraversalsTestCase(TestCase): ...@@ -87,9 +81,6 @@ class GraphTraversalsTestCase(TestCase):
) )
def test_topological(self): def test_topological(self):
"""
...
"""
self.assertEqual( self.assertEqual(
list(traverse_topologically( list(traverse_topologically(
start_node='b1', start_node='b1',
...@@ -100,16 +91,38 @@ class GraphTraversalsTestCase(TestCase): ...@@ -100,16 +91,38 @@ class GraphTraversalsTestCase(TestCase):
['b1', 'c1', 'd1', 'd2', 'e1', 'e2', 'f1', 'c2'] ['b1', 'c1', 'd1', 'd2', 'e1', 'e2', 'f1', 'c2']
) )
def test_topological_with_predicate(self): def test_topological_yield_descendants(self):
""" self.assertEqual(
... list(traverse_topologically(
""" start_node='b1',
get_children=(lambda node: self.graph_1[node]),
get_parents=(lambda node: self.graph_1_parents[node]),
predicate=(lambda node: node != 'd2'),
yield_descendants_of_unyielded=True,
)),
['b1', 'c1', 'd1', 'e1', 'e2', 'f1', 'c2', 'd3']
)
def test_topological_not_yield_descendants(self):
self.assertEqual( self.assertEqual(
list(traverse_topologically( list(traverse_topologically(
start_node='b1', start_node='b1',
get_children=(lambda node: self.graph_1[node]), get_children=(lambda node: self.graph_1[node]),
get_parents=(lambda node: self.graph_1_parents[node]), get_parents=(lambda node: self.graph_1_parents[node]),
predicate=(lambda node: node != 'd2') predicate=(lambda node: node != 'd2'),
yield_descendants_of_unyielded=False,
)), )),
['b1', 'c1', 'd1', 'e1', 'c2', 'd3'] ['b1', 'c1', 'd1', 'e1', 'c2', 'd3']
) )
def test_topological_yield_single_node(self):
self.assertEqual(
list(traverse_topologically(
start_node='b1',
get_children=(lambda node: self.graph_1[node]),
get_parents=(lambda node: self.graph_1_parents[node]),
predicate=(lambda node: node == 'c2'),
yield_descendants_of_unyielded=True,
)),
['c2']
)
...@@ -52,7 +52,6 @@ setup( ...@@ -52,7 +52,6 @@ setup(
"visibility = lms.djangoapps.course_blocks.transformers.visibility:VisibilityTransformer", "visibility = lms.djangoapps.course_blocks.transformers.visibility:VisibilityTransformer",
"start_date = lms.djangoapps.course_blocks.transformers.start_date:StartDateTransformer", "start_date = lms.djangoapps.course_blocks.transformers.start_date:StartDateTransformer",
"user_partitions = lms.djangoapps.course_blocks.transformers.user_partitions:UserPartitionTransformer", "user_partitions = lms.djangoapps.course_blocks.transformers.user_partitions:UserPartitionTransformer",
"split_test = lms.djangoapps.course_blocks.transformers.split_test:SplitTestTransformer",
"library_content = lms.djangoapps.course_blocks.transformers.library_content:ContentLibraryTransformer", "library_content = lms.djangoapps.course_blocks.transformers.library_content:ContentLibraryTransformer",
"blocks_api = lms.djangoapps.course_api.blocks.transformers.blocks_api:BlocksAPITransformer", "blocks_api = lms.djangoapps.course_api.blocks.transformers.blocks_api:BlocksAPITransformer",
], ],
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment