Skip to content

Commit

Permalink
Add support for custom templates and extra template data (#71)
Browse files Browse the repository at this point in the history
* Add support for custom templates and extra template data

- point to a directory containing template overrides
  (`BaseModel.jinja2`, etc)
- add custom data to templates for rendering from JSON file to support
  incorporating data that exists outside of openapi spec

* fix typo
  • Loading branch information
joshbode authored and koxudaxi committed Oct 16, 2019
1 parent 30ae213 commit 5697b69
Show file tree
Hide file tree
Showing 11 changed files with 198 additions and 15 deletions.
10 changes: 7 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,22 +49,26 @@ $ pip install datamodel-code-generator
## Usage

The `datamodel-codegen` command:
``` datamodel-code-generator[10:03:22]
```
usage: datamodel-codegen [-h] [--input INPUT] [--output OUTPUT]
[--base-class BASE_CLASS]
[--custom-template-dir CUSTOM_TEMPLATE_DIR]
[--extra-template-data EXTRA_TEMPLATE_DATA]
[--target-python-version {3.6,3.7}] [--debug]
optional arguments:
-h, --help show this help message and exit
--input INPUT Open API YAML file (default: stdin)
--output OUTPUT Output file (default: stdout)
--base-class BASE_CLASS
Base Class (default: pydantic.BaseModel)
--custom-template-dir CUSTOM_TEMPLATE_DIR
Custom Template Directory
--extra-template-data EXTRA_TEMPLATE_DATA
Extra Template Data
--target-python-version {3.6,3.7}
target python version (default: 3.7)
--debug show debug message
```

## Example
Expand Down
18 changes: 17 additions & 1 deletion datamodel_code_generator/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
Main function.
"""

import json
import sys
from argparse import ArgumentParser, FileType, Namespace
from datetime import datetime, timezone
from enum import IntEnum
from pathlib import Path
from typing import IO, Any, Optional, Sequence
from typing import IO, Any, Mapping, Optional, Sequence

import argcomplete
from datamodel_code_generator import PythonVersion, enable_debug_message
Expand Down Expand Up @@ -41,6 +42,12 @@ class Exit(IntEnum):
type=str,
default='pydantic.BaseModel',
)
arg_parser.add_argument(
'--custom-template-dir', help='Custom Template Directory', type=str
)
arg_parser.add_argument(
'--extra-template-data', help='Extra Template Data', type=FileType('rt')
)
arg_parser.add_argument(
'--target-python-version',
help='target python version (default: 3.7)',
Expand All @@ -66,10 +73,19 @@ def main(args: Optional[Sequence[str]] = None) -> Exit:

from datamodel_code_generator.parser.openapi import OpenAPIParser

extra_template_data: Optional[Mapping[str, Any]]
if namespace.extra_template_data is not None:
with namespace.extra_template_data as data:
extra_template_data = json.load(data)
else:
extra_template_data = None

parser = OpenAPIParser(
BaseModel,
CustomRootType,
base_class=namespace.base_class,
custom_template_dir=namespace.custom_template_dir,
extra_template_data=extra_template_data,
target_python_version=PythonVersion(namespace.target_python_version),
text=namespace.input.read(),
dump_resolve_reference_action=dump_resolve_reference_action,
Expand Down
24 changes: 20 additions & 4 deletions datamodel_code_generator/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from abc import ABC, abstractmethod
from functools import wraps
from pathlib import Path
from typing import Any, Callable, List, Optional, Set
from typing import Any, Callable, List, Mapping, Optional, Set

from datamodel_code_generator.imports import (
IMPORT_LIST,
Expand Down Expand Up @@ -87,8 +87,8 @@ def __init__(self, **values: Any) -> None:


class TemplateBase(ABC):
def __init__(self, template_file_path: str) -> None:
self.template_file_path: str = template_file_path
def __init__(self, template_file_path: Path) -> None:
self.template_file_path: Path = template_file_path
self._template: Template = Template(
(TEMPLATE_DIR / self.template_file_path).read_text()
)
Expand Down Expand Up @@ -119,13 +119,21 @@ def __init__(
decorators: Optional[List[str]] = None,
base_classes: Optional[List[str]] = None,
custom_base_class: Optional[str] = None,
custom_template_dir: Optional[Path] = None,
extra_template_data: Optional[Mapping[str, Any]] = None,
imports: Optional[List[Import]] = None,
auto_import: bool = True,
reference_classes: Optional[List[str]] = None,
) -> None:
if not self.TEMPLATE_FILE_PATH:
raise Exception('TEMPLATE_FILE_PATH is undefined')

template_file_path = Path(self.TEMPLATE_FILE_PATH)
if custom_template_dir is not None:
custom_template_file_path = custom_template_dir / template_file_path.name
if custom_template_file_path.exists():
template_file_path = custom_template_file_path

self.name: str = name
self.fields: List[DataModelField] = fields or []
self.decorators: List[str] = decorators or []
Expand All @@ -149,6 +157,12 @@ def __init__(
self.imports.append(Import.from_full_path(base_class_full_path))
self.base_class = base_class_full_path.split('.')[-1]

self.extra_template_data = (
extra_template_data.get(self.name, {})
if extra_template_data is not None
else {}
)

unresolved_types: Set[str] = set()
for field in self.fields:
unresolved_types.update(set(field.unresolved_types))
Expand All @@ -158,14 +172,16 @@ def __init__(
if auto_import:
for field in self.fields:
self.imports.extend(field.imports)
super().__init__(template_file_path=self.TEMPLATE_FILE_PATH)

super().__init__(template_file_path=template_file_path)

def render(self) -> str:
response = self._render(
class_name=self.name,
fields=self.fields,
decorators=self.decorators,
base_class=self.base_class,
**self.extra_template_data,
)
return response

Expand Down
7 changes: 6 additions & 1 deletion datamodel_code_generator/model/pydantic/base_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, List, Optional
from pathlib import Path
from typing import Any, List, Mapping, Optional

from datamodel_code_generator.model import DataModel, DataModelField
from datamodel_code_generator.model.pydantic.types import get_data_type
Expand All @@ -16,6 +17,8 @@ def __init__(
decorators: Optional[List[str]] = None,
base_classes: Optional[List[str]] = None,
custom_base_class: Optional[str] = None,
custom_template_dir: Optional[Path] = None,
extra_template_data: Optional[Mapping[str, Any]] = None,
auto_import: bool = True,
reference_classes: Optional[List[str]] = None,
):
Expand All @@ -25,6 +28,8 @@ def __init__(
decorators=decorators,
base_classes=base_classes,
custom_base_class=custom_base_class,
custom_template_dir=custom_template_dir,
extra_template_data=extra_template_data,
auto_import=auto_import,
reference_classes=reference_classes,
)
Expand Down
7 changes: 6 additions & 1 deletion datamodel_code_generator/model/pydantic/custom_root_type.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, List, Optional
from pathlib import Path
from typing import Any, List, Mapping, Optional

from datamodel_code_generator.imports import Import
from datamodel_code_generator.model.base import DataModel, DataModelField
Expand All @@ -17,6 +18,8 @@ def __init__(
decorators: Optional[List[str]] = None,
base_classes: Optional[List[str]] = None,
custom_base_class: Optional[str] = None,
custom_template_dir: Optional[Path] = None,
extra_template_data: Optional[Mapping[str, Any]] = None,
imports: Optional[List[Import]] = None,
auto_import: bool = True,
reference_classes: Optional[List[str]] = None,
Expand All @@ -27,6 +30,8 @@ def __init__(
decorators=decorators,
base_classes=base_classes,
custom_base_class=custom_base_class,
custom_template_dir=custom_template_dir,
extra_template_data=extra_template_data,
imports=imports,
auto_import=auto_import,
reference_classes=reference_classes,
Expand Down
7 changes: 6 additions & 1 deletion datamodel_code_generator/model/pydantic/dataclass.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, List, Optional
from pathlib import Path
from typing import Any, List, Mapping, Optional

from datamodel_code_generator.imports import Import
from datamodel_code_generator.model import DataModel, DataModelField
Expand All @@ -16,6 +17,8 @@ def __init__(
decorators: Optional[List[str]] = None,
base_classes: Optional[List[str]] = None,
custom_base_class: Optional[str] = None,
custom_template_dir: Optional[Path] = None,
extra_template_data: Optional[Mapping[str, Any]] = None,
auto_import: bool = True,
reference_classes: Optional[List[str]] = None,
):
Expand All @@ -26,6 +29,8 @@ def __init__(
decorators,
base_classes,
custom_base_class=custom_base_class,
custom_template_dir=custom_template_dir,
extra_template_data=extra_template_data,
auto_import=auto_import,
reference_classes=reference_classes,
)
Expand Down
2 changes: 1 addition & 1 deletion datamodel_code_generator/model/pydantic/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def get_data_type(types: Types, **kwargs: Any) -> DataType:
if types == Types.string:
return get_data_str_type(types, **kwargs)
elif types in (Types.int32, Types.int64, Types.integer):
return get_data_str_type(types, **kwargs)
return get_data_int_type(types, **kwargs)
elif types in (Types.float, Types.double, Types.number, Types.time):
return get_data_float_type(types, **kwargs)
return type_map[types]
27 changes: 24 additions & 3 deletions datamodel_code_generator/parser/openapi.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from itertools import groupby
from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union
from pathlib import Path
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Tuple, Type, Union

from datamodel_code_generator import PythonVersion, snooper_to_methods
from datamodel_code_generator.format import format_code
Expand Down Expand Up @@ -27,6 +28,8 @@ def __init__(
data_model_field_type: Type[DataModelField] = DataModelField,
filename: Optional[str] = None,
base_class: Optional[str] = None,
custom_template_dir: Optional[str] = None,
extra_template_data: Optional[Mapping[str, Any]] = None,
target_python_version: PythonVersion = PythonVersion.PY_37,
text: Optional[str] = None,
result: Optional[List[DataModel]] = None,
Expand All @@ -37,6 +40,12 @@ def __init__(
if filename or text
else None
)
self.custom_template_dir = (
Path(custom_template_dir).expanduser().resolve()
if custom_template_dir is not None
else None
)
self.extra_template_data = extra_template_data

super().__init__(
data_model_type,
Expand Down Expand Up @@ -105,6 +114,8 @@ def parse_all_of(self, name: str, obj: JsonSchemaObject) -> List[DataType]:
base_classes=[b.type for b in base_classes],
auto_import=False,
custom_base_class=self.base_class,
custom_template_dir=self.custom_template_dir,
extra_template_data=self.extra_template_data,
)
self.append_result(data_model_type)

Expand Down Expand Up @@ -178,7 +189,11 @@ def parse_object_fields(self, obj: JsonSchemaObject) -> List[DataModelField]:
def parse_object(self, name: str, obj: JsonSchemaObject) -> None:
fields = self.parse_object_fields(obj)
data_model_type = self.data_model_type(
name, fields=fields, custom_base_class=self.base_class
name,
fields=fields,
custom_base_class=self.base_class,
custom_template_dir=self.custom_template_dir,
extra_template_data=self.extra_template_data,
)
self.append_result(data_model_type)

Expand Down Expand Up @@ -226,7 +241,11 @@ def parse_array_fields(
def parse_array(self, name: str, obj: JsonSchemaObject) -> None:
fields, item_obj_names = self.parse_array_fields(name, obj)
data_model_root = self.data_model_root_type(
name, fields, custom_base_class=self.base_class
name,
fields,
custom_base_class=self.base_class,
custom_template_dir=self.custom_template_dir,
extra_template_data=self.extra_template_data,
)

self.append_result(data_model_root)
Expand All @@ -247,6 +266,8 @@ def parse_root_type(self, name: str, obj: JsonSchemaObject) -> None:
name,
[self.data_model_field_type(data_types=types, required=not obj.nullable)],
custom_base_class=self.base_class,
custom_template_dir=self.custom_template_dir,
extra_template_data=self.extra_template_data,
)
self.append_result(data_model_root_type)

Expand Down
5 changes: 5 additions & 0 deletions tests/data/extra_data.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"Pet": {
"comment": "1 2, 1 2, this is just a pet"
}
}
14 changes: 14 additions & 0 deletions tests/data/templates/BaseModel.jinja2
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
{% for decorator in decorators -%}
{{ decorator }}
{% endfor -%}
class {{ class_name }}({{ base_class }}):{% if comment is defined %} # {{ comment }}{% endif %}
{%- if not fields %}
pass
{%- endif %}
{%- for field in fields -%}
{%- if field.required %}
{{ field.name }}: {{ field.type_hint }}
{%- else %}
{{ field.name }}: {{ field.type_hint }} = {{ field.default }}
{%- endif %}
{%- endfor -%}
Loading

0 comments on commit 5697b69

Please sign in to comment.