Commit e394c181 by Calen Pennington

Merge pull request #7476 from cpennington/bulk-assert-all-assertions

Make BulkAssertionTest.bulk_assertions work with any assert* method
parents c1f3f558 3d7bd9aa
......@@ -7,35 +7,35 @@ Run like this:
"""
import inspect
import json
import os
import pprint
import sys
import traceback
import unittest
import inspect
import mock
from contextlib import contextmanager
from contextlib import contextmanager, nested
from eventtracking import tracker
from eventtracking.django import DjangoTracker
from functools import wraps
from lazy import lazy
from mock import Mock
from mock import Mock, patch
from operator import attrgetter
from path import path
from eventtracking import tracker
from eventtracking.django import DjangoTracker
from opaque_keys.edx.locations import SlashSeparatedCourseKey
from xblock.field_data import DictFieldData
from xblock.fields import ScopeIds, Scope, Reference, ReferenceList, ReferenceValueDict
from xmodule.x_module import ModuleSystem, XModuleDescriptor, XModuleMixin
from xmodule.modulestore.inheritance import InheritanceMixin, own_metadata
from opaque_keys.edx.locations import SlashSeparatedCourseKey
from xmodule.mako_module import MakoDescriptorSystem
from xmodule.error_module import ErrorDescriptor
from xmodule.assetstore import AssetMetadata
from xmodule.error_module import ErrorDescriptor
from xmodule.mako_module import MakoDescriptorSystem
from xmodule.modulestore import ModuleStoreEnum
from xmodule.modulestore.draft_and_published import DIRECT_ONLY_CATEGORIES, ModuleStoreDraftAndPublished
from xmodule.modulestore.inheritance import InheritanceMixin, own_metadata
from xmodule.modulestore.mongo.draft import DraftModuleStore
from xmodule.modulestore.xml import CourseLocationManager
from xmodule.modulestore.draft_and_published import DIRECT_ONLY_CATEGORIES, ModuleStoreDraftAndPublished
from xmodule.x_module import ModuleSystem, XModuleDescriptor, XModuleMixin
MODULE_DIR = path(__file__).dirname()
......@@ -58,7 +58,7 @@ class TestModuleSystem(ModuleSystem): # pylint: disable=abstract-method
"""
ModuleSystem for testing
"""
@mock.patch('eventtracking.tracker.emit')
@patch('eventtracking.tracker.emit')
def __init__(self, mock_emit, **kwargs): # pylint: disable=unused-argument
id_manager = CourseLocationManager(kwargs['course_id'])
kwargs.setdefault('id_reader', id_manager)
......@@ -239,57 +239,152 @@ def map_references(value, field, actual_course_key):
return value
class BulkAssertionManager(object):
class BulkAssertionError(AssertionError):
"""
An AssertionError that contains many sub-assertions.
"""
def __init__(self, assertion_errors):
self.errors = assertion_errors
super(BulkAssertionError, self).__init__("The following assertions were raised:\n{}".format(
"\n\n".join(self.errors)
))
class _BulkAssertionManager(object):
"""
This provides a facility for making a large number of assertions, and seeing all of
the failures at once, rather than only seeing single failures.
"""
def __init__(self, test_case):
self._equal_assertions = []
self._assertion_errors = []
self._test_case = test_case
def run_assertions(self):
if len(self._equal_assertions) > 0:
raise AssertionError(self._equal_assertions)
def log_error(self, formatted_exc):
"""
Record ``formatted_exc`` in the set of exceptions captured by this assertion manager.
"""
self._assertion_errors.append(formatted_exc)
def raise_assertion_errors(self):
"""
Raise a BulkAssertionError containing all of the captured AssertionErrors,
if there were any.
"""
if self._assertion_errors:
raise BulkAssertionError(self._assertion_errors)
class BulkAssertionTest(unittest.TestCase):
"""
This context manager provides a BulkAssertionManager to assert with,
and then calls `run_assertions` at the end of the block to validate all
This context manager provides a _BulkAssertionManager to assert with,
and then calls `raise_assertion_errors` at the end of the block to validate all
of the assertions.
"""
def setUp(self, *args, **kwargs):
super(BulkAssertionTest, self).setUp(*args, **kwargs)
self._manager = None
# Use __ to not pollute the namespace of subclasses with what could be a fairly generic name.
self.__manager = None
@contextmanager
def bulk_assertions(self):
if self._manager:
"""
A context manager that will capture all assertion failures made by self.assert*
methods within its context, and raise a single combined assertion error at
the end of the context.
"""
if self.__manager:
yield
else:
try:
self._manager = BulkAssertionManager(self)
self.__manager = _BulkAssertionManager(self)
yield
finally:
self._manager.run_assertions()
self._manager = None
except Exception:
raise
else:
manager = self.__manager
self.__manager = None
manager.raise_assertion_errors()
def assertEqual(self, expected, actual, message=None):
if self._manager is not None:
try:
super(BulkAssertionTest, self).assertEqual(expected, actual, message)
except Exception as error: # pylint: disable=broad-except
exc_stack = inspect.stack()[1]
if message is not None:
msg = '{} -> {}:{} -> {}'.format(message, exc_stack[1], exc_stack[2], unicode(error))
else:
msg = '{}:{} -> {}'.format(exc_stack[1], exc_stack[2], unicode(error))
self._manager._equal_assertions.append(msg) # pylint: disable=protected-access
@contextmanager
def _capture_assertion_errors(self):
"""
A context manager that captures any AssertionError raised within it,
and, if within a ``bulk_assertions`` context, records the captured
assertion to the bulk assertion manager. If not within a ``bulk_assertions``
context, just raises the original exception.
"""
try:
# Only wrap the first layer of assert functions by stashing away the manager
# before executing the assertion.
manager = self.__manager
self.__manager = None
yield
except AssertionError: # pylint: disable=broad-except
if manager is not None:
# Reconstruct the stack in which the error was thrown (so that the traceback)
# isn't cut off at `assertion(*args, **kwargs)`.
exc_type, exc_value, exc_tb = sys.exc_info()
# Count the number of stack frames before you get to a
# unittest context (walking up the stack from here).
relevant_frames = 0
for frame_record in inspect.stack():
# This is the same criterion used by unittest to decide if a
# stack frame is relevant to exception printing.
frame = frame_record[0]
if '__unittest' in frame.f_globals:
break
relevant_frames += 1
stack_above = traceback.extract_stack()[-relevant_frames:-1]
stack_below = traceback.extract_tb(exc_tb)
formatted_stack = traceback.format_list(stack_above + stack_below)
formatted_exc = traceback.format_exception_only(exc_type, exc_value)
manager.log_error(
"".join(formatted_stack + formatted_exc)
)
else:
raise
finally:
self.__manager = manager
def _wrap_assertion(self, assertion):
"""
Wraps an assert* method to capture an immediate exception,
or to generate a new assertion capturing context (in the case of assertRaises
and assertRaisesRegexp).
"""
@wraps(assertion)
def assert_(*args, **kwargs):
"""
Execute a captured assertion, and catch any assertion errors raised.
"""
context = None
# Run the assertion, and capture any raised assertionErrors
with self._capture_assertion_errors():
context = assertion(*args, **kwargs)
# Handle the assertRaises family of functions by returning
# a context manager that surrounds the assertRaises
# with our assertion capturing context manager.
if context is not None:
return nested(self._capture_assertion_errors(), context)
return assert_
def __getattribute__(self, name):
"""
Wrap all assert* methods of this class using self._wrap_assertion,
to capture all assertion errors in bulk.
"""
base_attr = super(BulkAssertionTest, self).__getattribute__(name)
if name.startswith('assert'):
return self._wrap_assertion(base_attr)
else:
super(BulkAssertionTest, self).assertEqual(expected, actual, message)
assertEquals = assertEqual
return base_attr
class LazyFormat(object):
......@@ -312,6 +407,12 @@ class LazyFormat(object):
def __repr__(self):
return unicode(self)
def __len__(self):
return len(unicode(self))
def __getitem__(self, index):
return unicode(self)[index]
class CourseComparisonTest(BulkAssertionTest):
"""
......@@ -456,6 +557,13 @@ class CourseComparisonTest(BulkAssertionTest):
for item in actual_items
}
# Split Mongo and Old-Mongo disagree about what the block_id of courses is, so skip those in
# this comparison
self.assertItemsEqual(
[map_key(item.location) for item in expected_items if item.scope_ids.block_type != 'course'],
[key for key in actual_item_map.keys() if key[0] != 'course'],
)
for expected_item in expected_items:
actual_item_location = actual_course_key.make_usage_key(expected_item.category, expected_item.location.block_id)
# split and old mongo use different names for the course root but we don't know which
......@@ -469,7 +577,10 @@ class CourseComparisonTest(BulkAssertionTest):
actual_item = actual_item_map.get(map_key(actual_item_location))
# Formatting the message slows down tests of large courses significantly, so only do it if it would be used
self.assertIsNotNone(actual_item, LazyFormat(u'cannot find {} in {}', map_key(actual_item_location), actual_item_map))
self.assertIn(map_key(actual_item_location), actual_item_map.keys())
if actual_item is None:
continue
# compare fields
self.assertEqual(expected_item.fields, actual_item.fields)
......
import ddt
from xmodule.tests import BulkAssertionTest
import itertools
from xmodule.tests import BulkAssertionTest, BulkAssertionError
STATIC_PASSING_ASSERTIONS = (
('assertTrue', True),
('assertFalse', False),
('assertIs', 1, 1),
('assertEqual', 1, 1),
('assertEquals', 1, 1),
('assertIsNot', 1, 2),
('assertIsNone', None),
('assertIsNotNone', 1),
('assertIn', 1, (1, 2, 3)),
('assertNotIn', 5, (1, 2, 3)),
('assertIsInstance', 1, int),
('assertNotIsInstance', '1', int),
('assertItemsEqual', [1, 2, 3], [3, 2, 1])
)
STATIC_FAILING_ASSERTIONS = (
('assertTrue', False),
('assertFalse', True),
('assertIs', 1, 2),
('assertEqual', 1, 2),
('assertEquals', 1, 2),
('assertIsNot', 1, 1),
('assertIsNone', 1),
('assertIsNotNone', None),
('assertIn', 5, (1, 2, 3)),
('assertNotIn', 1, (1, 2, 3)),
('assertIsInstance', '1', int),
('assertNotIsInstance', 1, int),
('assertItemsEqual', [1, 1, 1], [1, 1])
)
CONTEXT_PASSING_ASSERTIONS = (
('assertRaises', KeyError, {}.__getitem__, '1'),
('assertRaisesRegexp', KeyError, "1", {}.__getitem__, '1'),
)
CONTEXT_FAILING_ASSERTIONS = (
('assertRaises', ValueError, lambda: None),
('assertRaisesRegexp', KeyError, "2", {}.__getitem__, '1'),
)
@ddt.ddt
class TestBulkAssertionTestCase(BulkAssertionTest):
@ddt.data(
('assertTrue', True),
('assertFalse', False),
('assertIs', 1, 1),
('assertIsNot', 1, 2),
('assertIsNone', None),
('assertIsNotNone', 1),
('assertIn', 1, (1, 2, 3)),
('assertNotIn', 5, (1, 2, 3)),
('assertIsInstance', 1, int),
('assertNotIsInstance', '1', int),
('assertRaises', KeyError, {}.__getitem__, '1'),
)
@ddt.unpack
def test_passing_asserts_passthrough(self, assertion, *args):
# We have to use assertion methods from the base UnitTest class,
# so we make a number of super calls that skip BulkAssertionTest.
# pylint: disable=bad-super-call
def _run_assertion(self, assertion_tuple):
"""
Run the supplied tuple of (assertion, *args) as a method on this class.
"""
assertion, args = assertion_tuple[0], assertion_tuple[1:]
getattr(self, assertion)(*args)
@ddt.data(
('assertTrue', False),
('assertFalse', True),
('assertIs', 1, 2),
('assertIsNot', 1, 1),
('assertIsNone', 1),
('assertIsNotNone', None),
('assertIn', 5, (1, 2, 3)),
('assertNotIn', 1, (1, 2, 3)),
('assertIsInstance', '1', int),
('assertNotIsInstance', 1, int),
('assertRaises', ValueError, lambda: None),
)
@ddt.unpack
def test_failing_asserts_passthrough(self, assertion, *args):
def _raw_assert(self, assertion_name, *args, **kwargs):
"""
Run an un-modified assertion.
"""
# Use super(BulkAssertionTest) to make sure we get un-adulturated assertions
with super(BulkAssertionTest, self).assertRaises(AssertionError):
getattr(self, assertion)(*args)
return getattr(super(BulkAssertionTest, self), 'assert' + assertion_name)(*args, **kwargs)
def test_no_bulk_assert_equals(self):
# Use super(BulkAssertionTest) to make sure we get un-adulturated assertions
with super(BulkAssertionTest, self).assertRaises(AssertionError):
self.assertEquals(1, 2)
@ddt.data(
'assertEqual', 'assertEquals'
)
def test_bulk_assert_equals(self, asserterFn):
asserter = getattr(self, asserterFn)
@ddt.data(*(STATIC_PASSING_ASSERTIONS + CONTEXT_PASSING_ASSERTIONS))
def test_passing_asserts_passthrough(self, assertion_tuple):
self._run_assertion(assertion_tuple)
@ddt.data(*(STATIC_FAILING_ASSERTIONS + CONTEXT_FAILING_ASSERTIONS))
def test_failing_asserts_passthrough(self, assertion_tuple):
with self._raw_assert('Raises', AssertionError) as context:
self._run_assertion(assertion_tuple)
self._raw_assert('NotIsInstance', context.exception, BulkAssertionError)
@ddt.data(*CONTEXT_PASSING_ASSERTIONS)
@ddt.unpack
def test_passing_context_assertion_passthrough(self, assertion, *args):
assertion_args = []
args = list(args)
exception = args.pop(0)
while not callable(args[0]):
assertion_args.append(args.pop(0))
function = args.pop(0)
with getattr(self, assertion)(exception, *assertion_args):
function(*args)
@ddt.data(*CONTEXT_FAILING_ASSERTIONS)
@ddt.unpack
def test_failing_context_assertion_passthrough(self, assertion, *args):
assertion_args = []
args = list(args)
exception = args.pop(0)
while not callable(args[0]):
assertion_args.append(args.pop(0))
function = args.pop(0)
with self._raw_assert('Raises', AssertionError) as context:
with getattr(self, assertion)(exception, *assertion_args):
function(*args)
self._raw_assert('NotIsInstance', context.exception, BulkAssertionError)
@ddt.data(*list(itertools.product(
CONTEXT_PASSING_ASSERTIONS,
CONTEXT_FAILING_ASSERTIONS,
CONTEXT_FAILING_ASSERTIONS
)))
@ddt.unpack
def test_bulk_assert(self, passing_assertion, failing_assertion1, failing_assertion2):
contextmanager = self.bulk_assertions()
contextmanager.__enter__()
super(BulkAssertionTest, self).assertIsNotNone(self._manager)
asserter(1, 2)
asserter(3, 4)
self._run_assertion(passing_assertion)
self._run_assertion(failing_assertion1)
self._run_assertion(failing_assertion2)
# Use super(BulkAssertionTest) to make sure we get un-adulturated assertions
with super(BulkAssertionTest, self).assertRaises(AssertionError):
with self._raw_assert('Raises', BulkAssertionError) as context:
contextmanager.__exit__(None, None, None)
@ddt.data(
'assertEqual', 'assertEquals'
)
def test_bulk_assert_closed(self, asserterFn):
asserter = getattr(self, asserterFn)
self._raw_assert('Equals', len(context.exception.errors), 2)
@ddt.data(*list(itertools.product(
CONTEXT_FAILING_ASSERTIONS
)))
@ddt.unpack
def test_nested_bulk_asserts(self, failing_assertion):
with self._raw_assert('Raises', BulkAssertionError) as context:
with self.bulk_assertions():
self._run_assertion(failing_assertion)
with self.bulk_assertions():
self._run_assertion(failing_assertion)
self._run_assertion(failing_assertion)
with self.bulk_assertions():
asserter(1, 1)
asserter(2, 2)
self._raw_assert('Equal', len(context.exception.errors), 3)
# Use super(BulkAssertionTest) to make sure we get un-adulturated assertions
with super(BulkAssertionTest, self).assertRaises(AssertionError):
asserter(1, 2)
@ddt.data(*list(itertools.product(
CONTEXT_PASSING_ASSERTIONS,
CONTEXT_FAILING_ASSERTIONS,
CONTEXT_FAILING_ASSERTIONS
)))
@ddt.unpack
def test_bulk_assert_closed(self, passing_assertion, failing_assertion1, failing_assertion2):
with self._raw_assert('Raises', BulkAssertionError) as context:
with self.bulk_assertions():
self._run_assertion(passing_assertion)
self._run_assertion(failing_assertion1)
self._raw_assert('Equals', len(context.exception.errors), 1)
with self._raw_assert('Raises', AssertionError) as context:
self._run_assertion(failing_assertion2)
self._raw_assert('NotIsInstance', context.exception, BulkAssertionError)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment