Skip to content

Commit

Permalink
Add InitVar support for merged dataclasses
Browse files Browse the repository at this point in the history
  • Loading branch information
wyfo committed Sep 23, 2020
1 parent 75ae4dd commit f4bb73d
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 42 deletions.
18 changes: 17 additions & 1 deletion apischema/dataclasses/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
import sys
from dataclasses import is_dataclass, replace as replace_
from dataclasses import ( # type: ignore
Field,
is_dataclass,
replace as replace_,
_FIELDS,
_FIELD_CLASSVAR,
)
from typing import Mapping, Type

if sys.version_info <= (3, 7):
is_dataclass_ = is_dataclass
Expand All @@ -16,3 +23,12 @@ def replace(*args, **changes):
if hasattr(obj, FIELDS_SET_ATTR):
set_fields(result, *fields_set(obj), *changes, overwrite=True)
return result


def fields_items(cls: Type) -> Mapping[str, Field]:
assert is_dataclass(cls)
return {
name: field
for name, field in getattr(cls, _FIELDS).items()
if field._field_type != _FIELD_CLASSVAR
}
35 changes: 15 additions & 20 deletions apischema/dataclasses/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
Deserialization,
Serialization,
)
from apischema.dataclasses import fields_items
from apischema.dependencies import DependentRequired
from apischema.metadata.keys import (
ALIAS_METADATA,
Expand Down Expand Up @@ -135,7 +136,7 @@ def _from_aggregate(aggregate_cache: AggregateFieldCache) -> FieldCache:
for field in aggregate_fields:
metadata = field.base_field.metadata
if MERGED_METADATA in metadata:
merged_fields.append((frozenset(_merged_aliases(field.type)), field))
merged_fields.append((_deserialization_merged_aliases(field.type), field))
else:
pattern = metadata[PROPERTIES_METADATA]
if pattern is not None:
Expand Down Expand Up @@ -343,24 +344,23 @@ def _serialization(
)


def _merged_aliases(cls: Type) -> Iterable[str]:
def _deserialization_merged_aliases(cls: Type) -> AbstractSet[str]:
"""Return all aliases used in cls deserialization."""
assert dataclasses.is_dataclass(cls)
types = get_type_hints(cls, include_extras=True)
for field in dataclasses.fields(cls):
result: Set[str] = set()
for field in fields_items(cls).values():
if not field.init:
continue
if MERGED_METADATA in field.metadata:
# No need to check overlapping here because it will be checked
# when merged dataclass will be cached
yield from _merged_aliases(types[field.name])
result |= _deserialization_merged_aliases(types[field.name])
elif PROPERTIES_METADATA in field.metadata:
raise TypeError("Merged dataclass cannot have properties field")
else:
yield field.metadata.get(ALIAS_METADATA, field.name)


def _check_fields_overlap(present: Set[str], other: AbstractSet[str]):
if present & other:
raise TypeError(f"Merged fields {present & other} overlap")
present.update(other)
result.add(field.metadata.get(ALIAS_METADATA, field.name))
return result


def _update_dependencies(cls: AnyType, all_fields: Mapping[str, Field]):
Expand Down Expand Up @@ -390,8 +390,8 @@ def _update_dependencies(cls: AnyType, all_fields: Mapping[str, Field]):


def _filter_by_kind(field_list: Iterable[F], kind: FieldKind) -> Sequence[F]:
fields = [(elt, elt[1] if isinstance(elt, tuple) else elt) for elt in field_list]
return [elt for elt, field in fields if field.kind != kind]
fields = [elt[1] if isinstance(elt, tuple) else elt for elt in field_list]
return [elt for elt, field in zip(field_list, fields) if field.kind != kind]


@dataclasses.dataclass
Expand Down Expand Up @@ -421,10 +421,7 @@ def cache_fields(cls: Type):
types = get_type_hints(cls, include_extras=True)
lists = FieldLists(cls)
all_fields: Dict[str, Field] = {}
all_merged_aliases: Set[str] = set()
for field in getattr(cls, dataclasses._FIELDS).values(): # type: ignore
if field._field_type == dataclasses._FIELD_CLASSVAR: # type: ignore
continue
for field in fields_items(cls).values():
metadata = field.metadata
if SKIP_METADATA in metadata:
continue
Expand Down Expand Up @@ -492,8 +489,7 @@ def cache_fields(cls: Type):
raise TypeError(
f"{error_prefix}Merged field must have a dataclass type"
)
merged_aliases: AbstractSet[str] = frozenset(_merged_aliases(type_))
_check_fields_overlap(all_merged_aliases, merged_aliases)
merged_aliases = _deserialization_merged_aliases(type_)
lists.merged.append((merged_aliases, new_field))
elif PROPERTIES_METADATA in metadata:
if any(key in metadata for key in INCOMPATIBLE_WITH_PROPERTIES):
Expand All @@ -505,7 +501,6 @@ def cache_fields(cls: Type):
lists.pattern.append((pattern, new_field))
else:
lists.normal.append(new_field)
_check_fields_overlap(all_merged_aliases, all_fields.keys())
_update_dependencies(cls, all_fields)
_deserialization_fields[cls] = lists.remove_kind(FieldKind.NO_INIT)
_aggregate_serialization_fields[cls] = _to_aggregate(
Expand Down
29 changes: 9 additions & 20 deletions apischema/validation/mock.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
from dataclasses import ( # type: ignore
Field,
MISSING,
_FIELDS,
_FIELD_CLASSVAR,
dataclass,
)
from dataclasses import MISSING, dataclass
from functools import partial
from types import FunctionType, MethodType
from typing import Any, Mapping, Optional, TYPE_CHECKING, Type, TypeVar

from apischema.dataclasses import fields_items
from apischema.fields import FIELDS_SET_ATTR
from apischema.utils import get_default

Expand Down Expand Up @@ -37,15 +32,12 @@ def __getattribute__(self, name: str) -> Any:
if name in fields:
return fields[name]
cls = super().__getattribute__("cls")
cls_fields: Mapping[str, Field] = getattr(cls, _FIELDS)
cls_fields = fields_items(cls)
if name in cls_fields:
if cls_fields[name]._field_type == _FIELD_CLASSVAR: # type: ignore
return getattr(cls, name)
else:
try:
return get_default(cls_fields[name])
except NotImplementedError:
raise NonTrivialDependency(name)
try:
return get_default(cls_fields[name])
except NotImplementedError:
raise NonTrivialDependency(name) from None
if name == "__class__":
return cls
if name == "__dict__":
Expand All @@ -54,11 +46,8 @@ def __getattribute__(self, name: str) -> Any:
**{
name: get_default(field)
for name, field in cls_fields.items()
if field._field_type != _FIELD_CLASSVAR # type: ignore
and (
field.default is not MISSING
or field.default_factory is not MISSING # type: ignore
)
if field.default is not MISSING
or field.default_factory is not MISSING # type: ignore
},
FIELDS_SET_ATTR: set(fields),
}
Expand Down
4 changes: 4 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Changelog

## 0.7.4

- Add `InitVar` support for merged dataclasses.

## 0.7.3

- Fix bugs in settings global default conversion and coercer assignation.
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name="apischema",
version="0.7.3",
version="0.7.4",
url="https://github.com/wyfo/apischema",
author="Joseph Perez",
author_email="[email protected]",
Expand Down

0 comments on commit f4bb73d

Please sign in to comment.