From e375a32e09d6c94a18c0ac663c49244f14f53a4c Mon Sep 17 00:00:00 2001 From: Josh Bode Date: Mon, 25 Nov 2019 10:58:55 +1000 Subject: [PATCH] add support for pyproject.toml if present (#80) --- README.md | 24 +++++++ datamodel_code_generator/__main__.py | 24 ++++++- datamodel_code_generator/format.py | 18 ++++- setup.cfg | 1 + tests/data/project/pyproject.toml | 3 + tests/test_main.py | 100 +++++++++++++++++++++++++++ 6 files changed, 165 insertions(+), 5 deletions(-) create mode 100644 tests/data/project/pyproject.toml diff --git a/README.md b/README.md index e3bc68a47..d9c1e6feb 100644 --- a/README.md +++ b/README.md @@ -76,6 +76,30 @@ optional arguments: --version show version ``` +## Formatting + +Code generated by `datamodel-codegen` will be passed through `isort` and +`black` to produce consistent, well-formatted results. Settings for these tools +can be specified in `pyproject.toml` (located in the output directory, or in +some parent of the output directory). + +Example `pyproject.toml`: +```toml +[tool.black] +string-normalization = true +line-length = 100 + +[tool.isort] +multi_line_output = 3 +include_trailing_comma = true +force_grid_wrap = 0 +use_parentheses = true +line_length = 100 +known_first_party = "kelvin" +``` + +See the [Black Project](https://black.readthedocs.io/en/stable/pyproject_toml.html) for more information. + ## Example ```sh diff --git a/datamodel_code_generator/__main__.py b/datamodel_code_generator/__main__.py index ac9ab7a41..77602650a 100755 --- a/datamodel_code_generator/__main__.py +++ b/datamodel_code_generator/__main__.py @@ -4,13 +4,15 @@ Main function. """ +import contextlib import json +import os import sys from argparse import ArgumentParser, FileType, Namespace from datetime import datetime, timezone from enum import IntEnum from pathlib import Path -from typing import IO, Any, Mapping, Optional, Sequence +from typing import IO, Any, Iterator, Mapping, Optional, Sequence import argcomplete from datamodel_code_generator import PythonVersion, enable_debug_message @@ -28,6 +30,21 @@ class Exit(IntEnum): ERROR = 1 +@contextlib.contextmanager +def chdir(path: Optional[Path]) -> Iterator[None]: + """Changes working directory and returns to previous on exit.""" + + if path is None: + yield + else: + prev_cwd = Path.cwd() + try: + os.chdir(path if path.is_dir() else path.parent) + yield + finally: + os.chdir(prev_cwd) + + arg_parser = ArgumentParser() arg_parser.add_argument( '--input', @@ -98,10 +115,11 @@ def main(args: Optional[Sequence[str]] = None) -> Exit: dump_resolve_reference_action=dump_resolve_reference_action, ) - result = parser.parse() - output = Path(namespace.output) if namespace.output is not None else None + with chdir(output): + result = parser.parse() + if isinstance(result, str): modules = {output: result} else: diff --git a/datamodel_code_generator/format.py b/datamodel_code_generator/format.py index cf0c5ba55..0cb611aec 100644 --- a/datamodel_code_generator/format.py +++ b/datamodel_code_generator/format.py @@ -1,6 +1,8 @@ +from pathlib import Path from typing import Dict import black +import toml from datamodel_code_generator import PythonVersion from isort import SortImports @@ -12,17 +14,29 @@ def format_code(code: str, python_version: PythonVersion) -> str: + + code = apply_isort(code) code = apply_black(code, python_version) - return apply_isort(code) + return code def apply_black(code: str, python_version: PythonVersion) -> str: + root = black.find_project_root((Path().resolve(),)) + path = root / "pyproject.toml" + if path.is_file(): + value = str(path) + pyproject_toml = toml.load(value) + config = pyproject_toml.get("tool", {}).get("black", {}) + else: + config = {} + return black.format_str( code, mode=black.FileMode( target_versions={BLACK_PYTHON_VERSION[python_version]}, - string_normalization=False, + line_length=config.get("line-length", black.DEFAULT_LINE_LENGTH), + string_normalization=config.get("string-normalization", False), ), ) diff --git a/setup.cfg b/setup.cfg index 6a2762cf0..c3153021f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -34,6 +34,7 @@ install_requires = black==19.3b0 isort==4.3.21 PySnooper==0.2.8 + toml==0.10.0 tests_require = pytest diff --git a/tests/data/project/pyproject.toml b/tests/data/project/pyproject.toml new file mode 100644 index 000000000..e5e55fc0a --- /dev/null +++ b/tests/data/project/pyproject.toml @@ -0,0 +1,3 @@ +[tool.black] +string-normalization = true +line-length = 30 diff --git a/tests/test_main.py b/tests/test_main.py index 722a652f0..d6ef89cf9 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,3 +1,4 @@ +import shutil from pathlib import Path from tempfile import TemporaryDirectory from typing import Mapping @@ -615,3 +616,102 @@ def test_main_custom_template_dir(capsys: CaptureFixture, expected: str) -> None captured = capsys.readouterr() assert captured.out == expected assert not captured.err + + +@freeze_time('2019-07-26') +def test_pyproject(): + with TemporaryDirectory() as output_dir: + output_dir = Path(output_dir) + pyproject_toml = Path(DATA_PATH) / "project" / "pyproject.toml" + shutil.copy(pyproject_toml, output_dir) + output_file: Path = output_dir / 'output.py' + return_code: Exit = main( + ['--input', str(DATA_PATH / 'api.yaml'), '--output', str(output_file)] + ) + assert return_code == Exit.OK + assert ( + output_file.read_text() + == '''# generated by datamodel-codegen: +# filename: api.yaml +# timestamp: 2019-07-26T00:00:00+00:00 + +from __future__ import ( + annotations, +) + +from typing import ( + List, + Optional, +) + +from pydantic import ( + BaseModel, + UrlStr, +) + + +class Pet(BaseModel): + id: int + name: str + tag: Optional[str] = None + + +class Pets(BaseModel): + __root__: List[Pet] + + +class User(BaseModel): + id: int + name: str + tag: Optional[str] = None + + +class Users(BaseModel): + __root__: List[User] + + +class Id(BaseModel): + __root__: str + + +class Rules(BaseModel): + __root__: List[str] + + +class Error(BaseModel): + code: int + message: str + + +class api(BaseModel): + apiKey: Optional[ + str + ] = None + apiVersionNumber: Optional[ + str + ] = None + apiUrl: Optional[ + UrlStr + ] = None + apiDocumentationUrl: Optional[ + UrlStr + ] = None + + +class apis(BaseModel): + __root__: List[api] + + +class Event(BaseModel): + name: Optional[str] = None + + +class Result(BaseModel): + event: Optional[ + Event + ] = None +''' + ) + + with pytest.raises(SystemExit): + main()