Add basic support for user-defined mypy plugins (#3517)
Configure them through "plugins=path/plugin.py, ..." in the ini file.
The paths are relative to the configuration file.
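
For illustration, a minimal setup might look like this (file and class names
are hypothetical, not part of this commit). In the config file:

    [mypy]
    plugins = custom_plugin.py

And in custom_plugin.py, next to the config file, an entry point function
named "plugin" that returns a Plugin subclass:

    from mypy.plugin import Plugin

    class CustomPlugin(Plugin):
        # Override get_function_hook(), get_method_hook(), etc. as needed;
        # the base class is a no-op plugin.
        pass

    def plugin(version):
        # mypy calls this with its own version string and expects the plugin
        # *class* back; it then instantiates it with the Python version.
        return CustomPlugin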

This is an almost minimal implementation and some features
are missing:

* Plugins installed through pip aren't properly supported.
* Plugins within packages aren't properly supported.
* Incremental mode doesn't invalidate cache files when
  plugins change.

Also change path normalization in test cases on Windows. Previously we
sometimes normalized to Windows paths and sometimes to Linux paths;
now we always use Linux paths.
JukkaL authored Jun 13, 2017
1 parent 7d630b7 commit fd0a416
Showing 16 changed files with 253 additions and 18 deletions.
75 changes: 70 additions & 5 deletions mypy/build.py
@@ -42,7 +42,7 @@
from mypy.stats import dump_type_stats
from mypy.types import Type
from mypy.version import __version__
from mypy.plugin import DefaultPlugin
from mypy.plugin import Plugin, DefaultPlugin, ChainedPlugin


# We need to know the location of this file to load data, but
@@ -183,7 +183,9 @@ def build(sources: List[BuildSource],
reports=reports,
options=options,
version_id=__version__,
)
plugin=DefaultPlugin(options.python_version))

manager.plugin = load_custom_plugins(manager.plugin, options, manager.errors)

try:
graph = dispatch(sources, manager)
@@ -334,6 +336,67 @@ def import_priority(imp: ImportBase, toplevel_priority: int) -> int:
return toplevel_priority


def load_custom_plugins(default_plugin: Plugin, options: Options, errors: Errors) -> Plugin:
"""Load custom plugins if any are configured.
Return a plugin that chains all custom plugins (if any) and falls
back to default_plugin.
"""

def plugin_error(message: str) -> None:
errors.report(0, 0, message)
errors.raise_error()

custom_plugins = []
for plugin_path in options.plugins:
if options.config_file:
# Plugin paths are relative to the config file location.
plugin_path = os.path.join(os.path.dirname(options.config_file), plugin_path)
errors.set_file(plugin_path, None)

if not os.path.isfile(plugin_path):
plugin_error("Can't find plugin")
plugin_dir = os.path.dirname(plugin_path)
fnam = os.path.basename(plugin_path)
if not fnam.endswith('.py'):
plugin_error("Plugin must have .py extension")
module_name = fnam[:-3]
import importlib
sys.path.insert(0, plugin_dir)
try:
m = importlib.import_module(module_name)
except Exception:
print('Error importing plugin {}\n'.format(plugin_path))
raise # Propagate to display traceback
finally:
assert sys.path[0] == plugin_dir
del sys.path[0]
if not hasattr(m, 'plugin'):
plugin_error('Plugin does not define entry point function "plugin"')
try:
plugin_type = getattr(m, 'plugin')(__version__)
except Exception:
print('Error calling the plugin(version) entry point of {}\n'.format(plugin_path))
raise # Propagate to display traceback
if not isinstance(plugin_type, type):
plugin_error(
'Type object expected as the return value of "plugin" (got {!r})'.format(
plugin_type))
if not issubclass(plugin_type, Plugin):
plugin_error(
'Return value of "plugin" must be a subclass of "mypy.plugin.Plugin"')
try:
custom_plugins.append(plugin_type(options.python_version))
except Exception:
print('Error constructing plugin instance of {}\n'.format(plugin_type.__name__))
raise # Propagate to display traceback
if not custom_plugins:
return default_plugin
else:
# Custom plugins take precedence over built-in plugins.
return ChainedPlugin(options.python_version, custom_plugins + [default_plugin])


# TODO: Get rid of all_types. It's not used except for one log message.
# Maybe we could instead publish a map from module ID to its type_map.
class BuildManager:
@@ -357,6 +420,7 @@ class BuildManager:
missing_modules: Set of modules that could not be imported encountered so far
stale_modules: Set of modules that needed to be rechecked
version_id: The current mypy version (based on commit id when possible)
plugin: Active mypy plugin(s)
"""

def __init__(self, data_dir: str,
@@ -365,7 +429,8 @@ def __init__(self, data_dir: str,
source_set: BuildSourceSet,
reports: Reports,
options: Options,
version_id: str) -> None:
version_id: str,
plugin: Plugin) -> None:
self.start_time = time.time()
self.data_dir = data_dir
self.errors = Errors(options.show_error_context, options.show_column_numbers)
@@ -385,6 +450,7 @@ def __init__(self, data_dir: str,
self.indirection_detector = TypeIndirectionVisitor()
self.stale_modules = set() # type: Set[str]
self.rechecked_modules = set() # type: Set[str]
self.plugin = plugin

def maybe_swap_for_shadow_path(self, path: str) -> str:
if (self.options.shadow_file and
@@ -1549,9 +1615,8 @@ def type_check_first_pass(self) -> None:
if self.options.semantic_analysis_only:
return
with self.wrap_context():
plugin = DefaultPlugin(self.options.python_version)
self.type_checker = TypeChecker(manager.errors, manager.modules, self.options,
self.tree, self.xpath, plugin)
self.tree, self.xpath, manager.plugin)
self.type_checker.check_first_pass()

def type_check_second_pass(self) -> bool:
3 changes: 2 additions & 1 deletion mypy/main.py
@@ -379,7 +379,7 @@ def add_invertible_flag(flag: str,
parser.parse_args(args, dummy)
config_file = dummy.config_file
if config_file is not None and not os.path.exists(config_file):
parser.error("Cannot file config file '%s'" % config_file)
parser.error("Cannot find config file '%s'" % config_file)

# Parse config file first, so command line can override.
options = Options()
@@ -613,6 +613,7 @@ def get_init_file(dir: str) -> Optional[str]:
# These two are for backwards compatibility
'silent_imports': bool,
'almost_silent': bool,
'plugins': lambda s: [p.strip() for p in s.split(',')],
}

SHARED_CONFIG_FILES = ('setup.cfg',)
3 changes: 3 additions & 0 deletions mypy/options.py
@@ -113,6 +113,9 @@ def __init__(self) -> None:
self.debug_cache = False
self.quick_and_dirty = False

# Paths of user plugins
self.plugins = [] # type: List[str]

# Per-module options (raw)
self.per_module_options = {} # type: Dict[Pattern[str], Dict[str, object]]

46 changes: 42 additions & 4 deletions mypy/plugin.py
@@ -1,4 +1,4 @@
from typing import Callable, List, Tuple, Optional, NamedTuple
from typing import Callable, List, Tuple, Optional, NamedTuple, TypeVar

from mypy.nodes import Expression, StrExpr, IntExpr, UnaryExpr, Context
from mypy.types import (
@@ -60,7 +60,7 @@


class Plugin:
"""Base class of type checker plugins.
"""Base class of all type checker plugins.
This defines a no-op plugin. Subclasses can override some methods to
provide some actual functionality.
@@ -69,8 +69,6 @@ class Plugin:
results might be cached).
"""

# TODO: Way of chaining multiple plugins

def __init__(self, python_version: Tuple[int, int]) -> None:
self.python_version = python_version

@@ -86,6 +84,46 @@ def get_method_hook(self, fullname: str) -> Optional[MethodHook]:
# TODO: metaclass / class decorator hook


T = TypeVar('T')


class ChainedPlugin(Plugin):
"""A plugin that represents a sequence of chained plugins.
Each lookup method returns the hook for the first plugin that
reports a match.
This class should not be subclassed -- use Plugin as the base class
for all plugins.
"""

# TODO: Support caching of lookup results (through an LRU cache, for example).

def __init__(self, python_version: Tuple[int, int], plugins: List[Plugin]) -> None:
"""Initialize chained plugin.
Assume that the child plugins aren't mutated (results may be cached).
"""
super().__init__(python_version)
self._plugins = plugins

def get_function_hook(self, fullname: str) -> Optional[FunctionHook]:
return self._find_hook(lambda plugin: plugin.get_function_hook(fullname))

def get_method_signature_hook(self, fullname: str) -> Optional[MethodSignatureHook]:
return self._find_hook(lambda plugin: plugin.get_method_signature_hook(fullname))

def get_method_hook(self, fullname: str) -> Optional[MethodHook]:
return self._find_hook(lambda plugin: plugin.get_method_hook(fullname))

def _find_hook(self, lookup: Callable[[Plugin], T]) -> Optional[T]:
for plugin in self._plugins:
hook = lookup(plugin)
if hook:
return hook
return None


class DefaultPlugin(Plugin):
"""Type checker plugin that is enabled by default."""

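As a rough sketch of how the chain resolves hooks (hypothetical plugin, not
part of this commit): the first plugin in the list that returns a hook wins,
and a plugin that returns None simply defers to the next one.

    from mypy.plugin import Plugin, ChainedPlugin

    class TracingPlugin(Plugin):
        # Hypothetical plugin that never provides hooks; it only records
        # which fullnames were looked up.
        def __init__(self, python_version):
            super().__init__(python_version)
            self.seen = []

        def get_function_hook(self, fullname):
            self.seen.append(fullname)
            return None  # no match, so the chain falls through

    tracing = TracingPlugin((3, 6))
    chained = ChainedPlugin((3, 6), [tracing, Plugin((3, 6))])
    assert chained.get_function_hook('builtins.open') is None
    assert tracing.seen == ['builtins.open']  # consulted before the base Plugin
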
22 changes: 17 additions & 5 deletions mypy/test/data.py
@@ -13,6 +13,9 @@
from mypy.myunit import TestCase, SkipTestCaseException


root_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), '..', '..'))


def parse_test_cases(
path: str,
perform: Optional[Callable[['DataDrivenTestCase'], None]],
@@ -62,7 +65,9 @@ def parse_test_cases(
# Record an extra file needed for the test case.
arg = p[i].arg
assert arg is not None
file_entry = (join(base_path, arg), '\n'.join(p[i].data))
contents = '\n'.join(p[i].data)
contents = expand_variables(contents)
file_entry = (join(base_path, arg), contents)
if p[i].id == 'file':
files.append(file_entry)
elif p[i].id == 'outfile':
@@ -119,13 +124,15 @@ def parse_test_cases(
deleted_paths.setdefault(num, set()).add(full)
elif p[i].id == 'out' or p[i].id == 'out1':
tcout = p[i].data
if native_sep and os.path.sep == '\\':
tcout = [expand_variables(line) for line in tcout]
if os.path.sep == '\\':
tcout = [fix_win_path(line) for line in tcout]
ok = True
elif re.match(r'out[0-9]*$', p[i].id):
passnum = int(p[i].id[3:])
assert passnum > 1
output = p[i].data
output = [expand_variables(line) for line in output]
if native_sep and os.path.sep == '\\':
output = [fix_win_path(line) for line in output]
tcout2[passnum] = output
@@ -415,6 +422,10 @@ def expand_includes(a: List[str], base_path: str) -> List[str]:
return res


def expand_variables(s: str) -> str:
return s.replace('<ROOT>', root_dir)


def expand_errors(input: List[str], output: List[str], fnam: str) -> None:
"""Transform comments such as '# E: message' or
'# E:3: message' in input.
@@ -445,16 +456,17 @@ def expand_errors(input: List[str], output: List[str], fnam: str) -> None:


def fix_win_path(line: str) -> str:
r"""Changes paths to Windows paths in error messages.
r"""Changes Windows paths to Linux paths in error messages.
E.g. foo/bar.py -> foo\bar.py.
E.g. foo\bar.py -> foo/bar.py.
"""
line = line.replace(root_dir, root_dir.replace('\\', '/'))
m = re.match(r'^([\S/]+):(\d+:)?(\s+.*)', line)
if not m:
return line
else:
filename, lineno, message = m.groups()
return '{}:{}{}'.format(filename.replace('/', '\\'),
return '{}:{}{}'.format(filename.replace('\\', '/'),
lineno or '', message)
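
A quick sketch of the new behavior (input path illustrative only):

    >>> from mypy.test.data import fix_win_path
    >>> fix_win_path(r'foo\bar.py:10: error: Unsupported operand types')
    'foo/bar.py:10: error: Unsupported operand types'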


4 changes: 3 additions & 1 deletion mypy/test/testcheck.py
@@ -76,6 +76,7 @@
'check-classvar.test',
'check-enum.test',
'check-incomplete-fixture.test',
'check-custom-plugin.test',
]


@@ -261,7 +262,8 @@ def find_error_paths(self, a: List[str]) -> Set[str]:
for line in a:
m = re.match(r'([^\s:]+):\d+: error:', line)
if m:
p = m.group(1).replace('/', os.path.sep)
# Normalize to Linux paths.
p = m.group(1).replace(os.path.sep, '/')
hits.add(p)
return hits

4 changes: 3 additions & 1 deletion mypy/test/testcmdline.py
@@ -15,7 +15,7 @@
from mypy.test.config import test_data_prefix, test_temp_dir
from mypy.test.data import fix_cobertura_filename
from mypy.test.data import parse_test_cases, DataDrivenTestCase
from mypy.test.helpers import assert_string_arrays_equal
from mypy.test.helpers import assert_string_arrays_equal, normalize_error_messages
from mypy.version import __version__, base_version

# Path to Python 3 interpreter
@@ -71,10 +71,12 @@ def test_python_evaluation(testcase: DataDrivenTestCase) -> None:
os.path.abspath(test_temp_dir))
if testcase.native_sep and os.path.sep == '\\':
normalized_output = [fix_cobertura_filename(line) for line in normalized_output]
normalized_output = normalize_error_messages(normalized_output)
assert_string_arrays_equal(expected_content.splitlines(), normalized_output,
'Output file {} did not match its expected output'.format(
path))
else:
out = normalize_error_messages(out)
assert_string_arrays_equal(testcase.output, out,
'Invalid output ({}, line {})'.format(
testcase.file, testcase.line))
3 changes: 3 additions & 0 deletions mypy/test/testgraph.py
@@ -8,6 +8,8 @@
from mypy.version import __version__
from mypy.options import Options
from mypy.report import Reports
from mypy.plugin import Plugin
from mypy import defaults


class GraphSuite(Suite):
@@ -42,6 +44,7 @@ def _make_manager(self) -> BuildManager:
reports=Reports('', {}),
options=Options(),
version_id=__version__,
plugin=Plugin(defaults.PYTHON3_VERSION),
)
return manager

1 change: 1 addition & 0 deletions mypy/test/testsemanal.py
@@ -89,6 +89,7 @@ def test_semanal(testcase: DataDrivenTestCase) -> None:
a += str(f).split('\n')
except CompileError as e:
a = e.messages
a = normalize_error_messages(a)
assert_string_arrays_equal(
testcase.output, a,
'Invalid semantic analyzer output ({}, line {})'.format(testcase.file,
5 changes: 4 additions & 1 deletion mypy/test/testtransform.py
@@ -7,7 +7,9 @@
from mypy import build
from mypy.build import BuildSource
from mypy.myunit import Suite
from mypy.test.helpers import assert_string_arrays_equal, testfile_pyversion
from mypy.test.helpers import (
assert_string_arrays_equal, testfile_pyversion, normalize_error_messages
)
from mypy.test.data import parse_test_cases, DataDrivenTestCase
from mypy.test.config import test_data_prefix, test_temp_dir
from mypy.errors import CompileError
@@ -73,6 +75,7 @@ def test_transform(testcase: DataDrivenTestCase) -> None:
a += str(f).split('\n')
except CompileError as e:
a = e.messages
a = normalize_error_messages(a)
assert_string_arrays_equal(
testcase.output, a,
'Invalid semantic analyzer output ({}, line {})'.format(testcase.file,