Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow default_object_fields to override object fields #181

Merged
merged 2 commits into from
Jul 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion apischema/objects/getters.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Any,
Callable,
Mapping,
Optional,
Sequence,
Type,
TypeVar,
Expand All @@ -23,14 +24,23 @@

@cache
def object_fields(
tp: AnyType, deserialization: bool = False, serialization: bool = False
tp: AnyType,
deserialization: bool = False,
serialization: bool = False,
default: Optional[
Callable[[type], Optional[Sequence[ObjectField]]]
] = ObjectVisitor._default_fields,
) -> Mapping[str, ObjectField]:
class GetFields(ObjectVisitor[Sequence[ObjectField]]):
def _skip_field(self, field: ObjectField) -> bool:
return (field.skip.deserialization and serialization) or (
field.skip.serialization and deserialization
)

@staticmethod
def _default_fields(cls: type) -> Optional[Sequence[ObjectField]]:
return None if default is None else default(cls)

def object(
self, cls: Type, fields: Sequence[ObjectField]
) -> Sequence[ObjectField]:
Expand Down
46 changes: 29 additions & 17 deletions apischema/objects/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,29 @@ def _field_conversion(self, field: ObjectField) -> Optional[AnyConversion]:
def _skip_field(self, field: ObjectField) -> bool:
raise NotImplementedError

@staticmethod
def _default_fields(cls: type) -> Optional[Sequence[ObjectField]]:
from apischema import settings

return settings.default_object_fields(cls)

def _override_fields(
self, tp: AnyType, fields: Sequence[ObjectField]
) -> Sequence[ObjectField]:

origin = get_origin_or_type(tp)
if isinstance(origin, type):
default_fields = self._default_fields(origin)
if default_fields is not None:
if get_args(tp):
sub = dict(zip(get_parameters(origin), get_args(tp)))
default_fields = [
replace(f, type=substitute_type_vars(f.type, sub))
for f in default_fields
]
return default_fields
return fields

def _object(self, tp: AnyType, fields: Sequence[ObjectField]) -> Result:
fields = [f for f in fields if not self._skip_field(f)]
aliaser = get_class_aliaser(get_origin_or_type(tp))
Expand All @@ -77,7 +100,7 @@ def dataclass(
for name in types
if name in by_name and by_name[name].kind != self._field_kind_filtered
]
return self._object(tp, object_fields)
return self._object(tp, self._override_fields(tp, object_fields))

def object(self, tp: AnyType, fields: Sequence[ObjectField]) -> Result:
raise NotImplementedError
Expand All @@ -89,7 +112,7 @@ def named_tuple(
ObjectField(name, type_, name not in defaults, default=defaults.get(name))
for name, type_ in types.items()
]
return self._object(tp, fields)
return self._object(tp, self._override_fields(tp, fields))

def typed_dict(
self, tp: AnyType, types: Mapping[str, AnyType], required_keys: Collection[str]
Expand All @@ -98,23 +121,12 @@ def typed_dict(
ObjectField(name, type_, name in required_keys, default=Undefined)
for name, type_ in types.items()
]
return self._object(tp, fields)
return self._object(tp, self._override_fields(tp, fields))

def unsupported(self, tp: AnyType) -> Result:
from apischema import settings

origin = get_origin_or_type(tp)
if isinstance(origin, type):
fields = settings.default_object_fields(origin)
if fields is not None:
if get_args(tp):
sub = dict(zip(get_parameters(origin), get_args(tp)))
fields = [
replace(f, type=substitute_type_vars(f.type, sub))
for f in fields
]
return self._object(origin, fields)
return super().unsupported(tp)
dummy: list = []
fields = self._override_fields(tp, dummy)
return super().unsupported(tp) if fields is dummy else self._object(tp, fields)


class DeserializationObjectVisitor(ObjectVisitor[Result]):
Expand Down
3 changes: 3 additions & 0 deletions docs/data_model.md
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,9 @@ Thus, support of dataclass-like types (*attrs*, *SQLAlchemy* traditional mappers

Another way to set object fields is to directly modify *apischema* default behavior, using `apischema.settings.default_object_fields`.

!!! note
`set_object_fields`/`settings.default_object_fields` can be used to override existing fields. Current fields can be retrieved using `apischema.objects.object_fields`.

```python
from collections.abc import Sequence
from typing import Optional
Expand Down
30 changes: 30 additions & 0 deletions tests/integration/test_object_fields_overriding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import sys
from dataclasses import dataclass, replace
from typing import Optional

from pytest import mark, raises

from apischema import ValidationError, deserialize, serialize
from apischema.metadata import none_as_undefined
from apischema.objects import object_fields, set_object_fields


@dataclass
class Foo:
bar: Optional[str] = None


@mark.skipif(sys.version_info < (3, 8), reason="dataclasses.replace bug with InitVar")
def test_object_fields_overriding():
set_object_fields(Foo, [])
assert serialize(Foo, Foo()) == {}
set_object_fields(
Foo,
[
replace(f, metadata=none_as_undefined | f.metadata)
for f in object_fields(Foo, default=None).values()
],
)
assert serialize(Foo, Foo()) == {}
with raises(ValidationError):
deserialize(Foo, {"bar": None})