Skip to content

Commit

Permalink
add support for pyproject.toml if present (#80)
Browse files Browse the repository at this point in the history
  • Loading branch information
joshbode authored and koxudaxi committed Nov 25, 2019
1 parent 0199de5 commit e375a32
Show file tree
Hide file tree
Showing 6 changed files with 165 additions and 5 deletions.
24 changes: 24 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,30 @@ optional arguments:
--version show version
```

## Formatting

Code generated by `datamodel-codegen` will be passed through `isort` and
`black` to produce consistent, well-formatted results. Settings for these tools
can be specified in `pyproject.toml` (located in the output directory, or in
some parent of the output directory).

Example `pyproject.toml`:
```toml
[tool.black]
string-normalization = true
line-length = 100

[tool.isort]
multi_line_output = 3
include_trailing_comma = true
force_grid_wrap = 0
use_parentheses = true
line_length = 100
known_first_party = "kelvin"
```

See the [Black Project](https://black.readthedocs.io/en/stable/pyproject_toml.html) for more information.

## Example

```sh
Expand Down
24 changes: 21 additions & 3 deletions datamodel_code_generator/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
Main function.
"""

import contextlib
import json
import os
import sys
from argparse import ArgumentParser, FileType, Namespace
from datetime import datetime, timezone
from enum import IntEnum
from pathlib import Path
from typing import IO, Any, Mapping, Optional, Sequence
from typing import IO, Any, Iterator, Mapping, Optional, Sequence

import argcomplete
from datamodel_code_generator import PythonVersion, enable_debug_message
Expand All @@ -28,6 +30,21 @@ class Exit(IntEnum):
ERROR = 1


@contextlib.contextmanager
def chdir(path: Optional[Path]) -> Iterator[None]:
"""Changes working directory and returns to previous on exit."""

if path is None:
yield
else:
prev_cwd = Path.cwd()
try:
os.chdir(path if path.is_dir() else path.parent)
yield
finally:
os.chdir(prev_cwd)


arg_parser = ArgumentParser()
arg_parser.add_argument(
'--input',
Expand Down Expand Up @@ -98,10 +115,11 @@ def main(args: Optional[Sequence[str]] = None) -> Exit:
dump_resolve_reference_action=dump_resolve_reference_action,
)

result = parser.parse()

output = Path(namespace.output) if namespace.output is not None else None

with chdir(output):
result = parser.parse()

if isinstance(result, str):
modules = {output: result}
else:
Expand Down
18 changes: 16 additions & 2 deletions datamodel_code_generator/format.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from pathlib import Path
from typing import Dict

import black
import toml
from datamodel_code_generator import PythonVersion
from isort import SortImports

Expand All @@ -12,17 +14,29 @@


def format_code(code: str, python_version: PythonVersion) -> str:

code = apply_isort(code)
code = apply_black(code, python_version)
return apply_isort(code)
return code


def apply_black(code: str, python_version: PythonVersion) -> str:

root = black.find_project_root((Path().resolve(),))
path = root / "pyproject.toml"
if path.is_file():
value = str(path)
pyproject_toml = toml.load(value)
config = pyproject_toml.get("tool", {}).get("black", {})
else:
config = {}

return black.format_str(
code,
mode=black.FileMode(
target_versions={BLACK_PYTHON_VERSION[python_version]},
string_normalization=False,
line_length=config.get("line-length", black.DEFAULT_LINE_LENGTH),
string_normalization=config.get("string-normalization", False),
),
)

Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ install_requires =
black==19.3b0
isort==4.3.21
PySnooper==0.2.8
toml==0.10.0

tests_require =
pytest
Expand Down
3 changes: 3 additions & 0 deletions tests/data/project/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[tool.black]
string-normalization = true
line-length = 30
100 changes: 100 additions & 0 deletions tests/test_main.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import shutil
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Mapping
Expand Down Expand Up @@ -615,3 +616,102 @@ def test_main_custom_template_dir(capsys: CaptureFixture, expected: str) -> None
captured = capsys.readouterr()
assert captured.out == expected
assert not captured.err


@freeze_time('2019-07-26')
def test_pyproject():
with TemporaryDirectory() as output_dir:
output_dir = Path(output_dir)
pyproject_toml = Path(DATA_PATH) / "project" / "pyproject.toml"
shutil.copy(pyproject_toml, output_dir)
output_file: Path = output_dir / 'output.py'
return_code: Exit = main(
['--input', str(DATA_PATH / 'api.yaml'), '--output', str(output_file)]
)
assert return_code == Exit.OK
assert (
output_file.read_text()
== '''# generated by datamodel-codegen:
# filename: api.yaml
# timestamp: 2019-07-26T00:00:00+00:00
from __future__ import (
annotations,
)
from typing import (
List,
Optional,
)
from pydantic import (
BaseModel,
UrlStr,
)
class Pet(BaseModel):
id: int
name: str
tag: Optional[str] = None
class Pets(BaseModel):
__root__: List[Pet]
class User(BaseModel):
id: int
name: str
tag: Optional[str] = None
class Users(BaseModel):
__root__: List[User]
class Id(BaseModel):
__root__: str
class Rules(BaseModel):
__root__: List[str]
class Error(BaseModel):
code: int
message: str
class api(BaseModel):
apiKey: Optional[
str
] = None
apiVersionNumber: Optional[
str
] = None
apiUrl: Optional[
UrlStr
] = None
apiDocumentationUrl: Optional[
UrlStr
] = None
class apis(BaseModel):
__root__: List[api]
class Event(BaseModel):
name: Optional[str] = None
class Result(BaseModel):
event: Optional[
Event
] = None
'''
)

with pytest.raises(SystemExit):
main()

0 comments on commit e375a32

Please sign in to comment.