Commit 3c0367d5 by Ned Batchelder

Polish up layered_test_check

Move method out of the class that could be useful for future thing.
Booleanify the return value just in case. Add test cases to get coverage
to 100%
parent 4214aede
...@@ -14,6 +14,26 @@ def register_checkers(linter): ...@@ -14,6 +14,26 @@ def register_checkers(linter):
linter.register_checker(LayeredTestClassChecker(linter)) linter.register_checker(LayeredTestClassChecker(linter))
def is_test_case_class(node):
"""Is this node a test class?
To be a test class, it has to derive from unittest.TestCase, and not
have __test__ defined as False.
"""
if not node.is_subtype_of('unittest.case.TestCase'):
return False
dunder_test = get_locals(node).get("__test__")
if dunder_test:
if isinstance(dunder_test[0], astroid.AssName):
value = list(dunder_test[0].assigned_stmts())
if len(value) == 1 and isinstance(value[0], astroid.Const):
return bool(value[0].value)
return True
class LayeredTestClassChecker(BaseChecker): class LayeredTestClassChecker(BaseChecker):
"""Pylint checker for tests inheriting test methods from other tests.""" """Pylint checker for tests inheriting test methods from other tests."""
...@@ -32,32 +52,13 @@ class LayeredTestClassChecker(BaseChecker): ...@@ -32,32 +52,13 @@ class LayeredTestClassChecker(BaseChecker):
} }
@utils.check_messages(MESSAGE_ID) @utils.check_messages(MESSAGE_ID)
def is_test_case_class(self, node):
"""Is this node a test class?
To be a test class, it has to derive from unittest.TestCase, and not
have __test__ defined as False.
"""
if not node.is_subtype_of('unittest.case.TestCase'):
return False
dunder_test = get_locals(node).get("__test__")
if dunder_test:
if isinstance(dunder_test[0], astroid.AssName):
value = list(dunder_test[0].assigned_stmts())
if len(value) == 1 and isinstance(value[0], astroid.Const):
return value[0].value
return True
def visit_class(self, node): def visit_class(self, node):
"""Check each class.""" """Check each class."""
if not self.is_test_case_class(node): if not is_test_case_class(node):
return return
for anc in node.ancestors(): for anc in node.ancestors():
if not self.is_test_case_class(anc): if not is_test_case_class(anc):
continue continue
for meth in anc.mymethods(): for meth in anc.mymethods():
if meth.name.startswith("test_"): if meth.name.startswith("test_"):
......
...@@ -63,3 +63,13 @@ class EmptyTestCase(unittest.TestCase): ...@@ -63,3 +63,13 @@ class EmptyTestCase(unittest.TestCase):
class ActualTestCase(EmptyTestCase): class ActualTestCase(EmptyTestCase):
def test_something(self): def test_something(self):
pass pass
# Bizzaro __test__ examples to complete branch coverage.
class WhatIsThis(unittest.TestCase):
def __test__(self):
return self.fail("I don't know what I'm doing.")
class TooTrickyForTheirOwnGood(unittest.TestCase):
__test__ = 1 - 1
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment