Commit 6cd7e08c by Calen Pennington

Assert course equality in bulk, so that all differences are displayed

parent f1da4efb
...@@ -12,6 +12,7 @@ import os ...@@ -12,6 +12,7 @@ import os
import pprint import pprint
import unittest import unittest
from contextlib import contextmanager
from mock import Mock from mock import Mock
from path import path from path import path
...@@ -174,12 +175,64 @@ def map_references(value, field, actual_course_key): ...@@ -174,12 +175,64 @@ def map_references(value, field, actual_course_key):
return value return value
class CourseComparisonTest(unittest.TestCase): class BulkAssertionManager(object):
"""
This provides a facility for making a large number of assertions, and seeing all of
the failures at once, rather than only seeing single failures.
"""
def __init__(self, test_case):
self._equal_expected = []
self._equal_actual = []
self._test_case = test_case
def assertEqual(self, expected, actual, description=None):
if description is None:
description = "{!r} does not equal {!r}".format(expected, actual)
if expected != actual:
self._equal_expected.append((description, expected))
self._equal_actual.append((description, actual))
def run_assertions(self):
self._test_case.assertEqual(self._equal_expected, self._equal_actual)
class BulkAssertionTest(unittest.TestCase):
"""
This context manager provides a BulkAssertionManager to assert with,
and then calls `run_assertions` at the end of the block to validate all
of the assertions.
"""
def setUp(self, *args, **kwargs):
super(BulkAssertionTest, self).setUp(*args, **kwargs)
self._manager = None
@contextmanager
def bulk_assertions(self):
if self._manager:
yield
else:
try:
self._manager = BulkAssertionManager(self)
yield
finally:
self._manager.run_assertions()
self._manager = None
def assertEqual(self, expected, actual, message=None):
if self._manager is not None:
self._manager.assertEqual(expected, actual, message)
else:
super(BulkAssertionTest, self).assertEqual(expected, actual, message)
class CourseComparisonTest(BulkAssertionTest):
""" """
Mixin that has methods for comparing courses for equality. Mixin that has methods for comparing courses for equality.
""" """
def setUp(self): def setUp(self):
super(CourseComparisonTest, self).setUp()
self.field_exclusions = set() self.field_exclusions = set()
self.ignored_asset_keys = set() self.ignored_asset_keys = set()
...@@ -235,68 +288,69 @@ class CourseComparisonTest(unittest.TestCase): ...@@ -235,68 +288,69 @@ class CourseComparisonTest(unittest.TestCase):
self._assertCoursesEqual(expected_items, actual_items, actual_course_key, expect_drafts=True) self._assertCoursesEqual(expected_items, actual_items, actual_course_key, expect_drafts=True)
def _assertCoursesEqual(self, expected_items, actual_items, actual_course_key, expect_drafts=False): def _assertCoursesEqual(self, expected_items, actual_items, actual_course_key, expect_drafts=False):
self.assertEqual(len(expected_items), len(actual_items)) with self.bulk_assertions():
self.assertEqual(len(expected_items), len(actual_items))
actual_item_map = {
item.location.block_id: item actual_item_map = {
for item in actual_items item.location.block_id: item
} for item in actual_items
}
for expected_item in expected_items:
actual_item_location = actual_course_key.make_usage_key(expected_item.category, expected_item.location.block_id) for expected_item in expected_items:
# split and old mongo use different names for the course root but we don't know which actual_item_location = actual_course_key.make_usage_key(expected_item.category, expected_item.location.block_id)
# modulestore actual's come from here; so, assume old mongo and if that fails, assume split # split and old mongo use different names for the course root but we don't know which
if expected_item.location.category == 'course': # modulestore actual's come from here; so, assume old mongo and if that fails, assume split
actual_item_location = actual_item_location.replace(name=actual_item_location.run) if expected_item.location.category == 'course':
actual_item = actual_item_map.get(actual_item_location.block_id) actual_item_location = actual_item_location.replace(name=actual_item_location.run)
# must be split
if actual_item is None and expected_item.location.category == 'course':
actual_item_location = actual_item_location.replace(name='course')
actual_item = actual_item_map.get(actual_item_location.block_id) actual_item = actual_item_map.get(actual_item_location.block_id)
self.assertIsNotNone(actual_item, u'cannot find {} in {}'.format(actual_item_location, actual_item_map)) # must be split
if actual_item is None and expected_item.location.category == 'course':
# compare fields actual_item_location = actual_item_location.replace(name='course')
self.assertEqual(expected_item.fields, actual_item.fields) actual_item = actual_item_map.get(actual_item_location.block_id)
self.assertIsNotNone(actual_item, u'cannot find {} in {}'.format(actual_item_location, actual_item_map))
for field_name, field in expected_item.fields.iteritems():
if (expected_item.scope_ids.usage_id, field_name) in self.field_exclusions: # compare fields
continue self.assertEqual(expected_item.fields, actual_item.fields)
if (None, field_name) in self.field_exclusions: for field_name, field in expected_item.fields.iteritems():
continue if (expected_item.scope_ids.usage_id, field_name) in self.field_exclusions:
continue
# Children are handled specially
if field_name == 'children': if (None, field_name) in self.field_exclusions:
continue continue
exp_value = map_references(field.read_from(expected_item), field, actual_course_key) # Children are handled specially
actual_value = field.read_from(actual_item) if field_name == 'children':
self.assertEqual( continue
exp_value,
actual_value, exp_value = map_references(field.read_from(expected_item), field, actual_course_key)
"Field {!r} doesn't match between usages {} and {}: {!r} != {!r}".format( actual_value = field.read_from(actual_item)
field_name, self.assertEqual(
expected_item.scope_ids.usage_id,
actual_item.scope_ids.usage_id,
exp_value, exp_value,
actual_value, actual_value,
"Field {!r} doesn't match between usages {} and {}: {!r} != {!r}".format(
field_name,
expected_item.scope_ids.usage_id,
actual_item.scope_ids.usage_id,
exp_value,
actual_value,
)
) )
)
# compare children
# compare children self.assertEqual(expected_item.has_children, actual_item.has_children)
self.assertEqual(expected_item.has_children, actual_item.has_children) if expected_item.has_children:
if expected_item.has_children: expected_children = [
expected_children = [ (expected_item_child.location.block_type, expected_item_child.location.block_id)
(expected_item_child.location.block_type, expected_item_child.location.block_id) # get_children() rather than children to strip privates from public parents
# get_children() rather than children to strip privates from public parents for expected_item_child in expected_item.get_children()
for expected_item_child in expected_item.get_children() ]
] actual_children = [
actual_children = [ (item_child.location.block_type, item_child.location.block_id)
(item_child.location.block_type, item_child.location.block_id) # get_children() rather than children to strip privates from public parents
# get_children() rather than children to strip privates from public parents for item_child in actual_item.get_children()
for item_child in actual_item.get_children() ]
] self.assertEqual(expected_children, actual_children)
self.assertEqual(expected_children, actual_children)
def assertAssetEqual(self, expected_course_key, expected_asset, actual_course_key, actual_asset): def assertAssetEqual(self, expected_course_key, expected_asset, actual_course_key, actual_asset):
""" """
...@@ -339,10 +393,12 @@ class CourseComparisonTest(unittest.TestCase): ...@@ -339,10 +393,12 @@ class CourseComparisonTest(unittest.TestCase):
expected_content, expected_count = expected_store.get_all_content_for_course(expected_course_key) expected_content, expected_count = expected_store.get_all_content_for_course(expected_course_key)
actual_content, actual_count = actual_store.get_all_content_for_course(actual_course_key) actual_content, actual_count = actual_store.get_all_content_for_course(actual_course_key)
self.assertEqual(expected_count, actual_count) with self.bulk_assertions():
self._assertAssetsEqual(expected_course_key, expected_content, actual_course_key, actual_content)
self.assertEqual(expected_count, actual_count)
self._assertAssetsEqual(expected_course_key, expected_content, actual_course_key, actual_content)
expected_thumbs = expected_store.get_all_content_thumbnails_for_course(expected_course_key) expected_thumbs = expected_store.get_all_content_thumbnails_for_course(expected_course_key)
actual_thumbs = actual_store.get_all_content_thumbnails_for_course(actual_course_key) actual_thumbs = actual_store.get_all_content_thumbnails_for_course(actual_course_key)
self._assertAssetsEqual(expected_course_key, expected_thumbs, actual_course_key, actual_thumbs) self._assertAssetsEqual(expected_course_key, expected_thumbs, actual_course_key, actual_thumbs)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment