Commit ab25d900 by Ned Batchelder

Required base class checking works

parent 0a0a214f
...@@ -2,9 +2,6 @@ ...@@ -2,9 +2,6 @@
import collections import collections
import astroid
from astroid.scoped_nodes import get_locals # not sure this is the right import
from pylint.checkers import BaseChecker, utils from pylint.checkers import BaseChecker, utils
from pylint.interfaces import IAstroidChecker from pylint.interfaces import IAstroidChecker
...@@ -47,6 +44,7 @@ class RequiredBaseClassChecker(BaseChecker): ...@@ -47,6 +44,7 @@ class RequiredBaseClassChecker(BaseChecker):
self.class_map = collections.defaultdict(set) self.class_map = collections.defaultdict(set)
def open(self): def open(self):
# pylint: disable=no-member
if self.config.required_base_class: if self.config.required_base_class:
for pair in self.config.required_base_class: for pair in self.config.required_base_class:
child, parent = pair.split(":") child, parent = pair.split(":")
...@@ -55,5 +53,22 @@ class RequiredBaseClassChecker(BaseChecker): ...@@ -55,5 +53,22 @@ class RequiredBaseClassChecker(BaseChecker):
@utils.check_messages(MESSAGE_ID) @utils.check_messages(MESSAGE_ID)
def visit_class(self, node): def visit_class(self, node):
"""Check each class.""" """Check each class."""
if self.class_map: if not self.class_map:
self.add_message(self.MESSAGE_ID, args=("Foo", self.class_map['Foo']), node=node) return
all_bases = [usable_class_name(c) for c in node.mro()]
for base in all_bases:
required = self.class_map.get(base)
if required is not None:
if not all(r in all_bases for r in required):
nice_required = ", ".join(sorted(required))
self.add_message(self.MESSAGE_ID, args=(node.name, nice_required), node=node)
def usable_class_name(node):
"""Make a reasonable class name for a class node."""
name = node.qname()
for prefix in ["__builtin__.", "builtins."]:
if name.startswith(prefix):
name = name[len(prefix):]
return name
"""Unit tests for required-base-classes.""" """Unit tests for required-base-classes."""
import unittest
import astroid import astroid
from pylint.testutils import CheckerTestCase, Message, set_config from pylint.testutils import CheckerTestCase, Message, set_config
...@@ -9,21 +7,63 @@ from edx_lint.pylint.required_base_class import RequiredBaseClassChecker ...@@ -9,21 +7,63 @@ from edx_lint.pylint.required_base_class import RequiredBaseClassChecker
class RequiredBaseClassTestCase(CheckerTestCase): class RequiredBaseClassTestCase(CheckerTestCase):
"""Unittest tests of RequiredBaseClassChecker."""
CHECKER_CLASS = RequiredBaseClassChecker CHECKER_CLASS = RequiredBaseClassChecker
def test_something(self): def get_class_node(self, code):
node = astroid.parse(''' """Parse `code`, and return the last class node.
class MyClass(object):
pass The code should have at least one class definition.
"""
node = astroid.parse(code)
class_node = None
for body_node in node.body:
if getattr(body_node, 'type', 'none') == "class":
class_node = body_node
return class_node
def test_no_messages_by_default(self):
node = self.get_class_node('''
class MyClass(object):
pass
''')
with self.assertNoMessages():
self.checker.visit_class(node)
@set_config(required_base_class=["BaseClass:MyMixin"])
def test_no_messages_if_class_not_used(self):
node = self.get_class_node('''
class MyClass(object):
pass
''') ''')
with self.assertNoMessages(): with self.assertNoMessages():
self.checker.visit_class(node) self.checker.visit_class(node)
@set_config(required_base_class=["Foo:Bax"]) @set_config(required_base_class=["unittest.case.TestCase:.MyTestMixin"])
def test_wut(self): def test_error_if_class_is_not_used(self):
node = astroid.parse(''' node = self.get_class_node('''
class MyClass(object): from unittest import TestCase
pass class MyClass(TestCase):
pass
''')
expected_msg = Message(
'missing-required-base-class',
node=node,
args=('MyClass', '.MyTestMixin'),
)
with self.assertAddsMessages(expected_msg):
self.checker.visit_class(node)
@set_config(required_base_class=["unittest.case.TestCase:.MyTestMixin"])
def test_no_messages_if_class_is_used(self):
node = self.get_class_node('''
from unittest import TestCase
class MyTestMixin(object):
pass
class MyClass(MyTestMixin, TestCase):
pass
''') ''')
with self.assertNoMessages(): with self.assertNoMessages():
self.checker.visit_class(node) self.checker.visit_class(node)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment