diff --git a/.gitignore b/.gitignore index c00e966c..32c2fec0 100644 --- a/.gitignore +++ b/.gitignore @@ -1,16 +1,11 @@ *.egg-info -*.iml *.py[co] .*.sw[a-z] -.pytest_cache .coverage -.idea -.project -.pydevproject .tox .venv.touch +/.mypy_cache +/.pytest_cache /venv* coverage-html dist -# SublimeText project/workspace files -*.sublime-* diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8bd0fdc5..49905378 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,7 +27,7 @@ repos: rev: v1.3.5 hooks: - id: reorder-python-imports - language_version: python2.7 + language_version: python3 - repo: https://github.com/asottile/pyupgrade rev: v1.11.1 hooks: @@ -36,3 +36,8 @@ repos: rev: v0.7.1 hooks: - id: add-trailing-comma +- repo: https://github.com/pre-commit/mirrors-mypy + rev: v0.660 + hooks: + - id: mypy + language_version: python3 diff --git a/.travis.yml b/.travis.yml index 477b5c4b..fa16ccef 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,3 +1,4 @@ +dist: xenial language: python matrix: include: # These should match the tox env list @@ -6,9 +7,8 @@ matrix: python: 3.6 - env: TOXENV=py37 python: 3.7 - dist: xenial - env: TOXENV=pypy - python: pypy-5.7.1 + python: pypy2.7-5.10.0 install: pip install coveralls tox script: tox before_install: diff --git a/get-git-lfs.py b/get-git-lfs.py index 48dd31eb..4b09cac6 100755 --- a/get-git-lfs.py +++ b/get-git-lfs.py @@ -4,7 +4,9 @@ import os.path import shutil import tarfile -from urllib.request import urlopen +import urllib.request +from typing import cast +from typing import IO DOWNLOAD_PATH = ( 'https://github.com/github/git-lfs/releases/download/' @@ -15,7 +17,7 @@ DEST_DIR = os.path.dirname(DEST_PATH) -def main(): +def main(): # type: () -> int if ( os.path.exists(DEST_PATH) and os.path.isfile(DEST_PATH) and @@ -27,12 +29,13 @@ def main(): shutil.rmtree(DEST_DIR, ignore_errors=True) os.makedirs(DEST_DIR, exist_ok=True) - contents = io.BytesIO(urlopen(DOWNLOAD_PATH).read()) + contents = io.BytesIO(urllib.request.urlopen(DOWNLOAD_PATH).read()) with tarfile.open(fileobj=contents) as tar: - with tar.extractfile(PATH_IN_TAR) as src_file: + with cast(IO[bytes], tar.extractfile(PATH_IN_TAR)) as src_file: with open(DEST_PATH, 'wb') as dest_file: shutil.copyfileobj(src_file, dest_file) os.chmod(DEST_PATH, 0o755) + return 0 if __name__ == '__main__': diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 00000000..ee62c89f --- /dev/null +++ b/mypy.ini @@ -0,0 +1,12 @@ +[mypy] +check_untyped_defs = true +disallow_any_generics = true +disallow_incomplete_defs = true +disallow_untyped_defs = true +no_implicit_optional = true + +[mypy-testing.*] +disallow_untyped_defs = false + +[mypy-tests.*] +disallow_untyped_defs = false diff --git a/pre_commit_hooks/autopep8_wrapper.py b/pre_commit_hooks/autopep8_wrapper.py index 9951924d..8b69a049 100644 --- a/pre_commit_hooks/autopep8_wrapper.py +++ b/pre_commit_hooks/autopep8_wrapper.py @@ -3,7 +3,7 @@ from __future__ import unicode_literals -def main(argv=None): +def main(): # type: () -> int raise SystemExit( 'autopep8-wrapper is deprecated. Instead use autopep8 directly via ' 'https://github.com/pre-commit/mirrors-autopep8', diff --git a/pre_commit_hooks/check_added_large_files.py b/pre_commit_hooks/check_added_large_files.py index 2d06706d..be394989 100644 --- a/pre_commit_hooks/check_added_large_files.py +++ b/pre_commit_hooks/check_added_large_files.py @@ -7,13 +7,17 @@ import json import math import os +from typing import Iterable +from typing import Optional +from typing import Sequence +from typing import Set from pre_commit_hooks.util import added_files from pre_commit_hooks.util import CalledProcessError from pre_commit_hooks.util import cmd_output -def lfs_files(): +def lfs_files(): # type: () -> Set[str] try: # Introduced in git-lfs 2.2.0, first working in 2.2.1 lfs_ret = cmd_output('git', 'lfs', 'status', '--json') @@ -24,6 +28,7 @@ def lfs_files(): def find_large_added_files(filenames, maxkb): + # type: (Iterable[str], int) -> int # Find all added files that are also in the list of files pre-commit tells # us about filenames = (added_files() & set(filenames)) - lfs_files() @@ -38,7 +43,7 @@ def find_large_added_files(filenames, maxkb): return retv -def main(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument( 'filenames', nargs='*', diff --git a/pre_commit_hooks/check_ast.py b/pre_commit_hooks/check_ast.py index ded65e46..0df35407 100644 --- a/pre_commit_hooks/check_ast.py +++ b/pre_commit_hooks/check_ast.py @@ -7,9 +7,11 @@ import platform import sys import traceback +from typing import Optional +from typing import Sequence -def check_ast(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument('filenames', nargs='*') args = parser.parse_args(argv) @@ -34,4 +36,4 @@ def check_ast(argv=None): if __name__ == '__main__': - exit(check_ast()) + exit(main()) diff --git a/pre_commit_hooks/check_builtin_literals.py b/pre_commit_hooks/check_builtin_literals.py index 4a4b9ce3..874c68c5 100644 --- a/pre_commit_hooks/check_builtin_literals.py +++ b/pre_commit_hooks/check_builtin_literals.py @@ -4,6 +4,10 @@ import ast import collections import sys +from typing import List +from typing import Optional +from typing import Sequence +from typing import Set BUILTIN_TYPES = { @@ -22,14 +26,17 @@ class BuiltinTypeVisitor(ast.NodeVisitor): def __init__(self, ignore=None, allow_dict_kwargs=True): - self.builtin_type_calls = [] + # type: (Optional[Sequence[str]], bool) -> None + self.builtin_type_calls = [] # type: List[BuiltinTypeCall] self.ignore = set(ignore) if ignore else set() self.allow_dict_kwargs = allow_dict_kwargs - def _check_dict_call(self, node): + def _check_dict_call(self, node): # type: (ast.Call) -> bool + return self.allow_dict_kwargs and (getattr(node, 'kwargs', None) or getattr(node, 'keywords', None)) - def visit_Call(self, node): + def visit_Call(self, node): # type: (ast.Call) -> None + if not isinstance(node.func, ast.Name): # Ignore functions that are object attributes (`foo.bar()`). # Assume that if the user calls `builtins.list()`, they know what @@ -47,6 +54,7 @@ def visit_Call(self, node): def check_file_for_builtin_type_constructors(filename, ignore=None, allow_dict_kwargs=True): + # type: (str, Optional[Sequence[str]], bool) -> List[BuiltinTypeCall] with open(filename, 'rb') as f: tree = ast.parse(f.read(), filename=filename) visitor = BuiltinTypeVisitor(ignore=ignore, allow_dict_kwargs=allow_dict_kwargs) @@ -54,24 +62,22 @@ def check_file_for_builtin_type_constructors(filename, ignore=None, allow_dict_k return visitor.builtin_type_calls -def parse_args(argv): - def parse_ignore(value): - return set(value.split(',')) +def parse_ignore(value): # type: (str) -> Set[str] + return set(value.split(',')) + +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument('filenames', nargs='*') parser.add_argument('--ignore', type=parse_ignore, default=set()) - allow_dict_kwargs = parser.add_mutually_exclusive_group(required=False) - allow_dict_kwargs.add_argument('--allow-dict-kwargs', action='store_true') - allow_dict_kwargs.add_argument('--no-allow-dict-kwargs', dest='allow_dict_kwargs', action='store_false') - allow_dict_kwargs.set_defaults(allow_dict_kwargs=True) - - return parser.parse_args(argv) + mutex = parser.add_mutually_exclusive_group(required=False) + mutex.add_argument('--allow-dict-kwargs', action='store_true') + mutex.add_argument('--no-allow-dict-kwargs', dest='allow_dict_kwargs', action='store_false') + mutex.set_defaults(allow_dict_kwargs=True) + args = parser.parse_args(argv) -def main(argv=None): - args = parse_args(argv) rc = 0 for filename in args.filenames: calls = check_file_for_builtin_type_constructors( diff --git a/pre_commit_hooks/check_byte_order_marker.py b/pre_commit_hooks/check_byte_order_marker.py index 1541b302..10667c33 100644 --- a/pre_commit_hooks/check_byte_order_marker.py +++ b/pre_commit_hooks/check_byte_order_marker.py @@ -3,9 +3,11 @@ from __future__ import unicode_literals import argparse +from typing import Optional +from typing import Sequence -def main(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument('filenames', nargs='*', help='Filenames to check') args = parser.parse_args(argv) diff --git a/pre_commit_hooks/check_case_conflict.py b/pre_commit_hooks/check_case_conflict.py index 0f782965..e343d61f 100644 --- a/pre_commit_hooks/check_case_conflict.py +++ b/pre_commit_hooks/check_case_conflict.py @@ -3,16 +3,20 @@ from __future__ import unicode_literals import argparse +from typing import Iterable +from typing import Optional +from typing import Sequence +from typing import Set from pre_commit_hooks.util import added_files from pre_commit_hooks.util import cmd_output -def lower_set(iterable): +def lower_set(iterable): # type: (Iterable[str]) -> Set[str] return {x.lower() for x in iterable} -def find_conflicting_filenames(filenames): +def find_conflicting_filenames(filenames): # type: (Sequence[str]) -> int repo_files = set(cmd_output('git', 'ls-files').splitlines()) relevant_files = set(filenames) | added_files() repo_files -= relevant_files @@ -41,7 +45,7 @@ def find_conflicting_filenames(filenames): return retv -def main(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument( 'filenames', nargs='*', diff --git a/pre_commit_hooks/check_docstring_first.py b/pre_commit_hooks/check_docstring_first.py index 9988378a..f4639f17 100644 --- a/pre_commit_hooks/check_docstring_first.py +++ b/pre_commit_hooks/check_docstring_first.py @@ -5,6 +5,8 @@ import argparse import io import tokenize +from typing import Optional +from typing import Sequence NON_CODE_TOKENS = frozenset(( @@ -13,6 +15,7 @@ def check_docstring_first(src, filename=''): + # type: (str, str) -> int """Returns nonzero if the source has what looks like a docstring that is not at the beginning of the source. @@ -50,7 +53,7 @@ def check_docstring_first(src, filename=''): return 0 -def main(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument('filenames', nargs='*') args = parser.parse_args(argv) diff --git a/pre_commit_hooks/check_executables_have_shebangs.py b/pre_commit_hooks/check_executables_have_shebangs.py index 89ac6e5b..c936a5dd 100644 --- a/pre_commit_hooks/check_executables_have_shebangs.py +++ b/pre_commit_hooks/check_executables_have_shebangs.py @@ -6,9 +6,11 @@ import argparse import pipes import sys +from typing import Optional +from typing import Sequence -def check_has_shebang(path): +def check_has_shebang(path): # type: (str) -> int with open(path, 'rb') as f: first_bytes = f.read(2) @@ -27,7 +29,7 @@ def check_has_shebang(path): return 0 -def main(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser(description=__doc__) parser.add_argument('filenames', nargs='*') args = parser.parse_args(argv) @@ -38,3 +40,7 @@ def main(argv=None): retv |= check_has_shebang(filename) return retv + + +if __name__ == '__main__': + exit(main()) diff --git a/pre_commit_hooks/check_json.py b/pre_commit_hooks/check_json.py index b403f4b2..b9393508 100644 --- a/pre_commit_hooks/check_json.py +++ b/pre_commit_hooks/check_json.py @@ -4,9 +4,11 @@ import io import json import sys +from typing import Optional +from typing import Sequence -def check_json(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument('filenames', nargs='*', help='JSON filenames to check.') args = parser.parse_args(argv) @@ -22,4 +24,4 @@ def check_json(argv=None): if __name__ == '__main__': - sys.exit(check_json()) + sys.exit(main()) diff --git a/pre_commit_hooks/check_merge_conflict.py b/pre_commit_hooks/check_merge_conflict.py index 6db5efe9..74e4ae17 100644 --- a/pre_commit_hooks/check_merge_conflict.py +++ b/pre_commit_hooks/check_merge_conflict.py @@ -2,6 +2,9 @@ import argparse import os.path +from typing import Optional +from typing import Sequence + CONFLICT_PATTERNS = [ b'<<<<<<< ', @@ -12,7 +15,7 @@ WARNING_MSG = 'Merge conflict string "{0}" found in {1}:{2}' -def is_in_merge(): +def is_in_merge(): # type: () -> int return ( os.path.exists(os.path.join('.git', 'MERGE_MSG')) and ( @@ -23,7 +26,7 @@ def is_in_merge(): ) -def detect_merge_conflict(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument('filenames', nargs='*') parser.add_argument('--assume-in-merge', action='store_true') @@ -47,4 +50,4 @@ def detect_merge_conflict(argv=None): if __name__ == '__main__': - exit(detect_merge_conflict()) + exit(main()) diff --git a/pre_commit_hooks/check_symlinks.py b/pre_commit_hooks/check_symlinks.py index 010c8715..736bf99c 100644 --- a/pre_commit_hooks/check_symlinks.py +++ b/pre_commit_hooks/check_symlinks.py @@ -4,9 +4,11 @@ import argparse import os.path +from typing import Optional +from typing import Sequence -def check_symlinks(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser(description='Checks for broken symlinks.') parser.add_argument('filenames', nargs='*', help='Filenames to check') args = parser.parse_args(argv) @@ -25,4 +27,4 @@ def check_symlinks(argv=None): if __name__ == '__main__': - exit(check_symlinks()) + exit(main()) diff --git a/pre_commit_hooks/check_vcs_permalinks.py b/pre_commit_hooks/check_vcs_permalinks.py index f0dcf5b6..f6e2a7d5 100644 --- a/pre_commit_hooks/check_vcs_permalinks.py +++ b/pre_commit_hooks/check_vcs_permalinks.py @@ -5,6 +5,8 @@ import argparse import re import sys +from typing import Optional +from typing import Sequence GITHUB_NON_PERMALINK = re.compile( @@ -12,7 +14,7 @@ ) -def _check_filename(filename): +def _check_filename(filename): # type: (str) -> int retv = 0 with open(filename, 'rb') as f: for i, line in enumerate(f, 1): @@ -24,7 +26,7 @@ def _check_filename(filename): return retv -def main(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument('filenames', nargs='*') args = parser.parse_args(argv) diff --git a/pre_commit_hooks/check_xml.py b/pre_commit_hooks/check_xml.py index a4c11a59..66e10bac 100644 --- a/pre_commit_hooks/check_xml.py +++ b/pre_commit_hooks/check_xml.py @@ -5,10 +5,12 @@ import argparse import io import sys -import xml.sax +import xml.sax.handler +from typing import Optional +from typing import Sequence -def check_xml(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument('filenames', nargs='*', help='XML filenames to check.') args = parser.parse_args(argv) @@ -17,7 +19,7 @@ def check_xml(argv=None): for filename in args.filenames: try: with io.open(filename, 'rb') as xml_file: - xml.sax.parse(xml_file, xml.sax.ContentHandler()) + xml.sax.parse(xml_file, xml.sax.handler.ContentHandler()) except xml.sax.SAXException as exc: print('{}: Failed to xml parse ({})'.format(filename, exc)) retval = 1 @@ -25,4 +27,4 @@ def check_xml(argv=None): if __name__ == '__main__': - sys.exit(check_xml()) + sys.exit(main()) diff --git a/pre_commit_hooks/check_yaml.py b/pre_commit_hooks/check_yaml.py index 208737f1..b638684b 100644 --- a/pre_commit_hooks/check_yaml.py +++ b/pre_commit_hooks/check_yaml.py @@ -3,22 +3,26 @@ import argparse import collections import sys +from typing import Any +from typing import Generator +from typing import Optional +from typing import Sequence import ruamel.yaml yaml = ruamel.yaml.YAML(typ='safe') -def _exhaust(gen): +def _exhaust(gen): # type: (Generator[str, None, None]) -> None for _ in gen: pass -def _parse_unsafe(*args, **kwargs): +def _parse_unsafe(*args, **kwargs): # type: (*Any, **Any) -> None _exhaust(yaml.parse(*args, **kwargs)) -def _load_all(*args, **kwargs): +def _load_all(*args, **kwargs): # type: (*Any, **Any) -> None _exhaust(yaml.load_all(*args, **kwargs)) @@ -31,7 +35,7 @@ def _load_all(*args, **kwargs): } -def check_yaml(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument( '-m', '--multi', '--allow-multiple-documents', action='store_true', @@ -63,4 +67,4 @@ def check_yaml(argv=None): if __name__ == '__main__': - sys.exit(check_yaml()) + sys.exit(main()) diff --git a/pre_commit_hooks/debug_statement_hook.py b/pre_commit_hooks/debug_statement_hook.py index 5d32277a..02dd3b29 100644 --- a/pre_commit_hooks/debug_statement_hook.py +++ b/pre_commit_hooks/debug_statement_hook.py @@ -5,6 +5,9 @@ import ast import collections import traceback +from typing import List +from typing import Optional +from typing import Sequence DEBUG_STATEMENTS = {'pdb', 'ipdb', 'pudb', 'q', 'rdb'} @@ -12,21 +15,21 @@ class DebugStatementParser(ast.NodeVisitor): - def __init__(self): - self.breakpoints = [] + def __init__(self): # type: () -> None + self.breakpoints = [] # type: List[Debug] - def visit_Import(self, node): + def visit_Import(self, node): # type: (ast.Import) -> None for name in node.names: if name.name in DEBUG_STATEMENTS: st = Debug(node.lineno, node.col_offset, name.name, 'imported') self.breakpoints.append(st) - def visit_ImportFrom(self, node): + def visit_ImportFrom(self, node): # type: (ast.ImportFrom) -> None if node.module in DEBUG_STATEMENTS: st = Debug(node.lineno, node.col_offset, node.module, 'imported') self.breakpoints.append(st) - def visit_Call(self, node): + def visit_Call(self, node): # type: (ast.Call) -> None """python3.7+ breakpoint()""" if isinstance(node.func, ast.Name) and node.func.id == 'breakpoint': st = Debug(node.lineno, node.col_offset, node.func.id, 'called') @@ -34,7 +37,7 @@ def visit_Call(self, node): self.generic_visit(node) -def check_file(filename): +def check_file(filename): # type: (str) -> int try: with open(filename, 'rb') as f: ast_obj = ast.parse(f.read(), filename=filename) @@ -58,7 +61,7 @@ def check_file(filename): return int(bool(visitor.breakpoints)) -def main(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument('filenames', nargs='*', help='Filenames to run') args = parser.parse_args(argv) diff --git a/pre_commit_hooks/detect_aws_credentials.py b/pre_commit_hooks/detect_aws_credentials.py index ecd9d40d..3c87d117 100644 --- a/pre_commit_hooks/detect_aws_credentials.py +++ b/pre_commit_hooks/detect_aws_credentials.py @@ -3,11 +3,16 @@ import argparse import os +from typing import Dict +from typing import List +from typing import Optional +from typing import Sequence +from typing import Set from six.moves import configparser -def get_aws_credential_files_from_env(): +def get_aws_credential_files_from_env(): # type: () -> Set[str] """Extract credential file paths from environment variables.""" files = set() for env_var in ( @@ -19,7 +24,7 @@ def get_aws_credential_files_from_env(): return files -def get_aws_secrets_from_env(): +def get_aws_secrets_from_env(): # type: () -> Set[str] """Extract AWS secrets from environment variables.""" keys = set() for env_var in ( @@ -30,7 +35,7 @@ def get_aws_secrets_from_env(): return keys -def get_aws_secrets_from_file(credentials_file): +def get_aws_secrets_from_file(credentials_file): # type: (str) -> Set[str] """Extract AWS secrets from configuration files. Read an ini-style configuration file and return a set with all found AWS @@ -62,6 +67,7 @@ def get_aws_secrets_from_file(credentials_file): def check_file_for_aws_keys(filenames, keys): + # type: (Sequence[str], Set[str]) -> List[Dict[str, str]] """Check if files contain AWS secrets. Return a list of all files containing AWS secrets and keys found, with all @@ -82,7 +88,7 @@ def check_file_for_aws_keys(filenames, keys): return bad_files -def main(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument('filenames', nargs='+', help='Filenames to run') parser.add_argument( @@ -111,7 +117,7 @@ def main(argv=None): # of files to to gather AWS secrets from. credential_files |= get_aws_credential_files_from_env() - keys = set() + keys = set() # type: Set[str] for credential_file in credential_files: keys |= get_aws_secrets_from_file(credential_file) diff --git a/pre_commit_hooks/detect_private_key.py b/pre_commit_hooks/detect_private_key.py index c8ee9611..d31957de 100644 --- a/pre_commit_hooks/detect_private_key.py +++ b/pre_commit_hooks/detect_private_key.py @@ -2,6 +2,8 @@ import argparse import sys +from typing import Optional +from typing import Sequence BLACKLIST = [ b'BEGIN RSA PRIVATE KEY', @@ -15,7 +17,7 @@ ] -def detect_private_key(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument('filenames', nargs='*', help='Filenames to check') args = parser.parse_args(argv) @@ -37,4 +39,4 @@ def detect_private_key(argv=None): if __name__ == '__main__': - sys.exit(detect_private_key()) + sys.exit(main()) diff --git a/pre_commit_hooks/end_of_file_fixer.py b/pre_commit_hooks/end_of_file_fixer.py index 5ab1b7b0..4e77c945 100644 --- a/pre_commit_hooks/end_of_file_fixer.py +++ b/pre_commit_hooks/end_of_file_fixer.py @@ -4,9 +4,12 @@ import argparse import os import sys +from typing import IO +from typing import Optional +from typing import Sequence -def fix_file(file_obj): +def fix_file(file_obj): # type: (IO[bytes]) -> int # Test for newline at end of file # Empty files will throw IOError here try: @@ -49,7 +52,7 @@ def fix_file(file_obj): return 0 -def end_of_file_fixer(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument('filenames', nargs='*', help='Filenames to fix') args = parser.parse_args(argv) @@ -68,4 +71,4 @@ def end_of_file_fixer(argv=None): if __name__ == '__main__': - sys.exit(end_of_file_fixer()) + sys.exit(main()) diff --git a/pre_commit_hooks/file_contents_sorter.py b/pre_commit_hooks/file_contents_sorter.py index fe7f7ee3..6f13c98a 100644 --- a/pre_commit_hooks/file_contents_sorter.py +++ b/pre_commit_hooks/file_contents_sorter.py @@ -12,12 +12,15 @@ from __future__ import print_function import argparse +from typing import IO +from typing import Optional +from typing import Sequence PASS = 0 FAIL = 1 -def sort_file_contents(f): +def sort_file_contents(f): # type: (IO[bytes]) -> int before = list(f) after = sorted([line.strip(b'\n\r') for line in before if line.strip()]) @@ -33,7 +36,7 @@ def sort_file_contents(f): return FAIL -def main(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument('filenames', nargs='+', help='Files to sort') args = parser.parse_args(argv) diff --git a/pre_commit_hooks/fix_encoding_pragma.py b/pre_commit_hooks/fix_encoding_pragma.py index 3bf234ed..b0b5c8ec 100644 --- a/pre_commit_hooks/fix_encoding_pragma.py +++ b/pre_commit_hooks/fix_encoding_pragma.py @@ -4,11 +4,15 @@ import argparse import collections +from typing import IO +from typing import Optional +from typing import Sequence +from typing import Union DEFAULT_PRAGMA = b'# -*- coding: utf-8 -*-\n' -def has_coding(line): +def has_coding(line): # type: (bytes) -> bool if not line.strip(): return False return ( @@ -33,15 +37,16 @@ class ExpectedContents(collections.namedtuple( __slots__ = () @property - def has_any_pragma(self): + def has_any_pragma(self): # type: () -> bool return self.pragma_status is not False - def is_expected_pragma(self, remove): + def is_expected_pragma(self, remove): # type: (bool) -> bool expected_pragma_status = not remove return self.pragma_status is expected_pragma_status def _get_expected_contents(first_line, second_line, rest, expected_pragma): + # type: (bytes, bytes, bytes, bytes) -> ExpectedContents if first_line.startswith(b'#!'): shebang = first_line potential_coding = second_line @@ -51,7 +56,7 @@ def _get_expected_contents(first_line, second_line, rest, expected_pragma): rest = second_line + rest if potential_coding == expected_pragma: - pragma_status = True + pragma_status = True # type: Optional[bool] elif has_coding(potential_coding): pragma_status = None else: @@ -64,6 +69,7 @@ def _get_expected_contents(first_line, second_line, rest, expected_pragma): def fix_encoding_pragma(f, remove=False, expected_pragma=DEFAULT_PRAGMA): + # type: (IO[bytes], bool, bytes) -> int expected = _get_expected_contents( f.readline(), f.readline(), f.read(), expected_pragma, ) @@ -93,17 +99,17 @@ def fix_encoding_pragma(f, remove=False, expected_pragma=DEFAULT_PRAGMA): return 1 -def _normalize_pragma(pragma): +def _normalize_pragma(pragma): # type: (Union[bytes, str]) -> bytes if not isinstance(pragma, bytes): pragma = pragma.encode('UTF-8') return pragma.rstrip() + b'\n' -def _to_disp(pragma): +def _to_disp(pragma): # type: (bytes) -> str return pragma.decode().rstrip() -def main(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser('Fixes the encoding pragma of python files') parser.add_argument('filenames', nargs='*', help='Filenames to fix') parser.add_argument( diff --git a/pre_commit_hooks/forbid_new_submodules.py b/pre_commit_hooks/forbid_new_submodules.py index c9464cf7..bdbd6f7f 100644 --- a/pre_commit_hooks/forbid_new_submodules.py +++ b/pre_commit_hooks/forbid_new_submodules.py @@ -2,10 +2,13 @@ from __future__ import print_function from __future__ import unicode_literals +from typing import Optional +from typing import Sequence + from pre_commit_hooks.util import cmd_output -def main(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int # `argv` is ignored, pre-commit will send us a list of files that we # don't care about added_diff = cmd_output( diff --git a/pre_commit_hooks/mixed_line_ending.py b/pre_commit_hooks/mixed_line_ending.py index e35a65c9..90aef035 100644 --- a/pre_commit_hooks/mixed_line_ending.py +++ b/pre_commit_hooks/mixed_line_ending.py @@ -4,6 +4,9 @@ import argparse import collections +from typing import Dict +from typing import Optional +from typing import Sequence CRLF = b'\r\n' @@ -14,7 +17,7 @@ FIX_TO_LINE_ENDING = {'cr': CR, 'crlf': CRLF, 'lf': LF} -def _fix(filename, contents, ending): +def _fix(filename, contents, ending): # type: (str, bytes, bytes) -> None new_contents = b''.join( line.rstrip(b'\r\n') + ending for line in contents.splitlines(True) ) @@ -22,11 +25,11 @@ def _fix(filename, contents, ending): f.write(new_contents) -def fix_filename(filename, fix): +def fix_filename(filename, fix): # type: (str, str) -> int with open(filename, 'rb') as f: contents = f.read() - counts = collections.defaultdict(int) + counts = collections.defaultdict(int) # type: Dict[bytes, int] for line in contents.splitlines(True): for ending in ALL_ENDINGS: @@ -63,7 +66,7 @@ def fix_filename(filename, fix): return other_endings -def main(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument( '-f', '--fix', diff --git a/pre_commit_hooks/no_commit_to_branch.py b/pre_commit_hooks/no_commit_to_branch.py index fdd146bc..6b68c915 100644 --- a/pre_commit_hooks/no_commit_to_branch.py +++ b/pre_commit_hooks/no_commit_to_branch.py @@ -1,12 +1,15 @@ from __future__ import print_function import argparse +from typing import Optional +from typing import Sequence +from typing import Set from pre_commit_hooks.util import CalledProcessError from pre_commit_hooks.util import cmd_output -def is_on_branch(protected): +def is_on_branch(protected): # type: (Set[str]) -> bool try: branch = cmd_output('git', 'symbolic-ref', 'HEAD') except CalledProcessError: @@ -15,7 +18,7 @@ def is_on_branch(protected): return '/'.join(chunks[2:]) in protected -def main(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument( '-b', '--branch', action='append', diff --git a/pre_commit_hooks/pretty_format_json.py b/pre_commit_hooks/pretty_format_json.py index 363037e2..de7f8d71 100644 --- a/pre_commit_hooks/pretty_format_json.py +++ b/pre_commit_hooks/pretty_format_json.py @@ -5,12 +5,20 @@ import json import sys from collections import OrderedDict +from typing import List +from typing import Mapping +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Union from six import text_type -def _get_pretty_format(contents, indent, ensure_ascii=True, sort_keys=True, top_keys=[]): +def _get_pretty_format(contents, indent, ensure_ascii=True, sort_keys=True, top_keys=()): + # type: (str, str, bool, bool, Sequence[str]) -> str def pairs_first(pairs): + # type: (Sequence[Tuple[str, str]]) -> Mapping[str, str] before = [pair for pair in pairs if pair[0] in top_keys] before = sorted(before, key=lambda x: top_keys.index(x[0])) after = [pair for pair in pairs if pair[0] not in top_keys] @@ -27,13 +35,13 @@ def pairs_first(pairs): return text_type(json_pretty) + '\n' -def _autofix(filename, new_contents): +def _autofix(filename, new_contents): # type: (str, str) -> None print('Fixing file {}'.format(filename)) with io.open(filename, 'w', encoding='UTF-8') as f: f.write(new_contents) -def parse_num_to_int(s): +def parse_num_to_int(s): # type: (str) -> Union[int, str] """Convert string numbers to int, leaving strings as is.""" try: return int(s) @@ -41,11 +49,11 @@ def parse_num_to_int(s): return s -def parse_topkeys(s): +def parse_topkeys(s): # type: (str) -> List[str] return s.split(',') -def pretty_format_json(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument( '--autofix', @@ -117,4 +125,4 @@ def pretty_format_json(argv=None): if __name__ == '__main__': - sys.exit(pretty_format_json()) + sys.exit(main()) diff --git a/pre_commit_hooks/requirements_txt_fixer.py b/pre_commit_hooks/requirements_txt_fixer.py index 6dcf8d09..3f85a17a 100644 --- a/pre_commit_hooks/requirements_txt_fixer.py +++ b/pre_commit_hooks/requirements_txt_fixer.py @@ -1,6 +1,10 @@ from __future__ import print_function import argparse +from typing import IO +from typing import List +from typing import Optional +from typing import Sequence PASS = 0 @@ -9,21 +13,23 @@ class Requirement(object): - def __init__(self): + def __init__(self): # type: () -> None super(Requirement, self).__init__() - self.value = None - self.comments = [] + self.value = None # type: Optional[bytes] + self.comments = [] # type: List[bytes] @property - def name(self): + def name(self): # type: () -> bytes + assert self.value is not None, self.value if self.value.startswith(b'-e '): return self.value.lower().partition(b'=')[-1] return self.value.lower().partition(b'==')[0] - def __lt__(self, requirement): + def __lt__(self, requirement): # type: (Requirement) -> int # \n means top of file comment, so always return True, # otherwise just do a string comparison with value. + assert self.value is not None, self.value if self.value == b'\n': return True elif requirement.value == b'\n': @@ -32,10 +38,10 @@ def __lt__(self, requirement): return self.name < requirement.name -def fix_requirements(f): - requirements = [] +def fix_requirements(f): # type: (IO[bytes]) -> int + requirements = [] # type: List[Requirement] before = tuple(f) - after = [] + after = [] # type: List[bytes] before_string = b''.join(before) @@ -46,6 +52,7 @@ def fix_requirements(f): for line in before: # If the most recent requirement object has a value, then it's # time to start building the next requirement object. + if not len(requirements) or requirements[-1].value is not None: requirements.append(Requirement()) @@ -78,6 +85,7 @@ def fix_requirements(f): for requirement in sorted(requirements): after.extend(requirement.comments) + assert requirement.value, requirement.value after.append(requirement.value) after.extend(rest) @@ -92,7 +100,7 @@ def fix_requirements(f): return FAIL -def fix_requirements_txt(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument('filenames', nargs='*', help='Filenames to fix') args = parser.parse_args(argv) @@ -109,3 +117,7 @@ def fix_requirements_txt(argv=None): retv |= ret_for_file return retv + + +if __name__ == '__main__': + exit(main()) diff --git a/pre_commit_hooks/sort_simple_yaml.py b/pre_commit_hooks/sort_simple_yaml.py index 7afae917..3c8ef165 100755 --- a/pre_commit_hooks/sort_simple_yaml.py +++ b/pre_commit_hooks/sort_simple_yaml.py @@ -21,12 +21,15 @@ from __future__ import print_function import argparse +from typing import List +from typing import Optional +from typing import Sequence QUOTES = ["'", '"'] -def sort(lines): +def sort(lines): # type: (List[str]) -> List[str] """Sort a YAML file in alphabetical order, keeping blocks together. :param lines: array of strings (without newlines) @@ -44,7 +47,7 @@ def sort(lines): return new_lines -def parse_block(lines, header=False): +def parse_block(lines, header=False): # type: (List[str], bool) -> List[str] """Parse and return a single block, popping off the start of `lines`. If parsing a header block, we stop after we reach a line that is not a @@ -60,7 +63,7 @@ def parse_block(lines, header=False): return block_lines -def parse_blocks(lines): +def parse_blocks(lines): # type: (List[str]) -> List[List[str]] """Parse and return all possible blocks, popping off the start of `lines`. :param lines: list of lines @@ -77,7 +80,7 @@ def parse_blocks(lines): return blocks -def first_key(lines): +def first_key(lines): # type: (List[str]) -> str """Returns a string representing the sort key of a block. The sort key is the first YAML key we encounter, ignoring comments, and @@ -95,9 +98,11 @@ def first_key(lines): if any(line.startswith(quote) for quote in QUOTES): return line[1:] return line + else: + return '' # not actually reached in reality -def main(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument('filenames', nargs='*', help='Filenames to fix') args = parser.parse_args(argv) diff --git a/pre_commit_hooks/string_fixer.py b/pre_commit_hooks/string_fixer.py index c432682f..a5ea1ea9 100644 --- a/pre_commit_hooks/string_fixer.py +++ b/pre_commit_hooks/string_fixer.py @@ -4,34 +4,39 @@ import argparse import io +import re import tokenize +from typing import List +from typing import Optional +from typing import Sequence +START_QUOTE_RE = re.compile('^[a-zA-Z]*"') -double_quote_starts = tuple(s for s in tokenize.single_quoted if '"' in s) - -def handle_match(token_text): +def handle_match(token_text): # type: (str) -> str if '"""' in token_text or "'''" in token_text: return token_text - for double_quote_start in double_quote_starts: - if token_text.startswith(double_quote_start): - meat = token_text[len(double_quote_start):-1] - if '"' in meat or "'" in meat: - break - return double_quote_start.replace('"', "'") + meat + "'" - return token_text + match = START_QUOTE_RE.match(token_text) + if match is not None: + meat = token_text[match.end():-1] + if '"' in meat or "'" in meat: + return token_text + else: + return match.group().replace('"', "'") + meat + "'" + else: + return token_text -def get_line_offsets_by_line_no(src): +def get_line_offsets_by_line_no(src): # type: (str) -> List[int] # Padded so we can index with line number - offsets = [None, 0] + offsets = [-1, 0] for line in src.splitlines(): offsets.append(offsets[-1] + len(line) + 1) return offsets -def fix_strings(filename): +def fix_strings(filename): # type: (str) -> int with io.open(filename, encoding='UTF-8') as f: contents = f.read() line_offsets = get_line_offsets_by_line_no(contents) @@ -60,7 +65,7 @@ def fix_strings(filename): return 0 -def main(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument('filenames', nargs='*', help='Filenames to fix') args = parser.parse_args(argv) @@ -74,3 +79,7 @@ def main(argv=None): retv |= return_value return retv + + +if __name__ == '__main__': + exit(main()) diff --git a/pre_commit_hooks/tests_should_end_in_test.py b/pre_commit_hooks/tests_should_end_in_test.py index 9bea20db..7a1e7c04 100644 --- a/pre_commit_hooks/tests_should_end_in_test.py +++ b/pre_commit_hooks/tests_should_end_in_test.py @@ -1,12 +1,14 @@ from __future__ import print_function import argparse +import os.path import re import sys -from os.path import basename +from typing import Optional +from typing import Sequence -def validate_files(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument('filenames', nargs='*') parser.add_argument( @@ -18,7 +20,7 @@ def validate_files(argv=None): retcode = 0 test_name_pattern = 'test.*.py' if args.django else '.*_test.py' for filename in args.filenames: - base = basename(filename) + base = os.path.basename(filename) if ( not re.match(test_name_pattern, base) and not base == '__init__.py' and @@ -35,4 +37,4 @@ def validate_files(argv=None): if __name__ == '__main__': - sys.exit(validate_files()) + sys.exit(main()) diff --git a/pre_commit_hooks/trailing_whitespace_fixer.py b/pre_commit_hooks/trailing_whitespace_fixer.py index 1b54fbd2..4fe7975e 100644 --- a/pre_commit_hooks/trailing_whitespace_fixer.py +++ b/pre_commit_hooks/trailing_whitespace_fixer.py @@ -3,9 +3,11 @@ import argparse import os import sys +from typing import Optional +from typing import Sequence -def _fix_file(filename, is_markdown): +def _fix_file(filename, is_markdown): # type: (str, bool) -> bool with open(filename, mode='rb') as file_processed: lines = file_processed.readlines() newlines = [_process_line(line, is_markdown) for line in lines] @@ -18,7 +20,7 @@ def _fix_file(filename, is_markdown): return False -def _process_line(line, is_markdown): +def _process_line(line, is_markdown): # type: (bytes, bool) -> bytes if line[-2:] == b'\r\n': eol = b'\r\n' elif line[-1:] == b'\n': @@ -31,7 +33,7 @@ def _process_line(line, is_markdown): return line.rstrip() + eol -def main(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument( '--no-markdown-linebreak-ext', diff --git a/pre_commit_hooks/util.py b/pre_commit_hooks/util.py index 269b5537..5d1d11bd 100644 --- a/pre_commit_hooks/util.py +++ b/pre_commit_hooks/util.py @@ -3,23 +3,25 @@ from __future__ import unicode_literals import subprocess +from typing import Any +from typing import Set class CalledProcessError(RuntimeError): pass -def added_files(): +def added_files(): # type: () -> Set[str] return set(cmd_output( 'git', 'diff', '--staged', '--name-only', '--diff-filter=A', ).splitlines()) -def cmd_output(*cmd, **kwargs): +def cmd_output(*cmd, **kwargs): # type: (*str, **Any) -> str retcode = kwargs.pop('retcode', 0) - popen_kwargs = {'stdout': subprocess.PIPE, 'stderr': subprocess.PIPE} - popen_kwargs.update(kwargs) - proc = subprocess.Popen(cmd, **popen_kwargs) + kwargs.setdefault('stdout', subprocess.PIPE) + kwargs.setdefault('stderr', subprocess.PIPE) + proc = subprocess.Popen(cmd, **kwargs) stdout, stderr = proc.communicate() stdout = stdout.decode('UTF-8') if stderr is not None: diff --git a/setup.py b/setup.py index 84892a7c..756500b2 100644 --- a/setup.py +++ b/setup.py @@ -28,35 +28,36 @@ 'ruamel.yaml>=0.15', 'six', ], + extras_require={':python_version<"3.5"': ['typing']}, entry_points={ 'console_scripts': [ 'autopep8-wrapper = pre_commit_hooks.autopep8_wrapper:main', 'check-added-large-files = pre_commit_hooks.check_added_large_files:main', - 'check-ast = pre_commit_hooks.check_ast:check_ast', + 'check-ast = pre_commit_hooks.check_ast:main', 'check-builtin-literals = pre_commit_hooks.check_builtin_literals:main', 'check-byte-order-marker = pre_commit_hooks.check_byte_order_marker:main', 'check-case-conflict = pre_commit_hooks.check_case_conflict:main', 'check-docstring-first = pre_commit_hooks.check_docstring_first:main', 'check-executables-have-shebangs = pre_commit_hooks.check_executables_have_shebangs:main', - 'check-json = pre_commit_hooks.check_json:check_json', - 'check-merge-conflict = pre_commit_hooks.check_merge_conflict:detect_merge_conflict', - 'check-symlinks = pre_commit_hooks.check_symlinks:check_symlinks', + 'check-json = pre_commit_hooks.check_json:main', + 'check-merge-conflict = pre_commit_hooks.check_merge_conflict:main', + 'check-symlinks = pre_commit_hooks.check_symlinks:main', 'check-vcs-permalinks = pre_commit_hooks.check_vcs_permalinks:main', - 'check-xml = pre_commit_hooks.check_xml:check_xml', - 'check-yaml = pre_commit_hooks.check_yaml:check_yaml', + 'check-xml = pre_commit_hooks.check_xml:main', + 'check-yaml = pre_commit_hooks.check_yaml:main', 'debug-statement-hook = pre_commit_hooks.debug_statement_hook:main', 'detect-aws-credentials = pre_commit_hooks.detect_aws_credentials:main', - 'detect-private-key = pre_commit_hooks.detect_private_key:detect_private_key', + 'detect-private-key = pre_commit_hooks.detect_private_key:main', 'double-quote-string-fixer = pre_commit_hooks.string_fixer:main', - 'end-of-file-fixer = pre_commit_hooks.end_of_file_fixer:end_of_file_fixer', + 'end-of-file-fixer = pre_commit_hooks.end_of_file_fixer:main', 'file-contents-sorter = pre_commit_hooks.file_contents_sorter:main', 'fix-encoding-pragma = pre_commit_hooks.fix_encoding_pragma:main', 'forbid-new-submodules = pre_commit_hooks.forbid_new_submodules:main', 'mixed-line-ending = pre_commit_hooks.mixed_line_ending:main', - 'name-tests-test = pre_commit_hooks.tests_should_end_in_test:validate_files', + 'name-tests-test = pre_commit_hooks.tests_should_end_in_test:main', 'no-commit-to-branch = pre_commit_hooks.no_commit_to_branch:main', - 'pretty-format-json = pre_commit_hooks.pretty_format_json:pretty_format_json', - 'requirements-txt-fixer = pre_commit_hooks.requirements_txt_fixer:fix_requirements_txt', + 'pretty-format-json = pre_commit_hooks.pretty_format_json:main', + 'requirements-txt-fixer = pre_commit_hooks.requirements_txt_fixer:main', 'sort-simple-yaml = pre_commit_hooks.sort_simple_yaml:main', 'trailing-whitespace-fixer = pre_commit_hooks.trailing_whitespace_fixer:main', ], diff --git a/testing/resources/bad_json_latin1.nonjson b/testing/resources/bad_json_latin1.nonjson old mode 100755 new mode 100644 diff --git a/testing/resources/builtin_constructors.py b/testing/resources/builtin_constructors.py deleted file mode 100644 index 174a9e85..00000000 --- a/testing/resources/builtin_constructors.py +++ /dev/null @@ -1,17 +0,0 @@ -from six.moves import builtins - -c1 = complex() -d1 = dict() -f1 = float() -i1 = int() -l1 = list() -s1 = str() -t1 = tuple() - -c2 = builtins.complex() -d2 = builtins.dict() -f2 = builtins.float() -i2 = builtins.int() -l2 = builtins.list() -s2 = builtins.str() -t2 = builtins.tuple() diff --git a/testing/resources/builtin_literals.py b/testing/resources/builtin_literals.py deleted file mode 100644 index 8513b706..00000000 --- a/testing/resources/builtin_literals.py +++ /dev/null @@ -1,7 +0,0 @@ -c1 = 0j -d1 = {} -f1 = 0.0 -i1 = 0 -l1 = [] -s1 = '' -t1 = () diff --git a/tests/check_ast_test.py b/tests/check_ast_test.py index 64916ba4..c16f5fcc 100644 --- a/tests/check_ast_test.py +++ b/tests/check_ast_test.py @@ -1,15 +1,15 @@ from __future__ import absolute_import from __future__ import unicode_literals -from pre_commit_hooks.check_ast import check_ast +from pre_commit_hooks.check_ast import main from testing.util import get_resource_path def test_failing_file(): - ret = check_ast([get_resource_path('cannot_parse_ast.notpy')]) + ret = main([get_resource_path('cannot_parse_ast.notpy')]) assert ret == 1 def test_passing_file(): - ret = check_ast([__file__]) + ret = main([__file__]) assert ret == 0 diff --git a/tests/check_builtin_literals_test.py b/tests/check_builtin_literals_test.py index 86b79e3b..d4ac30f8 100644 --- a/tests/check_builtin_literals_test.py +++ b/tests/check_builtin_literals_test.py @@ -5,7 +5,35 @@ from pre_commit_hooks.check_builtin_literals import BuiltinTypeCall from pre_commit_hooks.check_builtin_literals import BuiltinTypeVisitor from pre_commit_hooks.check_builtin_literals import main -from testing.util import get_resource_path + +BUILTIN_CONSTRUCTORS = '''\ +from six.moves import builtins + +c1 = complex() +d1 = dict() +f1 = float() +i1 = int() +l1 = list() +s1 = str() +t1 = tuple() + +c2 = builtins.complex() +d2 = builtins.dict() +f2 = builtins.float() +i2 = builtins.int() +l2 = builtins.list() +s2 = builtins.str() +t2 = builtins.tuple() +''' +BUILTIN_LITERALS = '''\ +c1 = 0j +d1 = {} +f1 = 0.0 +i1 = 0 +l1 = [] +s1 = '' +t1 = () +''' @pytest.fixture @@ -94,24 +122,26 @@ def test_dict_no_allow_kwargs_exprs(expression, calls): def test_ignore_constructors(): visitor = BuiltinTypeVisitor(ignore=('complex', 'dict', 'float', 'int', 'list', 'str', 'tuple')) - with open(get_resource_path('builtin_constructors.py'), 'rb') as f: - visitor.visit(ast.parse(f.read(), 'builtin_constructors.py')) + visitor.visit(ast.parse(BUILTIN_CONSTRUCTORS)) assert visitor.builtin_type_calls == [] -def test_failing_file(): - rc = main([get_resource_path('builtin_constructors.py')]) +def test_failing_file(tmpdir): + f = tmpdir.join('f.py') + f.write(BUILTIN_CONSTRUCTORS) + rc = main([f.strpath]) assert rc == 1 -def test_passing_file(): - rc = main([get_resource_path('builtin_literals.py')]) +def test_passing_file(tmpdir): + f = tmpdir.join('f.py') + f.write(BUILTIN_LITERALS) + rc = main([f.strpath]) assert rc == 0 -def test_failing_file_ignore_all(): - rc = main([ - '--ignore=complex,dict,float,int,list,str,tuple', - get_resource_path('builtin_constructors.py'), - ]) +def test_failing_file_ignore_all(tmpdir): + f = tmpdir.join('f.py') + f.write(BUILTIN_CONSTRUCTORS) + rc = main(['--ignore=complex,dict,float,int,list,str,tuple', f.strpath]) assert rc == 0 diff --git a/tests/check_json_test.py b/tests/check_json_test.py index 6ba26c14..6654ed10 100644 --- a/tests/check_json_test.py +++ b/tests/check_json_test.py @@ -1,6 +1,6 @@ import pytest -from pre_commit_hooks.check_json import check_json +from pre_commit_hooks.check_json import main from testing.util import get_resource_path @@ -11,8 +11,8 @@ ('ok_json.json', 0), ), ) -def test_check_json(capsys, filename, expected_retval): - ret = check_json([get_resource_path(filename)]) +def test_main(capsys, filename, expected_retval): + ret = main([get_resource_path(filename)]) assert ret == expected_retval if expected_retval == 1: stdout, _ = capsys.readouterr() diff --git a/tests/check_merge_conflict_test.py b/tests/check_merge_conflict_test.py index b04c70e0..50e389c9 100644 --- a/tests/check_merge_conflict_test.py +++ b/tests/check_merge_conflict_test.py @@ -6,7 +6,7 @@ import pytest -from pre_commit_hooks.check_merge_conflict import detect_merge_conflict +from pre_commit_hooks.check_merge_conflict import main from pre_commit_hooks.util import cmd_output from testing.util import get_resource_path @@ -102,7 +102,7 @@ def repository_pending_merge(tmpdir): @pytest.mark.usefixtures('f1_is_a_conflict_file') def test_merge_conflicts_git(): - assert detect_merge_conflict(['f1']) == 1 + assert main(['f1']) == 1 @pytest.mark.parametrize( @@ -110,7 +110,7 @@ def test_merge_conflicts_git(): ) def test_merge_conflicts_failing(contents, repository_pending_merge): repository_pending_merge.join('f2').write_binary(contents) - assert detect_merge_conflict(['f2']) == 1 + assert main(['f2']) == 1 @pytest.mark.parametrize( @@ -118,22 +118,22 @@ def test_merge_conflicts_failing(contents, repository_pending_merge): ) def test_merge_conflicts_ok(contents, f1_is_a_conflict_file): f1_is_a_conflict_file.join('f1').write_binary(contents) - assert detect_merge_conflict(['f1']) == 0 + assert main(['f1']) == 0 @pytest.mark.usefixtures('f1_is_a_conflict_file') def test_ignores_binary_files(): shutil.copy(get_resource_path('img1.jpg'), 'f1') - assert detect_merge_conflict(['f1']) == 0 + assert main(['f1']) == 0 def test_does_not_care_when_not_in_a_merge(tmpdir): f = tmpdir.join('README.md') f.write_binary(b'problem\n=======\n') - assert detect_merge_conflict([str(f.realpath())]) == 0 + assert main([str(f.realpath())]) == 0 def test_care_when_assumed_merge(tmpdir): f = tmpdir.join('README.md') f.write_binary(b'problem\n=======\n') - assert detect_merge_conflict([str(f.realpath()), '--assume-in-merge']) == 1 + assert main([str(f.realpath()), '--assume-in-merge']) == 1 diff --git a/tests/check_symlinks_test.py b/tests/check_symlinks_test.py index 0414df55..ecbc7aec 100644 --- a/tests/check_symlinks_test.py +++ b/tests/check_symlinks_test.py @@ -2,7 +2,7 @@ import pytest -from pre_commit_hooks.check_symlinks import check_symlinks +from pre_commit_hooks.check_symlinks import main xfail_symlink = pytest.mark.xfail(os.name == 'nt', reason='No symlink support') @@ -12,12 +12,12 @@ @pytest.mark.parametrize( ('dest', 'expected'), (('exists', 0), ('does-not-exist', 1)), ) -def test_check_symlinks(tmpdir, dest, expected): # pragma: no cover (symlinks) +def test_main(tmpdir, dest, expected): # pragma: no cover (symlinks) tmpdir.join('exists').ensure() symlink = tmpdir.join('symlink') symlink.mksymlinkto(tmpdir.join(dest)) - assert check_symlinks((symlink.strpath,)) == expected + assert main((symlink.strpath,)) == expected -def test_check_symlinks_normal_file(tmpdir): - assert check_symlinks((tmpdir.join('f').ensure().strpath,)) == 0 +def test_main_normal_file(tmpdir): + assert main((tmpdir.join('f').ensure().strpath,)) == 0 diff --git a/tests/check_xml_test.py b/tests/check_xml_test.py index 84e365d1..357bad64 100644 --- a/tests/check_xml_test.py +++ b/tests/check_xml_test.py @@ -1,6 +1,6 @@ import pytest -from pre_commit_hooks.check_xml import check_xml +from pre_commit_hooks.check_xml import main from testing.util import get_resource_path @@ -10,6 +10,6 @@ ('ok_xml.xml', 0), ), ) -def test_check_xml(filename, expected_retval): - ret = check_xml([get_resource_path(filename)]) +def test_main(filename, expected_retval): + ret = main([get_resource_path(filename)]) assert ret == expected_retval diff --git a/tests/check_yaml_test.py b/tests/check_yaml_test.py index aa357f13..d267150a 100644 --- a/tests/check_yaml_test.py +++ b/tests/check_yaml_test.py @@ -3,7 +3,7 @@ import pytest -from pre_commit_hooks.check_yaml import check_yaml +from pre_commit_hooks.check_yaml import main from testing.util import get_resource_path @@ -13,29 +13,29 @@ ('ok_yaml.yaml', 0), ), ) -def test_check_yaml(filename, expected_retval): - ret = check_yaml([get_resource_path(filename)]) +def test_main(filename, expected_retval): + ret = main([get_resource_path(filename)]) assert ret == expected_retval -def test_check_yaml_allow_multiple_documents(tmpdir): +def test_main_allow_multiple_documents(tmpdir): f = tmpdir.join('test.yaml') f.write('---\nfoo\n---\nbar\n') # should fail without the setting - assert check_yaml((f.strpath,)) + assert main((f.strpath,)) # should pass when we allow multiple documents - assert not check_yaml(('--allow-multiple-documents', f.strpath)) + assert not main(('--allow-multiple-documents', f.strpath)) def test_fails_even_with_allow_multiple_documents(tmpdir): f = tmpdir.join('test.yaml') f.write('[') - assert check_yaml(('--allow-multiple-documents', f.strpath)) + assert main(('--allow-multiple-documents', f.strpath)) -def test_check_yaml_unsafe(tmpdir): +def test_main_unsafe(tmpdir): f = tmpdir.join('test.yaml') f.write( 'some_foo: !vault |\n' @@ -43,12 +43,12 @@ def test_check_yaml_unsafe(tmpdir): ' deadbeefdeadbeefdeadbeef\n', ) # should fail "safe" check - assert check_yaml((f.strpath,)) + assert main((f.strpath,)) # should pass when we allow unsafe documents - assert not check_yaml(('--unsafe', f.strpath)) + assert not main(('--unsafe', f.strpath)) -def test_check_yaml_unsafe_still_fails_on_syntax_errors(tmpdir): +def test_main_unsafe_still_fails_on_syntax_errors(tmpdir): f = tmpdir.join('test.yaml') f.write('[') - assert check_yaml(('--unsafe', f.strpath)) + assert main(('--unsafe', f.strpath)) diff --git a/tests/detect_private_key_test.py b/tests/detect_private_key_test.py index fdd63a21..9266f2b0 100644 --- a/tests/detect_private_key_test.py +++ b/tests/detect_private_key_test.py @@ -1,6 +1,6 @@ import pytest -from pre_commit_hooks.detect_private_key import detect_private_key +from pre_commit_hooks.detect_private_key import main # Input, expected return value TESTS = ( @@ -18,7 +18,7 @@ @pytest.mark.parametrize(('input_s', 'expected_retval'), TESTS) -def test_detect_private_key(input_s, expected_retval, tmpdir): +def test_main(input_s, expected_retval, tmpdir): path = tmpdir.join('file.txt') path.write_binary(input_s) - assert detect_private_key([path.strpath]) == expected_retval + assert main([path.strpath]) == expected_retval diff --git a/tests/end_of_file_fixer_test.py b/tests/end_of_file_fixer_test.py index f8710afc..7f644e76 100644 --- a/tests/end_of_file_fixer_test.py +++ b/tests/end_of_file_fixer_test.py @@ -2,8 +2,8 @@ import pytest -from pre_commit_hooks.end_of_file_fixer import end_of_file_fixer from pre_commit_hooks.end_of_file_fixer import fix_file +from pre_commit_hooks.end_of_file_fixer import main # Input, expected return value, expected output @@ -35,7 +35,7 @@ def test_integration(input_s, expected_retval, output, tmpdir): path = tmpdir.join('file.txt') path.write_binary(input_s) - ret = end_of_file_fixer([path.strpath]) + ret = main([path.strpath]) file_output = path.read_binary() assert file_output == output diff --git a/tests/no_commit_to_branch_test.py b/tests/no_commit_to_branch_test.py index c275bf71..e978ba27 100644 --- a/tests/no_commit_to_branch_test.py +++ b/tests/no_commit_to_branch_test.py @@ -11,24 +11,24 @@ def test_other_branch(temp_git_dir): with temp_git_dir.as_cwd(): cmd_output('git', 'checkout', '-b', 'anotherbranch') - assert is_on_branch(('master',)) is False + assert is_on_branch({'master'}) is False def test_multi_branch(temp_git_dir): with temp_git_dir.as_cwd(): cmd_output('git', 'checkout', '-b', 'another/branch') - assert is_on_branch(('master',)) is False + assert is_on_branch({'master'}) is False def test_multi_branch_fail(temp_git_dir): with temp_git_dir.as_cwd(): cmd_output('git', 'checkout', '-b', 'another/branch') - assert is_on_branch(('another/branch',)) is True + assert is_on_branch({'another/branch'}) is True def test_master_branch(temp_git_dir): with temp_git_dir.as_cwd(): - assert is_on_branch(('master',)) is True + assert is_on_branch({'master'}) is True def test_main_branch_call(temp_git_dir): diff --git a/tests/pretty_format_json_test.py b/tests/pretty_format_json_test.py index 7ce7e160..8d82d746 100644 --- a/tests/pretty_format_json_test.py +++ b/tests/pretty_format_json_test.py @@ -3,8 +3,8 @@ import pytest from six import PY2 +from pre_commit_hooks.pretty_format_json import main from pre_commit_hooks.pretty_format_json import parse_num_to_int -from pre_commit_hooks.pretty_format_json import pretty_format_json from testing.util import get_resource_path @@ -23,8 +23,8 @@ def test_parse_num_to_int(): ('pretty_formatted_json.json', 0), ), ) -def test_pretty_format_json(filename, expected_retval): - ret = pretty_format_json([get_resource_path(filename)]) +def test_main(filename, expected_retval): + ret = main([get_resource_path(filename)]) assert ret == expected_retval @@ -36,8 +36,8 @@ def test_pretty_format_json(filename, expected_retval): ('pretty_formatted_json.json', 0), ), ) -def test_unsorted_pretty_format_json(filename, expected_retval): - ret = pretty_format_json(['--no-sort-keys', get_resource_path(filename)]) +def test_unsorted_main(filename, expected_retval): + ret = main(['--no-sort-keys', get_resource_path(filename)]) assert ret == expected_retval @@ -51,17 +51,17 @@ def test_unsorted_pretty_format_json(filename, expected_retval): ('tab_pretty_formatted_json.json', 0), ), ) -def test_tab_pretty_format_json(filename, expected_retval): # pragma: no cover - ret = pretty_format_json(['--indent', '\t', get_resource_path(filename)]) +def test_tab_main(filename, expected_retval): # pragma: no cover + ret = main(['--indent', '\t', get_resource_path(filename)]) assert ret == expected_retval -def test_non_ascii_pretty_format_json(): - ret = pretty_format_json(['--no-ensure-ascii', get_resource_path('non_ascii_pretty_formatted_json.json')]) +def test_non_ascii_main(): + ret = main(['--no-ensure-ascii', get_resource_path('non_ascii_pretty_formatted_json.json')]) assert ret == 0 -def test_autofix_pretty_format_json(tmpdir): +def test_autofix_main(tmpdir): srcfile = tmpdir.join('to_be_json_formatted.json') shutil.copyfile( get_resource_path('not_pretty_formatted_json.json'), @@ -69,30 +69,30 @@ def test_autofix_pretty_format_json(tmpdir): ) # now launch the autofix on that file - ret = pretty_format_json(['--autofix', srcfile.strpath]) + ret = main(['--autofix', srcfile.strpath]) # it should have formatted it assert ret == 1 # file was formatted (shouldn't trigger linter again) - ret = pretty_format_json([srcfile.strpath]) + ret = main([srcfile.strpath]) assert ret == 0 def test_orderfile_get_pretty_format(): - ret = pretty_format_json(['--top-keys=alist', get_resource_path('pretty_formatted_json.json')]) + ret = main(['--top-keys=alist', get_resource_path('pretty_formatted_json.json')]) assert ret == 0 def test_not_orderfile_get_pretty_format(): - ret = pretty_format_json(['--top-keys=blah', get_resource_path('pretty_formatted_json.json')]) + ret = main(['--top-keys=blah', get_resource_path('pretty_formatted_json.json')]) assert ret == 1 def test_top_sorted_get_pretty_format(): - ret = pretty_format_json(['--top-keys=01-alist,alist', get_resource_path('top_sorted_json.json')]) + ret = main(['--top-keys=01-alist,alist', get_resource_path('top_sorted_json.json')]) assert ret == 0 -def test_badfile_pretty_format_json(): - ret = pretty_format_json([get_resource_path('ok_yaml.yaml')]) +def test_badfile_main(): + ret = main([get_resource_path('ok_yaml.yaml')]) assert ret == 1 diff --git a/tests/requirements_txt_fixer_test.py b/tests/requirements_txt_fixer_test.py index 437cebd9..b3a79423 100644 --- a/tests/requirements_txt_fixer_test.py +++ b/tests/requirements_txt_fixer_test.py @@ -1,7 +1,7 @@ import pytest from pre_commit_hooks.requirements_txt_fixer import FAIL -from pre_commit_hooks.requirements_txt_fixer import fix_requirements_txt +from pre_commit_hooks.requirements_txt_fixer import main from pre_commit_hooks.requirements_txt_fixer import PASS from pre_commit_hooks.requirements_txt_fixer import Requirement @@ -36,7 +36,7 @@ def test_integration(input_s, expected_retval, output, tmpdir): path = tmpdir.join('file.txt') path.write_binary(input_s) - output_retval = fix_requirements_txt([path.strpath]) + output_retval = main([path.strpath]) assert path.read_binary() == output assert output_retval == expected_retval @@ -44,7 +44,7 @@ def test_integration(input_s, expected_retval, output, tmpdir): def test_requirement_object(): top_of_file = Requirement() - top_of_file.comments.append('#foo') + top_of_file.comments.append(b'#foo') top_of_file.value = b'\n' requirement_foo = Requirement() diff --git a/tests/sort_simple_yaml_test.py b/tests/sort_simple_yaml_test.py index 176d12f5..72f5becc 100644 --- a/tests/sort_simple_yaml_test.py +++ b/tests/sort_simple_yaml_test.py @@ -110,9 +110,9 @@ def test_first_key(): lines = ['# some comment', '"a": 42', 'b: 17', '', 'c: 19'] assert first_key(lines) == 'a": 42' - # no lines + # no lines (not a real situation) lines = [] - assert first_key(lines) is None + assert first_key(lines) == '' @pytest.mark.parametrize('bad_lines,good_lines,_', TEST_SORTS) diff --git a/tests/tests_should_end_in_test_test.py b/tests/tests_should_end_in_test_test.py index dc686a5f..4eb98e7d 100644 --- a/tests/tests_should_end_in_test_test.py +++ b/tests/tests_should_end_in_test_test.py @@ -1,36 +1,36 @@ -from pre_commit_hooks.tests_should_end_in_test import validate_files +from pre_commit_hooks.tests_should_end_in_test import main -def test_validate_files_all_pass(): - ret = validate_files(['foo_test.py', 'bar_test.py']) +def test_main_all_pass(): + ret = main(['foo_test.py', 'bar_test.py']) assert ret == 0 -def test_validate_files_one_fails(): - ret = validate_files(['not_test_ending.py', 'foo_test.py']) +def test_main_one_fails(): + ret = main(['not_test_ending.py', 'foo_test.py']) assert ret == 1 -def test_validate_files_django_all_pass(): - ret = validate_files(['--django', 'tests.py', 'test_foo.py', 'test_bar.py', 'tests/test_baz.py']) +def test_main_django_all_pass(): + ret = main(['--django', 'tests.py', 'test_foo.py', 'test_bar.py', 'tests/test_baz.py']) assert ret == 0 -def test_validate_files_django_one_fails(): - ret = validate_files(['--django', 'not_test_ending.py', 'test_foo.py']) +def test_main_django_one_fails(): + ret = main(['--django', 'not_test_ending.py', 'test_foo.py']) assert ret == 1 def test_validate_nested_files_django_one_fails(): - ret = validate_files(['--django', 'tests/not_test_ending.py', 'test_foo.py']) + ret = main(['--django', 'tests/not_test_ending.py', 'test_foo.py']) assert ret == 1 -def test_validate_files_not_django_fails(): - ret = validate_files(['foo_test.py', 'bar_test.py', 'test_baz.py']) +def test_main_not_django_fails(): + ret = main(['foo_test.py', 'bar_test.py', 'test_baz.py']) assert ret == 1 -def test_validate_files_django_fails(): - ret = validate_files(['--django', 'foo_test.py', 'test_bar.py', 'test_baz.py']) +def test_main_django_fails(): + ret = main(['--django', 'foo_test.py', 'test_bar.py', 'test_baz.py']) assert ret == 1 diff --git a/tox.ini b/tox.ini index c131e6f0..d1e6a796 100644 --- a/tox.ini +++ b/tox.ini @@ -1,6 +1,6 @@ [tox] # These should match the travis env list -envlist = py27,py36,py37,pypy +envlist = py27,py36,py37,pypy3 [testenv] deps = -rrequirements-dev.txt