Skip to content

Commit

Permalink
dataclass generator improvements (#2102)
Browse files Browse the repository at this point in the history
* Use apply_discriminator_type for dataclasses

Allow dataclass models to be properly generated with discriminator field

* Fix dataclass inheritance

Thanks to keyword only, dataclass models can use inheritance and no have issues with default values

* Support datetime types in dataclass fields

applying `--output-datetime-class` from #2100 to dataclass to map date, time and date time to the python `datetime` objects instead of strings.

* fix unittest

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix unittest

* fix unittest

---------

Co-authored-by: Koudai Aono <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Oct 16, 2024
1 parent 2df133c commit d0c0f16
Show file tree
Hide file tree
Showing 29 changed files with 372 additions and 29 deletions.
8 changes: 6 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,12 @@ Model customization:
--enable-version-header
Enable package version on file headers
--keep-model-order Keep generated models'' order
--keyword-only Defined models as keyword only (for example
dataclass(kw_only=True)).
--output-datetime-class {datetime,AwareDatetime,NaiveDatetime}
Choose Datetime class between AwareDatetime, NaiveDatetime or
datetime. Each output model has its default mapping, and only
pydantic and dataclass support this override"
--reuse-model Reuse models on the field when a module has the model with the same
content
--target-python-version {3.6,3.7,3.8,3.9,3.10,3.11,3.12}
Expand All @@ -462,8 +468,6 @@ Model customization:
--use-schema-description
Use schema description to populate class docstring
--use-title-as-name use titles as class names of models

----output-datetime-class Choose Datetime class between AwareDatetime, NaiveDatetime or datetime, default: "datetime"

Template customization:
--aliases ALIASES Alias mapping file
Expand Down
4 changes: 3 additions & 1 deletion datamodel_code_generator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,8 @@ def generate(
treat_dots_as_module: bool = False,
use_exact_imports: bool = False,
union_mode: Optional[UnionMode] = None,
output_datetime_class: DataModelType = DatetimeClassType.Datetime,
output_datetime_class: Optional[DatetimeClassType] = None,
keyword_only: bool = False,
) -> None:
remote_text_cache: DefaultPutDict[str, str] = DefaultPutDict()
if isinstance(input_, str):
Expand Down Expand Up @@ -476,6 +477,7 @@ def get_header_and_first_line(csv_file: IO[str]) -> Dict[str, Any]:
use_exact_imports=use_exact_imports,
default_field_extras=default_field_extras,
target_datetime_class=output_datetime_class,
keyword_only=keyword_only,
**kwargs,
)

Expand Down
31 changes: 29 additions & 2 deletions datamodel_code_generator/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def validate_use_generic_container_types(
target_python_version: PythonVersion = values['target_python_version']
if target_python_version == target_python_version.PY_36:
raise Error(
f'`--use-generic-container-types` can not be used with `--target-python_version` {target_python_version.PY_36.value}.\n'
f'`--use-generic-container-types` can not be used with `--target-python-version` {target_python_version.PY_36.value}.\n'
' The version will be not supported in a future version'
)
return values
Expand All @@ -184,6 +184,31 @@ def validate_custom_file_header(cls, values: Dict[str, Any]) -> Dict[str, Any]:
) # pragma: no cover
return values

@model_validator(mode='after')
def validate_keyword_only(cls, values: Dict[str, Any]) -> Dict[str, Any]:
python_target: PythonVersion = values.get('target_python_version')
if values.get('keyword_only') and not python_target.has_kw_only_dataclass:
raise Error(
f'`--keyword-only` requires `--target-python-version` {PythonVersion.PY_310.value} or higher.'
)
return values

@model_validator(mode='after')
def validate_output_datetime_class(cls, values: Dict[str, Any]) -> Dict[str, Any]:
datetime_class_type: Optional[DatetimeClassType] = values.get(
'output_datetime_class'
)
if (
datetime_class_type
and datetime_class_type is not DatetimeClassType.Datetime
and values.get('output_model_type') == DataModelType.DataclassesDataclass
):
raise Error(
'`--output-datetime-class` only allows "datetime" for '
f'`--output-model-type` {DataModelType.DataclassesDataclass.value}'
)
return values

# Pydantic 1.5.1 doesn't support each_item=True correctly
@field_validator('http_headers', mode='before')
def validate_http_headers(cls, value: Any) -> Optional[List[Tuple[str, str]]]:
Expand Down Expand Up @@ -314,7 +339,8 @@ def validate_root(cls, values: Any) -> Any:
treat_dot_as_module: bool = False
use_exact_imports: bool = False
union_mode: Optional[UnionMode] = None
output_datetime_class: DatetimeClassType = DatetimeClassType.Datetime
output_datetime_class: Optional[DatetimeClassType] = None
keyword_only: bool = False

def merge_args(self, args: Namespace) -> None:
set_args = {
Expand Down Expand Up @@ -515,6 +541,7 @@ def main(args: Optional[Sequence[str]] = None) -> Exit:
use_exact_imports=config.use_exact_imports,
union_mode=config.union_mode,
output_datetime_class=config.output_datetime_class,
keyword_only=config.keyword_only,
)
return Exit.OK
except InvalidClassNameError as e:
Expand Down
10 changes: 9 additions & 1 deletion datamodel_code_generator/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,12 @@ def start_section(self, heading: Optional[str]) -> None:
action='store_true',
default=None,
)
model_options.add_argument(
'--keyword-only',
help='Defined models as keyword only (for example dataclass(kw_only=True)).',
action='store_true',
default=None,
)
model_options.add_argument(
'--reuse-model',
help='Reuse models on the field when a module has the model with the same content',
Expand Down Expand Up @@ -194,8 +200,10 @@ def start_section(self, heading: Optional[str]) -> None:
)
model_options.add_argument(
'--output-datetime-class',
help='Choose Datetime class between AwareDatetime, NaiveDatetime or datetime, default: "datetime"',
help='Choose Datetime class between AwareDatetime, NaiveDatetime or datetime. '
'Each output model has its default mapping (for example pydantic: datetime, dataclass: str, ...)',
choices=[i.value for i in DatetimeClassType],
default=None,
)

# ======================================================================================
Expand Down
4 changes: 4 additions & 0 deletions datamodel_code_generator/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ def has_typed_dict(self) -> bool:
def has_typed_dict_non_required(self) -> bool:
return self._is_py_311_or_later

@property
def has_kw_only_dataclass(self) -> bool:
return self._is_py_310_or_later


if TYPE_CHECKING:

Expand Down
2 changes: 1 addition & 1 deletion datamodel_code_generator/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def get_data_model_types(
data_model=dataclass.DataClass,
root_model=rootmodel.RootModel,
field_model=dataclass.DataModelField,
data_type_manager=DataTypeManager,
data_type_manager=dataclass.DataTypeManager,
dump_resolve_reference_action=None,
)
elif data_model_type == DataModelType.TypingTypedDict:
Expand Down
3 changes: 3 additions & 0 deletions datamodel_code_generator/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,9 @@ def __init__(
description: Optional[str] = None,
default: Any = UNDEFINED,
nullable: bool = False,
keyword_only: bool = False,
) -> None:
self.keyword_only = keyword_only
if not self.TEMPLATE_FILE_PATH:
raise Exception('TEMPLATE_FILE_PATH is undefined')

Expand Down Expand Up @@ -452,6 +454,7 @@ def render(self, *, class_name: Optional[str] = None) -> str:
base_class=self.base_class,
methods=self.methods,
description=self.description,
keyword_only=self.keyword_only,
**self.extra_template_data,
)
return response
69 changes: 65 additions & 4 deletions datamodel_code_generator/model/dataclass.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,32 @@
from pathlib import Path
from typing import Any, ClassVar, DefaultDict, Dict, List, Optional, Set, Tuple

from datamodel_code_generator.imports import Import
from typing import (
Any,
ClassVar,
DefaultDict,
Dict,
List,
Optional,
Sequence,
Set,
Tuple,
)

from datamodel_code_generator import DatetimeClassType, PythonVersion
from datamodel_code_generator.imports import (
IMPORT_DATE,
IMPORT_DATETIME,
IMPORT_TIME,
IMPORT_TIMEDELTA,
Import,
)
from datamodel_code_generator.model import DataModel, DataModelFieldBase
from datamodel_code_generator.model.base import UNDEFINED
from datamodel_code_generator.model.imports import IMPORT_DATACLASS, IMPORT_FIELD
from datamodel_code_generator.model.pydantic.base_model import Constraints
from datamodel_code_generator.model.types import DataTypeManager as _DataTypeManager
from datamodel_code_generator.model.types import type_map_factory
from datamodel_code_generator.reference import Reference
from datamodel_code_generator.types import chain_as_tuple
from datamodel_code_generator.types import DataType, StrictTypes, Types, chain_as_tuple


def _has_field_assignment(field: DataModelFieldBase) -> bool:
Expand Down Expand Up @@ -36,6 +55,7 @@ def __init__(
description: Optional[str] = None,
default: Any = UNDEFINED,
nullable: bool = False,
keyword_only: bool = False,
) -> None:
super().__init__(
reference=reference,
Expand All @@ -50,6 +70,7 @@ def __init__(
description=description,
default=default,
nullable=nullable,
keyword_only=keyword_only,
)


Expand Down Expand Up @@ -118,3 +139,43 @@ def __str__(self) -> str:
f'{k}={v if k == "default_factory" else repr(v)}' for k, v in data.items()
]
return f'field({", ".join(kwargs)})'


class DataTypeManager(_DataTypeManager):
def __init__(
self,
python_version: PythonVersion = PythonVersion.PY_38,
use_standard_collections: bool = False,
use_generic_container_types: bool = False,
strict_types: Optional[Sequence[StrictTypes]] = None,
use_non_positive_negative_number_constrained_types: bool = False,
use_union_operator: bool = False,
use_pendulum: bool = False,
target_datetime_class: DatetimeClassType = DatetimeClassType.Datetime,
):
super().__init__(
python_version,
use_standard_collections,
use_generic_container_types,
strict_types,
use_non_positive_negative_number_constrained_types,
use_union_operator,
use_pendulum,
target_datetime_class,
)

datetime_map = (
{
Types.time: self.data_type.from_import(IMPORT_TIME),
Types.date: self.data_type.from_import(IMPORT_DATE),
Types.date_time: self.data_type.from_import(IMPORT_DATETIME),
Types.timedelta: self.data_type.from_import(IMPORT_TIMEDELTA),
}
if target_datetime_class is DatetimeClassType.Datetime
else {}
)

self.type_map: Dict[Types, DataType] = {
**type_map_factory(self.data_type),
**datetime_map,
}
2 changes: 2 additions & 0 deletions datamodel_code_generator/model/enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def __init__(
type_: Optional[Types] = None,
default: Any = UNDEFINED,
nullable: bool = False,
keyword_only: bool = False,
):
super().__init__(
reference=reference,
Expand All @@ -61,6 +62,7 @@ def __init__(
description=description,
default=default,
nullable=nullable,
keyword_only=keyword_only,
)

if not base_classes and type_:
Expand Down
2 changes: 2 additions & 0 deletions datamodel_code_generator/model/msgspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def __init__(
description: Optional[str] = None,
default: Any = UNDEFINED,
nullable: bool = False,
keyword_only: bool = False,
) -> None:
super().__init__(
reference=reference,
Expand All @@ -100,6 +101,7 @@ def __init__(
description=description,
default=default,
nullable=nullable,
keyword_only=keyword_only,
)


Expand Down
4 changes: 4 additions & 0 deletions datamodel_code_generator/model/pydantic/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ def __init__(
description: Optional[str] = None,
default: Any = UNDEFINED,
nullable: bool = False,
keyword_only: bool = False,
) -> None:
methods: List[str] = [field.method for field in fields if field.method]

Expand All @@ -241,6 +242,7 @@ def __init__(
description=description,
default=default,
nullable=nullable,
keyword_only=keyword_only,
)

@cached_property
Expand Down Expand Up @@ -275,6 +277,7 @@ def __init__(
description: Optional[str] = None,
default: Any = UNDEFINED,
nullable: bool = False,
keyword_only: bool = False,
) -> None:
super().__init__(
reference=reference,
Expand All @@ -288,6 +291,7 @@ def __init__(
description=description,
default=default,
nullable=nullable,
keyword_only=keyword_only,
)
config_parameters: Dict[str, Any] = {}

Expand Down
2 changes: 1 addition & 1 deletion datamodel_code_generator/model/pydantic/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def __init__(
use_non_positive_negative_number_constrained_types: bool = False,
use_union_operator: bool = False,
use_pendulum: bool = False,
target_datetime_class: DatetimeClassType = DatetimeClassType.Datetime,
target_datetime_class: Optional[DatetimeClassType] = None,
):
super().__init__(
python_version,
Expand Down
2 changes: 2 additions & 0 deletions datamodel_code_generator/model/pydantic_v2/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ def __init__(
description: Optional[str] = None,
default: Any = UNDEFINED,
nullable: bool = False,
keyword_only: bool = False,
) -> None:
super().__init__(
reference=reference,
Expand All @@ -196,6 +197,7 @@ def __init__(
description=description,
default=default,
nullable=nullable,
keyword_only=keyword_only,
)
config_parameters: Dict[str, Any] = {}

Expand Down
4 changes: 2 additions & 2 deletions datamodel_code_generator/model/pydantic_v2/types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import ClassVar, Dict, Sequence, Type
from typing import ClassVar, Dict, Optional, Sequence, Type

from datamodel_code_generator.format import DatetimeClassType
from datamodel_code_generator.model.pydantic import DataTypeManager as _DataTypeManager
Expand All @@ -20,7 +20,7 @@ def type_map_factory(
data_type: Type[DataType],
strict_types: Sequence[StrictTypes],
pattern_key: str,
target_datetime_class: DatetimeClassType,
target_datetime_class: Optional[DatetimeClassType] = None,
) -> Dict[Types, DataType]:
result = {
**super().type_map_factory(
Expand Down
2 changes: 2 additions & 0 deletions datamodel_code_generator/model/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def __init__(
description: Optional[str] = None,
default: Any = UNDEFINED,
nullable: bool = False,
keyword_only: bool = False,
):
extra_template_data = extra_template_data or defaultdict(dict)

Expand Down Expand Up @@ -75,4 +76,5 @@ def __init__(
description=description,
default=default,
nullable=nullable,
keyword_only=keyword_only,
)
2 changes: 1 addition & 1 deletion datamodel_code_generator/model/template/dataclass.jinja2
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{% for decorator in decorators -%}
{{ decorator }}
{% endfor -%}
@dataclass
@dataclass{%- if keyword_only -%}(kw_only=True){%- endif %}
{%- if base_class %}
class {{ class_name }}({{ base_class }}):
{%- else %}
Expand Down
Loading

0 comments on commit d0c0f16

Please sign in to comment.