Commit 7a3b4104 by Nimisha Asthagiri

fixup! SplitTestTransformer to handle deeper hierarchies and DAGs.

parent 183f2f82
......@@ -20,7 +20,6 @@ LMS_COURSE_TRANSFORMERS = [
visibility.VisibilityTransformer(),
start_date.StartDateTransformer(),
user_partitions.UserPartitionTransformer(),
split_test.SplitTestTransformer(),
library_content.ContentLibraryTransformer(),
]
......
"""
Split Test Block Transformer, used to filter course structure per user.
Split Test Block Transformer
"""
from openedx.core.lib.block_cache.transformer import BlockStructureTransformer
from .helpers import get_user_partition_groups
class SplitTestTransformer(BlockStructureTransformer):
......@@ -11,25 +10,6 @@ class SplitTestTransformer(BlockStructureTransformer):
"""
VERSION = 1
@staticmethod
def check_split_access(split_test_groups, user_groups):
"""
Check that user has access to specific split test group.
Arguments:
split_test_groups (list)
user_groups (dict[Partition Id: Group])
Returns:
bool
"""
if split_test_groups:
for _, group in user_groups.iteritems():
if group.id in split_test_groups:
return True
return False
return True
@classmethod
def collect(cls, block_structure):
"""
......@@ -40,46 +20,33 @@ class SplitTestTransformer(BlockStructureTransformer):
block_structure (BlockStructureCollectedData)
"""
# Check potential previously set values for user_partitions and split_test_partitions
xblock = block_structure.get_xblock(block_structure.root_block_key)
user_partitions = getattr(xblock, 'user_partitions', [])
split_test_partitions = getattr(xblock, 'split_test_partition', []) or []
root_block = block_structure.get_xblock(block_structure.root_block_key)
user_partitions = getattr(root_block, 'user_partitions', [])
# For each block, check if it is a split_test block.
# If split_test is found, check its user_partition value and get children.
# Set split_test_group on each of the children for fast retrieval in transform phase.
# Add same group to children's children, because due to structure restrictions first level
# children are verticals.
for block_key in block_structure.topological_traversal():
for block_key in block_structure.topological_traversal(
predicate=lambda block_key: block_key.block_type == 'split_test',
yield_descendants_of_unyielded=True,
):
xblock = block_structure.get_xblock(block_key)
category = getattr(xblock, 'category', None)
if category == 'split_test':
for user_partition in user_partitions:
if user_partition.id == xblock.user_partition_id:
if user_partition not in split_test_partitions:
split_test_partitions.append(user_partition)
for child in xblock.children:
for group in user_partition.groups:
child_location = xblock.group_id_to_child.get(
unicode(group.id),
None
)
if child_location == child:
block_structure.set_transformer_block_data(
child,
cls,
'split_test_groups',
[group.id]
)
for component in block_structure.get_xblock(child).children:
block_structure.set_transformer_block_data(
component,
cls,
'split_test_groups',
[group.id]
)
block_structure.set_transformer_data(cls, 'split_test_partition', split_test_partitions)
partition_for_this_block = next(
(
partition for partition in user_partitions
if partition.id == xblock.user_partition_id
),
None
)
if not partition_for_this_block:
continue
# create dict of child location to group_id
child_to_group = {
xblock.group_id_to_child.get(unicode(group.id), None): group.id
for group in partition_for_this_block.groups
}
# set group access for each child
for child_location in xblock.children:
child = block_structure.get_xblock(child_location)
child.group_access[partition_for_this_block.id] = [child_to_group[child_location]]
def transform(self, user_info, block_structure):
"""
......@@ -89,22 +56,4 @@ class SplitTestTransformer(BlockStructureTransformer):
user_info (object)
block_structure (BlockStructureCollectedData)
"""
user_partitions = block_structure.get_transformer_data(self, 'split_test_partition')
# If there are no split test user partitions, this transformation is a no-op,
# so there is nothing to transform.
if not user_partitions:
return
if not user_info.has_staff_access:
user_groups = get_user_partition_groups(
user_info.course_key, user_partitions, user_info.user
)
block_structure.remove_block_if(
lambda block_key: not SplitTestTransformer.check_split_access(
block_structure.get_transformer_block_data(
block_key, self, 'split_test_groups', default=[]
), user_groups
)
)
pass
......@@ -2,12 +2,12 @@
Tests for SplitTestTransformer.
"""
from openedx.core.djangoapps.user_api.partition_schemes import RandomUserPartitionScheme
from opaque_keys.edx.keys import CourseKey
from student.tests.factories import CourseEnrollmentFactory
from xmodule.partitions.partitions import Group, UserPartition
from course_blocks.transformers.split_test import SplitTestTransformer
from course_blocks.api import get_course_blocks, clear_course_from_cache
from course_blocks.transformers.user_partitions import UserPartitionTransformer
from course_blocks.api import get_course_blocks
from lms.djangoapps.course_blocks.transformers.tests.test_helpers import CourseStructureTestCase
from course_blocks.transformers.helpers import get_user_partition_groups
......@@ -43,9 +43,9 @@ class SplitTestTransformerTestCase(CourseStructureTestCase):
# Enroll user in course.
CourseEnrollmentFactory.create(user=self.user, course_id=self.course.id, is_active=True)
self.transformer = SplitTestTransformer()
self.transformer = UserPartitionTransformer()
def get_course_hierarchy(self):
"""
Get a course hierarchy to test with.
......@@ -148,6 +148,3 @@ class SplitTestTransformerTestCase(CourseStructureTestCase):
)
self.assertEqual(set(reloaded_structure.get_block_keys()), set(self.get_block_key_set(*expected_blocks)))
def test_staff_user(self):
self.assert_staff_access_to_all_blocks(self.staff, self.course, self.blocks, self.transformer)
"""
User Partitions Transformer, used to filter course structure per user.
User Partitions Transformer
"""
from openedx.core.lib.block_cache.transformer import BlockStructureTransformer
from .helpers import get_user_partition_groups
from .split_test import SplitTestTransformer
class MergedGroupAccess(object):
......@@ -122,6 +123,9 @@ class UserPartitionTransformer(BlockStructureTransformer):
Arguments:
block_structure (BlockStructureCollectedData)
"""
# First have the split test transformer setup its group access data for each block.
SplitTestTransformer.collect(block_structure)
# Because user partitions are course-wide, only store data for them on the root block.
root_block = block_structure.get_xblock(block_structure.root_block_key)
user_partitions = getattr(root_block, 'user_partitions', []) or []
......@@ -153,6 +157,8 @@ class UserPartitionTransformer(BlockStructureTransformer):
user_info (object)
block_structure (BlockStructureCollectedData)
"""
SplitTestTransformer().transform(user_info, block_structure)
user_partitions = block_structure.get_transformer_data(self, 'user_partitions')
if not user_partitions or user_info.has_staff_access:
......
......@@ -45,21 +45,19 @@ class BlockStructure(object):
def get_block_keys(self):
return self._block_relations.iterkeys()
def topological_traversal(self, get_result=None, predicate=None):
def topological_traversal(self, **kwargs):
return traverse_topologically(
start_node=self.root_block_key,
get_parents=self.get_parents,
get_children=self.get_children,
get_result=get_result,
predicate=predicate,
**kwargs
)
def post_order_traversal(self, get_result=None, predicate=None):
def post_order_traversal(self, **kwargs):
return traverse_post_order(
start_node=self.root_block_key,
get_children=self.get_children,
get_result=get_result,
predicate=predicate,
**kwargs
)
def prune(self):
......@@ -67,14 +65,6 @@ class BlockStructure(object):
pruned_block_relations = defaultdict(self.BlockRelations)
old_block_relations = self._block_relations
# def do_for_each_block(block_key):
# if block_key in old_block_relations:
# self._add_block(pruned_block_relations, block_key)
#
# for parent in old_block_relations[block_key].parents:
# if parent in pruned_block_relations:
# self._add_relation(pruned_block_relations, parent, block_key)
def do_for_each_block(block_key):
if block_key in old_block_relations:
self._add_block(pruned_block_relations, block_key)
......@@ -157,13 +147,13 @@ class BlockStructureBlockData(BlockStructure):
self._block_relations.pop(usage_key, None)
self._block_data_map.pop(usage_key, None)
def remove_block_if(self, removal_condition):
def remove_block_if(self, removal_condition, **kwargs):
def predicate(block_key):
if removal_condition(block_key):
self.remove_block(block_key)
return False
return True
list(self.topological_traversal(predicate=predicate))
list(self.topological_traversal(predicate=predicate, **kwargs))
class BlockStructureCollectedData(BlockStructureBlockData):
......
......@@ -4,7 +4,9 @@
from collections import deque
def _traverse_generic(start_node, get_parents, get_children, get_result=None, predicate=None):
def _traverse_generic(
start_node, get_parents, get_children, get_result=None, predicate=None, yield_descendants_of_unyielded=False
):
"""
Helper function to avoid duplicating functionality between
traverse_depth_first and traverse_topologically.
......@@ -21,6 +23,9 @@ def _traverse_generic(start_node, get_parents, get_children, get_result=None, pr
get_children - function that returns a list of children nodes for the given node
get_result - function that computes and returns the resulting value to be yielded for the given node
predicate - function that returns whether or not to yield the given node
yield_descendants_of_unyielded -
if False, all descendants of an unyielded node are not yielded.
if True, descendants of an unyielded node are yielded even if none of their parents were yielded.
"""
# If get_result or predicate aren't provided, just make them to no-ops.
get_result = get_result or (lambda node_: node_)
......@@ -44,7 +49,7 @@ def _traverse_generic(start_node, get_parents, get_children, get_result=None, pr
parents = get_parents(curr_node)
all_parents_visited = all(parent in yield_results for parent in parents)
any_parent_yielded = any(yield_results[parent] for parent in parents) if all_parents_visited else False
if not all_parents_visited or not any_parent_yielded:
if not all_parents_visited or (not yield_descendants_of_unyielded and not any_parent_yielded):
continue
# Add its unvisited children to the stack in reverse order so that
......@@ -71,23 +76,21 @@ def _traverse_generic(start_node, get_parents, get_children, get_result=None, pr
yield get_result(curr_node)
def traverse_topologically(start_node, get_parents, get_children, get_result=None, predicate=None):
def traverse_topologically(start_node, get_parents, get_children, **kwargs):
return _traverse_generic(
start_node,
get_parents=get_parents,
get_children=get_children,
get_result=get_result,
predicate=predicate
**kwargs
)
def traverse_pre_order(start_node, get_children, get_result=None, predicate=None):
def traverse_pre_order(start_node, get_children, **kwargs):
return _traverse_generic(
start_node,
get_parents=None,
get_children=get_children,
get_result=get_result,
predicate=predicate
**kwargs
)
......
......@@ -59,9 +59,6 @@ class GraphTraversalsTestCase(TestCase):
return result
def test_pre_order(self):
"""
...
"""
self.assertEqual(
list(traverse_pre_order(
start_node='b1',
......@@ -73,9 +70,6 @@ class GraphTraversalsTestCase(TestCase):
)
def test_post_order(self):
"""
...
"""
self.assertEqual(
list(traverse_post_order(
start_node='b1',
......@@ -87,9 +81,6 @@ class GraphTraversalsTestCase(TestCase):
)
def test_topological(self):
"""
...
"""
self.assertEqual(
list(traverse_topologically(
start_node='b1',
......@@ -100,16 +91,38 @@ class GraphTraversalsTestCase(TestCase):
['b1', 'c1', 'd1', 'd2', 'e1', 'e2', 'f1', 'c2']
)
def test_topological_with_predicate(self):
"""
...
"""
def test_topological_yield_descendants(self):
self.assertEqual(
list(traverse_topologically(
start_node='b1',
get_children=(lambda node: self.graph_1[node]),
get_parents=(lambda node: self.graph_1_parents[node]),
predicate=(lambda node: node != 'd2'),
yield_descendants_of_unyielded=True,
)),
['b1', 'c1', 'd1', 'e1', 'e2', 'f1', 'c2', 'd3']
)
def test_topological_not_yield_descendants(self):
self.assertEqual(
list(traverse_topologically(
start_node='b1',
get_children=(lambda node: self.graph_1[node]),
get_parents=(lambda node: self.graph_1_parents[node]),
predicate=(lambda node: node != 'd2')
predicate=(lambda node: node != 'd2'),
yield_descendants_of_unyielded=False,
)),
['b1', 'c1', 'd1', 'e1', 'c2', 'd3']
)
def test_topological_yield_single_node(self):
self.assertEqual(
list(traverse_topologically(
start_node='b1',
get_children=(lambda node: self.graph_1[node]),
get_parents=(lambda node: self.graph_1_parents[node]),
predicate=(lambda node: node == 'c2'),
yield_descendants_of_unyielded=True,
)),
['c2']
)
......@@ -52,7 +52,6 @@ setup(
"visibility = lms.djangoapps.course_blocks.transformers.visibility:VisibilityTransformer",
"start_date = lms.djangoapps.course_blocks.transformers.start_date:StartDateTransformer",
"user_partitions = lms.djangoapps.course_blocks.transformers.user_partitions:UserPartitionTransformer",
"split_test = lms.djangoapps.course_blocks.transformers.split_test:SplitTestTransformer",
"library_content = lms.djangoapps.course_blocks.transformers.library_content:ContentLibraryTransformer",
"blocks_api = lms.djangoapps.course_api.blocks.transformers.blocks_api:BlocksAPITransformer",
],
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment