Commit 324d4785 by Calen Pennington

Include stack traces when counting calls in unit tests

parent de142b2b
......@@ -2,12 +2,15 @@
Factories for use in tests of XBlocks.
"""
import functools
import inspect
import pprint
import pymongo.message
import threading
from uuid import uuid4
import traceback
from collections import defaultdict
from decorator import contextmanager
import pymongo.message
from uuid import uuid4
from factory import Factory, Sequence, lazy_attribute_sequence, lazy_attribute
from factory.containers import CyclicDefinitionError
......@@ -320,47 +323,157 @@ def check_number_of_calls(object_with_method, method_name, maximum_calls, minimu
return check_sum_of_calls(object_with_method, [method_name], maximum_calls, minimum_calls)
class StackTraceCounter(object):
"""
A class that counts unique stack traces underneath a particular stack frame.
"""
def __init__(self, stack_depth, include_arguments=True):
"""
Arguments:
stack_depth (int): The number of stack frames above this constructor to capture.
include_arguments (bool): Whether to store the arguments that are passed
when capturing a stack trace.
"""
self.include_arguments = include_arguments
self._top_of_stack = traceback.extract_stack(limit=stack_depth)[0]
if self.include_arguments:
self._stacks = defaultdict(lambda: defaultdict(int))
else:
self._stacks = defaultdict(int)
def capture_stack(self, args, kwargs):
"""
Record the stack frames starting at the caller of this method, and
ending at the top of the stack as defined by the ``stack_depth``.
Arguments:
args: The positional arguments to capture at this stack frame
kwargs: The keyword arguments to capture at this stack frame
"""
stack = traceback.extract_stack()[:-2]
if self._top_of_stack in stack:
stack = stack[stack.index(self._top_of_stack):]
if self.include_arguments:
safe_args = []
for arg in args:
try:
safe_args.append(repr(arg))
except Exception as exc:
safe_args.append('<un-repr-able value: {}'.format(exc))
safe_kwargs = {}
for key, kwarg in kwargs.items():
try:
safe_kwargs[key] = repr(kwarg)
except Exception as exc:
safe_kwargs[key] = '<un-repr-able value: {}'.format(exc)
self._stacks[tuple(stack)][tuple(safe_args), tuple(safe_kwargs.items())] += 1
else:
self._stacks[tuple(stack)] += 1
@property
def total_calls(self):
"""
Return the total number of stacks recorded.
"""
return sum(self.stack_calls(stack) for stack in self._stacks)
def stack_calls(self, stack):
"""
Return the number of calls to the supplied ``stack``.
"""
if self.include_arguments:
return sum(self._stacks[stack].values())
else:
return self._stacks[stack]
def __iter__(self):
"""
Iterate over all unique captured stacks.
"""
return iter(sorted(self._stacks.keys(), key=lambda stack: (self.stack_calls(stack), stack), reverse=True))
def __getitem__(self, stack):
"""
Return the set of captured calls with the supplied stack.
"""
return self._stacks[stack]
@classmethod
def capture_call(cls, func, stack_depth, include_arguments=True):
"""
A decorator that wraps ``func``, and captures each call to ``func``,
recording the stack trace, and optionally the arguments that the function
is called with.
Arguments:
func: the function to wrap
stack_depth: how far up the stack to truncate the stored stack traces (
this is counted from the call to ``capture_call``, rather than calls
to the captured function).
"""
stacks = StackTraceCounter(stack_depth, include_arguments)
@functools.wraps(func)
def capture(*args, **kwargs):
stacks.capture_stack(args, kwargs)
return func(*args, **kwargs)
capture.stack_counter = stacks
return capture
@contextmanager
def check_sum_of_calls(object_, methods, maximum_calls, minimum_calls=1):
def check_sum_of_calls(object_, methods, maximum_calls, minimum_calls=1, include_arguments=True):
"""
Instruments the given methods on the given object to verify that the total sum of calls made to the
methods falls between minumum_calls and maximum_calls.
"""
mocks = {
method: Mock(wraps=getattr(object_, method))
method: StackTraceCounter.capture_call(
getattr(object_, method),
stack_depth=7,
include_arguments=include_arguments
)
for method in methods
}
if inspect.isclass(object_):
# If the object that we're intercepting methods on is a class, rather than a module,
# then we need to set the method to a real function, so that self gets passed to it,
# and then explicitly pass that self into the call to the mock
# pylint: disable=unnecessary-lambda,cell-var-from-loop
mock_kwargs = {
method: lambda self, *args, **kwargs: mocks[method](self, *args, **kwargs)
for method in methods
}
else:
mock_kwargs = mocks
with patch.multiple(object_, **mock_kwargs):
with patch.multiple(object_, **mocks):
yield
call_count = sum(mock.call_count for mock in mocks.values())
call_count = sum(capture_fn.stack_counter.total_calls for capture_fn in mocks.values())
# Assertion errors don't handle multi-line values, so pretty-print to std-out instead
if not minimum_calls <= call_count <= maximum_calls:
calls = {
method_name: mock.call_args_list
for method_name, mock in mocks.items()
}
print "Expected between {} and {} calls, {} were made. Calls: {}".format(
messages = ["Expected between {} and {} calls, {} were made.\n\n".format(
minimum_calls,
maximum_calls,
call_count,
pprint.pformat(calls),
)
)]
for method_name, capture_fn in mocks.items():
stack_counter = capture_fn.stack_counter
messages.append("{!r} was called {} times:\n".format(
method_name,
stack_counter.total_calls
))
for stack in stack_counter:
messages.append(" called {} times:\n\n".format(stack_counter.stack_calls(stack)))
messages.append(" " + " ".join(traceback.format_list(stack)))
messages.append("\n\n")
if include_arguments:
for (args, kwargs), count in stack_counter[stack].items():
messages.append(" called {} times with:\n".format(count))
messages.append(" args: {}\n".format(args))
messages.append(" kwargs: {}\n\n".format(dict(kwargs)))
print "".join(messages)
# verify the counter actually worked by ensuring we have counted greater than (or equal to) the minimum calls
assert_greater_equal(call_count, minimum_calls)
......
......@@ -143,7 +143,7 @@ class FieldOverridePerformanceTestCase(ProceduralCourseTestMixin,
with self.assertNumQueries(queries):
with check_mongo_calls(reads):
with check_sum_of_calls(XBlock, ['__init__'], xblocks, xblocks):
with check_sum_of_calls(XBlock, ['__init__'], xblocks, xblocks, include_arguments=False):
self.grade_course(self.course)
@ddt.data(*itertools.product(('no_overrides', 'ccx'), range(1, 4), (True, False)))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment