diff --git a/apischema/objects/getters.py b/apischema/objects/getters.py index 1fab1270..b2220152 100644 --- a/apischema/objects/getters.py +++ b/apischema/objects/getters.py @@ -3,6 +3,7 @@ Any, Callable, Mapping, + Optional, Sequence, Type, TypeVar, @@ -23,7 +24,12 @@ @cache def object_fields( - tp: AnyType, deserialization: bool = False, serialization: bool = False + tp: AnyType, + deserialization: bool = False, + serialization: bool = False, + default: Optional[ + Callable[[type], Optional[Sequence[ObjectField]]] + ] = ObjectVisitor._default_fields, ) -> Mapping[str, ObjectField]: class GetFields(ObjectVisitor[Sequence[ObjectField]]): def _skip_field(self, field: ObjectField) -> bool: @@ -31,6 +37,10 @@ def _skip_field(self, field: ObjectField) -> bool: field.skip.serialization and deserialization ) + @staticmethod + def _default_fields(cls: type) -> Optional[Sequence[ObjectField]]: + return None if default is None else default(cls) + def object( self, cls: Type, fields: Sequence[ObjectField] ) -> Sequence[ObjectField]: diff --git a/apischema/objects/visitor.py b/apischema/objects/visitor.py index d7704e74..9573564c 100644 --- a/apischema/objects/visitor.py +++ b/apischema/objects/visitor.py @@ -53,6 +53,29 @@ def _field_conversion(self, field: ObjectField) -> Optional[AnyConversion]: def _skip_field(self, field: ObjectField) -> bool: raise NotImplementedError + @staticmethod + def _default_fields(cls: type) -> Optional[Sequence[ObjectField]]: + from apischema import settings + + return settings.default_object_fields(cls) + + def _override_fields( + self, tp: AnyType, fields: Sequence[ObjectField] + ) -> Sequence[ObjectField]: + + origin = get_origin_or_type(tp) + if isinstance(origin, type): + default_fields = self._default_fields(origin) + if default_fields is not None: + if get_args(tp): + sub = dict(zip(get_parameters(origin), get_args(tp))) + default_fields = [ + replace(f, type=substitute_type_vars(f.type, sub)) + for f in default_fields + ] + return default_fields + return fields + def _object(self, tp: AnyType, fields: Sequence[ObjectField]) -> Result: fields = [f for f in fields if not self._skip_field(f)] aliaser = get_class_aliaser(get_origin_or_type(tp)) @@ -77,7 +100,7 @@ def dataclass( for name in types if name in by_name and by_name[name].kind != self._field_kind_filtered ] - return self._object(tp, object_fields) + return self._object(tp, self._override_fields(tp, object_fields)) def object(self, tp: AnyType, fields: Sequence[ObjectField]) -> Result: raise NotImplementedError @@ -89,7 +112,7 @@ def named_tuple( ObjectField(name, type_, name not in defaults, default=defaults.get(name)) for name, type_ in types.items() ] - return self._object(tp, fields) + return self._object(tp, self._override_fields(tp, fields)) def typed_dict( self, tp: AnyType, types: Mapping[str, AnyType], required_keys: Collection[str] @@ -98,23 +121,12 @@ def typed_dict( ObjectField(name, type_, name in required_keys, default=Undefined) for name, type_ in types.items() ] - return self._object(tp, fields) + return self._object(tp, self._override_fields(tp, fields)) def unsupported(self, tp: AnyType) -> Result: - from apischema import settings - - origin = get_origin_or_type(tp) - if isinstance(origin, type): - fields = settings.default_object_fields(origin) - if fields is not None: - if get_args(tp): - sub = dict(zip(get_parameters(origin), get_args(tp))) - fields = [ - replace(f, type=substitute_type_vars(f.type, sub)) - for f in fields - ] - return self._object(origin, fields) - return super().unsupported(tp) + dummy: list = [] + fields = self._override_fields(tp, dummy) + return super().unsupported(tp) if fields is dummy else self._object(tp, fields) class DeserializationObjectVisitor(ObjectVisitor[Result]): diff --git a/docs/data_model.md b/docs/data_model.md index cc8ec0bd..4a146070 100644 --- a/docs/data_model.md +++ b/docs/data_model.md @@ -219,6 +219,9 @@ Thus, support of dataclass-like types (*attrs*, *SQLAlchemy* traditional mappers Another way to set object fields is to directly modify *apischema* default behavior, using `apischema.settings.default_object_fields`. +!!! note + `set_object_fields`/`settings.default_object_fields` can be used to override existing fields. Current fields can be retrieved using `apischema.objects.object_fields`. + ```python from collections.abc import Sequence from typing import Optional diff --git a/tests/integration/test_object_fields_overriding.py b/tests/integration/test_object_fields_overriding.py new file mode 100644 index 00000000..88f8f654 --- /dev/null +++ b/tests/integration/test_object_fields_overriding.py @@ -0,0 +1,28 @@ +from dataclasses import dataclass, replace +from typing import Optional + +from pytest import raises + +from apischema import ValidationError, deserialize, serialize +from apischema.metadata import none_as_undefined +from apischema.objects import object_fields, set_object_fields + + +@dataclass +class Foo: + bar: Optional[str] = None + + +def test_object_fields_overriding(): + set_object_fields(Foo, []) + assert serialize(Foo, Foo()) == {} + set_object_fields( + Foo, + [ + replace(f, metadata=none_as_undefined | f.metadata) + for f in object_fields(Foo, default=None).values() + ], + ) + assert serialize(Foo, Foo()) == {} + with raises(ValidationError): + deserialize(Foo, {"bar": None})