From fd0a416ab49ba997dd80bda5c4276630dec24d30 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Tue, 13 Jun 2017 18:43:25 +0100 Subject: [PATCH] Add basic support for user-defined mypy plugins (#3517) Configure them through "plugins=path/plugin.py, ..." in the ini file. The paths are relative to the configuration file. This is an almost minimal implementation and some features are missing: * Plugins installed through pip aren't properly supported. * Plugins within packages aren't properly supported. * Incremental mode doesn't invalidate cache files when plugins change. Also change path normalization in test cases in Windows. Previously we sometimes normalized to Windows paths and sometimes to Linux paths. Now switching to always use Linux paths. --- mypy/build.py | 75 +++++++++++++++++++++++-- mypy/main.py | 3 +- mypy/options.py | 3 + mypy/plugin.py | 46 +++++++++++++-- mypy/test/data.py | 22 ++++++-- mypy/test/testcheck.py | 4 +- mypy/test/testcmdline.py | 4 +- mypy/test/testgraph.py | 3 + mypy/test/testsemanal.py | 1 + mypy/test/testtransform.py | 5 +- test-data/unit/check-custom-plugin.test | 71 +++++++++++++++++++++++ test-data/unit/plugins/badreturn.py | 2 + test-data/unit/plugins/badreturn2.py | 5 ++ test-data/unit/plugins/fnplugin.py | 13 +++++ test-data/unit/plugins/noentry.py | 1 + test-data/unit/plugins/plugin2.py | 13 +++++ 16 files changed, 253 insertions(+), 18 deletions(-) create mode 100644 test-data/unit/check-custom-plugin.test create mode 100644 test-data/unit/plugins/badreturn.py create mode 100644 test-data/unit/plugins/badreturn2.py create mode 100644 test-data/unit/plugins/fnplugin.py create mode 100644 test-data/unit/plugins/noentry.py create mode 100644 test-data/unit/plugins/plugin2.py diff --git a/mypy/build.py b/mypy/build.py index 81cab5a99ef8..e4b202c4e72b 100644 --- a/mypy/build.py +++ b/mypy/build.py @@ -42,7 +42,7 @@ from mypy.stats import dump_type_stats from mypy.types import Type from mypy.version import __version__ -from mypy.plugin import DefaultPlugin +from mypy.plugin import Plugin, DefaultPlugin, ChainedPlugin # We need to know the location of this file to load data, but @@ -183,7 +183,9 @@ def build(sources: List[BuildSource], reports=reports, options=options, version_id=__version__, - ) + plugin=DefaultPlugin(options.python_version)) + + manager.plugin = load_custom_plugins(manager.plugin, options, manager.errors) try: graph = dispatch(sources, manager) @@ -334,6 +336,67 @@ def import_priority(imp: ImportBase, toplevel_priority: int) -> int: return toplevel_priority +def load_custom_plugins(default_plugin: Plugin, options: Options, errors: Errors) -> Plugin: + """Load custom plugins if any are configured. + + Return a plugin that chains all custom plugins (if any) and falls + back to default_plugin. + """ + + def plugin_error(message: str) -> None: + errors.report(0, 0, message) + errors.raise_error() + + custom_plugins = [] + for plugin_path in options.plugins: + if options.config_file: + # Plugin paths are relative to the config file location. + plugin_path = os.path.join(os.path.dirname(options.config_file), plugin_path) + errors.set_file(plugin_path, None) + + if not os.path.isfile(plugin_path): + plugin_error("Can't find plugin") + plugin_dir = os.path.dirname(plugin_path) + fnam = os.path.basename(plugin_path) + if not fnam.endswith('.py'): + plugin_error("Plugin must have .py extension") + module_name = fnam[:-3] + import importlib + sys.path.insert(0, plugin_dir) + try: + m = importlib.import_module(module_name) + except Exception: + print('Error importing plugin {}\n'.format(plugin_path)) + raise # Propagate to display traceback + finally: + assert sys.path[0] == plugin_dir + del sys.path[0] + if not hasattr(m, 'plugin'): + plugin_error('Plugin does not define entry point function "plugin"') + try: + plugin_type = getattr(m, 'plugin')(__version__) + except Exception: + print('Error calling the plugin(version) entry point of {}\n'.format(plugin_path)) + raise # Propagate to display traceback + if not isinstance(plugin_type, type): + plugin_error( + 'Type object expected as the return value of "plugin" (got {!r})'.format( + plugin_type)) + if not issubclass(plugin_type, Plugin): + plugin_error( + 'Return value of "plugin" must be a subclass of "mypy.plugin.Plugin"') + try: + custom_plugins.append(plugin_type(options.python_version)) + except Exception: + print('Error constructing plugin instance of {}\n'.format(plugin_type.__name__)) + raise # Propagate to display traceback + if not custom_plugins: + return default_plugin + else: + # Custom plugins take precendence over built-in plugins. + return ChainedPlugin(options.python_version, custom_plugins + [default_plugin]) + + # TODO: Get rid of all_types. It's not used except for one log message. # Maybe we could instead publish a map from module ID to its type_map. class BuildManager: @@ -357,6 +420,7 @@ class BuildManager: missing_modules: Set of modules that could not be imported encountered so far stale_modules: Set of modules that needed to be rechecked version_id: The current mypy version (based on commit id when possible) + plugin: Active mypy plugin(s) """ def __init__(self, data_dir: str, @@ -365,7 +429,8 @@ def __init__(self, data_dir: str, source_set: BuildSourceSet, reports: Reports, options: Options, - version_id: str) -> None: + version_id: str, + plugin: Plugin) -> None: self.start_time = time.time() self.data_dir = data_dir self.errors = Errors(options.show_error_context, options.show_column_numbers) @@ -385,6 +450,7 @@ def __init__(self, data_dir: str, self.indirection_detector = TypeIndirectionVisitor() self.stale_modules = set() # type: Set[str] self.rechecked_modules = set() # type: Set[str] + self.plugin = plugin def maybe_swap_for_shadow_path(self, path: str) -> str: if (self.options.shadow_file and @@ -1549,9 +1615,8 @@ def type_check_first_pass(self) -> None: if self.options.semantic_analysis_only: return with self.wrap_context(): - plugin = DefaultPlugin(self.options.python_version) self.type_checker = TypeChecker(manager.errors, manager.modules, self.options, - self.tree, self.xpath, plugin) + self.tree, self.xpath, manager.plugin) self.type_checker.check_first_pass() def type_check_second_pass(self) -> bool: diff --git a/mypy/main.py b/mypy/main.py index 422ca3ccec03..f4f13f29c1e1 100644 --- a/mypy/main.py +++ b/mypy/main.py @@ -379,7 +379,7 @@ def add_invertible_flag(flag: str, parser.parse_args(args, dummy) config_file = dummy.config_file if config_file is not None and not os.path.exists(config_file): - parser.error("Cannot file config file '%s'" % config_file) + parser.error("Cannot find config file '%s'" % config_file) # Parse config file first, so command line can override. options = Options() @@ -613,6 +613,7 @@ def get_init_file(dir: str) -> Optional[str]: # These two are for backwards compatibility 'silent_imports': bool, 'almost_silent': bool, + 'plugins': lambda s: [p.strip() for p in s.split(',')], } SHARED_CONFIG_FILES = ('setup.cfg',) diff --git a/mypy/options.py b/mypy/options.py index 69f99cce9501..fac8fe6d4459 100644 --- a/mypy/options.py +++ b/mypy/options.py @@ -113,6 +113,9 @@ def __init__(self) -> None: self.debug_cache = False self.quick_and_dirty = False + # Paths of user plugins + self.plugins = [] # type: List[str] + # Per-module options (raw) self.per_module_options = {} # type: Dict[Pattern[str], Dict[str, object]] diff --git a/mypy/plugin.py b/mypy/plugin.py index 5015f7b4c940..7acd4d0b29a5 100644 --- a/mypy/plugin.py +++ b/mypy/plugin.py @@ -1,4 +1,4 @@ -from typing import Callable, List, Tuple, Optional, NamedTuple +from typing import Callable, List, Tuple, Optional, NamedTuple, TypeVar from mypy.nodes import Expression, StrExpr, IntExpr, UnaryExpr, Context from mypy.types import ( @@ -60,7 +60,7 @@ class Plugin: - """Base class of type checker plugins. + """Base class of all type checker plugins. This defines a no-op plugin. Subclasses can override some methods to provide some actual functionality. @@ -69,8 +69,6 @@ class Plugin: results might be cached). """ - # TODO: Way of chaining multiple plugins - def __init__(self, python_version: Tuple[int, int]) -> None: self.python_version = python_version @@ -86,6 +84,46 @@ def get_method_hook(self, fullname: str) -> Optional[MethodHook]: # TODO: metaclass / class decorator hook +T = TypeVar('T') + + +class ChainedPlugin(Plugin): + """A plugin that represents a sequence of chained plugins. + + Each lookup method returns the hook for the first plugin that + reports a match. + + This class should not be subclassed -- use Plugin as the base class + for all plugins. + """ + + # TODO: Support caching of lookup results (through a LRU cache, for example). + + def __init__(self, python_version: Tuple[int, int], plugins: List[Plugin]) -> None: + """Initialize chained plugin. + + Assume that the child plugins aren't mutated (results may be cached). + """ + super().__init__(python_version) + self._plugins = plugins + + def get_function_hook(self, fullname: str) -> Optional[FunctionHook]: + return self._find_hook(lambda plugin: plugin.get_function_hook(fullname)) + + def get_method_signature_hook(self, fullname: str) -> Optional[MethodSignatureHook]: + return self._find_hook(lambda plugin: plugin.get_method_signature_hook(fullname)) + + def get_method_hook(self, fullname: str) -> Optional[MethodHook]: + return self._find_hook(lambda plugin: plugin.get_method_hook(fullname)) + + def _find_hook(self, lookup: Callable[[Plugin], T]) -> Optional[T]: + for plugin in self._plugins: + hook = lookup(plugin) + if hook: + return hook + return None + + class DefaultPlugin(Plugin): """Type checker plugin that is enabled by default.""" diff --git a/mypy/test/data.py b/mypy/test/data.py index ccee92eac276..09fe931d0c62 100644 --- a/mypy/test/data.py +++ b/mypy/test/data.py @@ -13,6 +13,9 @@ from mypy.myunit import TestCase, SkipTestCaseException +root_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), '..', '..')) + + def parse_test_cases( path: str, perform: Optional[Callable[['DataDrivenTestCase'], None]], @@ -62,7 +65,9 @@ def parse_test_cases( # Record an extra file needed for the test case. arg = p[i].arg assert arg is not None - file_entry = (join(base_path, arg), '\n'.join(p[i].data)) + contents = '\n'.join(p[i].data) + contents = expand_variables(contents) + file_entry = (join(base_path, arg), contents) if p[i].id == 'file': files.append(file_entry) elif p[i].id == 'outfile': @@ -119,13 +124,15 @@ def parse_test_cases( deleted_paths.setdefault(num, set()).add(full) elif p[i].id == 'out' or p[i].id == 'out1': tcout = p[i].data - if native_sep and os.path.sep == '\\': + tcout = [expand_variables(line) for line in tcout] + if os.path.sep == '\\': tcout = [fix_win_path(line) for line in tcout] ok = True elif re.match(r'out[0-9]*$', p[i].id): passnum = int(p[i].id[3:]) assert passnum > 1 output = p[i].data + output = [expand_variables(line) for line in output] if native_sep and os.path.sep == '\\': output = [fix_win_path(line) for line in output] tcout2[passnum] = output @@ -415,6 +422,10 @@ def expand_includes(a: List[str], base_path: str) -> List[str]: return res +def expand_variables(s: str) -> str: + return s.replace('', root_dir) + + def expand_errors(input: List[str], output: List[str], fnam: str) -> None: """Transform comments such as '# E: message' or '# E:3: message' in input. @@ -445,16 +456,17 @@ def expand_errors(input: List[str], output: List[str], fnam: str) -> None: def fix_win_path(line: str) -> str: - r"""Changes paths to Windows paths in error messages. + r"""Changes Windows paths to Linux paths in error messages. - E.g. foo/bar.py -> foo\bar.py. + E.g. foo\bar.py -> foo/bar.py. """ + line = line.replace(root_dir, root_dir.replace('\\', '/')) m = re.match(r'^([\S/]+):(\d+:)?(\s+.*)', line) if not m: return line else: filename, lineno, message = m.groups() - return '{}:{}{}'.format(filename.replace('/', '\\'), + return '{}:{}{}'.format(filename.replace('\\', '/'), lineno or '', message) diff --git a/mypy/test/testcheck.py b/mypy/test/testcheck.py index cbf0b5856aab..39a2d3a3a308 100644 --- a/mypy/test/testcheck.py +++ b/mypy/test/testcheck.py @@ -76,6 +76,7 @@ 'check-classvar.test', 'check-enum.test', 'check-incomplete-fixture.test', + 'check-custom-plugin.test', ] @@ -261,7 +262,8 @@ def find_error_paths(self, a: List[str]) -> Set[str]: for line in a: m = re.match(r'([^\s:]+):\d+: error:', line) if m: - p = m.group(1).replace('/', os.path.sep) + # Normalize to Linux paths. + p = m.group(1).replace(os.path.sep, '/') hits.add(p) return hits diff --git a/mypy/test/testcmdline.py b/mypy/test/testcmdline.py index 8e56b42bc766..06009f107b14 100644 --- a/mypy/test/testcmdline.py +++ b/mypy/test/testcmdline.py @@ -15,7 +15,7 @@ from mypy.test.config import test_data_prefix, test_temp_dir from mypy.test.data import fix_cobertura_filename from mypy.test.data import parse_test_cases, DataDrivenTestCase -from mypy.test.helpers import assert_string_arrays_equal +from mypy.test.helpers import assert_string_arrays_equal, normalize_error_messages from mypy.version import __version__, base_version # Path to Python 3 interpreter @@ -71,10 +71,12 @@ def test_python_evaluation(testcase: DataDrivenTestCase) -> None: os.path.abspath(test_temp_dir)) if testcase.native_sep and os.path.sep == '\\': normalized_output = [fix_cobertura_filename(line) for line in normalized_output] + normalized_output = normalize_error_messages(normalized_output) assert_string_arrays_equal(expected_content.splitlines(), normalized_output, 'Output file {} did not match its expected output'.format( path)) else: + out = normalize_error_messages(out) assert_string_arrays_equal(testcase.output, out, 'Invalid output ({}, line {})'.format( testcase.file, testcase.line)) diff --git a/mypy/test/testgraph.py b/mypy/test/testgraph.py index 7a9062914f89..d168ad53e236 100644 --- a/mypy/test/testgraph.py +++ b/mypy/test/testgraph.py @@ -8,6 +8,8 @@ from mypy.version import __version__ from mypy.options import Options from mypy.report import Reports +from mypy.plugin import Plugin +from mypy import defaults class GraphSuite(Suite): @@ -42,6 +44,7 @@ def _make_manager(self) -> BuildManager: reports=Reports('', {}), options=Options(), version_id=__version__, + plugin=Plugin(defaults.PYTHON3_VERSION), ) return manager diff --git a/mypy/test/testsemanal.py b/mypy/test/testsemanal.py index 99c0078e9196..6d7f2ddb24bb 100644 --- a/mypy/test/testsemanal.py +++ b/mypy/test/testsemanal.py @@ -89,6 +89,7 @@ def test_semanal(testcase: DataDrivenTestCase) -> None: a += str(f).split('\n') except CompileError as e: a = e.messages + a = normalize_error_messages(a) assert_string_arrays_equal( testcase.output, a, 'Invalid semantic analyzer output ({}, line {})'.format(testcase.file, diff --git a/mypy/test/testtransform.py b/mypy/test/testtransform.py index 1dac3081efbd..0dcdd1d0c649 100644 --- a/mypy/test/testtransform.py +++ b/mypy/test/testtransform.py @@ -7,7 +7,9 @@ from mypy import build from mypy.build import BuildSource from mypy.myunit import Suite -from mypy.test.helpers import assert_string_arrays_equal, testfile_pyversion +from mypy.test.helpers import ( + assert_string_arrays_equal, testfile_pyversion, normalize_error_messages +) from mypy.test.data import parse_test_cases, DataDrivenTestCase from mypy.test.config import test_data_prefix, test_temp_dir from mypy.errors import CompileError @@ -73,6 +75,7 @@ def test_transform(testcase: DataDrivenTestCase) -> None: a += str(f).split('\n') except CompileError as e: a = e.messages + a = normalize_error_messages(a) assert_string_arrays_equal( testcase.output, a, 'Invalid semantic analyzer output ({}, line {})'.format(testcase.file, diff --git a/test-data/unit/check-custom-plugin.test b/test-data/unit/check-custom-plugin.test new file mode 100644 index 000000000000..30b00a4b3a62 --- /dev/null +++ b/test-data/unit/check-custom-plugin.test @@ -0,0 +1,71 @@ +-- Test cases for user-defined plugins +-- +-- Note: Plugins used by tests live under test-data/unit/plugins. Defining +-- plugin files in test cases does not work reliably. + +[case testFunctionPlugin] +# flags: --config-file tmp/mypy.ini +def f() -> str: ... +reveal_type(f()) # E: Revealed type is 'builtins.int' +[file mypy.ini] +[[mypy] +plugins=/test-data/unit/plugins/fnplugin.py + +[case testTwoPlugins] +# flags: --config-file tmp/mypy.ini +def f(): ... +def g(): ... +def h(): ... +reveal_type(f()) # E: Revealed type is 'builtins.int' +reveal_type(g()) # E: Revealed type is 'builtins.str' +reveal_type(h()) # E: Revealed type is 'Any' +[file mypy.ini] +[[mypy] +plugins=/test-data/unit/plugins/fnplugin.py, + /test-data/unit/plugins/plugin2.py + +[case testMissingPlugin] +# flags: --config-file tmp/mypy.ini +[file mypy.ini] +[[mypy] +plugins=missing.py +[out] +tmp/missing.py:0: error: Can't find plugin +--' (work around syntax highlighting) + +[case testInvalidPluginExtension] +# flags: --config-file tmp/mypy.ini +[file mypy.ini] +[[mypy] +plugins=badext.pyi +[file badext.pyi] +[out] +tmp/badext.pyi:0: error: Plugin must have .py extension + +[case testMissingPluginEntryPoint] +# flags: --config-file tmp/mypy.ini +[file mypy.ini] +[[mypy] +plugins=/test-data/unit/plugins/noentry.py +[out] +/test-data/unit/plugins/noentry.py:0: error: Plugin does not define entry point function "plugin" + +[case testInvalidPluginEntryPointReturnValue] +# flags: --config-file tmp/mypy.ini +def f(): pass +f() +[file mypy.ini] +[[mypy] +plugins=/test-data/unit/plugins/badreturn.py +[out] +/test-data/unit/plugins/badreturn.py:0: error: Type object expected as the return value of "plugin" (got None) + +[case testInvalidPluginEntryPointReturnValue2] +# flags: --config-file tmp/mypy.ini +def f(): pass +f() +[file mypy.ini] +[[mypy] +plugins=/test-data/unit/plugins/badreturn2.py +[out] +/test-data/unit/plugins/badreturn2.py:0: error: Return value of "plugin" must be a subclass of "mypy.plugin.Plugin" diff --git a/test-data/unit/plugins/badreturn.py b/test-data/unit/plugins/badreturn.py new file mode 100644 index 000000000000..fd7430606dd6 --- /dev/null +++ b/test-data/unit/plugins/badreturn.py @@ -0,0 +1,2 @@ +def plugin(version): + pass diff --git a/test-data/unit/plugins/badreturn2.py b/test-data/unit/plugins/badreturn2.py new file mode 100644 index 000000000000..c7e0447841c1 --- /dev/null +++ b/test-data/unit/plugins/badreturn2.py @@ -0,0 +1,5 @@ +class MyPlugin: + pass + +def plugin(version): + return MyPlugin diff --git a/test-data/unit/plugins/fnplugin.py b/test-data/unit/plugins/fnplugin.py new file mode 100644 index 000000000000..d5027219a09f --- /dev/null +++ b/test-data/unit/plugins/fnplugin.py @@ -0,0 +1,13 @@ +from mypy.plugin import Plugin + +class MyPlugin(Plugin): + def get_function_hook(self, fullname): + if fullname == '__main__.f': + return my_hook + return None + +def my_hook(arg_types, args, inferred_return_type, named_generic_type): + return named_generic_type('builtins.int', []) + +def plugin(version): + return MyPlugin diff --git a/test-data/unit/plugins/noentry.py b/test-data/unit/plugins/noentry.py new file mode 100644 index 000000000000..c591ad11fd64 --- /dev/null +++ b/test-data/unit/plugins/noentry.py @@ -0,0 +1 @@ +# empty plugin diff --git a/test-data/unit/plugins/plugin2.py b/test-data/unit/plugins/plugin2.py new file mode 100644 index 000000000000..1584871fae1d --- /dev/null +++ b/test-data/unit/plugins/plugin2.py @@ -0,0 +1,13 @@ +from mypy.plugin import Plugin + +class Plugin2(Plugin): + def get_function_hook(self, fullname): + if fullname in ('__main__.f', '__main__.g'): + return str_hook + return None + +def str_hook(arg_types, args, inferred_return_type, named_generic_type): + return named_generic_type('builtins.str', []) + +def plugin(version): + return Plugin2