Skip to content

Commit

Permalink
Allow default_object_fields to override object fields
Browse files Browse the repository at this point in the history
  • Loading branch information
wyfo committed Jul 23, 2021
1 parent bfe4934 commit 98c2ca8
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 18 deletions.
12 changes: 11 additions & 1 deletion apischema/objects/getters.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Any,
Callable,
Mapping,
Optional,
Sequence,
Type,
TypeVar,
Expand All @@ -23,14 +24,23 @@

@cache
def object_fields(
tp: AnyType, deserialization: bool = False, serialization: bool = False
tp: AnyType,
deserialization: bool = False,
serialization: bool = False,
default: Optional[
Callable[[type], Optional[Sequence[ObjectField]]]
] = ObjectVisitor._default_fields,
) -> Mapping[str, ObjectField]:
class GetFields(ObjectVisitor[Sequence[ObjectField]]):
def _skip_field(self, field: ObjectField) -> bool:
return (field.skip.deserialization and serialization) or (
field.skip.serialization and deserialization
)

@staticmethod
def _default_fields(cls: type) -> Optional[Sequence[ObjectField]]:
return None if default is None else default(cls)

def object(
self, cls: Type, fields: Sequence[ObjectField]
) -> Sequence[ObjectField]:
Expand Down
46 changes: 29 additions & 17 deletions apischema/objects/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,29 @@ def _field_conversion(self, field: ObjectField) -> Optional[AnyConversion]:
def _skip_field(self, field: ObjectField) -> bool:
raise NotImplementedError

@staticmethod
def _default_fields(cls: type) -> Optional[Sequence[ObjectField]]:
from apischema import settings

return settings.default_object_fields(cls)

def _override_fields(
self, tp: AnyType, fields: Sequence[ObjectField]
) -> Sequence[ObjectField]:

origin = get_origin_or_type(tp)
if isinstance(origin, type):
default_fields = self._default_fields(origin)
if default_fields is not None:
if get_args(tp):
sub = dict(zip(get_parameters(origin), get_args(tp)))
default_fields = [
replace(f, type=substitute_type_vars(f.type, sub))
for f in default_fields
]
return default_fields
return fields

def _object(self, tp: AnyType, fields: Sequence[ObjectField]) -> Result:
fields = [f for f in fields if not self._skip_field(f)]
aliaser = get_class_aliaser(get_origin_or_type(tp))
Expand All @@ -77,7 +100,7 @@ def dataclass(
for name in types
if name in by_name and by_name[name].kind != self._field_kind_filtered
]
return self._object(tp, object_fields)
return self._object(tp, self._override_fields(tp, object_fields))

def object(self, tp: AnyType, fields: Sequence[ObjectField]) -> Result:
raise NotImplementedError
Expand All @@ -89,7 +112,7 @@ def named_tuple(
ObjectField(name, type_, name not in defaults, default=defaults.get(name))
for name, type_ in types.items()
]
return self._object(tp, fields)
return self._object(tp, self._override_fields(tp, fields))

def typed_dict(
self, tp: AnyType, types: Mapping[str, AnyType], required_keys: Collection[str]
Expand All @@ -98,23 +121,12 @@ def typed_dict(
ObjectField(name, type_, name in required_keys, default=Undefined)
for name, type_ in types.items()
]
return self._object(tp, fields)
return self._object(tp, self._override_fields(tp, fields))

def unsupported(self, tp: AnyType) -> Result:
from apischema import settings

origin = get_origin_or_type(tp)
if isinstance(origin, type):
fields = settings.default_object_fields(origin)
if fields is not None:
if get_args(tp):
sub = dict(zip(get_parameters(origin), get_args(tp)))
fields = [
replace(f, type=substitute_type_vars(f.type, sub))
for f in fields
]
return self._object(origin, fields)
return super().unsupported(tp)
dummy: list = []
fields = self._override_fields(tp, dummy)
return super().unsupported(tp) if fields is dummy else self._object(tp, fields)


class DeserializationObjectVisitor(ObjectVisitor[Result]):
Expand Down
3 changes: 3 additions & 0 deletions docs/data_model.md
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,9 @@ Thus, support of dataclass-like types (*attrs*, *SQLAlchemy* traditional mappers

Another way to set object fields is to directly modify *apischema* default behavior, using `apischema.settings.default_object_fields`.

!!! note
`set_object_fields`/`settings.default_object_fields` can be used to override existing fields. Current fields can be retrieved using `apischema.objects.object_fields`.

```python
from collections.abc import Sequence
from typing import Optional
Expand Down
28 changes: 28 additions & 0 deletions tests/integration/test_object_fields_overriding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from dataclasses import dataclass, replace
from typing import Optional

from pytest import raises

from apischema import ValidationError, deserialize, serialize
from apischema.metadata import none_as_undefined
from apischema.objects import object_fields, set_object_fields


@dataclass
class Foo:
bar: Optional[str] = None


def test_object_fields_overriding():
set_object_fields(Foo, [])
assert serialize(Foo, Foo()) == {}
set_object_fields(
Foo,
[
replace(f, metadata=none_as_undefined | f.metadata)
for f in object_fields(Foo, default=None).values()
],
)
assert serialize(Foo, Foo()) == {}
with raises(ValidationError):
deserialize(Foo, {"bar": None})

0 comments on commit 98c2ca8

Please sign in to comment.