From 347b330750f4329cac0daec80447d0ece74e6092 Mon Sep 17 00:00:00 2001 From: wyfo Date: Sun, 18 Jul 2021 21:11:09 +0200 Subject: [PATCH] Remove Conversion (de)serialization parameters (coerce/check_type/etc.) (#177) --- apischema/conversions/conversions.py | 12 +---- apischema/conversions/dataclass_models.py | 11 +---- apischema/conversions/visitor.py | 2 + apischema/deserialization/__init__.py | 53 +++++++++-------------- apischema/deserialization/coercion.py | 13 ++---- apischema/serialization/__init__.py | 43 ++++++++---------- 6 files changed, 48 insertions(+), 86 deletions(-) diff --git a/apischema/conversions/conversions.py b/apischema/conversions/conversions.py index d5e62c44..639ec1a8 100644 --- a/apischema/conversions/conversions.py +++ b/apischema/conversions/conversions.py @@ -26,7 +26,7 @@ ) if TYPE_CHECKING: - from apischema.deserialization.coercion import Coerce + pass @dataclass(frozen=True) @@ -36,12 +36,6 @@ class Conversion: target: AnyType = None sub_conversion: Optional["AnyConversion"] = None inherited: Optional[bool] = None - additional_properties: Optional[bool] = None - check_type: Optional[bool] = None - coerce: Optional["Coerce"] = None - fall_back_on_any: Optional[bool] = None - fall_back_on_default: Optional[bool] = None - exclude_unset: Optional[bool] = None def __call__(self, *args, **kwargs): return self.converter(*args, **kwargs) @@ -118,8 +112,4 @@ def is_identity(conversion: ResolvedConversion) -> bool: conversion.converter == identity and conversion.source == conversion.target and conversion.sub_conversion is None - and conversion.additional_properties is None - and conversion.coerce is None - and conversion.fall_back_on_default is None - and conversion.exclude_unset is None ) diff --git a/apischema/conversions/dataclass_models.py b/apischema/conversions/dataclass_models.py index 50ca63a6..84b9bca8 100644 --- a/apischema/conversions/dataclass_models.py +++ b/apischema/conversions/dataclass_models.py @@ -70,16 +70,9 @@ def dataclass_model( check_model(origin, model) model_type = DataclassModel(origin, model, fields_only) - conversion = Conversion( - identity, - additional_properties=additional_properties, - coerce=coercion, - fall_back_on_default=fall_back_on_default, - exclude_unset=exclude_unset, + return Conversion(identity, source=model_type, target=origin), Conversion( + identity, source=origin, target=model_type ) - d_conv = replace(conversion, source=model_type, target=origin) - s_conv = replace(conversion, source=origin, target=model_type) - return d_conv, s_conv def has_model_origin(cls: Type) -> bool: diff --git a/apischema/conversions/visitor.py b/apischema/conversions/visitor.py index 067aa9b0..4d8d60f3 100644 --- a/apischema/conversions/visitor.py +++ b/apischema/conversions/visitor.py @@ -252,6 +252,8 @@ def __init__(self, default_conversion: DefaultConversion): ] = {} def _cache_key(self) -> Hashable: + """When other attributes are modified during visit, they can be used as + additional cache key""" return None def _recursive_result(self, lazy: Lazy[Result]) -> Result: diff --git a/apischema/deserialization/__init__.py b/apischema/deserialization/__init__.py index ba9d4b16..31d60524 100644 --- a/apischema/deserialization/__init__.py +++ b/apischema/deserialization/__init__.py @@ -8,7 +8,6 @@ Callable, Collection, Dict, - Hashable, List, Mapping, Optional, @@ -32,7 +31,7 @@ sub_conversion, ) from apischema.dependencies import get_dependent_required -from apischema.deserialization.coercion import Coerce, Coercer, get_coercer +from apischema.deserialization.coercion import Coerce, Coercer, wrap_coercer from apischema.deserialization.flattened import get_deserialization_flattened_aliases from apischema.json_schema.patterns import infer_pattern from apischema.json_schema.types import bad_type @@ -54,7 +53,6 @@ from apischema.utils import ( Lazy, PREFIX, - context_setter, deprecate_kwargs, get_origin_or_type, identity, @@ -141,20 +139,15 @@ def __init__( self, additional_properties: bool, aliaser: Aliaser, - coercion: bool, - coercer: Coercer, + coercer: Optional[Coercer], default_conversion: DefaultConversion, fall_back_on_default: bool, ): super().__init__(default_conversion) - self._additional_properties = additional_properties + self.additional_properties = additional_properties self.aliaser = aliaser - self._coerce = coercion - self._coercer = coercer - self._fall_back_on_default = fall_back_on_default - - def _cache_key(self) -> Hashable: - return self._coerce, self._coercer + self.coercer = coercer + self.fall_back_on_default = fall_back_on_default def _recursive_result( self, lazy: Lazy[DeserializationMethodFactory] @@ -325,8 +318,8 @@ def object( ) for f in fields ] - additional_properties = self._additional_properties - fall_back_on_default = self._fall_back_on_default + additional_properties = self.additional_properties + fall_back_on_default = self.fall_back_on_default def factory( constraints: Optional[Constraints], validators: Sequence[Validator] @@ -678,18 +671,10 @@ def _visit_conversion( next_conversion: Optional[AnyConversion], ) -> DeserializationMethodFactory: assert conversion - conv_factories = [] - for conv in conversion: - with context_setter(self) as setter: - if conv.additional_properties is not None: - setter._additional_properties = conv.additional_properties - if conv.fall_back_on_default is not None: - setter._fall_back_on_default = conv.fall_back_on_default - setter._coerce, setter._coercer = get_coercer( - conv.coerce, self._coerce, self._coercer - ) - sub_conv = sub_conversion(conv, next_conversion) - conv_factories.append(self.visit_with_conv(conv.source, sub_conv)) + conv_factories = [ + self.visit_with_conv(conv.source, sub_conversion(conv, next_conversion)) + for conv in conversion + ] def factory( constraints: Optional[Constraints], validators: Sequence[Validator] @@ -749,8 +734,8 @@ def visit_conversion( next_conversion: Optional[AnyConversion] = None, ) -> DeserializationMethodFactory: factory = super().visit_conversion(tp, conversion, dynamic, next_conversion) - if factory.coercer is None and self._coerce: - factory = replace(factory, coercer=self._coercer) + if self.coercer is not None and factory.coercer is None: + factory = replace(factory, coercer=self.coercer) if not dynamic: factory = factory.merge(get_constraints(get_schema(tp)), get_validators(tp)) if get_args(tp): @@ -805,14 +790,18 @@ def deserialization_method( ) -> DeserializationMethod: from apischema import settings - coerce, coercer = get_coercer( - coerce, settings.deserialization.coerce, settings.deserialization.coercer - ) + coerce = opt_or(coerce, settings.deserialization.coerce) + coercer: Optional[Coercer] + if callable(coerce): + coercer = wrap_coercer(coerce) + elif coerce: + coercer = settings.deserialization.coercer + else: + coercer = None return ( DeserializationMethodVisitor( opt_or(additional_properties, settings.additional_properties), opt_or(aliaser, settings.aliaser), - coerce, coercer, opt_or(default_conversion, settings.deserialization.default_conversion), opt_or(fall_back_on_default, settings.deserialization.fall_back_on_default), diff --git a/apischema/deserialization/coercion.py b/apischema/deserialization/coercion.py index c5c8abcd..faade167 100644 --- a/apischema/deserialization/coercion.py +++ b/apischema/deserialization/coercion.py @@ -1,5 +1,5 @@ from functools import wraps -from typing import Any, Callable, Dict, Optional, Tuple, Type, TypeVar, Union +from typing import Any, Callable, Dict, Type, TypeVar, Union from apischema.json_schema.types import bad_type from apischema.types import NoneType @@ -45,6 +45,9 @@ def coerce(cls: Type[T], data: Any) -> T: def wrap_coercer(coercer: Coercer) -> Coercer: + if coercer is coerce: + return coercer + @wraps(coercer) def wrapper(cls, data): try: @@ -58,11 +61,3 @@ def wrapper(cls, data): return result return wrapper - - -def get_coercer( - coerce: Optional[Coerce], default_coerce: bool, default_coercer: Coercer -) -> Tuple[bool, Coercer]: - if coerce is None: - return default_coerce, default_coercer - return bool(coerce), wrap_coercer(coerce) if callable(coerce) else default_coercer diff --git a/apischema/serialization/__init__.py b/apischema/serialization/__init__.py index 6f08c1c7..457b0620 100644 --- a/apischema/serialization/__init__.py +++ b/apischema/serialization/__init__.py @@ -33,7 +33,6 @@ from apischema.typing import is_new_type, is_type, is_type_var, is_typed_dict from apischema.utils import ( Lazy, - context_setter, deprecate_kwargs, get_origin_or_type, get_origin_or_type2, @@ -80,10 +79,10 @@ def __init__( super().__init__(default_conversion) self.additional_properties = additional_properties self.aliaser = aliaser + self.check_type = check_type + self.exclude_unset = exclude_unset + self.fall_back_on_any = fall_back_on_any self.pass_through_options = pass_through_options - self._fall_back_on_any = fall_back_on_any - self._check_type = check_type - self._exclude_unset = exclude_unset def _recursive_result(self, lazy: Lazy[SerializationMethod]) -> SerializationMethod: rec_method = None @@ -104,7 +103,7 @@ def pass_through(self, tp: AnyType) -> bool: aliaser=self.aliaser, conversions=self._conversions, default_conversion=self.default_conversion, - exclude_unset=self._exclude_unset, + exclude_unset=self.exclude_unset, options=self.pass_through_options, ) except (TypeError, Unsupported): # TypeError because tp can be unhashable @@ -115,11 +114,11 @@ def _factory(self) -> SerializationMethodFactory: return serialization_method_factory( self.additional_properties, self.aliaser, - self._check_type, + self.check_type, self._conversions, self.default_conversion, - self._exclude_unset, - self._fall_back_on_any, + self.exclude_unset, + self.fall_back_on_any, self.pass_through_options, ) @@ -132,9 +131,9 @@ def method(obj: Any) -> Any: return method def _wrap(self, cls: type, method: SerializationMethod) -> SerializationMethod: - if not self._check_type: + if not self.check_type: return method - fall_back_on_any, any_method = self._fall_back_on_any, self.any() + fall_back_on_any, any_method = self.fall_back_on_any, self.any() if is_typed_dict(cls): cls = Mapping @@ -260,7 +259,7 @@ def method( result[alias] = field_method(getattr(obj, name)) return result - if self._exclude_unset and support_fields_set(cls): + if self.exclude_unset and support_fields_set(cls): wrapped_exclude_unset = method def method(obj: Any) -> Any: @@ -339,9 +338,9 @@ def method(obj: Any) -> Any: serialize_elt(elt) for serialize_elt, elt in zip(elt_serializers, obj) ] - if self._check_type: + if self.check_type: wrapped = method - fall_back_on_any, as_list = self._fall_back_on_any, self._factory(list) + fall_back_on_any, as_list = self.fall_back_on_any, self._factory(list) def method(obj: Any) -> Any: if len(obj) == len(elt_serializers): @@ -361,7 +360,7 @@ def union(self, alternatives: Sequence[AnyType]) -> SerializationMethod: checks = [instance_checker(alt) for alt in alternatives if alt is not NoneType] methods_and_checks = list(zip(methods, checks)) none_check = None if NoneType in alternatives else NOT_NONE - fall_back_on_any, any_method = self._fall_back_on_any, self.any() + fall_back_on_any, any_method = self.fall_back_on_any, self.any() def method(obj: Any) -> Any: # Optional/Undefined optimization @@ -390,7 +389,7 @@ def unsupported(self, tp: AnyType) -> SerializationMethod: try: return super().unsupported(tp) except Unsupported: - if self._fall_back_on_any and is_type(tp): + if self.fall_back_on_any and is_type(tp): any_method = self.any() if issubclass(tp, Mapping): @@ -418,15 +417,9 @@ def _visit_conversion( dynamic: bool, next_conversion: Optional[AnyConversion], ) -> SerializationMethod: - with context_setter(self) as setter: - if conversion.fall_back_on_any is not None: - setter._fall_back_on_any = conversion.fall_back_on_any - if conversion.exclude_unset is not None: - setter._exclude_unset = conversion.exclude_unset - serialize_conv = self.visit_with_conv( - conversion.target, sub_conversion(conversion, next_conversion) - ) - + serialize_conv = self.visit_with_conv( + conversion.target, sub_conversion(conversion, next_conversion) + ) converter = cast(Converter, conversion.converter) if converter is identity: method = serialize_conv @@ -442,7 +435,7 @@ def method(obj: Any) -> Any: def visit(self, tp: AnyType) -> SerializationMethod: if tp is AliasedStr: return self._wrap(AliasedStr, self.aliaser) - elif not self._check_type and self.pass_through(tp): + elif not self.check_type and self.pass_through(tp): return identity else: return super().visit(tp)