From d2c365c2b8ff1bbe47b8d6b4f170c2e3bc8d6f22 Mon Sep 17 00:00:00 2001
From: Joseph Perez
Date: Sun, 31 Oct 2021 18:57:01 +0100
Subject: [PATCH 01/15] Cythonize (de)serialization methods and enable error
 customization

Closures have been converted into classes with a single method and regrouped
in modules named methods.py; only these methods modules are cythonized.
Cythonization is done by generating a pyx file from the Python file. This
approach requires hardly any adaptation of the Python code, and it also
allows optimizations such as converting dynamic dispatch into switches
(pure Python code would have been dirty, with a lot of if-chains).

Because constraints validation has been completely refactored in order to be
cythonized too, error customization has been added at the same time, as it
was simpler to do both together.
---
 .gitignore                            |   3 +
 README.md                             |   2 +-
 apischema/conversions/visitor.py      |  12 +-
 apischema/deserialization/__init__.py | 652 ++++++++----------
 apischema/deserialization/methods.py  | 680 ++++++++++++++++++++++++++
 apischema/graphql/resolvers.py        |  20 +-
 apischema/graphql/schema.py           |  12 +-
 apischema/recursion.py                |   1 -
 apischema/schemas/constraints.py      | 100 +---
 apischema/serialization/__init__.py   | 438 ++++++++---------
 apischema/serialization/methods.py    | 371 ++++++++++++++
 apischema/settings.py                 |  24 +
 apischema/utils.py                    |  39 +-
 apischema/visitor.py                  |   2 +-
 docs/performance_and_benchmark.md     |   2 +-
 examples/pass_through.py              |   6 +-
 examples/pass_through_primitives.py   |   5 +-
 examples/validation_error.py          |   2 +-
 scripts/generate_pyx.py               | 294 +++++++++++
 setup.py                              |  34 +-
 tests/requirements.txt                |   1 +
 tests/test_deserialization_methods.py |   8 +
 tests/test_utils.py                   |   8 -
 tox.ini                               |  12 +-
 24 files changed, 1874 insertions(+), 854 deletions(-)
 create mode 100644 apischema/deserialization/methods.py
 create mode 100644 apischema/serialization/methods.py
 create mode 100755 scripts/generate_pyx.py
 create mode 100644 tests/test_deserialization_methods.py

diff --git a/.gitignore b/.gitignore
index 4dd39b31..9aa7ef6c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -106,3 +106,6 @@ venv.bak/
 .idea
 __generated__
 cov-*
+*.c
+*.pyx
+*.pxd
diff --git a/README.md b/README.md
index 33904df0..8df15843 100644
--- a/README.md
+++ b/README.md
@@ -24,7 +24,7 @@ This library fulfills the following goals:
 - stay as close as possible to the standard library (dataclasses, typing, etc.) — as a consequence we do not need plugins for editors/linters/etc.;
 - be adaptable, provide tools to support any types (ORM, etc.);
-- avoid dynamic things like using raw strings for attributes name - play nicely with your IDE.
+- avoid dynamic things like using raw strings for attributes name — play nicely with your IDE.
 
 No known alternative achieves all of this, and apischema is also [faster](https://wyfo.github.io/apischema/performance_and_benchmark) than all of them.
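Note: the closure-to-class conversion described in the commit message has the following shape. This is a minimal, self-contained sketch; the names minimum_method_factory and MinimumMethod are invented for illustration and do not appear in this patch.

    from dataclasses import dataclass
    from typing import Any

    # Before: the factory returns a closure; `minimum` lives in a closure
    # cell, which Cython can neither type nor store as a C struct field.
    def minimum_method_factory(minimum: int):
        def method(data: Any) -> Any:
            if not isinstance(data, int):
                raise TypeError(f"expected int, found {type(data).__name__}")
            if data < minimum:
                raise ValueError(f"less than {minimum} (minimum)")
            return data

        return method

    # After: the captured state becomes an attribute of a class with a single
    # `deserialize` method, so the generated pyx file can declare `minimum`
    # as a typed C attribute.
    @dataclass
    class MinimumMethod:
        minimum: int

        def deserialize(self, data: Any) -> Any:
            if not isinstance(data, int):
                raise TypeError(f"expected int, found {type(data).__name__}")
            if data < self.minimum:
                raise ValueError(f"less than {self.minimum} (minimum)")
            return data

    # Both styles are interchangeable from the caller's point of view.
    assert minimum_method_factory(0)(42) == MinimumMethod(0).deserialize(42) == 42

Because all method classes share the single deserialize entry point, the pyx generator can also replace the Python-level dynamic dispatch with a switch over the concrete method classes, which is the optimization the commit message mentions.
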
diff --git a/apischema/conversions/visitor.py b/apischema/conversions/visitor.py index b9d58e08..49e41cc0 100644 --- a/apischema/conversions/visitor.py +++ b/apischema/conversions/visitor.py @@ -72,20 +72,20 @@ def annotated(self, tp: AnyType, annotations: Sequence[Any]) -> Result: return super().annotated(tp, annotations) return super().annotated(tp, annotations) - def _union_results(self, alternatives: Iterable[AnyType]) -> Sequence[Result]: + def _union_results(self, types: Iterable[AnyType]) -> Sequence[Result]: results = [] - for alt in alternatives: + for alt in types: with suppress(Unsupported): results.append(self.visit(alt)) if not results: - raise Unsupported(Union[tuple(alternatives)]) + raise Unsupported(Union[tuple(types)]) return results def _visited_union(self, results: Sequence[Result]) -> Result: raise NotImplementedError - def union(self, alternatives: Sequence[AnyType]) -> Result: - return self._visited_union(self._union_results(alternatives)) + def union(self, types: Sequence[AnyType]) -> Result: + return self._visited_union(self._union_results(types)) @contextmanager def _replace_conversion(self, conversion: Optional[AnyConversion]): @@ -130,7 +130,7 @@ def visit(self, tp: AnyType) -> Result: tp, self.default_conversion(get_origin_or_type(tp)) # type: ignore ) next_conversion = None - if not dynamic and is_subclass(tp, Collection): + if not dynamic and is_subclass(tp, Collection) and not is_subclass(tp, str): next_conversion = self._conversion return self.visit_conversion(tp, conversion, dynamic, next_conversion) diff --git a/apischema/deserialization/__init__.py b/apischema/deserialization/__init__.py index 43fe49d4..35fe958a 100644 --- a/apischema/deserialization/__init__.py +++ b/apischema/deserialization/__init__.py @@ -1,14 +1,14 @@ +import collections.abc +import re from collections import defaultdict from dataclasses import dataclass, replace from enum import Enum from functools import lru_cache from typing import ( - AbstractSet, Any, Callable, Collection, Dict, - List, Mapping, Optional, Pattern, @@ -31,8 +31,41 @@ from apischema.dependencies import get_dependent_required from apischema.deserialization.coercion import Coerce, Coercer from apischema.deserialization.flattened import get_deserialization_flattened_aliases +from apischema.deserialization.methods import ( + AdditionalField, + AnyMethod, + BoolMethod, + CoercerMethod, + CollectionMethod, + ConstrainedFloatMethod, + ConstrainedIntMethod, + ConstrainedStrMethod, + Constraint, + ConversionAlternative, + ConversionMethod, + ConversionUnionMethod, + DeserializationMethod, + Field, + FlattenedField, + FloatMethod, + IntMethod, + LiteralMethod, + MappingMethod, + NoneMethod, + ObjectMethod, + OptionalMethod, + PatternField, + RecMethod, + SetMethod, + StrMethod, + SubprimitiveMethod, + TupleMethod, + UnionByTypeMethod, + UnionMethod, + ValidatorMethod, + VariadicTupleMethod, +) from apischema.json_schema.patterns import infer_pattern -from apischema.json_schema.types import bad_type from apischema.metadata.implem import ValidatorsMetadata from apischema.metadata.keys import SCHEMA_METADATA, VALIDATORS_METADATA from apischema.objects import ObjectField @@ -40,34 +73,28 @@ from apischema.objects.visitor import DeserializationObjectVisitor from apischema.recursion import RecursiveConversionsVisitor from apischema.schemas import Schema, get_schema -from apischema.schemas.constraints import Check, Constraints, merge_constraints +from apischema.schemas.constraints import Constraints, merge_constraints from 
apischema.types import AnyType, NoneType from apischema.typing import get_args, get_origin from apischema.utils import ( Lazy, - PREFIX, deprecate_kwargs, get_origin_or_type, literal_values, opt_or, + partial_format, + to_pascal_case, + to_snake_case, ) from apischema.validation import get_validators -from apischema.validation.errors import ErrorKey, ValidationError, merge_errors -from apischema.validation.mock import ValidatorMock -from apischema.validation.validators import Validator, validate -from apischema.visitor import Unsupported +from apischema.validation.validators import Validator MISSING_PROPERTY = "missing property" UNEXPECTED_PROPERTY = "unexpected property" -NOT_NONE = object() - -INIT_VARS_ATTR = f"{PREFIX}_init_vars" - T = TypeVar("T") -DeserializationMethod = Callable[[Any], T] Factory = Callable[[Optional[Constraints], Sequence[Validator]], DeserializationMethod] @@ -89,20 +116,50 @@ def merge( validators=(*validators, *self.validators), ) - @property # type: ignore + # private intermediate method instead of decorated property because of mypy @lru_cache() - def method(self) -> DeserializationMethod: + def _method(self) -> DeserializationMethod: return self.factory(self.constraints, self.validators) # type: ignore + @property + def method(self) -> DeserializationMethod: + return self._method() + def get_constraints(schema: Optional[Schema]) -> Optional[Constraints]: return schema.constraints if schema is not None else None -def get_constraint_checks( - constraints: Optional[Constraints], cls: type -) -> Collection[Tuple[Check, Any, str]]: - return () if constraints is None else constraints.checks_by_type[cls] +constraint_classes = {cls.__name__: cls for cls in Constraint.__subclasses__()} + + +def constraints_validators( + constraints: Optional[Constraints], +) -> Mapping[type, Tuple[Constraint, ...]]: + from apischema import settings + + result: Dict[type, Tuple[Constraint, ...]] = defaultdict(tuple) + if constraints is not None: + for name, attr, metadata in constraints.attr_and_metata: + if attr is None or attr is False: + continue + error = getattr(settings.errors, to_snake_case(metadata.alias)) + error = partial_format( + error, + constraint=attr + if not isinstance(attr, type(re.compile(r""))) + else attr.pattern, + ) + constraint_cls = constraint_classes[ + to_pascal_case(metadata.alias) + "Constraint" + ] + result[metadata.cls] = ( + *result[metadata.cls], + constraint_cls(error) if attr is True else constraint_cls(error, attr), # type: ignore + ) + if float in result: + result[int] = result[float] + return result class DeserializationMethodVisitor( @@ -130,15 +187,7 @@ def _recursive_result( def factory( constraints: Optional[Constraints], validators: Sequence[Validator] ) -> DeserializationMethod: - rec_method = None - - def method(data: Any) -> Any: - nonlocal rec_method - if rec_method is None: - rec_method = lazy().merge(constraints, validators).method - return rec_method(data) - - return method + return RecMethod(lazy, constraints, validators) return DeserializationMethodFactory(factory) @@ -175,46 +224,20 @@ def wrapper( ) -> DeserializationMethod: method: DeserializationMethod if validation and validators: - wrapped, aliaser = factory(constraints, ()), self.aliaser - - def method(data: Any) -> Any: - result = wrapped(data) - validate(result, validators, aliaser=aliaser) - return result - + method = ValidatorMethod( + factory(constraints, ()), validators, self.aliaser + ) else: method = factory(constraints, validators) if self.coercer is not None and cls is 
not None:
-            coercer = self.coercer
-
-            def wrapper(data: Any) -> Any:
-                assert cls is not None
-                return method(coercer(cls, data))
-
-            return wrapper
-
-        else:
-            return method
+                method = CoercerMethod(self.coercer, cls, method)
+            return method
 
         return DeserializationMethodFactory(wrapper, cls)
 
     def any(self) -> DeserializationMethodFactory:
         def factory(constraints: Optional[Constraints], _) -> DeserializationMethod:
-            checks = None if constraints is None else constraints.checks_by_type
-
-            def method(data: Any) -> Any:
-                if checks is not None:
-                    if data.__class__ in checks:
-                        errors = [
-                            err
-                            for check, attr, err in checks[data.__class__]
-                            if check(data, attr)
-                        ]
-                        if errors:
-                            raise ValidationError(errors)
-                return data
-
-            return method
+            return AnyMethod(dict(constraints_validators(constraints)))
 
         return self._factory(factory)
 
@@ -224,35 +247,16 @@ def collection(
         value_factory = self.visit(value_type)
 
         def factory(constraints: Optional[Constraints], _) -> DeserializationMethod:
-            deserialize_value = value_factory.method
-            checks = get_constraint_checks(constraints, list)
-            constructor: Optional[Callable[[list], Collection]] = None
-            if issubclass(cls, AbstractSet):
-                constructor = set
-            elif issubclass(cls, tuple):
-                constructor = tuple
-
-            def method(data: Any) -> Any:
-                if not isinstance(data, list):
-                    raise bad_type(data, list)
-                elt_errors: Dict[ErrorKey, ValidationError] = {}
-                values: list = [None] * len(data)
-                index = 0  # don't use `enumerate` for performance
-                for elt in data:
-                    try:
-                        values[index] = deserialize_value(elt)
-                    except ValidationError as err:
-                        elt_errors[index] = err
-                    index += 1
-                if checks:
-                    errors = [err for check, attr, err in checks if check(data, attr)]
-                    if errors or elt_errors:
-                        raise ValidationError(errors, elt_errors)
-                elif elt_errors:
-                    raise ValidationError([], elt_errors)
-                return constructor(values) if constructor else values
-
-            return method
+            method_cls: Type[CollectionMethod]
+            if issubclass(cls, collections.abc.Set):
+                method_cls = SetMethod
+            elif issubclass(cls, tuple):
+                method_cls = VariadicTupleMethod
+            else:
+                method_cls = CollectionMethod
+            return method_cls(
+                constraints_validators(constraints)[list], value_factory.method
+            )
 
         return self._factory(factory, list)
 
@@ -261,24 +265,15 @@ def enum(self, cls: Type[Enum]) -> DeserializationMethodFactory:
 
     def literal(self, values: Sequence[Any]) -> DeserializationMethodFactory:
         def factory(constraints: Optional[Constraints], _) -> DeserializationMethod:
-            value_map = dict(zip(literal_values(values), values))
-            types = list(set(map(type, value_map))) if self.coercer else []
-            error = f"not one of {list(value_map)}"
-            coercer = self.coercer
-
-            def method(data: Any) -> Any:
-                try:
-                    return value_map[data]
-                except KeyError:
-                    if coercer:
-                        for cls in types:
-                            try:
-                                return value_map[coercer(cls, data)]
-                            except IndexError:
-                                pass
-                    raise ValidationError([error])
+            from apischema import settings
 
-            return method
+            value_map = dict(zip(literal_values(values), values))
+            return LiteralMethod(
+                value_map,
+                partial_format(settings.errors.one_of, constraint=list(value_map)),
+                self.coercer,
+                tuple(set(map(type, value_map))),
+            )
 
         return self._factory(factory)
 
@@ -288,30 +283,11 @@ def mapping(
         key_factory, value_factory = self.visit(key_type), self.visit(value_type)
 
         def factory(constraints: Optional[Constraints], _) -> DeserializationMethod:
-            deserialize_key = key_factory.method
-            deserialize_value = value_factory.method
-            checks = get_constraint_checks(constraints, dict)
-
-            def method(data: 
Any) -> Any: - if not isinstance(data, dict): - raise bad_type(data, dict) - item_errors: Dict[ErrorKey, ValidationError] = {} - items = {} - for key, value in data.items(): - assert isinstance(key, str) - try: - items[deserialize_key(key)] = deserialize_value(value) - except ValidationError as err: - item_errors[key] = err - if checks: - errors = [err for check, attr, err in checks if check(data, attr)] - if errors or item_errors: - raise ValidationError(errors, item_errors) - elif item_errors: - raise ValidationError([], item_errors) - return items - - return method + return MappingMethod( + constraints_validators(constraints)[dict], + key_factory.method, + value_factory.method, + ) return self._factory(factory, dict) @@ -328,6 +304,8 @@ def object( def factory( constraints: Optional[Constraints], validators: Sequence[Validator] ) -> DeserializationMethod: + from apischema import settings + cls = get_origin_or_type(tp) alias_by_name = {field.name: self.aliaser(field.alias) for field in fields} requiring: Dict[str, Set[str]] = defaultdict(set) @@ -337,7 +315,7 @@ def factory( normal_fields, flattened_fields, pattern_fields = [], [], [] additional_field = None for field, field_factory in zip(fields, field_factories): - deserialize_field: DeserializationMethod = field_factory.method + field_method: DeserializationMethod = field_factory.method fall_back_on_default = ( field.fall_back_on_default or self.fall_back_on_default ) @@ -346,10 +324,10 @@ def factory( cls, field, self.default_conversion ) flattened_fields.append( - ( + FlattenedField( field.name, - set(map(self.aliaser, flattened_aliases)), - deserialize_field, + tuple(set(map(self.aliaser, flattened_aliases))), + field_method, fall_back_on_default, ) ) @@ -361,236 +339,68 @@ def factory( ) assert isinstance(field_pattern, Pattern) pattern_fields.append( - ( + PatternField( field.name, field_pattern, - deserialize_field, + field_method, fall_back_on_default, ) ) elif field.additional_properties: - additional_field = ( - field.name, - deserialize_field, - fall_back_on_default, + additional_field = AdditionalField( + field.name, field_method, fall_back_on_default ) else: normal_fields.append( - ( + Field( field.name, self.aliaser(field.alias), - deserialize_field, + field_method, field.required, requiring[field.name], fall_back_on_default, ) ) - has_aggregate_field = ( - flattened_fields or pattern_fields or (additional_field is not None) + return ObjectMethod( + cls, + constraints_validators(constraints)[dict], + tuple(normal_fields), + tuple(flattened_fields), + tuple(pattern_fields), + additional_field, + set(alias_by_name.values()), + self.additional_properties, + tuple(validators), + tuple( + (f.name, f.default_factory) + for f in fields + if f.kind == FieldKind.WRITE_ONLY + ), + {field.name for field in fields if field.post_init}, + self.aliaser, + settings.errors.missing_property, + settings.errors.unexpected_property, ) - post_init_modified = {field.name for field in fields if field.post_init} - checks = get_constraint_checks(constraints, dict) - aliaser = self.aliaser - additional_properties = self.additional_properties - all_aliases = set(alias_by_name.values()) - init_defaults = [ - (f.name, f.default_factory) - for f in fields - if f.kind == FieldKind.WRITE_ONLY - ] - - def method(data: Any) -> Any: - if not isinstance(data, dict): - raise bad_type(data, dict) - values: Dict[str, Any] = {} - fields_count = 0 - errors = ( - [err for check, attr, err in checks if check(data, attr)] - if checks - else [] - ) - field_errors: 
Dict[ErrorKey, ValidationError] = {} - for ( - name, - alias, - deserialize_field, - required, - required_by, - fall_back_on_default, - ) in normal_fields: - if required: - try: - value = data[alias] - except KeyError: - field_errors[alias] = ValidationError([MISSING_PROPERTY]) - else: - fields_count += 1 - try: - values[name] = deserialize_field(value) - except ValidationError as err: - field_errors[alias] = err - elif alias in data: - fields_count += 1 - try: - values[name] = deserialize_field(data[alias]) - except ValidationError as err: - if not fall_back_on_default: - field_errors[alias] = err - elif required_by and not required_by.isdisjoint(data): - requiring = sorted(required_by & data.keys()) - msg = f"missing property (required by {requiring})" - field_errors[alias] = ValidationError([msg]) - if has_aggregate_field: - remain = data.keys() - all_aliases - for ( - name, - flattened_alias, - deserialize_field, - fall_back_on_default, - ) in flattened_fields: - flattened = { - alias: data[alias] - for alias in flattened_alias - if alias in data - } - remain.difference_update(flattened) - try: - values[name] = deserialize_field(flattened) - except ValidationError as err: - if not fall_back_on_default: - errors.extend(err.messages) - field_errors.update(err.children) - for ( - name, - pattern, - deserialize_field, - fall_back_on_default, - ) in pattern_fields: - matched = { - key: data[key] for key in remain if pattern.match(key) - } - remain.difference_update(matched) - try: - values[name] = deserialize_field(matched) - except ValidationError as err: - if not fall_back_on_default: - errors.extend(err.messages) - field_errors.update(err.children) - if additional_field: - name, deserialize_field, fall_back_on_default = additional_field - additional = {key: data[key] for key in remain} - try: - values[name] = deserialize_field(additional) - except ValidationError as err: - if not fall_back_on_default: - errors.extend(err.messages) - field_errors.update(err.children) - elif remain and not additional_properties: - for key in remain: - field_errors[key] = ValidationError([UNEXPECTED_PROPERTY]) - elif not additional_properties and len(data) != fields_count: - for key in data.keys() - all_aliases: - field_errors[key] = ValidationError([UNEXPECTED_PROPERTY]) - validators2: Sequence[Validator] - if validators: - init: Dict[str, Any] = {} - for name, default_factory in init_defaults: - if name in values: - init[name] = values[name] - elif name not in field_errors: - assert default_factory is not None - init[name] = default_factory() - # Don't keep validators when all dependencies are default - validators2 = [ - v - for v in validators - if not v.dependencies.isdisjoint(values.keys()) - ] - if field_errors or errors: - error = ValidationError(errors, field_errors) - invalid_fields = field_errors.keys() | post_init_modified - try: - validate( - ValidatorMock(cls, values), - [ - v - for v in validators2 - if v.dependencies.isdisjoint(invalid_fields) - ], - init, - aliaser=aliaser, - ) - except ValidationError as err: - error = merge_errors(error, err) - raise error - elif field_errors or errors: - raise ValidationError(errors, field_errors) - else: - validators2, init = (), ... 
# type: ignore # only for linter - try: - res = cls(**values) - except (AssertionError, ValidationError): - raise - except TypeError as err: - if str(err).startswith("__init__() got"): - raise Unsupported(cls) - else: - raise ValidationError([str(err)]) - except Exception as err: - raise ValidationError([str(err)]) - if validators: - validate(res, validators2, init, aliaser=aliaser) - return res - - return method return self._factory(factory, dict, validation=False) def primitive(self, cls: Type) -> DeserializationMethodFactory: def factory(constraints: Optional[Constraints], _) -> DeserializationMethod: - checks = get_constraint_checks(constraints, cls) + validators = constraints_validators(constraints)[cls] if cls is NoneType: - - def method(data: Any) -> Any: - if data is not None: - raise bad_type(data, cls) - return data - - elif cls is not float and not checks: - - def method(data: Any) -> Any: - if not isinstance(data, cls): - raise bad_type(data, cls) - return data - - elif cls is not float and len(checks) == 1: - ((check, attr, err),) = checks - - def method(data: Any) -> Any: - if not isinstance(data, cls): - raise bad_type(data, cls) - elif check(data, attr): - raise ValidationError([err]) - return data - + return NoneMethod() + elif cls is bool: + return BoolMethod() + elif cls is str: + return ConstrainedStrMethod(validators) if validators else StrMethod() + elif cls is int: + return ConstrainedIntMethod(validators) if validators else IntMethod() + elif cls is float: + return ( + ConstrainedFloatMethod(validators) if validators else FloatMethod() + ) else: - is_float = cls is float - - def method(data: Any) -> Any: - if not isinstance(data, cls): - if is_float and isinstance(data, int): - data = float(data) - else: - raise bad_type(data, cls) - if checks: - errors = [ - err for check, attr, err in checks if check(data, attr) - ] - if errors: - raise ValidationError(errors) - return data - - return method + raise NotImplementedError return self._factory(factory, cls) @@ -600,14 +410,9 @@ def subprimitive(self, cls: Type, superclass: Type) -> DeserializationMethodFact def factory( constraints: Optional[Constraints], validators: Sequence[Validator] ) -> DeserializationMethod: - deserialize_primitive = primitive_factory.merge( - constraints, validators - ).method - - def method(data: Any) -> Any: - return superclass(deserialize_primitive(data)) - - return method + return SubprimitiveMethod( + cls, primitive_factory.merge(constraints, validators).method + ) return replace(primitive_factory, factory=factory) @@ -615,93 +420,43 @@ def tuple(self, types: Sequence[AnyType]) -> DeserializationMethodFactory: elt_factories = [self.visit(tp) for tp in types] def factory(constraints: Optional[Constraints], _) -> DeserializationMethod: - expected_len = len(types) - (_, _, min_err), (_, _, max_err) = Constraints( - min_items=len(types), max_items=len(types) - ).checks_by_type[list] - elt_methods = list(enumerate(fact.method for fact in elt_factories)) - checks = get_constraint_checks(constraints, list) - - def method(data: Any) -> Any: - if not isinstance(data, list): - raise bad_type(data, list) - if len(data) != expected_len: - raise ValidationError([min_err, max_err]) - elt_errors: Dict[ErrorKey, ValidationError] = {} - elts: List[Any] = [None] * expected_len - for i, deserialize_elt in elt_methods: - try: - elts[i] = deserialize_elt(data[i]) - except ValidationError as err: - elt_errors[i] = err - if checks: - errors = [err for check, attr, err in checks if check(data, attr)] - if errors 
or elt_errors: - raise ValidationError(errors, elt_errors) - elif elt_errors: - raise ValidationError([], elt_errors) - return tuple(elts) - - return method + def len_error(constraints: Constraints) -> str: + return constraints_validators(constraints)[list][0].error + + return TupleMethod( + constraints_validators(constraints)[list], + len_error(Constraints(min_items=len(types))), + len_error(Constraints(max_items=len(types))), + tuple(fact.method for fact in elt_factories), + ) return self._factory(factory, list) - def union(self, alternatives: Sequence[AnyType]) -> DeserializationMethodFactory: - alt_factories = self._union_results(alternatives) + def union(self, types: Sequence[AnyType]) -> DeserializationMethodFactory: + alt_factories = self._union_results(types) if len(alt_factories) == 1: return alt_factories[0] def factory(constraints: Optional[Constraints], _) -> DeserializationMethod: - alt_methods = [fact.merge(constraints).method for fact in alt_factories] + alt_methods = tuple( + fact.merge(constraints).method for fact in alt_factories + ) # method_by_cls cannot replace alt_methods, because there could be several # methods for one class - method_by_cls = dict(zip((f.cls for f in alt_factories), alt_methods)) - if NoneType in alternatives and len(alt_methods) == 2: - deserialize_alt = next( + method_by_cls = dict( + zip((f.cls for f in alt_factories if f.cls is not None), alt_methods) + ) + if NoneType in types and len(alt_methods) == 2: + value_method = next( meth for fact, meth in zip(alt_factories, alt_methods) if fact.cls is not NoneType ) - coercer = self.coercer - - def method(data: Any) -> Any: - if data is None: - return None - try: - return deserialize_alt(data) - except ValidationError as err: - if coercer and coercer(NoneType, data) is None: - return None - else: - raise merge_errors(err, bad_type(data, NoneType)) - - elif None not in method_by_cls and len(method_by_cls) == len(alt_factories): - classes = tuple(cls for cls in method_by_cls if cls is not None) - - def method(data: Any) -> Any: - try: - return method_by_cls[data.__class__](data) - except KeyError: - raise bad_type(data, *classes) from None - except ValidationError as err: - other_classes = ( - cls for cls in classes if cls is not data.__class__ - ) - raise merge_errors(err, bad_type(data, *other_classes)) - + return OptionalMethod(value_method, self.coercer) + elif len(method_by_cls) == len(alt_factories): + return UnionByTypeMethod(method_by_cls) else: - - def method(data: Any) -> Any: - error = None - for deserialize_alt in alt_methods: - try: - return deserialize_alt(data) - except ValidationError as err: - error = merge_errors(error, err) - assert error is not None - raise error - - return method + return UnionMethod(alt_methods) return self._factory(factory) @@ -719,42 +474,19 @@ def _visit_conversion( ] def factory(constraints: Optional[Constraints], _) -> DeserializationMethod: - conv_methods = [ - ((fact if dynamic else fact.merge(constraints)).method, conv.converter) + conv_alternatives = tuple( + ConversionAlternative( + conv.converter, + (fact if dynamic else fact.merge(constraints)).method, + ) for conv, fact in zip(conversion, conv_factories) - ] - method: DeserializationMethod - if len(conv_methods) == 1: - deserialize_alt, converter = conv_methods[0] - - def method(data: Any) -> Any: - try: - return converter(deserialize_alt(data)) - except (ValidationError, AssertionError): - raise - except Exception as err: - raise ValidationError([str(err)]) - + ) + if len(conv_alternatives) == 1: + 
return ConversionMethod( + conv_alternatives[0].converter, conv_alternatives[0].method + ) else: - - def method(data: Any) -> Any: - error: Optional[ValidationError] = None - for deserialize_alt, converter in conv_methods: - try: - value = deserialize_alt(data) - except ValidationError as err: - error = merge_errors(error, err) - else: - try: - return converter(value) - except (ValidationError, AssertionError): - raise - except Exception as err: - raise ValidationError([str(err)]) - assert error is not None - raise error - - return method + return ConversionUnionMethod(conv_alternatives) return self._factory(factory, validation=not dynamic) @@ -806,7 +538,7 @@ def deserialization_method( default_conversion: DefaultConversion = None, fall_back_on_default: bool = None, schema: Schema = None, -) -> DeserializationMethod[T]: +) -> Callable[[Any], T]: ... @@ -821,7 +553,7 @@ def deserialization_method( default_conversion: DefaultConversion = None, fall_back_on_default: bool = None, schema: Schema = None, -) -> DeserializationMethod: +) -> Callable[[Any], Any]: ... @@ -835,7 +567,7 @@ def deserialization_method( default_conversion: DefaultConversion = None, fall_back_on_default: bool = None, schema: Schema = None, -) -> DeserializationMethod: +) -> Callable[[Any], Any]: from apischema import settings coercer: Optional[Coercer] = None @@ -854,7 +586,7 @@ def deserialization_method( opt_or(fall_back_on_default, settings.deserialization.fall_back_on_default), ) .merge(get_constraints(schema), ()) - .method + .method.deserialize ) diff --git a/apischema/deserialization/methods.py b/apischema/deserialization/methods.py new file mode 100644 index 00000000..e9d00174 --- /dev/null +++ b/apischema/deserialization/methods.py @@ -0,0 +1,680 @@ +from dataclasses import dataclass, field +from typing import ( + AbstractSet, + Any, + Callable, + Dict, + Optional, + Pattern, + Sequence, + TYPE_CHECKING, + Tuple, +) + +from apischema.aliases import Aliaser +from apischema.conversions.utils import Converter +from apischema.deserialization.coercion import Coercer +from apischema.json_schema.types import bad_type +from apischema.schemas.constraints import Constraints +from apischema.types import NoneType +from apischema.utils import Lazy +from apischema.validation.errors import ValidationError, merge_errors +from apischema.validation.mock import ValidatorMock +from apischema.validation.validators import Validator, validate +from apischema.visitor import Unsupported + +if TYPE_CHECKING: + from apischema.deserialization import DeserializationMethodFactory + + +@dataclass +class Constraint: + error: str + + def validate(self, data: Any) -> bool: + raise NotImplementedError + + +@dataclass +class MinimumConstraint(Constraint): + minimum: int + + def validate(self, data: int) -> bool: + return data >= self.minimum + + +@dataclass +class MaximumConstraint(Constraint): + maximum: int + + def validate(self, data: int) -> bool: + return data <= self.maximum + + +@dataclass +class ExclusiveMinimumConstraint(Constraint): + exc_min: int + + def validate(self, data: int) -> bool: + return data > self.exc_min + + +@dataclass +class ExclusiveMaximumConstraint(Constraint): + exc_max: int + + def validate(self, data: int) -> bool: + return data < self.exc_max + + +@dataclass +class MultipleOfConstraint(Constraint): + mult_of: int + + def validate(self, data: int) -> bool: + return not (data % self.mult_of) + + +@dataclass +class MinLengthConstraint(Constraint): + min_len: int + + def validate(self, data: str) -> bool: + return 
len(data) >= self.min_len + + +@dataclass +class MaxLengthConstraint(Constraint): + max_len: int + + def validate(self, data: str) -> bool: + return len(data) <= self.max_len + + +@dataclass +class PatternConstraint(Constraint): + pattern: Pattern + + def validate(self, data: str) -> bool: + return self.pattern.match(data) is not None + + +@dataclass +class MinItemsConstraint(Constraint): + min_items: int + + def validate(self, data: list) -> bool: + return len(data) >= self.min_items + + +@dataclass +class MaxItemsConstraint(Constraint): + max_items: int + + def validate(self, data: list) -> bool: + return len(data) <= self.max_items + + +def to_hashable(data: Any) -> Any: + if isinstance(data, list): + return tuple(map(to_hashable, data)) + elif isinstance(data, dict): + # Cython doesn't support tuple comprehension yet -> intermediate list + return tuple([(k, to_hashable(data[k])) for k in sorted(data)]) + else: + return data + + +class UniqueItemsConstraint(Constraint): + def validate(self, data: list) -> bool: + return len(set(map(to_hashable, data))) == len(data) + + +@dataclass +class MinPropertiesConstraint(Constraint): + min_properties: int + + def validate(self, data: dict) -> bool: + return len(data) >= self.min_properties + + +@dataclass +class MaxPropertiesConstraint(Constraint): + max_properties: int + + def validate(self, data: dict) -> bool: + return len(data) <= self.max_properties + + +def validate_constraints( + data: Any, constraints: Tuple[Constraint, ...], children_errors: Optional[dict] +) -> Any: + for i in range(len(constraints)): + constraint: Constraint = constraints[i] + if not constraint.validate(data): + errors: list = [constraint.error.format(data)] + for j in range(i + 1, len(constraints)): + constraint = constraints[j] + if not constraint.validate(data): + errors.append(constraint.error.format(data)) + raise ValidationError(errors, children_errors or {}) + if children_errors: + raise ValidationError([], children_errors) + return data + + +class DeserializationMethod: + def deserialize(self, data: Any) -> Any: + raise NotImplementedError + + +@dataclass +class RecMethod(DeserializationMethod): + lazy: Lazy["DeserializationMethodFactory"] + constraints: Optional[Constraints] + validators: Sequence[Validator] + method: Optional[DeserializationMethod] = field(init=False) + + def __post_init__(self): + self.method = None + + def deserialize(self, data: Any) -> Any: + if self.method is None: + self.method = self.lazy().merge(self.constraints, self.validators).method + return self.method.deserialize(data) + + +@dataclass +class ValidatorMethod(DeserializationMethod): + method: DeserializationMethod + validators: Sequence[Validator] + aliaser: Aliaser + + def deserialize(self, data: Any) -> Any: + return validate( + self.method.deserialize(data), self.validators, aliaser=self.aliaser + ) + + +@dataclass +class CoercerMethod(DeserializationMethod): + coercer: Coercer + cls: type + method: DeserializationMethod + + def deserialize(self, data: Any) -> Any: + return self.method.deserialize(self.coercer(self.cls, data)) + + +@dataclass +class AnyMethod(DeserializationMethod): + constraints: Dict[type, Tuple[Constraint, ...]] + + def deserialize(self, data: Any) -> Any: + if type(data) in self.constraints: + validate_constraints(data, self.constraints[type(data)], None) + return data + + +@dataclass +class CollectionMethod(DeserializationMethod): + constraints: Tuple[Constraint, ...] 
+ value_method: DeserializationMethod + + def deserialize(self, data: Any) -> Any: + if not isinstance(data, list): + raise bad_type(data, list) + data2: list = data + elt_errors: dict = {} + values: list = [None] * len(data2) + for i, elt in enumerate(data2): + try: + values[i] = self.value_method.deserialize(elt) + except ValidationError as err: + elt_errors[i] = err + validate_constraints(data2, self.constraints, elt_errors) + return values + + +@dataclass +class SetMethod(CollectionMethod): + def deserialize(self, data: Any) -> Any: + return set(super().deserialize(data)) + + +@dataclass +class VariadicTupleMethod(CollectionMethod): + def deserialize(self, data: Any) -> Any: + return tuple(super().deserialize(data)) + + +@dataclass +class LiteralMethod(DeserializationMethod): + value_map: dict + error: str + coercer: Optional[Coercer] + types: Tuple[type, ...] + + def deserialize(self, data: Any) -> Any: + try: + return self.value_map[data] + except KeyError: + if self.coercer is not None: + for cls in self.types: + try: + return self.value_map[self.coercer(cls, data)] + except IndexError: + pass + raise ValidationError([self.error.format(data)]) + + +@dataclass +class MappingMethod(DeserializationMethod): + constraints: Tuple[Constraint, ...] + key_method: DeserializationMethod + value_method: DeserializationMethod + + def deserialize(self, data: Any) -> Any: + if not isinstance(data, dict): + raise bad_type(data, dict) + data2: dict = data + item_errors: dict = {} + items: dict = {} + for key, value in data2.items(): + assert isinstance(key, str) + try: + items[self.key_method.deserialize(key)] = self.value_method.deserialize( + value + ) + except ValidationError as err: + item_errors[key] = err + validate_constraints(data2, self.constraints, item_errors) + return items + + +@dataclass +class Field: + name: str + alias: str + method: DeserializationMethod + required: bool + required_by: Optional[AbstractSet[str]] + fall_back_on_default: bool + + +@dataclass +class FlattenedField: + name: str + aliases: Tuple[str, ...] + method: DeserializationMethod + fall_back_on_default: bool + + +@dataclass +class PatternField: + name: str + pattern: Pattern + method: DeserializationMethod + fall_back_on_default: bool + + +@dataclass +class AdditionalField: + name: str + method: DeserializationMethod + fall_back_on_default: bool + + +@dataclass +class ObjectMethod(DeserializationMethod): + cls: Any # cython doesn't handle type subclasses properly + constraints: Tuple[Constraint, ...] + fields: Tuple[Field, ...] + flattened_fields: Tuple[FlattenedField, ...] + pattern_fields: Tuple[PatternField, ...] + additional_field: Optional[AdditionalField] + all_aliases: AbstractSet[str] + additional_properties: bool + validators: Tuple[Validator, ...] + init_defaults: Tuple[Tuple[str, Optional[Callable[[], Any]]], ...] 
+ post_init_modified: AbstractSet[str] + aliaser: Aliaser + missing: str + unexpected: str + aggregate_fields: bool = field(init=False) + + def __post_init__(self): + self.aggregate_fields = bool( + self.flattened_fields + or self.pattern_fields + or self.additional_field is not None + ) + + def deserialize(self, data: Any) -> Any: + if not isinstance(data, dict): + raise bad_type(data, dict) + data2: dict = data + values: dict = {} + fields_count = 0 + errors: list = [] + try: + validate_constraints(data, self.constraints, None) + except ValidationError as err: + errors.extend(err.messages) + field_errors: dict = {} + for i in range(len(self.fields)): + field: Field = self.fields[i] + if field.required: + try: + value = data2[field.alias] + except KeyError: + field_errors[field.alias] = ValidationError([self.missing]) + else: + fields_count += 1 + try: + values[field.name] = field.method.deserialize(value) + except ValidationError as err: + field_errors[field.alias] = err + elif field.alias in data2: + fields_count += 1 + try: + values[field.name] = field.method.deserialize(data2[field.alias]) + except ValidationError as err: + if not field.fall_back_on_default: + field_errors[field.alias] = err + elif field.required_by is not None and not field.required_by.isdisjoint( + data2 + ): + requiring = sorted(field.required_by & data2.keys()) + msg = self.missing + f" (required by {requiring})" + field_errors[field.alias] = ValidationError([msg]) + if self.aggregate_fields: + remain = data2.keys() - self.all_aliases + for i in range(len(self.flattened_fields)): + flattened_field: FlattenedField = self.flattened_fields[i] + flattened = { + alias: data2[alias] + for alias in flattened_field.aliases + if alias in data2 + } + remain.difference_update(flattened) + try: + values[flattened_field.name] = flattened_field.method.deserialize( + flattened + ) + except ValidationError as err: + if not flattened_field.fall_back_on_default: + errors.extend(err.messages) + field_errors.update(err.children) + for i in range(len(self.pattern_fields)): + pattern_field: PatternField = self.pattern_fields[i] + matched = { + key: data2[key] + for key in remain + if pattern_field.pattern.match(key) + } + remain.difference_update(matched) + try: + values[pattern_field.name] = pattern_field.method.deserialize( + matched + ) + except ValidationError as err: + if not pattern_field.fall_back_on_default: + errors.extend(err.messages) + field_errors.update(err.children) + if self.additional_field is not None: + additional = {key: data2[key] for key in remain} + try: + values[ + self.additional_field.name + ] = self.additional_field.method.deserialize(additional) + except ValidationError as err: + if not self.additional_field.fall_back_on_default: + errors.extend(err.messages) + field_errors.update(err.children) + elif remain and not self.additional_properties: + for key in remain: + field_errors[key] = ValidationError([self.unexpected]) + elif not self.additional_properties and len(data2) != fields_count: + for key in data2.keys() - self.all_aliases: + field_errors[key] = ValidationError([self.unexpected]) + validators2: list = [] + init: dict = {} + if self.validators: + for name, default_factory in self.init_defaults: + if name in values: + init[name] = values[name] + elif name not in field_errors: + assert default_factory is not None + init[name] = default_factory() + # Don't keep validators when all dependencies are default + validators2 = [ + v + for v in self.validators + if not v.dependencies.isdisjoint(values.keys()) 
+ ] + if field_errors or errors: + error = ValidationError(errors, field_errors) + invalid_fields = field_errors.keys() | self.post_init_modified + try: + validate( + ValidatorMock(self.cls, values), + [ + v + for v in validators2 + if v.dependencies.isdisjoint(invalid_fields) + ], + init, + aliaser=self.aliaser, + ) + except ValidationError as err: + error = merge_errors(error, err) + raise error + elif field_errors or errors: + raise ValidationError(errors, field_errors) + try: + res = self.cls(**values) + except (AssertionError, ValidationError): + raise + except TypeError as err: + if str(err).startswith("__init__() got"): + raise Unsupported(self.cls) + else: + raise ValidationError([str(err)]) + except Exception as err: + raise ValidationError([str(err)]) + if self.validators: + validate(res, validators2, init, aliaser=self.aliaser) + return res + + +class NoneMethod(DeserializationMethod): + def deserialize(self, data: Any) -> Any: + if data is not None: + raise bad_type(data, NoneType) + return data + + +class IntMethod(DeserializationMethod): + def deserialize(self, data: Any) -> Any: + if not isinstance(data, int): + raise bad_type(data, int) + return data + + +class FloatMethod(DeserializationMethod): + def deserialize(self, data: Any) -> Any: + if isinstance(data, float): + return data + elif isinstance(data, int): + return float(data) + else: + raise bad_type(data, float) + + +class StrMethod(DeserializationMethod): + def deserialize(self, data: Any) -> Any: + if not isinstance(data, str): + raise bad_type(data, str) + return data + + +class BoolMethod(DeserializationMethod): + def deserialize(self, data: Any) -> Any: + if not isinstance(data, bool): + raise bad_type(data, bool) + return data + + +@dataclass +class ConstrainedIntMethod(IntMethod): + constraints: Tuple[Constraint, ...] + + def deserialize(self, data: Any) -> Any: + return validate_constraints(super().deserialize(data), self.constraints, None) + + +@dataclass +class ConstrainedFloatMethod(FloatMethod): + constraints: Tuple[Constraint, ...] + + def deserialize(self, data: Any) -> Any: + return validate_constraints(super().deserialize(data), self.constraints, None) + + +@dataclass +class ConstrainedStrMethod(StrMethod): + constraints: Tuple[Constraint, ...] + + def deserialize(self, data: Any) -> Any: + return validate_constraints(super().deserialize(data), self.constraints, None) + + +@dataclass +class SubprimitiveMethod(DeserializationMethod): + cls: type + method: DeserializationMethod + + def deserialize(self, data: Any) -> Any: + return self.cls(self.method.deserialize(data)) + + +@dataclass +class TupleMethod(DeserializationMethod): + constraints: Tuple[Constraint, ...] + min_len_error: str + max_len_error: str + elt_methods: Tuple[DeserializationMethod, ...] 
+ + def deserialize(self, data: Any) -> Any: + if not isinstance(data, list): + raise bad_type(data, list) + data2: list = data + if len(data2) != len(self.elt_methods): + if len(data2) < len(self.elt_methods): + raise ValidationError([self.min_len_error % len(data2)]) + elif len(data2) > len(self.elt_methods): + raise ValidationError([self.max_len_error % len(data2)]) + else: + raise NotImplementedError + elt_errors: dict = {} + elts: list = [None] * len(self.elt_methods) + for i in range(len(self.elt_methods)): + elt_method: DeserializationMethod = self.elt_methods[i] + try: + elts[i] = elt_method.deserialize(data2[i]) + except ValidationError as err: + elt_errors[i] = err + validate_constraints(data2, self.constraints, elt_errors) + return tuple(elts) + + +@dataclass +class OptionalMethod(DeserializationMethod): + value_method: DeserializationMethod + coercer: Optional[Coercer] + + def deserialize(self, data: Any) -> Any: + if data is None: + return None + try: + return self.value_method.deserialize(data) + except ValidationError as err: + if self.coercer is not None and self.coercer(NoneType, data) is None: + return None + else: + raise merge_errors(err, bad_type(data, NoneType)) + + +@dataclass +class UnionByTypeMethod(DeserializationMethod): + method_by_cls: Dict[type, DeserializationMethod] + + def deserialize(self, data: Any) -> Any: + try: + method: DeserializationMethod = self.method_by_cls[type(data)] + return method.deserialize(data) + except KeyError: + raise bad_type(data, *self.method_by_cls) from None + except ValidationError as err: + other_classes = (cls for cls in self.method_by_cls if cls is not type(data)) + raise merge_errors(err, bad_type(data, *other_classes)) + + +@dataclass +class UnionMethod(DeserializationMethod): + alt_methods: Tuple[DeserializationMethod, ...] + + def deserialize(self, data: Any) -> Any: + error = None + for i in range(len(self.alt_methods)): + alt_method: DeserializationMethod = self.alt_methods[i] + try: + return alt_method.deserialize(data) + except ValidationError as err: + error = merge_errors(error, err) + assert error is not None + raise error + + +@dataclass +class ConversionMethod(DeserializationMethod): + converter: Converter + method: DeserializationMethod + + def deserialize(self, data: Any) -> Any: + try: + return self.converter(self.method.deserialize(data)) + except (ValidationError, AssertionError): + raise + except Exception as err: + raise ValidationError([str(err)]) + + +@dataclass +class ConversionAlternative: + converter: Converter + method: DeserializationMethod + + +@dataclass +class ConversionUnionMethod(DeserializationMethod): + alternatives: Tuple[ConversionAlternative, ...] 
+ + def deserialize(self, data: Any) -> Any: + error: Optional[ValidationError] = None + for i in range(len(self.alternatives)): + alternative: ConversionAlternative = self.alternatives[i] + try: + value = alternative.method.deserialize(data) + except ValidationError as err: + error = merge_errors(error, err) + else: + try: + return alternative.converter(value) + except (ValidationError, AssertionError): + raise + except Exception as err: + raise ValidationError([str(err)]) + assert error is not None + raise error diff --git a/apischema/graphql/resolvers.py b/apischema/graphql/resolvers.py index 14e0c451..ad78bdb9 100644 --- a/apischema/graphql/resolvers.py +++ b/apischema/graphql/resolvers.py @@ -33,6 +33,8 @@ from apischema.ordering import Ordering from apischema.schemas import Schema from apischema.serialization import ( + IDENTITY_METHOD, + METHODS, PassThroughOptions, SerializationMethod, SerializationMethodVisitor, @@ -64,18 +66,16 @@ class PartialSerializationMethodVisitor(SerializationMethodVisitor): @property def _factory(self) -> Callable[[type], SerializationMethod]: - return lambda _: identity + raise NotImplementedError def enum(self, cls: Type[Enum]) -> SerializationMethod: - return identity + return IDENTITY_METHOD def object(self, tp: AnyType, fields: Sequence[ObjectField]) -> SerializationMethod: - return identity + return IDENTITY_METHOD def visit(self, tp: AnyType) -> SerializationMethod: - if tp is UndefinedType: - return lambda obj: None - return super().visit(tp) + return METHODS[NoneType] if tp is UndefinedType else super().visit(tp) @cache @@ -291,16 +291,16 @@ def handle_enum(tp: AnyType) -> Optional[AnyConversion]: if not serialized: serialize_result = identity elif is_async(resolver.func): - serialize_result = as_async(method_factory(types["return"])) + serialize_result = as_async(method_factory(types["return"]).serialize) else: - serialize_result = method_factory(types["return"]) + serialize_result = method_factory(types["return"]).serialize serialize_error: Optional[Callable[[Any], Any]] if error_handler is None: serialize_error = None elif is_async(error_handler): - serialize_error = as_async(method_factory(resolver.error_type())) + serialize_error = as_async(method_factory(resolver.error_type()).serialize) else: - serialize_error = method_factory(resolver.error_type()) + serialize_error = method_factory(resolver.error_type()).serialize def resolve(__self, __info, **kwargs): values = {} diff --git a/apischema/graphql/schema.py b/apischema/graphql/schema.py index bdf43674..3a198f05 100644 --- a/apischema/graphql/schema.py +++ b/apischema/graphql/schema.py @@ -388,15 +388,13 @@ def factory( def tuple(self, types: Sequence[AnyType]) -> TypeFactory[GraphQLTp]: raise TypeError("Tuple are not supported") - def union(self, alternatives: Sequence[AnyType]) -> TypeFactory[GraphQLTp]: - factories = self._union_results( - (alt for alt in alternatives if alt is not NoneType) - ) + def union(self, types: Sequence[AnyType]) -> TypeFactory[GraphQLTp]: + factories = self._union_results((alt for alt in types if alt is not NoneType)) if len(factories) == 1: factory = factories[0] else: factory = self._visited_union(factories) - if NoneType in alternatives or UndefinedType in alternatives: + if NoneType in types or UndefinedType in types: def nullable(name: Optional[str], description: Optional[str]) -> GraphQLTp: res = factory.factory(name, description) # type: ignore @@ -616,7 +614,7 @@ def resolve_wrapper(__obj, __info, **kwargs): def _field(self, tp: AnyType, field: 
ObjectField) -> Lazy[graphql.GraphQLField]: field_name = field.name - partial_serialize = self._field_serialization_method(field) + partial_serialize = self._field_serialization_method(field).serialize @self._wrap_resolve def resolve(obj, _): @@ -711,7 +709,7 @@ def _visit_flattened( self.get_flattened if self.get_flattened is not None else identity ) field_name = field.name - partial_serialize = self._field_serialization_method(field) + partial_serialize = self._field_serialization_method(field).serialize def get_flattened(obj): return partial_serialize(getattr(get_prev_flattened(obj), field_name)) diff --git a/apischema/recursion.py b/apischema/recursion.py index d05b6b16..beac1478 100644 --- a/apischema/recursion.py +++ b/apischema/recursion.py @@ -154,7 +154,6 @@ def visit(self, tp: AnyType) -> Result: DeserializationRecursiveChecker # type: ignore if isinstance(self, DeserializationVisitor) else SerializationRecursiveChecker, - # None, ): cache_key = tp, self._conversion if cache_key in self._cache: diff --git a/apischema/schemas/constraints.py b/apischema/schemas/constraints.py index c6a48d07..1367f781 100644 --- a/apischema/schemas/constraints.py +++ b/apischema/schemas/constraints.py @@ -1,37 +1,14 @@ import operator as op -from collections import defaultdict from dataclasses import dataclass, field, fields from math import gcd -from typing import ( - Any, - Callable, - Collection, - Dict, - Mapping, - Optional, - Pattern, - Tuple, - TypeVar, -) +from typing import Any, Callable, Collection, Dict, Optional, Pattern, Tuple, TypeVar from apischema.types import Number -from apischema.utils import merge_opts, to_hashable +from apischema.utils import merge_opts T = TypeVar("T") U = TypeVar("U") -COMPARISON_MERGE_AND_ERRORS: Dict[Callable, Tuple[Callable, str]] = { - op.lt: (max, "less than %s"), - op.le: (max, "less than or equal to %s"), - op.gt: (min, "greater than %s"), - op.ge: (min, "greater than or equal to %s"), -} -PREFIX_DICT: Mapping[type, str] = { - str: "string length", - list: "item count", - dict: "property count", -} -Check = Callable[[Any, Any], Any] CONSTRAINT_METADATA_KEY = "constraint" @@ -39,8 +16,6 @@ class ConstraintMetadata: alias: str cls: type - check: Check - error: Callable[[Any], str] merge: Callable[[T, T], T] @property @@ -48,18 +23,11 @@ def field(self) -> Any: return field(default=None, metadata={CONSTRAINT_METADATA_KEY: self}) -def comparison(alias: str, cls: type, check: Check) -> Any: - merge, error = COMPARISON_MERGE_AND_ERRORS[check] - prefix = PREFIX_DICT.get(cls) # type: ignore - if prefix: - error = prefix + " " + error.replace("less", "lower") - if cls in (str, list, dict): - wrapped = check - - def check(data: Any, value: Any) -> bool: - return wrapped(len(data), value) - - return ConstraintMetadata(alias, cls, check, lambda v: error % v, merge).field +def constraint(alias: str, cls: type, merge: Callable[[T, T], T]) -> Any: + return field( + default=None, + metadata={CONSTRAINT_METADATA_KEY: ConstraintMetadata(alias, cls, merge)}, + ) def merge_mult_of(m1: Number, m2: Number) -> Number: @@ -68,47 +36,32 @@ def merge_mult_of(m1: Number, m2: Number) -> Number: return m1 * m2 / gcd(m1, m2) # type: ignore -def not_match_pattern(data: str, pattern: Pattern) -> bool: - return not pattern.match(data) - - def merge_pattern(p1: Pattern, p2: Pattern) -> Pattern: raise TypeError("Cannot merge patterns") -def not_unique(data: list, unique: bool) -> bool: - return (op.ne if unique else op.eq)(len(set(map(to_hashable, data))), len(data)) +min_, max_ = min, 
max @dataclass(frozen=True) class Constraints: # number - min: Optional[Number] = comparison("minimum", float, op.lt) - max: Optional[Number] = comparison("maximum", float, op.gt) - exc_min: Optional[Number] = comparison("exclusiveMinimum", float, op.le) - exc_max: Optional[Number] = comparison("exclusiveMaximum", float, op.ge) - mult_of: Optional[Number] = ConstraintMetadata( - "multipleOf", float, op.mod, lambda n: f"not a multiple of {n}", merge_mult_of # type: ignore - ).field + min: Optional[Number] = constraint("minimum", float, max_) + max: Optional[Number] = constraint("maximum", float, min_) + exc_min: Optional[Number] = constraint("exclusiveMinimum", float, max_) + exc_max: Optional[Number] = constraint("exclusiveMaximum", float, min_) + mult_of: Optional[Number] = constraint("multipleOf", float, merge_mult_of) # string - min_len: Optional[int] = comparison("minLength", str, op.lt) - max_len: Optional[int] = comparison("maxLength", str, op.gt) - pattern: Optional[Pattern] = ConstraintMetadata( - "pattern", - str, - not_match_pattern, - lambda p: f"not matching '{p.pattern}'", - merge_pattern, # type: ignore - ).field + min_len: Optional[int] = constraint("minLength", str, max_) + max_len: Optional[int] = constraint("maxLength", str, min_) + pattern: Optional[Pattern] = constraint("pattern", str, merge_pattern) # array - min_items: Optional[int] = comparison("minItems", list, op.lt) - max_items: Optional[int] = comparison("maxItems", list, op.gt) - unique: Optional[bool] = ConstraintMetadata( - "uniqueItems", list, not_unique, lambda _: "duplicate items", op.or_ - ).field + min_items: Optional[int] = constraint("minItems", list, max_) + max_items: Optional[int] = constraint("maxItems", list, min_) + unique: Optional[bool] = constraint("uniqueItems", list, op.or_) # object - min_props: Optional[int] = comparison("minProperties", dict, op.lt) - max_props: Optional[int] = comparison("maxProperties", dict, op.gt) + min_props: Optional[int] = constraint("minProperties", dict, max_) + max_props: Optional[int] = constraint("maxProperties", dict, min_) @property def attr_and_metata( @@ -120,17 +73,6 @@ def attr_and_metata( if CONSTRAINT_METADATA_KEY in f.metadata ] - @property - def checks_by_type(self) -> Mapping[type, Collection[Tuple[Check, Any, str]]]: - result = defaultdict(list) - for _, attr, metadata in self.attr_and_metata: - if attr is None: - continue - error = f"{metadata.error(attr)} ({metadata.alias})" - result[metadata.cls].append((metadata.check, attr, error)) - result[int] = result[float] - return result - def merge_into(self, base_schema: Dict[str, Any]): for name, attr, metadata in self.attr_and_metata: if attr is not None: diff --git a/apischema/serialization/__init__.py b/apischema/serialization/__init__.py index 73b7058b..141cace2 100644 --- a/apischema/serialization/__init__.py +++ b/apischema/serialization/__init__.py @@ -1,9 +1,9 @@ import collections.abc -import operator from contextlib import suppress from dataclasses import dataclass from enum import Enum from functools import lru_cache +from itertools import starmap from typing import ( Any, Callable, @@ -11,7 +11,6 @@ Mapping, Optional, Sequence, - Tuple, Type, TypeVar, Union, @@ -28,19 +27,56 @@ SerializationVisitor, sub_conversion, ) -from apischema.fields import FIELDS_SET_ATTR, support_fields_set +from apischema.fields import support_fields_set from apischema.objects import AliasedStr, ObjectField from apischema.objects.visitor import SerializationObjectVisitor -from apischema.ordering import 
sort_by_order +from apischema.ordering import Ordering, sort_by_order from apischema.recursion import RecursiveConversionsVisitor +from apischema.serialization.methods import ( + AnyFallback, + AnyMethod, + BaseField, + BoolMethod, + CheckedTupleMethod, + ClassMethod, + ClassWithFieldsSetMethod, + CollectionMethod, + ComplexField, + ConversionMethod, + DictMethod, + EnumMethod, + Fallback, + FloatMethod, + IdentityField, + IdentityMethod, + IntMethod, + ListMethod, + MappingMethod, + NoFallback, + NoneMethod, + OptionalMethod, + RecMethod, + SerializationMethod, + SerializedField, + SimpleField, + StrMethod, + TupleMethod, + TypeCheckIdentityMethod, + TypeCheckMethod, + TypedDictMethod, + TypedDictWithAdditionalMethod, + UnionAlternative, + UnionMethod, + ValueMethod, + WrapperMethod, +) from apischema.serialization.serialized_methods import get_serialized_methods from apischema.types import AnyType, NoneType, Undefined, UndefinedType -from apischema.typing import is_new_type, is_type, is_type_var, is_typed_dict, is_union +from apischema.typing import is_new_type, is_type, is_type_var, is_typed_dict from apischema.utils import ( Lazy, as_predicate, deprecate_kwargs, - get_args2, get_origin_or_type, get_origin_or_type2, identity, @@ -49,36 +85,40 @@ ) from apischema.visitor import Unsupported -SerializationMethod = Callable[[Any], Any] -SerializationMethodFactory = Callable[[AnyType], SerializationMethod] +IDENTITY_METHOD = IdentityMethod() + +METHODS = { + identity: IDENTITY_METHOD, + list: ListMethod(), + dict: DictMethod(), + str: StrMethod(), + int: IntMethod(), + bool: BoolMethod(), + float: FloatMethod(), + NoneType: NoneMethod(), +} +SerializationMethodFactory = Callable[[AnyType], SerializationMethod] T = TypeVar("T") -def instance_checker(tp: AnyType) -> Tuple[Callable[[Any, Any], bool], Any]: +def expected_class(tp: AnyType) -> type: origin = get_origin_or_type2(tp) if origin is NoneType: - return operator.is_, None + return NoneType elif is_typed_dict(origin): - return isinstance, collections.abc.Mapping + return collections.abc.Mapping elif is_type(origin): - return isinstance, origin + return origin elif is_new_type(origin): - return instance_checker(origin.__supertype__) + return expected_class(origin.__supertype__) elif is_type_var(origin) or origin is Any: - return (lambda data, _: True), ... - elif is_union(origin): - checks = list(map(instance_checker, get_args2(tp))) - return (lambda data, _: any(check(data, arg) for check, arg in checks)), ... 
+ return object else: raise TypeError(f"{tp} is not supported in union serialization") -def identity_as_none(method: SerializationMethod) -> Optional[SerializationMethod]: - return method if method is not identity else None - - @dataclass(frozen=True) class PassThroughOptions: any: bool = False @@ -90,6 +130,13 @@ def __post_init__(self): object.__setattr__(self, "types", as_predicate(self.types)) +@dataclass +class FieldToOrder: + name: str + ordering: Optional[Ordering] + field: BaseField + + class SerializationMethodVisitor( RecursiveConversionsVisitor[Serialization, SerializationMethod], SerializationVisitor[SerializationMethod], @@ -139,273 +186,187 @@ def visit_not_recursive(self, tp: AnyType): return self._factory(tp) if self.use_cache else super().visit_not_recursive(tp) def _recursive_result(self, lazy: Lazy[SerializationMethod]) -> SerializationMethod: - rec_method = None - - def method(obj: Any) -> Any: - nonlocal rec_method - if rec_method is None: - rec_method = lazy() - return rec_method(obj) - - return method + return RecMethod(lazy) def any(self) -> SerializationMethod: if self.pass_through_options.any: - return identity - factory = self._factory - - def method(obj: Any) -> Any: - return factory(obj.__class__)(obj) - - return method + return IDENTITY_METHOD + return AnyMethod(self._factory) - def _any_fallback(self, tp: AnyType) -> SerializationMethod: - fallback, serialize_any = self.fall_back_on_any, self.any() + def _any_fallback(self, tp: AnyType) -> Fallback: + return AnyFallback(self.any()) if self.fall_back_on_any else NoFallback(tp) - def method(obj: Any) -> Any: - if fallback: - return serialize_any(obj) - else: - raise TypeError(f"Expected {tp}, found {obj.__class__}") - - return method - - def _wrap(self, cls: type, method: SerializationMethod) -> SerializationMethod: + def _wrap(self, tp: AnyType, method: SerializationMethod) -> SerializationMethod: if not self.check_type: return method - fallback = self._any_fallback(cls) - cls_to_check = Mapping if is_typed_dict(cls) else cls - - def wrapper(obj: Any) -> Any: - if isinstance(obj, cls_to_check): - try: - return method(obj) - except Exception: - pass - return fallback(obj) - - return wrapper + elif method is IDENTITY_METHOD: + return TypeCheckIdentityMethod(expected_class(tp), self._any_fallback(tp)) + else: + return TypeCheckMethod(expected_class(tp), self._any_fallback(tp), method) def collection( self, cls: Type[Collection], value_type: AnyType ) -> SerializationMethod: - serialize_value = self.visit(value_type) + value_method = self.visit(value_type) method: SerializationMethod - if serialize_value is not identity: - - def method(obj: Any) -> Any: - # using map is faster than comprehension - return list(map(serialize_value, obj)) - + if value_method is not IDENTITY_METHOD: + return CollectionMethod(value_method) elif issubclass(cls, (list, tuple)) or ( self.pass_through_options.collections and not issubclass(cls, collections.abc.Set) ): - method = identity + method = IDENTITY_METHOD else: - method = list + method = METHODS[list] return self._wrap(cls, method) def enum(self, cls: Type[Enum]) -> SerializationMethod: + method: SerializationMethod if self.pass_through_options.enums or issubclass(cls, (int, str)): - return identity - elif all( - method is identity - for method in map(self.visit, {elt.value.__class__ for elt in cls}) - ): - method: SerializationMethod = operator.attrgetter("value") + method = IDENTITY_METHOD else: any_method = self.any() - - def method(obj: Any) -> Any: - return 
any_method(obj.value) - + if any_method is IDENTITY_METHOD or all( + m is IDENTITY_METHOD + for m in map(self.visit, {elt.value.__class__ for elt in cls}) + ): + method = ValueMethod() + else: + assert isinstance(any_method, AnyMethod) + method = EnumMethod(any_method) return self._wrap(cls, method) def literal(self, values: Sequence[Any]) -> SerializationMethod: if self.pass_through_options.enums or all( isinstance(v, (int, str)) for v in values ): - return identity + return IDENTITY_METHOD else: return self.any() def mapping( self, cls: Type[Mapping], key_type: AnyType, value_type: AnyType ) -> SerializationMethod: - serialize_key, serialize_value = self.visit(key_type), self.visit(value_type) + key_method, value_method = self.visit(key_type), self.visit(value_type) method: SerializationMethod - if serialize_key is not identity or serialize_value is not identity: - - def method(obj: Any) -> Any: - return { - serialize_key(key): serialize_value(value) - for key, value in obj.items() - } - + if key_method is not IDENTITY_METHOD or value_method is not IDENTITY_METHOD: + method = MappingMethod(key_method, value_method) elif self.pass_through_options.collections or issubclass(cls, dict): - method = identity + method = IDENTITY_METHOD else: - method = dict + method = METHODS[dict] return self._wrap(cls, method) def object(self, tp: AnyType, fields: Sequence[ObjectField]) -> SerializationMethod: cls = get_origin_or_type(tp) - typed_dict = is_typed_dict(cls) - getter: Callable[[str], Callable[[Any], Any]] = ( - operator.itemgetter if typed_dict else operator.attrgetter - ) - serialization_fields = [ - ( - field.name, - self.aliaser(field.alias) if not field.is_aggregate else None, - getter(field.name), - field.required, - field.skip.serialization_if, - is_union_of(field.type, UndefinedType) or default is Undefined, - (is_union_of(field.type, NoneType) and self.exclude_none) - or field.none_as_undefined - or (default is None and self.exclude_defaults), - (field.skip.serialization_default or self.exclude_defaults) - and default not in (None, Undefined), - default, - identity_as_none(self.visit_with_conv(field.type, field.serialization)), - field.ordering, - ) - for field in fields - for default in [... if field.required else field.get_default()] - ] + [ - ( - serialized.func.__name__, - self.aliaser(serialized.alias), - serialized.func, - True, - None, - is_union_of(ret_type, UndefinedType), - is_union_of(ret_type, NoneType) and self.exclude_none, - False, - ..., - self.visit_with_conv(ret_type, serialized.conversion), - serialized.ordering, + fields_to_order = [] + for field in fields: + field_alias = self.aliaser(field.alias) if not field.is_aggregate else None + field_method = self.visit_with_conv(field.type, field.serialization) + field_default = ... 
if field.required else field.get_default() + base_field: BaseField + if field_alias is None or field.skippable( + self.exclude_defaults, self.exclude_none + ): + base_field = ComplexField( + field.name, + field_alias, # type: ignore + field.required, + field_method, + field.skip.serialization_if, + is_union_of(field.type, UndefinedType) + or field_default is Undefined, + (is_union_of(field.type, NoneType) and self.exclude_none) + or field.none_as_undefined + or (field_default is None and self.exclude_defaults), + (field.skip.serialization_default or self.exclude_defaults) + and field_default not in (None, Undefined), + field_default, + ) + elif field_method is IDENTITY_METHOD: + base_field = IdentityField(field.name, field_alias, field.required) + else: + base_field = SimpleField( + field.name, field_alias, field.required, field_method + ) + fields_to_order.append(FieldToOrder(field.name, field.ordering, base_field)) + for serialized, types in get_serialized_methods(tp): + ret_type = types["return"] + fields_to_order.append( + FieldToOrder( + serialized.func.__name__, + serialized.ordering, + SerializedField( + self.aliaser(serialized.alias), + serialized.func, + is_union_of(ret_type, UndefinedType), + is_union_of(ret_type, NoneType) and self.exclude_none, + self.visit_with_conv(ret_type, serialized.conversion), + ), + ) ) - for serialized, types in get_serialized_methods(tp) - for ret_type in [types["return"]] - ] - serialization_fields = sort_by_order( # type: ignore - cls, serialization_fields, lambda f: f[0], lambda f: f[-1] - ) - field_names = {f.name for f in fields} - any_method = self.any() - exclude_unset = self.exclude_unset and support_fields_set(cls) - additional_properties = self.additional_properties and typed_dict - - def method(obj: Any) -> Any: - result = {} - for ( - name, - alias, - get_field, - required, - skip_if, - undefined, - skip_none, - skip_default, - default, - serialize_field, - _, - ) in serialization_fields: - if (not exclude_unset or name in getattr(obj, FIELDS_SET_ATTR)) and ( - not typed_dict or required or name in obj - ): - field_value = get_field(obj) - if not ( - (skip_if and skip_if(field_value)) - or (undefined and field_value is Undefined) - or (skip_none and field_value is None) - or (skip_default and field_value == default) - ): - if serialize_field: - field_value = serialize_field(field_value) - if alias: - result[alias] = field_value - else: - result.update(field_value) - if additional_properties: - assert isinstance(obj, Mapping) - for key, value in obj.items(): - if key not in field_names and isinstance(key, str): - result[key] = any_method(value) - return result + fields_to_order = sort_by_order( # type: ignore + cls, fields_to_order, lambda f: f.name, lambda f: f.ordering + ) + base_fields = tuple(f.field for f in fields_to_order) + method: SerializationMethod + if is_typed_dict(cls): + if self.additional_properties: + method = TypedDictWithAdditionalMethod( + base_fields, {f.name for f in fields}, self.any() + ) + else: + method = TypedDictMethod(base_fields) + elif self.exclude_unset and support_fields_set(cls): + method = ClassWithFieldsSetMethod(base_fields) + else: + method = ClassMethod(base_fields) return self._wrap(cls, method) def primitive(self, cls: Type) -> SerializationMethod: - return self._wrap(cls, identity) + return self._wrap(cls, IDENTITY_METHOD) def subprimitive(self, cls: Type, superclass: Type) -> SerializationMethod: if cls is AliasedStr: - return self.aliaser + return WrapperMethod(self.aliaser) else: return 
super().subprimitive(cls, superclass)
 
     def tuple(self, types: Sequence[AnyType]) -> SerializationMethod:
-        elt_serializers = list(enumerate(map(self.visit, types)))
-        if all(method is identity for _, method in elt_serializers):
-            return identity
-
-        def method(obj: Any) -> Any:
-            return [serialize_elt(obj[i]) for i, serialize_elt in elt_serializers]
-
+        elt_methods = tuple(map(self.visit, types))
+        method: SerializationMethod
+        if all(method is IDENTITY_METHOD for method in elt_methods):
+            method = IDENTITY_METHOD
+        else:
+            method = TupleMethod(elt_methods)
         if self.check_type:
-            nb_elts = len(elt_serializers)
-            wrapped = method
-            fall_back_on_any, as_list = self.fall_back_on_any, self._factory(list)
-
-            def method(obj: Any) -> Any:
-                if len(obj) == nb_elts:
-                    return wrapped(obj)
-                elif fall_back_on_any:
-                    return as_list(obj)
-                else:
-                    raise TypeError(f"Expected {nb_elts}-tuple, found {len(obj)}-tuple")
-
+            method = CheckedTupleMethod(len(types), method)
         return self._wrap(tuple, method)
 
-    def union(self, alternatives: Sequence[AnyType]) -> SerializationMethod:
-        methods = []
-        for tp in alternatives:
+    def union(self, types: Sequence[AnyType]) -> SerializationMethod:
+        alternatives = []
+        for tp in types:
             with suppress(Unsupported):
-                methods.append((self.visit(tp), *instance_checker(tp)))
-        # No need to catch the case with all methods being identity,
-        # because passthrough
-        if not methods:
-            raise Unsupported(Union[tuple(alternatives)])  # type: ignore
-        elif len(methods) == 1:
-            return methods[0][0]
-        elif all(method is identity for method, _, _ in methods):
-            return identity
-        elif len(methods) == 2 and NoneType in alternatives:
-            serialize_alt = next(meth for meth, _, arg in methods if arg is not None)
-
-            def method(obj: Any) -> Any:
-                return serialize_alt(obj) if obj is not None else None
-
+                # Do NOT use UnionAlternative here because it would erase type checking
+                # (forward and optional cases would then lose their type checking)
+                alternatives.append((expected_class(tp), self.visit(tp)))
+        if not alternatives:
+            raise Unsupported(Union[tuple(types)])  # type: ignore
+        elif len(alternatives) == 1:
+            return alternatives[0][1]
+        elif all(alt[1] is IDENTITY_METHOD for alt in alternatives):
+            return IDENTITY_METHOD
+        elif len(alternatives) == 2 and NoneType in types:
+            return OptionalMethod(
+                next(meth for cls, meth in alternatives if cls is not NoneType)
+            )
         else:
-            fallback = self._any_fallback(Union[alternatives])
-
-            def method(obj: Any) -> Any:
-                for serialize_alt, check, arg in methods:
-                    if check(obj, arg):
-                        try:
-                            return serialize_alt(obj)
-                        except Exception:
-                            pass
-                return fallback(obj)
-
-            return method
+            fallback = self._any_fallback(Union[types])
+            return UnionMethod(tuple(starmap(UnionAlternative, alternatives)), fallback)
 
     def unsupported(self, tp: AnyType) -> SerializationMethod:
         try:
@@ -425,20 +386,17 @@ def _visit_conversion(
         dynamic: bool,
         next_conversion: Optional[AnyConversion],
     ) -> SerializationMethod:
-        serialize_conv = self.visit_with_conv(
+        conv_method = self.visit_with_conv(
             conversion.target, sub_conversion(conversion, next_conversion)
         )
         converter = cast(Converter, conversion.converter)
         if converter is identity:
-            method = serialize_conv
-        elif serialize_conv is identity:
-            method = converter
+            method = conv_method
+        elif conv_method is IDENTITY_METHOD:
+            method = METHODS.get(converter, WrapperMethod(converter))
         else:
-
-            def method(obj: Any) -> Any:
-                return serialize_conv(converter(obj))
-
-        return self._wrap(get_origin_or_type(tp), method)
+            method = 
ConversionMethod(converter, conv_method) + return self._wrap(tp, method) def visit_conversion( self, @@ -448,7 +406,7 @@ def visit_conversion( next_conversion: Optional[AnyConversion] = None, ) -> SerializationMethod: if not dynamic and self.pass_through_type(tp): - return identity + return self._wrap(tp, IDENTITY_METHOD) else: return super().visit_conversion(tp, conversion, dynamic, next_conversion) @@ -496,7 +454,7 @@ def serialization_method( exclude_unset: bool = None, fall_back_on_any: bool = None, pass_through: PassThroughOptions = None, -) -> SerializationMethod: +) -> Callable[[Any], Any]: from apischema import settings return serialization_method_factory( @@ -510,7 +468,7 @@ def serialization_method( opt_or(exclude_unset, settings.serialization.exclude_unset), opt_or(fall_back_on_any, settings.serialization.fall_back_on_any), opt_or(pass_through, settings.serialization.pass_through), - )(type) + )(type).serialize NO_OBJ = object() @@ -597,7 +555,7 @@ def serialization_default( exclude_defaults: bool = None, exclude_none: bool = None, exclude_unset: bool = None, -) -> SerializationMethod: +) -> Callable[[Any], Any]: from apischema import settings factory = serialization_method_factory( @@ -614,6 +572,6 @@ def serialization_default( ) def method(obj: Any) -> Any: - return factory(obj.__class__)(obj) + return factory(obj.__class__).serialize(obj) return method diff --git a/apischema/serialization/methods.py b/apischema/serialization/methods.py new file mode 100644 index 00000000..9936e7ef --- /dev/null +++ b/apischema/serialization/methods.py @@ -0,0 +1,371 @@ +from dataclasses import dataclass, field +from typing import AbstractSet, Any, Callable, Optional, Tuple + +from apischema.conversions.utils import Converter +from apischema.fields import FIELDS_SET_ATTR +from apischema.types import AnyType, Undefined +from apischema.utils import Lazy +from apischema.visitor import Unsupported + + +class SerializationMethod: + def serialize(self, obj: Any) -> Any: + raise NotImplementedError + + +class IdentityMethod(SerializationMethod): + def serialize(self, obj: Any) -> Any: + return obj + + +class ListMethod(SerializationMethod): + serialize = staticmethod(list) # type: ignore + + +class DictMethod(SerializationMethod): + serialize = staticmethod(dict) # type: ignore + + +class StrMethod(SerializationMethod): + serialize = staticmethod(str) # type: ignore + + +class IntMethod(SerializationMethod): + serialize = staticmethod(int) # type: ignore + + +class BoolMethod(SerializationMethod): + serialize = staticmethod(bool) # type: ignore + + +class FloatMethod(SerializationMethod): + serialize = staticmethod(float) # type: ignore + + +class NoneMethod(SerializationMethod): + def serialize(self, obj: Any) -> Any: + return None + + +@dataclass +class RecMethod(SerializationMethod): + lazy: Lazy[SerializationMethod] + method: Optional[SerializationMethod] = field(init=False) + + def __post_init__(self): + self.method = None + + def serialize(self, obj: Any) -> Any: + if self.method is None: + self.method = self.lazy() + return self.method.serialize(obj) + + +@dataclass +class AnyMethod(SerializationMethod): + factory: Callable[[AnyType], SerializationMethod] + + def serialize(self, obj: Any) -> Any: + method = self.factory(obj.__class__) # tmp variable for substitution + return method.serialize(obj) + + +class Fallback: + def fall_back(self, obj: Any) -> Any: + raise NotImplementedError + + +@dataclass +class NoFallback(Fallback): + tp: AnyType + + def fall_back(self, obj: Any) -> Any: + raise 
TypeError(f"Expected {self.tp}, found {obj.__class__}") + + +@dataclass +class AnyFallback(Fallback): + any_method: SerializationMethod + + def fall_back(self, obj: Any) -> Any: + return self.any_method.serialize(obj) + + +@dataclass +class TypeCheckIdentityMethod(SerializationMethod): + expected: type + fallback: Fallback + + def serialize(self, obj: Any) -> Any: + return obj if isinstance(obj, self.expected) else self.fallback.fall_back(obj) + + +@dataclass +class TypeCheckMethod(TypeCheckIdentityMethod): + method: SerializationMethod + + def serialize(self, obj: Any) -> Any: + if isinstance(obj, self.expected): + try: + return self.method.serialize(obj) + except Unsupported: + raise + except Exception: + pass + return self.fallback.fall_back(obj) + + +@dataclass +class CollectionMethod(SerializationMethod): + value_method: SerializationMethod + + def serialize(self, obj: Any) -> Any: + return [self.value_method.serialize(elt) for elt in obj] + + +class ValueMethod(SerializationMethod): + def serialize(self, obj: Any) -> Any: + return obj.value + + +@dataclass +class EnumMethod(SerializationMethod): + any_method: AnyMethod + + def serialize(self, obj: Any) -> Any: + return self.any_method.serialize(obj.value) + + +@dataclass +class MappingMethod(SerializationMethod): + key_method: SerializationMethod + value_method: SerializationMethod + + def serialize(self, obj: Any) -> Any: + return { + self.key_method.serialize(key): self.value_method.serialize(value) + for key, value in obj.items() + } + + +class BaseField: + def update_result( + self, obj: Any, result: dict, typed_dict: bool, exclude_unset: bool + ): + raise NotImplementedError + + +@dataclass +class IdentityField(BaseField): + name: str + alias: str + required: bool + + def update_result( + self, obj: Any, result: dict, typed_dict: bool, exclude_unset: bool + ): + if serialize_field(self, obj, typed_dict, exclude_unset): + result[self.alias] = get_field_value(self, obj, typed_dict) + + +def serialize_field( + field: IdentityField, obj: Any, typed_dict: bool, exclude_unset: bool +) -> bool: + if typed_dict: + return field.required or field.name in obj + else: + return not exclude_unset or field.name in getattr(obj, FIELDS_SET_ATTR) + + +def get_field_value(field: IdentityField, obj: Any, typed_dict: bool) -> object: + return obj[field.name] if typed_dict else getattr(obj, field.name) + + +@dataclass +class SimpleField(IdentityField): + method: SerializationMethod + + def update_result( + self, obj: Any, result: dict, typed_dict: bool, exclude_unset: bool + ): + if serialize_field(self, obj, typed_dict, exclude_unset): + result[self.alias] = self.method.serialize( + get_field_value(self, obj, typed_dict) + ) + + +@dataclass +class ComplexField(SimpleField): + skip_if: Optional[Callable] + undefined: bool + skip_none: bool + skip_default: bool + default_value: Any # https://github.com/cython/cython/issues/4383 + skippable: bool = field(init=False) + + def __post_init__(self): + self.skippable = ( + self.skip_if or self.undefined or self.skip_none or self.skip_default + ) + + def update_result( + self, obj: Any, result: dict, typed_dict: bool, exclude_unset: bool + ): + if serialize_field(self, obj, typed_dict, exclude_unset): + value: object = get_field_value(self, obj, typed_dict) + if not self.skippable or not ( + (self.skip_if is not None and self.skip_if(value)) + or (self.undefined and value is Undefined) + or (self.skip_none and value is None) + or (self.skip_default and value == self.default_value) + ): + if self.alias is not 
None:
+                    result[self.alias] = self.method.serialize(value)
+                else:
+                    result.update(self.method.serialize(value))
+
+
+@dataclass
+class SerializedField(BaseField):
+    alias: str
+    func: Callable[[Any], Any]
+    undefined: bool
+    skip_none: bool
+    method: SerializationMethod
+
+    def update_result(
+        self, obj: Any, result: dict, typed_dict: bool, exclude_unset: bool
+    ):
+        value = self.func(obj)
+        if not (self.undefined and value is Undefined) and not (
+            self.skip_none and value is None
+        ):
+            result[self.alias] = self.method.serialize(value)
+
+
+@dataclass
+class ObjectMethod(SerializationMethod):
+    fields: Tuple[BaseField, ...]
+
+
+@dataclass
+class ClassMethod(ObjectMethod):
+    def serialize(self, obj: Any) -> Any:
+        result: dict = {}
+        for i in range(len(self.fields)):
+            field: BaseField = self.fields[i]
+            field.update_result(obj, result, False, False)
+        return result
+
+
+@dataclass
+class ClassWithFieldsSetMethod(ObjectMethod):
+    def serialize(self, obj: Any) -> Any:
+        result: dict = {}
+        for i in range(len(self.fields)):
+            field: BaseField = self.fields[i]
+            field.update_result(obj, result, False, True)
+        return result
+
+
+@dataclass
+class TypedDictMethod(ObjectMethod):
+    def serialize(self, obj: Any) -> Any:
+        result: dict = {}
+        for i in range(len(self.fields)):
+            field: BaseField = self.fields[i]
+            field.update_result(obj, result, True, False)
+        return result
+
+
+@dataclass
+class TypedDictWithAdditionalMethod(TypedDictMethod):
+    field_names: AbstractSet[str]
+    any_method: SerializationMethod
+
+    def serialize(self, obj: Any) -> Any:
+        result: dict = super().serialize(obj)
+        for key, value in obj.items():
+            if key not in self.field_names and isinstance(key, str):
+                result[str(key)] = self.any_method.serialize(value)
+        return result
+
+
+@dataclass
+class TupleMethod(SerializationMethod):
+    elt_methods: Tuple[SerializationMethod, ...]
+
+    def serialize(self, obj: tuple) -> Any:
+        elts: list = []
+        for i in range(len(self.elt_methods)):
+            method: SerializationMethod = self.elt_methods[i]
+            elts.append(method.serialize(obj[i]))
+        return elts
+
+
+@dataclass
+class CheckedTupleMethod(SerializationMethod):
+    nb_elts: int
+    method: SerializationMethod
+
+    def serialize(self, obj: tuple) -> Any:
+        if len(obj) != self.nb_elts:
+            raise TypeError(f"Expected {self.nb_elts}-tuple, found {len(obj)}-tuple")
+        return self.method.serialize(obj)
+
+
+# There is no need for an OptionalIdentityMethod, because it would mean that all the
+# union members use IdentityMethod, in which case the union reduces to IdentityMethod.
+
+
+@dataclass
+class OptionalMethod(SerializationMethod):
+    value_method: SerializationMethod
+
+    def serialize(self, obj: Any) -> Any:
+        return self.value_method.serialize(obj) if obj is not None else None
+
+
+@dataclass
+class UnionAlternative:
+    cls: type
+    method: SerializationMethod
+
+    def __post_init__(self):
+        if isinstance(self.method, TypeCheckMethod):
+            self.method = self.method.method
+        elif isinstance(self.method, TypeCheckIdentityMethod):
+            self.method = IdentityMethod()
+
+
+@dataclass
+class UnionMethod(SerializationMethod):
+    alternatives: Tuple[UnionAlternative, ...] 
+    fallback: Fallback
+
+    def serialize(self, obj: Any) -> Any:
+        for i in range(len(self.alternatives)):
+            alternative: UnionAlternative = self.alternatives[i]
+            if isinstance(obj, alternative.cls):
+                try:
+                    return alternative.method.serialize(obj)
+                except Exception:
+                    pass
+        return self.fallback.fall_back(obj)
+
+
+@dataclass
+class WrapperMethod(SerializationMethod):
+    wrapped: Callable[[Any], Any]
+
+    def serialize(self, obj: Any) -> Any:
+        return self.wrapped(obj)
+
+
+@dataclass
+class ConversionMethod(SerializationMethod):
+    converter: Converter
+    method: SerializationMethod
+
+    def serialize(self, obj: Any) -> Any:
+        return self.method.serialize(self.converter(obj))
diff --git a/apischema/settings.py b/apischema/settings.py
index c4a26df2..ee6168fa 100644
--- a/apischema/settings.py
+++ b/apischema/settings.py
@@ -66,6 +66,30 @@ class base_schema:
         ] = lambda *_: None
         type: Callable[[AnyType], Optional[Schema]] = lambda *_: None
 
+    class errors:
+        minimum: str = "less than {constraint} (minimum)"
+        maximum: str = "greater than {constraint} (maximum)"
+        exclusive_minimum: str = "less than or equal to {constraint} (exclusiveMinimum)"
+        exclusive_maximum: str = (
+            "greater than or equal to {constraint} (exclusiveMaximum)"
+        )
+        multiple_of: str = "not a multiple of {constraint} (multipleOf)"
+
+        min_length: str = "string length lower than {constraint} (minLength)"
+        max_length: str = "string length greater than {constraint} (maxLength)"
+        pattern: str = 'not matching pattern "{constraint}" (pattern)'
+
+        min_items: str = "item count lower than {constraint} (minItems)"
+        max_items: str = "item count greater than {constraint} (maxItems)"
+        unique_items: str = "duplicate items (uniqueItems)"
+
+        min_properties: str = "property count lower than {constraint} (minProperties)"
+        max_properties: str = "property count greater than {constraint} (maxProperties)"
+
+        one_of: str = "not one of {constraint} (oneOf)"
+        unexpected_property: str = "unexpected property"
+        missing_property: str = "missing property"
+
     class deserialization(metaclass=ResetCache):
         coerce: bool = False
         coercer: Coercer = coerce_
diff --git a/apischema/utils.py b/apischema/utils.py
index a10eeafa..1b9f42cc 100644
--- a/apischema/utils.py
+++ b/apischema/utils.py
@@ -17,7 +17,6 @@
     Container,
     Dict,
     Generic,
-    Hashable,
     Iterable,
     Iterator,
     List,
@@ -34,13 +33,7 @@
     cast,
 )
 
-from apischema.types import (
-    AnyType,
-    COLLECTION_TYPES,
-    MAPPING_TYPES,
-    OrderedDict,
-    PRIMITIVE_TYPES,
-)
+from apischema.types import AnyType, COLLECTION_TYPES, MAPPING_TYPES, PRIMITIVE_TYPES
 from apischema.typing import (
     _collect_type_vars,
     generic_mro,
@@ -94,16 +87,8 @@ def opt_or(opt: Optional[T], default: U) -> Union[T, U]:
     return opt if opt is not None else default
 
 
-def to_hashable(data: Union[None, int, float, str, bool, list, dict]) -> Hashable:
-    if isinstance(data, list):
-        return tuple(map(to_hashable, data))
-    if isinstance(data, dict):
-        return tuple(sorted((to_hashable(k), to_hashable(v)) for k, v in data.items()))
-    return data  # type: ignore
-
-
 SNAKE_CASE_REGEX = re.compile(r"_([a-z\d])")
-CAMEL_CASE_REGEX = re.compile(r"[a-z\d]([A-Z])")
+CAMEL_CASE_REGEX = re.compile(r"([a-z\d])([A-Z])")
 
 
 def to_camel_case(s: str) -> str:
@@ -111,7 +96,7 @@ def to_camel_case(s: str) -> str:
 
 
 def to_snake_case(s: str) -> str:
-    return CAMEL_CASE_REGEX.sub(lambda m: "_" + m.group(1).lower(), s)
+    return CAMEL_CASE_REGEX.sub(lambda m: m.group(1) + "_" + m.group(2).lower(), s)
 
 
 def to_pascal_case(s: str) -> 
str:
     return camel[0].upper() + camel[1:] if camel else camel
 
 
-MakeDataclassField = Union[Tuple[str, AnyType], Tuple[str, AnyType, Any]]
+class PartialFormatter(dict):
+    def __missing__(self, key):
+        return "{%s}" % key
+
+
+def partial_format(s: str, **kwargs) -> str:
+    return s.format_map(PartialFormatter(kwargs))
 
 
 def merge_opts(
@@ -260,16 +251,6 @@ def replace_builtins(tp: AnyType) -> AnyType:
     return keep_annotations(res, tp)
 
 
-def sort_by_annotations_position(
-    cls: Type, elts: Collection[T], key: Callable[[T], str]
-) -> List[T]:
-    annotations: Dict[str, Any] = OrderedDict()
-    for base in reversed(cls.__mro__):
-        annotations.update(getattr(base, "__annotations__", ()))
-    positions = {key: i for i, key in enumerate(annotations)}
-    return sorted(elts, key=lambda elt: positions.get(key(elt), len(positions)))
-
-
 def stop_signature_abuse() -> NoReturn:
     raise TypeError("Stop signature abuse")
 
diff --git a/apischema/visitor.py b/apischema/visitor.py
index 951250e0..eee88e2f 100644
--- a/apischema/visitor.py
+++ b/apischema/visitor.py
@@ -157,7 +157,7 @@ def typed_dict(
     ) -> Result:
         raise NotImplementedError
 
-    def union(self, alternatives: Sequence[AnyType]) -> Result:
+    def union(self, types: Sequence[AnyType]) -> Result:
         raise NotImplementedError
 
     def unsupported(self, tp: AnyType) -> Result:
diff --git a/docs/performance_and_benchmark.md b/docs/performance_and_benchmark.md
index e075ecf2..8e85360d 100644
--- a/docs/performance_and_benchmark.md
+++ b/docs/performance_and_benchmark.md
@@ -25,7 +25,7 @@ However, if `lru_cache` is fast, using the methods directly is faster, so *apisc
 
 JSON serialization libraries expect primitive data types (`dict`/`list`/`str`/etc.).
 A non-negligible part of objects to be serialized are primitive.
 
-When [type checking](#type-checking) is disabled (this is default), objects annotated with primitive types doesn't need to be transformed or checked; *apischema* can simply "pass through" them, and it will result into an identity serialization method.
+When [type checking](#type-checking) is disabled (this is the default), objects annotated with primitive types don't need to be transformed or checked; *apischema* can simply "pass through" them, and this results in an identity serialization method that just returns its argument.
 
 Container types like `list` or `dict` are passed through only when the contained types are passed through too. 
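To make the pass-through rule above concrete, here is a minimal sketch of the resulting behavior (it assumes Python >= 3.9 syntax and apischema's default UUID-to-str conversion; the variables are illustrative only):

    from uuid import UUID, uuid4

    from apischema import serialize

    ints = list(range(5))
    # list[int] is passed through: serialization returns the argument itself
    assert serialize(list[int], ints) is ints

    # UUID serializes to str by default, so list[UUID] cannot be passed
    # through, and a new list of strings is built instead
    uuids = [uuid4()]
    assert serialize(list[UUID], uuids) == [str(u) for u in uuids]
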
diff --git a/examples/pass_through.py b/examples/pass_through.py index e4776fcf..15adc487 100644 --- a/examples/pass_through.py +++ b/examples/pass_through.py @@ -1,10 +1,10 @@ from collections.abc import Collection -from uuid import UUID +from uuid import UUID, uuid4 from apischema import PassThroughOptions, serialization_method -from apischema.conversions import identity uuids_method = serialization_method( Collection[UUID], pass_through=PassThroughOptions(collections=True, types={UUID}) ) -assert uuids_method == identity +uuids = [uuid4() for _ in range(5)] +assert uuids_method(uuids) is uuids diff --git a/examples/pass_through_primitives.py b/examples/pass_through_primitives.py index 98eab911..0a27c583 100644 --- a/examples/pass_through_primitives.py +++ b/examples/pass_through_primitives.py @@ -1,3 +1,4 @@ -from apischema import identity, serialization_method +from apischema import serialize -assert serialization_method(list[int]) == identity +ints = list(range(5)) +assert serialize(list[int], ints) is ints diff --git a/examples/validation_error.py b/examples/validation_error.py index 21df1cc1..bd21d8a4 100644 --- a/examples/validation_error.py +++ b/examples/validation_error.py @@ -27,6 +27,6 @@ class Resource: assert err.value.errors == [ {"loc": ["tags"], "msg": "item count greater than 3 (maxItems)"}, {"loc": ["tags"], "msg": "duplicate items (uniqueItems)"}, - {"loc": ["tags", 3], "msg": "not matching '^\\w*$' (pattern)"}, + {"loc": ["tags", 3], "msg": 'not matching pattern "^\\w*$" (pattern)'}, {"loc": ["tags", 4], "msg": "string length lower than 3 (minLength)"}, ] diff --git a/scripts/generate_pyx.py b/scripts/generate_pyx.py new file mode 100755 index 00000000..ffce198b --- /dev/null +++ b/scripts/generate_pyx.py @@ -0,0 +1,294 @@ +#!/usr/bin/env python3 +import collections.abc +import dataclasses +import importlib +import inspect +import re +import sys +from contextlib import contextmanager +from pathlib import Path +from types import FunctionType +from typing import ( + AbstractSet, + Any, + Dict, + Iterable, + Mapping, + Match, + Optional, + Sequence, + TextIO, + Tuple, + get_type_hints, +) + +try: + from typing import Literal + + CythonDef = Literal["cdef", "cpdef", "cdef inline"] +except ImportError: + CythonDef = str # type: ignore + + +ROOT_DIR = Path(__file__).parent.parent +TYPE_FIELD = "_type" + + +def remove_prev_compilation(package: str): + for ext in ["so", "pyd"]: + for file in (ROOT_DIR / "apischema" / package).glob(f"**/*.{ext}"): + file.unlink() + + +cython_types_mapping = { + type: "type", + bytes: "bytes", + bytearray: "bytearray", + bool: "bint", + str: "str", + tuple: "tuple", + Tuple: "tuple", + list: "list", + int: "long", + dict: "dict", + Mapping: "dict", + collections.abc.Mapping: "dict", + set: "set", + AbstractSet: "set", + collections.abc.Set: "set", +} + + +def cython_type(tp: Any) -> str: + return cython_types_mapping.get(getattr(tp, "__origin__", tp), "object") + + +def cython_signature( + def_type: CythonDef, func: FunctionType, self_type: Optional[type] = None +) -> str: + parameters = list(inspect.signature(func).parameters.values()) + assert all(p.default is inspect.Parameter.empty for p in parameters) + types = get_type_hints(func) + param_with_types = [] + prefix = "" + if parameters[0].name == "self": + if self_type is not None: + types["self"] = self_type + prefix = self_type.__name__ + "_" + else: + param_with_types.append("self") + parameters.pop(0) + for param in parameters: + param_with_types.append(cython_type(types[param.name]) + 
" " + param.name) + return f"{def_type} {prefix}{func.__name__}(" + ", ".join(param_with_types) + "):" + + +class IndentedFile: + def __init__(self, file: TextIO): + self.file = file + self.indentation = "" + + def write(self, txt: str): + self.file.write(txt) + + def writelines(self, lines: Iterable[str]): + self.file.writelines(lines) + + def writeln(self, txt: str = ""): + self.write((self.indentation + txt + "\n") if txt else "\n") + + @contextmanager + def indent(self): + self.indentation += 4 * " " + yield + self.indentation = self.indentation[:-4] + + @contextmanager + def write_block(self, txt: str): + self.writeln(txt) + with self.indent(): + yield + + +def rec_subclasses(cls: type) -> Iterable[type]: + for sub_cls in cls.__subclasses__(): + yield sub_cls + yield from rec_subclasses(sub_cls) + + +def get_body( + func: FunctionType, + switches: Mapping[str, Tuple[type, FunctionType]], + cls: Optional[type] = None, +) -> Iterable[str]: + lines, _ = inspect.getsourcelines(func) + line_iter = iter(lines) + for line in line_iter: + if line.rstrip().endswith(":"): + break + else: + raise NotImplementedError + for line in line_iter: + if cls is not None: + + def replace_super(match: Match): + assert cls is not None + super_cls = cls.__bases__[0].__name__ + return f"{super_cls}_{match.group(1)}(<{super_cls}>self, " + + line = re.sub(r"super\(\).(\w+)\(", replace_super, line) + names = "|".join(switches) + + def sub(match: Match): + self, name = match.groups() + cls, _ = switches[name] + return f"{cls.__name__}_{name}({self}, " + + yield re.sub(rf"([\w\.]+)\.({names})\(", sub, line) + + +def get_fields(cls: type) -> Sequence[dataclasses.Field]: + return dataclasses.fields(cls) if dataclasses.is_dataclass(cls) else () + + +def generate(package: str): + module = importlib.import_module(f"apischema.{package}.methods") + classes = [ + cls + for cls in module.__dict__.values() + if isinstance(cls, type) and cls.__module__ == module.__name__ + ] + for cls in classes: + cython_types_mapping[cls] = cls.__name__ + cython_types_mapping[Optional[cls]] = cls.__name__ + if sys.version_info >= (3, 10): + cython_types_mapping[cls | None] = cls.__name__ + subclass_type: Dict[type, int] = {} + switches = {} + with open(ROOT_DIR / "apischema" / package / "methods.pyx", "w") as pyx_file: + pyx = IndentedFile(pyx_file) + pyx.write("cimport cython\n") + with open(ROOT_DIR / "apischema" / package / "methods.py") as methods_file: + for line in methods_file: + if ( + line.startswith("from ") + or line.startswith("import ") + or line.startswith(" ") + or line.startswith(")") + or not line.strip() + ): + pyx.write(line) + else: + break + for cls in classes: + class_def = f"cdef class {cls.__name__}" + if cls.__bases__ != (object,): + bases = ", ".join(base.__name__ for base in cls.__bases__) + class_def += f"({bases})" + with pyx.write_block(class_def + ":"): + pyx.writeln("pass") + write_init = cls in subclass_type + for field in get_fields(cls): + if field.name in cls.__dict__.get("__annotations__", ()): + write_init = True + pyx.writeln( + f"cdef readonly {cython_type(field.type)} {field.name}" + ) + pyx.writeln() + if write_init: + init_fields = [ + field.name for field in get_fields(cls) if field.init + ] + with pyx.write_block( + "def __init__(" + ", ".join(["self"] + init_fields) + "):" + ): + for name in init_fields: + pyx.writeln(f"self.{name} = {name}") + if hasattr(cls, "__post_init__"): + lines, _ = inspect.getsourcelines(cls.__post_init__) # type: ignore + pyx.writelines(lines[1:]) + if cls in 
subclass_type: + pyx.writeln(f"self.{TYPE_FIELD} = {subclass_type[cls]}") + pyx.writeln() + if cls.__bases__ == (object,): + if cls.__subclasses__(): + for i, sub_cls in enumerate(rec_subclasses(cls)): + subclass_type[sub_cls] = i + pyx.writeln(f"cdef int {TYPE_FIELD}") + for name, obj in cls.__dict__.items(): + if isinstance(obj, FunctionType) and not name.startswith("_"): + assert name not in switches + switches[name] = (cls, obj) + with pyx.write_block(cython_signature("cpdef", obj)): + pyx.writeln("raise NotImplementedError") + pyx.writeln() + else: + for name, obj in cls.__dict__.items(): + if ( + isinstance(obj, (FunctionType, staticmethod)) + and name in switches + ): + _, base_method = switches[name] + with pyx.write_block( + cython_signature("cpdef", base_method) + ): + args = ", ".join( + inspect.signature(base_method).parameters + ) + pyx.writeln(f"return {cls.__name__}_{name}({args})") + pyx.writeln() + + for cls, method in switches.values(): + for i, sub_cls in enumerate(rec_subclasses(cls)): + if method.__name__ in sub_cls.__dict__: + sub_method = sub_cls.__dict__[method.__name__] + if isinstance(sub_method, staticmethod): + with pyx.write_block( + cython_signature("cdef inline", method, sub_cls) + ): + _, param = inspect.signature(method).parameters + func = sub_method.__get__(None, object) + pyx.writeln(f"return {func.__name__}({param})") + else: + with pyx.write_block( + cython_signature("cdef inline", sub_method, sub_cls) + ): + pyx.writelines(get_body(sub_method, switches, sub_cls)) + pyx.writeln() + for cls, method in switches.values(): + with pyx.write_block(cython_signature("cdef", method, cls)): + pyx.writeln(f"cdef int {TYPE_FIELD} = self.{TYPE_FIELD}") + for i, sub_cls in enumerate(rec_subclasses(cls)): + if method.__name__ in sub_cls.__dict__: + if_ = "if" if i == 0 else "elif" + with pyx.write_block(f"{if_} {TYPE_FIELD} == {i}:"): + self, *params = inspect.signature(method).parameters + args = ", ".join([f"<{sub_cls.__name__}>{self}", *params]) + pyx.writeln( + f"return {sub_cls.__name__}_{method.__name__}({args})" + ) + pyx.writeln() + for obj in module.__dict__.values(): + if isinstance(obj, FunctionType) and obj.__module__ == module.__name__: + pyx.writeln(cython_signature("cdef inline", obj)) + pyx.writelines(get_body(obj, switches)) + pyx.writeln() + + +packages = ["deserialization", "serialization"] + + +def clean(): + for package in packages: + remove_prev_compilation(package) + + +def main(): + clean() # remove all before generate, because .so would be imported otherwise + sys.path.append(str(ROOT_DIR)) + for package in packages: + generate(package) + + +if __name__ == "__main__": + main() diff --git a/setup.py b/setup.py index 24faf841..b22dc07c 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,32 @@ +import os +import sys + from setuptools import find_packages, setup -with open("README.md") as f: - README = f.read() +import scripts.generate_pyx + +README = None +# README cannot be read by older python version run by tox +if "TOX_ENV_NAME" not in os.environ: + with open("README.md") as f: + README = f.read() + + +ext_modules = None +if "clean" in sys.argv: + scripts.generate_pyx.clean() +if ( + not any(arg in sys.argv for arg in ["clean", "check"]) + and "SKIP_CYTHON" not in os.environ +): + try: + from Cython.Build import cythonize + except ImportError: + pass + else: + scripts.generate_pyx.main() + os.environ["CFLAGS"] = "-O3 " + os.environ.get("CFLAGS", "") + ext_modules = cythonize(["apischema/**/*.pyx"], language_level=3) setup( name="apischema", @@ 
-16,7 +41,7 @@ long_description=README, long_description_content_type="text/markdown", python_requires=">=3.6", - install_requires=["dataclasses==0.7;python_version<'3.7'"], + install_requires=["dataclasses>=0.7;python_version<'3.7'"], extras_require={ "graphql": ["graphql-core>=3.1.2"], "examples": [ @@ -41,4 +66,7 @@ "Programming Language :: Python :: 3.10", "Topic :: Software Development :: Libraries :: Python Modules", ], + ext_modules=ext_modules + # ext_modules=cythonize("apischema/deserialization/methods.py", language_level=3), + # ext_modules=cythonize(["cythonized.pyx", "cythonized2.py"], language_level=3), ) diff --git a/tests/requirements.txt b/tests/requirements.txt index 8fb6ee51..6749ccd3 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -9,3 +9,4 @@ pytest-cov pytest-asyncio sqlalchemy typing_extensions +cython diff --git a/tests/test_deserialization_methods.py b/tests/test_deserialization_methods.py new file mode 100644 index 00000000..06b8c1c8 --- /dev/null +++ b/tests/test_deserialization_methods.py @@ -0,0 +1,8 @@ +from apischema.deserialization.methods import to_hashable + + +def test_to_hashable(): + hashable1 = to_hashable({"key1": 0, "key2": [1, 2]}) + hashable2 = to_hashable({"key2": [1, 2], "key1": 0}) + assert hashable1 == hashable2 + assert hash(hashable1) == hash(hashable2) diff --git a/tests/test_utils.py b/tests/test_utils.py index 520deee9..1b39cdb5 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -27,18 +27,10 @@ is_async, replace_builtins, to_camel_case, - to_hashable, type_dict_wrapper, ) -def test_to_hashable(): - hashable1 = to_hashable({"key1": 0, "key2": [1, 2]}) - hashable2 = to_hashable({"key2": [1, 2], "key1": 0}) - assert hashable1 == hashable2 - assert hash(hashable1) == hash(hashable2) - - def test_to_camel_case(): assert to_camel_case("min_length") == "minLength" diff --git a/tox.ini b/tox.ini index bf5ab8ea..253a01f9 100644 --- a/tox.ini +++ b/tox.ini @@ -27,11 +27,19 @@ exclude_lines = deps = -r tests/requirements.txt +allowlist_externals = which + commands = + which pytest + python3 setup.py clean python3 scripts/generate_tests_from_examples.py - py{36,py3}: pytest tests --ignore=tests/__generated__/test_recursive_postponned.py - py3{7,8,9}: pytest tests +; py{36}: pytest tests --ignore=tests/__generated__/test_recursive_postponned.py + py{py3,37,38,39}: pytest tests py310: pytest tests --cov=apischema --cov-report html + python3 setup.py install + py{36}: pytest tests --ignore=tests/__generated__/test_recursive_postponned.py + py{py3,37,38,39,310}: pytest tests + [testenv:static] deps = From aa501677b95848151b1e764b9b3597f989f7a950 Mon Sep 17 00:00:00 2001 From: Joseph Perez Date: Sun, 31 Oct 2021 22:00:28 +0100 Subject: [PATCH 02/15] Fix Pypy pipeline --- setup.py | 4 +--- tox.ini | 6 +++--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/setup.py b/setup.py index b22dc07c..024e367d 100644 --- a/setup.py +++ b/setup.py @@ -66,7 +66,5 @@ "Programming Language :: Python :: 3.10", "Topic :: Software Development :: Libraries :: Python Modules", ], - ext_modules=ext_modules - # ext_modules=cythonize("apischema/deserialization/methods.py", language_level=3), - # ext_modules=cythonize(["cythonized.pyx", "cythonized2.py"], language_level=3), + ext_modules=ext_modules, ) diff --git a/tox.ini b/tox.ini index 253a01f9..c850d839 100644 --- a/tox.ini +++ b/tox.ini @@ -33,12 +33,12 @@ commands = which pytest python3 setup.py clean python3 scripts/generate_tests_from_examples.py -; py{36}: pytest tests 
--ignore=tests/__generated__/test_recursive_postponned.py + py{36,py3}: pytest tests --ignore=tests/__generated__/test_recursive_postponned.py py{py3,37,38,39}: pytest tests py310: pytest tests --cov=apischema --cov-report html python3 setup.py install - py{36}: pytest tests --ignore=tests/__generated__/test_recursive_postponned.py - py{py3,37,38,39,310}: pytest tests + py{36,py3}: pytest tests --ignore=tests/__generated__/test_recursive_postponned.py + py{37,38,39,310}: pytest tests [testenv:static] From 24d7289778c5158b4839db789e159bd4c235e310 Mon Sep 17 00:00:00 2001 From: Joseph Perez Date: Sun, 31 Oct 2021 23:21:14 +0100 Subject: [PATCH 03/15] Disable cythonization for pypy --- setup.py | 2 ++ tox.ini | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 024e367d..6747998a 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,5 @@ import os +import platform import sys from setuptools import find_packages, setup @@ -18,6 +19,7 @@ if ( not any(arg in sys.argv for arg in ["clean", "check"]) and "SKIP_CYTHON" not in os.environ + and platform.python_implementation() != "PyPy" ): try: from Cython.Build import cythonize diff --git a/tox.ini b/tox.ini index c850d839..8c1902ba 100644 --- a/tox.ini +++ b/tox.ini @@ -34,7 +34,7 @@ commands = python3 setup.py clean python3 scripts/generate_tests_from_examples.py py{36,py3}: pytest tests --ignore=tests/__generated__/test_recursive_postponned.py - py{py3,37,38,39}: pytest tests + py{37,38,39}: pytest tests py310: pytest tests --cov=apischema --cov-report html python3 setup.py install py{36,py3}: pytest tests --ignore=tests/__generated__/test_recursive_postponned.py From cae5103e05343ddf34eb77fe58e68b532ab7b5f7 Mon Sep 17 00:00:00 2001 From: Joseph Perez Date: Mon, 1 Nov 2021 14:33:55 +0100 Subject: [PATCH 04/15] Refactor generate_pyx.py -> cythonize.py --- apischema/deserialization/__init__.py | 2 +- apischema/deserialization/methods.py | 9 +- scripts/cythonize.py | 328 ++++++++++++++++++++++++++ scripts/generate_pyx.py | 294 ----------------------- setup.py | 29 +-- 5 files changed, 341 insertions(+), 321 deletions(-) create mode 100755 scripts/cythonize.py delete mode 100755 scripts/generate_pyx.py diff --git a/apischema/deserialization/__init__.py b/apischema/deserialization/__init__.py index 35fe958a..9bc78264 100644 --- a/apischema/deserialization/__init__.py +++ b/apischema/deserialization/__init__.py @@ -187,7 +187,7 @@ def _recursive_result( def factory( constraints: Optional[Constraints], validators: Sequence[Validator] ) -> DeserializationMethod: - return RecMethod(lazy, constraints, validators) + return RecMethod(lambda: lazy().merge(constraints, validators).method) return DeserializationMethodFactory(factory) diff --git a/apischema/deserialization/methods.py b/apischema/deserialization/methods.py index e9d00174..b56728fb 100644 --- a/apischema/deserialization/methods.py +++ b/apischema/deserialization/methods.py @@ -15,7 +15,6 @@ from apischema.conversions.utils import Converter from apischema.deserialization.coercion import Coercer from apischema.json_schema.types import bad_type -from apischema.schemas.constraints import Constraints from apischema.types import NoneType from apischema.utils import Lazy from apischema.validation.errors import ValidationError, merge_errors @@ -24,7 +23,7 @@ from apischema.visitor import Unsupported if TYPE_CHECKING: - from apischema.deserialization import DeserializationMethodFactory + pass @dataclass @@ -170,9 +169,7 @@ def deserialize(self, data: Any) -> Any: 
@dataclass class RecMethod(DeserializationMethod): - lazy: Lazy["DeserializationMethodFactory"] - constraints: Optional[Constraints] - validators: Sequence[Validator] + lazy: Lazy[DeserializationMethod] method: Optional[DeserializationMethod] = field(init=False) def __post_init__(self): @@ -180,7 +177,7 @@ def __post_init__(self): def deserialize(self, data: Any) -> Any: if self.method is None: - self.method = self.lazy().merge(self.constraints, self.validators).method + self.method = self.lazy() return self.method.deserialize(data) diff --git a/scripts/cythonize.py b/scripts/cythonize.py new file mode 100755 index 00000000..2b78a1e7 --- /dev/null +++ b/scripts/cythonize.py @@ -0,0 +1,328 @@ +#!/usr/bin/env python3 +import collections.abc +import dataclasses +import importlib +import inspect +import re +import sys +from contextlib import contextmanager +from functools import lru_cache +from pathlib import Path +from types import FunctionType +from typing import ( + AbstractSet, + Any, + Iterable, + List, + Mapping, + Match, + NamedTuple, + Optional, + TextIO, + Tuple, + Type, + TypeVar, + Union, + get_type_hints, +) + +from Cython.Build import cythonize + +try: + from typing import Literal + + CythonDef = Literal["cdef", "cpdef", "cdef inline"] +except ImportError: + CythonDef = str # type: ignore + + +ROOT_DIR = Path(__file__).parent.parent +DISPATCH_FIELD = "_dispatch" +CYTHON_TYPES = { + type: "type", + bytes: "bytes", + bytearray: "bytearray", + bool: "bint", + str: "str", + tuple: "tuple", + Tuple: "tuple", + list: "list", + int: "long", + dict: "dict", + Mapping: "dict", + collections.abc.Mapping: "dict", + set: "set", + AbstractSet: "set", + collections.abc.Set: "set", +} + +Elt = TypeVar("Elt", type, FunctionType) + + +@lru_cache() +def module_elements(module: str, cls: Type[Elt]) -> Iterable[Elt]: + return [ + obj + for obj in importlib.import_module(module).__dict__.values() + if isinstance(obj, cls) and obj.__module__ == module + ] + + +@lru_cache() +def module_type_mapping(module: str) -> Mapping[type, str]: + mapping = CYTHON_TYPES.copy() + for cls in module_elements(module, type): + mapping[cls] = cls.__name__ # type: ignore + mapping[Optional[cls]] = cls.__name__ # type: ignore + if sys.version_info >= (3, 10): + mapping[cls | None] = cls.__name__ # type: ignore + return mapping # type: ignore + + +def method_name(cls: type, method: str) -> str: + return f"{cls.__name__}_{method}" + + +def cython_type(tp: Any, module: str) -> str: + return module_type_mapping(module).get(getattr(tp, "__origin__", tp), "object") + + +def cython_signature( + def_type: CythonDef, func: FunctionType, self_type: Optional[type] = None +) -> str: + parameters = list(inspect.signature(func).parameters.values()) + assert all(p.default is inspect.Parameter.empty for p in parameters) + types = get_type_hints(func) + param_with_types = [] + if parameters[0].name == "self": + if self_type is not None: + types["self"] = self_type + else: + param_with_types.append("self") + parameters.pop(0) + for param in parameters: + param_type = cython_type(types[param.name], func.__module__) + param_with_types.append(f"{param_type} {param.name}") + func_name = method_name(self_type, func.__name__) if self_type else func.__name__ + return f"{def_type} {func_name}(" + ", ".join(param_with_types) + "):" + + +class IndentedWriter: + def __init__(self, file: TextIO): + self.file = file + self.indentation = "" + + def write(self, txt: str): + self.file.write(txt) + + def writelines(self, lines: Iterable[str]): + 
self.file.writelines(lines) + + def writeln(self, txt: str = ""): + self.write((self.indentation + txt + "\n") if txt else "\n") + + @contextmanager + def indent(self): + self.indentation += 4 * " " + yield + self.indentation = self.indentation[:-4] + + @contextmanager + def write_block(self, txt: str): + self.writeln(txt) + with self.indent(): + yield + + +def rec_subclasses(cls: type) -> Iterable[type]: + for sub_cls in cls.__subclasses__(): + yield sub_cls + yield from rec_subclasses(sub_cls) + + +@lru_cache +def get_dispatch(base_class: type) -> Mapping[type, int]: + return {cls: i for i, cls in enumerate(rec_subclasses(base_class))} + + +class Method(NamedTuple): + base_class: type + function: FunctionType + + @property + def name(self) -> str: + return self.function.__name__ + + +@lru_cache() +def module_methods(module: str) -> Mapping[str, Method]: + all_methods = [ + Method(cls, func) # type: ignore + for cls in module_elements(module, type) + if cls.__bases__ == (object,) and cls.__subclasses__() # type: ignore + for func in cls.__dict__.values() + if isinstance(func, FunctionType) and not func.__name__.startswith("_") + ] + methods_by_name = {method.name: method for method in all_methods} + assert len(methods_by_name) == len( + all_methods + ), "method substitution requires unique method names" + return methods_by_name + + +def get_body(func: FunctionType, cls: Optional[type] = None) -> Iterable[str]: + lines, _ = inspect.getsourcelines(func) + line_iter = iter(lines) + for line in line_iter: + if line.rstrip().endswith(":"): + break + else: + raise NotImplementedError + if cls is not None: + + def replace_super(match: Match): + assert cls is not None + super_cls = cls.__bases__[0].__name__ + return f"{super_cls}_{match.group(1)}(<{super_cls}>self, " + + super_regex = re.compile(r"super\(\).(\w+)\(") + line_iter = (super_regex.sub(replace_super, line) for line in line_iter) + methods = module_methods(func.__module__) + + def replace_method(match: Match): + self, name = match.groups() + cls, _ = methods[name] + return f"{cls.__name__}_{name}({self}, " + + method_names = "|".join(methods) + method_regex = re.compile(rf"([\w\.]+)\.({method_names})\(") + return (method_regex.sub(replace_method, line) for line in line_iter) + + +def import_lines(path: Union[str, Path]) -> Iterable[str]: + # could also be retrieved with ast + with open(path) as field: + for line in field: + if not line.strip() or any( + # " " and ")" because of multiline imports + map(line.startswith, ("from ", "import ", " ", ")")) + ): + yield line + else: + break + + +def write_class(pyx: IndentedWriter, cls: type): + bases = ", ".join(b.__name__ for b in cls.__bases__ if b is not object) + with pyx.write_block(f"cdef class {cls.__name__}({bases}):"): + annotations = cls.__dict__.get("__annotations__", {}) + for name, tp in get_type_hints(cls).items(): + if name in annotations: + pyx.writeln(f"cdef readonly {cython_type(tp, cls.__module__)} {name}") + dispatch = None + if cls.__bases__ == (object,): + if cls.__subclasses__(): + pyx.writeln(f"cdef int {DISPATCH_FIELD}") + else: + base_class = cls.__mro__[-2] + dispatch = get_dispatch(base_class)[cls] + for name, obj in cls.__dict__.items(): + if ( + not name.startswith("_") + and name not in annotations + and isinstance(obj, (FunctionType, staticmethod)) + ): + pyx.writeln() + base_method = getattr(base_class, name) + with pyx.write_block(cython_signature("cpdef", base_method)): + args = ", ".join(inspect.signature(base_method).parameters) + pyx.writeln(f"return 
{cls.__name__}_{name}({args})") + if annotations or dispatch is not None: + pyx.writeln() + init_fields: List[str] = [] + if dataclasses.is_dataclass(cls): + init_fields.extend( + field.name for field in dataclasses.fields(cls) if field.init + ) + with pyx.write_block( + "def __init__(" + ", ".join(["self"] + init_fields) + "):" + ): + for name in init_fields: + pyx.writeln(f"self.{name} = {name}") + if hasattr(cls, "__post_init__"): + lines, _ = inspect.getsourcelines(cls.__post_init__) # type: ignore + pyx.writelines(lines[1:]) + if dispatch is not None: + pyx.writeln(f"self.{DISPATCH_FIELD} = {dispatch}") + + +def write_function(pyx: IndentedWriter, func: FunctionType): + pyx.writeln(cython_signature("cdef inline", func)) + pyx.writelines(get_body(func)) + + +def write_methods(pyx: IndentedWriter, method: Method): + for cls, dispatch in get_dispatch(method.base_class).items(): + if method.name in cls.__dict__: + sub_method = cls.__dict__[method.name] + if isinstance(sub_method, staticmethod): + with pyx.write_block( + cython_signature("cdef inline", method.function, cls) # type: ignore + ): + _, param = inspect.signature(method.function).parameters + func = sub_method.__get__(None, object) + pyx.writeln(f"return {func.__name__}({param})") + else: + with pyx.write_block(cython_signature("cdef inline", sub_method, cls)): + pyx.writelines(get_body(sub_method, cls)) + pyx.writeln() + + +def write_dispatch(pyx: IndentedWriter, method: Method): + with pyx.write_block(cython_signature("cdef", method.function, method.base_class)): # type: ignore + pyx.writeln(f"cdef int {DISPATCH_FIELD} = self.{DISPATCH_FIELD}") + for cls, dispatch in get_dispatch(method.base_class).items(): + if method.name in cls.__dict__: + if_ = "if" if dispatch == 0 else "elif" + with pyx.write_block(f"{if_} {DISPATCH_FIELD} == {dispatch}:"): + self, *params = inspect.signature(method.function).parameters + args = ", ".join([f"<{cls.__name__}>{self}", *params]) + pyx.writeln(f"return {method_name(cls, method.name)}({args})") + + +def generate(package: str) -> str: + module = f"apischema.{package}.methods" + pyx_file_name = ROOT_DIR / "apischema" / package / "methods.pyx" + with open(pyx_file_name, "w") as pyx_file: + pyx = IndentedWriter(pyx_file) + pyx.write("cimport cython\n") + pyx.writelines(import_lines(ROOT_DIR / "apischema" / package / "methods.py")) + for cls in module_elements(module, type): + write_class(pyx, cls) # type: ignore + pyx.writeln() + for func in module_elements(module, FunctionType): + write_function(pyx, func) # type: ignore + pyx.writeln() + methods = module_methods(module) + for method in methods.values(): + write_methods(pyx, method) + for method in methods.values(): + write_dispatch(pyx, method) + pyx.writeln() + return str(pyx_file_name) + + +packages = ["deserialization", "serialization"] + + +def main(): + # remove compiled before generate, because .so would be imported otherwise + for ext in ["so", "pyd"]: + for file in (ROOT_DIR / "apischema").glob(f"**/*.{ext}"): + file.unlink() + sys.path.append(str(ROOT_DIR)) + cythonize(list(map(generate, packages)), language_level=3) + + +if __name__ == "__main__": + main() diff --git a/scripts/generate_pyx.py b/scripts/generate_pyx.py deleted file mode 100755 index ffce198b..00000000 --- a/scripts/generate_pyx.py +++ /dev/null @@ -1,294 +0,0 @@ -#!/usr/bin/env python3 -import collections.abc -import dataclasses -import importlib -import inspect -import re -import sys -from contextlib import contextmanager -from pathlib import Path -from types import 
FunctionType -from typing import ( - AbstractSet, - Any, - Dict, - Iterable, - Mapping, - Match, - Optional, - Sequence, - TextIO, - Tuple, - get_type_hints, -) - -try: - from typing import Literal - - CythonDef = Literal["cdef", "cpdef", "cdef inline"] -except ImportError: - CythonDef = str # type: ignore - - -ROOT_DIR = Path(__file__).parent.parent -TYPE_FIELD = "_type" - - -def remove_prev_compilation(package: str): - for ext in ["so", "pyd"]: - for file in (ROOT_DIR / "apischema" / package).glob(f"**/*.{ext}"): - file.unlink() - - -cython_types_mapping = { - type: "type", - bytes: "bytes", - bytearray: "bytearray", - bool: "bint", - str: "str", - tuple: "tuple", - Tuple: "tuple", - list: "list", - int: "long", - dict: "dict", - Mapping: "dict", - collections.abc.Mapping: "dict", - set: "set", - AbstractSet: "set", - collections.abc.Set: "set", -} - - -def cython_type(tp: Any) -> str: - return cython_types_mapping.get(getattr(tp, "__origin__", tp), "object") - - -def cython_signature( - def_type: CythonDef, func: FunctionType, self_type: Optional[type] = None -) -> str: - parameters = list(inspect.signature(func).parameters.values()) - assert all(p.default is inspect.Parameter.empty for p in parameters) - types = get_type_hints(func) - param_with_types = [] - prefix = "" - if parameters[0].name == "self": - if self_type is not None: - types["self"] = self_type - prefix = self_type.__name__ + "_" - else: - param_with_types.append("self") - parameters.pop(0) - for param in parameters: - param_with_types.append(cython_type(types[param.name]) + " " + param.name) - return f"{def_type} {prefix}{func.__name__}(" + ", ".join(param_with_types) + "):" - - -class IndentedFile: - def __init__(self, file: TextIO): - self.file = file - self.indentation = "" - - def write(self, txt: str): - self.file.write(txt) - - def writelines(self, lines: Iterable[str]): - self.file.writelines(lines) - - def writeln(self, txt: str = ""): - self.write((self.indentation + txt + "\n") if txt else "\n") - - @contextmanager - def indent(self): - self.indentation += 4 * " " - yield - self.indentation = self.indentation[:-4] - - @contextmanager - def write_block(self, txt: str): - self.writeln(txt) - with self.indent(): - yield - - -def rec_subclasses(cls: type) -> Iterable[type]: - for sub_cls in cls.__subclasses__(): - yield sub_cls - yield from rec_subclasses(sub_cls) - - -def get_body( - func: FunctionType, - switches: Mapping[str, Tuple[type, FunctionType]], - cls: Optional[type] = None, -) -> Iterable[str]: - lines, _ = inspect.getsourcelines(func) - line_iter = iter(lines) - for line in line_iter: - if line.rstrip().endswith(":"): - break - else: - raise NotImplementedError - for line in line_iter: - if cls is not None: - - def replace_super(match: Match): - assert cls is not None - super_cls = cls.__bases__[0].__name__ - return f"{super_cls}_{match.group(1)}(<{super_cls}>self, " - - line = re.sub(r"super\(\).(\w+)\(", replace_super, line) - names = "|".join(switches) - - def sub(match: Match): - self, name = match.groups() - cls, _ = switches[name] - return f"{cls.__name__}_{name}({self}, " - - yield re.sub(rf"([\w\.]+)\.({names})\(", sub, line) - - -def get_fields(cls: type) -> Sequence[dataclasses.Field]: - return dataclasses.fields(cls) if dataclasses.is_dataclass(cls) else () - - -def generate(package: str): - module = importlib.import_module(f"apischema.{package}.methods") - classes = [ - cls - for cls in module.__dict__.values() - if isinstance(cls, type) and cls.__module__ == module.__name__ - ] - for cls 
in classes: - cython_types_mapping[cls] = cls.__name__ - cython_types_mapping[Optional[cls]] = cls.__name__ - if sys.version_info >= (3, 10): - cython_types_mapping[cls | None] = cls.__name__ - subclass_type: Dict[type, int] = {} - switches = {} - with open(ROOT_DIR / "apischema" / package / "methods.pyx", "w") as pyx_file: - pyx = IndentedFile(pyx_file) - pyx.write("cimport cython\n") - with open(ROOT_DIR / "apischema" / package / "methods.py") as methods_file: - for line in methods_file: - if ( - line.startswith("from ") - or line.startswith("import ") - or line.startswith(" ") - or line.startswith(")") - or not line.strip() - ): - pyx.write(line) - else: - break - for cls in classes: - class_def = f"cdef class {cls.__name__}" - if cls.__bases__ != (object,): - bases = ", ".join(base.__name__ for base in cls.__bases__) - class_def += f"({bases})" - with pyx.write_block(class_def + ":"): - pyx.writeln("pass") - write_init = cls in subclass_type - for field in get_fields(cls): - if field.name in cls.__dict__.get("__annotations__", ()): - write_init = True - pyx.writeln( - f"cdef readonly {cython_type(field.type)} {field.name}" - ) - pyx.writeln() - if write_init: - init_fields = [ - field.name for field in get_fields(cls) if field.init - ] - with pyx.write_block( - "def __init__(" + ", ".join(["self"] + init_fields) + "):" - ): - for name in init_fields: - pyx.writeln(f"self.{name} = {name}") - if hasattr(cls, "__post_init__"): - lines, _ = inspect.getsourcelines(cls.__post_init__) # type: ignore - pyx.writelines(lines[1:]) - if cls in subclass_type: - pyx.writeln(f"self.{TYPE_FIELD} = {subclass_type[cls]}") - pyx.writeln() - if cls.__bases__ == (object,): - if cls.__subclasses__(): - for i, sub_cls in enumerate(rec_subclasses(cls)): - subclass_type[sub_cls] = i - pyx.writeln(f"cdef int {TYPE_FIELD}") - for name, obj in cls.__dict__.items(): - if isinstance(obj, FunctionType) and not name.startswith("_"): - assert name not in switches - switches[name] = (cls, obj) - with pyx.write_block(cython_signature("cpdef", obj)): - pyx.writeln("raise NotImplementedError") - pyx.writeln() - else: - for name, obj in cls.__dict__.items(): - if ( - isinstance(obj, (FunctionType, staticmethod)) - and name in switches - ): - _, base_method = switches[name] - with pyx.write_block( - cython_signature("cpdef", base_method) - ): - args = ", ".join( - inspect.signature(base_method).parameters - ) - pyx.writeln(f"return {cls.__name__}_{name}({args})") - pyx.writeln() - - for cls, method in switches.values(): - for i, sub_cls in enumerate(rec_subclasses(cls)): - if method.__name__ in sub_cls.__dict__: - sub_method = sub_cls.__dict__[method.__name__] - if isinstance(sub_method, staticmethod): - with pyx.write_block( - cython_signature("cdef inline", method, sub_cls) - ): - _, param = inspect.signature(method).parameters - func = sub_method.__get__(None, object) - pyx.writeln(f"return {func.__name__}({param})") - else: - with pyx.write_block( - cython_signature("cdef inline", sub_method, sub_cls) - ): - pyx.writelines(get_body(sub_method, switches, sub_cls)) - pyx.writeln() - for cls, method in switches.values(): - with pyx.write_block(cython_signature("cdef", method, cls)): - pyx.writeln(f"cdef int {TYPE_FIELD} = self.{TYPE_FIELD}") - for i, sub_cls in enumerate(rec_subclasses(cls)): - if method.__name__ in sub_cls.__dict__: - if_ = "if" if i == 0 else "elif" - with pyx.write_block(f"{if_} {TYPE_FIELD} == {i}:"): - self, *params = inspect.signature(method).parameters - args = ", 
".join([f"<{sub_cls.__name__}>{self}", *params]) - pyx.writeln( - f"return {sub_cls.__name__}_{method.__name__}({args})" - ) - pyx.writeln() - for obj in module.__dict__.values(): - if isinstance(obj, FunctionType) and obj.__module__ == module.__name__: - pyx.writeln(cython_signature("cdef inline", obj)) - pyx.writelines(get_body(obj, switches)) - pyx.writeln() - - -packages = ["deserialization", "serialization"] - - -def clean(): - for package in packages: - remove_prev_compilation(package) - - -def main(): - clean() # remove all before generate, because .so would be imported otherwise - sys.path.append(str(ROOT_DIR)) - for package in packages: - generate(package) - - -if __name__ == "__main__": - main() diff --git a/setup.py b/setup.py index 6747998a..9ed7ddca 100644 --- a/setup.py +++ b/setup.py @@ -1,10 +1,7 @@ import os import platform -import sys -from setuptools import find_packages, setup - -import scripts.generate_pyx +from setuptools import Extension, find_packages, setup README = None # README cannot be read by older python version run by tox @@ -12,23 +9,15 @@ with open("README.md") as f: README = f.read() - ext_modules = None -if "clean" in sys.argv: - scripts.generate_pyx.clean() -if ( - not any(arg in sys.argv for arg in ["clean", "check"]) - and "SKIP_CYTHON" not in os.environ - and platform.python_implementation() != "PyPy" -): - try: - from Cython.Build import cythonize - except ImportError: - pass - else: - scripts.generate_pyx.main() - os.environ["CFLAGS"] = "-O3 " + os.environ.get("CFLAGS", "") - ext_modules = cythonize(["apischema/**/*.pyx"], language_level=3) +# Cythonization makes apischema a lot slower using PyPy +if platform.python_implementation() != "PyPy": + ext_modules = [ + Extension( + f"apischema.{package}.methods", sources=[f"apischema/{package}/methods.c"] + ) + for package in ("deserialization", "serialization") + ] setup( name="apischema", From fad51ab4a959d911bc54161f5f27513c9aad2d41 Mon Sep 17 00:00:00 2001 From: Joseph Perez Date: Mon, 1 Nov 2021 19:22:07 +0100 Subject: [PATCH 05/15] Try to test with and without compilation --- .github/workflows/ci.yml | 6 ++++++ tox.ini | 4 ---- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 64ed3a38..87350446 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -14,6 +14,7 @@ jobs: fail-fast: false matrix: python-version: ['3.6', '3.7', '3.8', '3.9', '3.10', 'pypy3'] + compile: [true, false] steps: - uses: actions/cache@v2 with: @@ -24,6 +25,11 @@ jobs: uses: actions/setup-python@v2 with: python-version: ${{ matrix.python-version }} + - name: cythonize + if: matrix.compile + run: | + python scripts/cythonize.py + python setup.py build_ext --in-place - name: Install tox run: | python -m pip install --upgrade pip diff --git a/tox.ini b/tox.ini index 8c1902ba..68b40178 100644 --- a/tox.ini +++ b/tox.ini @@ -34,10 +34,6 @@ commands = python3 setup.py clean python3 scripts/generate_tests_from_examples.py py{36,py3}: pytest tests --ignore=tests/__generated__/test_recursive_postponned.py - py{37,38,39}: pytest tests - py310: pytest tests --cov=apischema --cov-report html - python3 setup.py install - py{36,py3}: pytest tests --ignore=tests/__generated__/test_recursive_postponned.py py{37,38,39,310}: pytest tests From 929295e4ca0878c8a1204c1495003d001a007d7c Mon Sep 17 00:00:00 2001 From: Joseph Perez Date: Mon, 1 Nov 2021 19:45:14 +0100 Subject: [PATCH 06/15] CI test --- .github/workflows/ci.yml | 3 ++- 1 file changed, 2 
insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 87350446..791b78d5 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -8,7 +8,7 @@ on: jobs: test: - name: Test ${{ matrix.python-version }} + name: Test ${{ matrix.python-version }}${{ matrix.compiled && ' compiled' || '' }} runs-on: ubuntu-latest strategy: fail-fast: false @@ -28,6 +28,7 @@ jobs: - name: cythonize if: matrix.compile run: | + python -m pip install cython ${{ matrix.compiled && matrix.python-version == '3.6' && 'dataclasses' || '' }} python scripts/cythonize.py python setup.py build_ext --in-place - name: Install tox From c798699113a07fa86545f8675673cbbdf3d16a67 Mon Sep 17 00:00:00 2001 From: Joseph Perez Date: Mon, 1 Nov 2021 19:52:41 +0100 Subject: [PATCH 07/15] Fix CI --- .github/workflows/ci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 791b78d5..8a94d1b1 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -8,7 +8,7 @@ on: jobs: test: - name: Test ${{ matrix.python-version }}${{ matrix.compiled && ' compiled' || '' }} + name: Test ${{ matrix.python-version }}${{ matrix.compile && ' compiled' || '' }} runs-on: ubuntu-latest strategy: fail-fast: false @@ -28,7 +28,7 @@ jobs: - name: cythonize if: matrix.compile run: | - python -m pip install cython ${{ matrix.compiled && matrix.python-version == '3.6' && 'dataclasses' || '' }} + python -m pip install cython ${{ matrix.compile && matrix.python-version == '3.6' && 'dataclasses' || '' }} python scripts/cythonize.py python setup.py build_ext --in-place - name: Install tox From cdf1d88930758cdc16f6ce689ef44dcdf97ef336 Mon Sep 17 00:00:00 2001 From: Joseph Perez Date: Mon, 1 Nov 2021 19:54:55 +0100 Subject: [PATCH 08/15] Fix CI --- .github/workflows/ci.yml | 2 +- scripts/cythonize.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8a94d1b1..ec7433eb 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -30,7 +30,7 @@ jobs: run: | python -m pip install cython ${{ matrix.compile && matrix.python-version == '3.6' && 'dataclasses' || '' }} python scripts/cythonize.py - python setup.py build_ext --in-place + python setup.py build_ext --inplace - name: Install tox run: | python -m pip install --upgrade pip diff --git a/scripts/cythonize.py b/scripts/cythonize.py index 2b78a1e7..ff7ed834 100755 --- a/scripts/cythonize.py +++ b/scripts/cythonize.py @@ -140,7 +140,7 @@ def rec_subclasses(cls: type) -> Iterable[type]: yield from rec_subclasses(sub_cls) -@lru_cache +@lru_cache() def get_dispatch(base_class: type) -> Mapping[type, int]: return {cls: i for i, cls in enumerate(rec_subclasses(base_class))} From c7d62616d61892748ebc423eabfe1a67a4d34aa2 Mon Sep 17 00:00:00 2001 From: Joseph Perez Date: Mon, 1 Nov 2021 20:17:49 +0100 Subject: [PATCH 09/15] Fix CI --- apischema/typing.py | 1 + scripts/cythonize.py | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/apischema/typing.py b/apischema/typing.py index c34f9731..47dd511f 100644 --- a/apischema/typing.py +++ b/apischema/typing.py @@ -53,6 +53,7 @@ def get_type_hints( # type: ignore try: from typing_extensions import get_origin, get_args except ImportError: + Annotated = ... 
# noqa # type: ignore def _assemble_tree(tree: Tuple[Any]) -> Any: if not isinstance(tree, tuple): diff --git a/scripts/cythonize.py b/scripts/cythonize.py index ff7ed834..61417214 100755 --- a/scripts/cythonize.py +++ b/scripts/cythonize.py @@ -31,7 +31,7 @@ try: from typing import Literal - CythonDef = Literal["cdef", "cpdef", "cdef inline"] + CythonDef = Literal["cdef", "cpdef", "cdef inline", "cpdef inline"] except ImportError: CythonDef = str # type: ignore @@ -257,7 +257,7 @@ def write_class(pyx: IndentedWriter, cls: type): def write_function(pyx: IndentedWriter, func: FunctionType): - pyx.writeln(cython_signature("cdef inline", func)) + pyx.writeln(cython_signature("cpdef inline", func)) pyx.writelines(get_body(func)) @@ -279,7 +279,7 @@ def write_methods(pyx: IndentedWriter, method: Method): def write_dispatch(pyx: IndentedWriter, method: Method): - with pyx.write_block(cython_signature("cdef", method.function, method.base_class)): # type: ignore + with pyx.write_block(cython_signature("cdef inline", method.function, method.base_class)): # type: ignore pyx.writeln(f"cdef int {DISPATCH_FIELD} = self.{DISPATCH_FIELD}") for cls, dispatch in get_dispatch(method.base_class).items(): if method.name in cls.__dict__: From e4e961dc1689339946663261e7304ee3aeaf1e3b Mon Sep 17 00:00:00 2001 From: Joseph Perez Date: Mon, 1 Nov 2021 23:18:12 +0100 Subject: [PATCH 10/15] Fix CI --- apischema/serialization/methods.py | 18 +++++++----------- apischema/typing.py | 10 +++++----- 2 files changed, 12 insertions(+), 16 deletions(-) diff --git a/apischema/serialization/methods.py b/apischema/serialization/methods.py index 9936e7ef..3a031e0b 100644 --- a/apischema/serialization/methods.py +++ b/apischema/serialization/methods.py @@ -5,7 +5,6 @@ from apischema.fields import FIELDS_SET_ATTR from apischema.types import AnyType, Undefined from apischema.utils import Lazy -from apischema.visitor import Unsupported class SerializationMethod: @@ -93,7 +92,7 @@ def fall_back(self, obj: Any) -> Any: @dataclass class TypeCheckIdentityMethod(SerializationMethod): - expected: type + expected: AnyType # `type` would require exact match (i.e. no EnumMeta) fallback: Fallback def serialize(self, obj: Any) -> Any: @@ -105,14 +104,11 @@ class TypeCheckMethod(TypeCheckIdentityMethod): method: SerializationMethod def serialize(self, obj: Any) -> Any: - if isinstance(obj, self.expected): - try: - return self.method.serialize(obj) - except Unsupported: - raise - except Exception: - pass - return self.fallback.fall_back(obj) + return ( + self.method.serialize(obj) + if isinstance(obj, self.expected) + else self.fallback.fall_back(obj) + ) @dataclass @@ -328,7 +324,7 @@ def serialize(self, obj: Any) -> Any: @dataclass class UnionAlternative: - cls: type + cls: AnyType # `type` would require exact match (i.e. no EnumMeta) method: SerializationMethod def __post_init__(self): diff --git a/apischema/typing.py b/apischema/typing.py index 47dd511f..b6ee27be 100644 --- a/apischema/typing.py +++ b/apischema/typing.py @@ -2,6 +2,7 @@ __all__ = ["get_args", "get_origin", "get_type_hints"] import sys +from contextlib import suppress from types import ModuleType, new_class from typing import ( # type: ignore Any, @@ -53,17 +54,16 @@ def get_type_hints( # type: ignore try: from typing_extensions import get_origin, get_args except ImportError: - Annotated = ... 
# noqa # type: ignore def _assemble_tree(tree: Tuple[Any]) -> Any: if not isinstance(tree, tuple): return tree else: origin, *args = tree # type: ignore - if origin is Annotated: - return Annotated[(_assemble_tree(args[0]), *args[1])] - else: - return origin[tuple(map(_assemble_tree, args))] + with suppress(NameError): + if origin is Annotated: + return Annotated[(_assemble_tree(args[0]), *args[1])] + return origin[tuple(map(_assemble_tree, args))] def get_origin(tp): # type: ignore # In Python 3.6: List[Collection[T]][int].__args__ == int != Collection[int] From 68ba8be14d41923cfdfb6f22f453fcc194afa0a1 Mon Sep 17 00:00:00 2001 From: Joseph Perez Date: Mon, 1 Nov 2021 23:38:12 +0100 Subject: [PATCH 11/15] Fix CI --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ec7433eb..d436ae3e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -28,7 +28,7 @@ jobs: - name: cythonize if: matrix.compile run: | - python -m pip install cython ${{ matrix.compile && matrix.python-version == '3.6' && 'dataclasses' || '' }} + python -m pip install cython ${{ matrix.compile && (matrix.python-version == '3.6' || matrix.python-version == 'pypy3') && 'dataclasses' || '' }} python scripts/cythonize.py python setup.py build_ext --inplace - name: Install tox From 2b2673e089598ef5a8f782f7e79ab1668e10c156 Mon Sep 17 00:00:00 2001 From: Joseph Perez Date: Tue, 2 Nov 2021 07:27:43 +0100 Subject: [PATCH 12/15] Exclude compilation for pypy in test --- .github/workflows/ci.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d436ae3e..09acee35 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -15,6 +15,9 @@ jobs: matrix: python-version: ['3.6', '3.7', '3.8', '3.9', '3.10', 'pypy3'] compile: [true, false] + exclude: + - python-version: pypy3 + compile: true steps: - uses: actions/cache@v2 with: From e25ae6de791b1c1b7202a094bdda8e6c2da3b10e Mon Sep 17 00:00:00 2001 From: Joseph Perez Date: Wed, 3 Nov 2021 23:21:32 +0100 Subject: [PATCH 13/15] Add documentation --- apischema/deserialization/__init__.py | 36 ++++++++++++++---------- apischema/deserialization/methods.py | 31 +++++++++++++++------ apischema/settings.py | 40 ++++++++++++++++----------- apischema/utils.py | 9 ------ docs/difference_with_pydantic.md | 4 +-- docs/json_schema.md | 3 ++ docs/performance_and_benchmark.md | 12 +++++++- docs/validation.md | 15 ++++++++++ examples/settings_errors.py | 12 ++++++++ examples/validation_error.py | 2 +- 10 files changed, 112 insertions(+), 52 deletions(-) create mode 100644 examples/settings_errors.py diff --git a/apischema/deserialization/__init__.py b/apischema/deserialization/__init__.py index 9bc78264..2a7b651a 100644 --- a/apischema/deserialization/__init__.py +++ b/apischema/deserialization/__init__.py @@ -3,7 +3,7 @@ from collections import defaultdict from dataclasses import dataclass, replace from enum import Enum -from functools import lru_cache +from functools import lru_cache, partial from typing import ( Any, Callable, @@ -14,6 +14,7 @@ Pattern, Sequence, Set, + TYPE_CHECKING, Tuple, Type, TypeVar, @@ -55,6 +56,7 @@ ObjectMethod, OptionalMethod, PatternField, + PreformatedConstraintError, RecMethod, SetMethod, StrMethod, @@ -82,13 +84,15 @@ get_origin_or_type, literal_values, opt_or, - partial_format, to_pascal_case, to_snake_case, ) from apischema.validation import get_validators from 
apischema.validation.validators import Validator +if TYPE_CHECKING: + from apischema.settings import ConstraintError + MISSING_PROPERTY = "missing property" UNEXPECTED_PROPERTY = "unexpected property" @@ -133,6 +137,16 @@ def get_constraints(schema: Optional[Schema]) -> Optional[Constraints]: constraint_classes = {cls.__name__: cls for cls in Constraint.__subclasses__()} +def preformat_error( + error: "ConstraintError", constraint: Any +) -> PreformatedConstraintError: + return ( + error.format(constraint) + if isinstance(error, str) + else partial(error, constraint) + ) + + def constraints_validators( constraints: Optional[Constraints], ) -> Mapping[type, Tuple[Constraint, ...]]: @@ -143,20 +157,14 @@ def constraints_validators( for name, attr, metadata in constraints.attr_and_metata: if attr is None or attr is False: continue - error = getattr(settings.errors, to_snake_case(metadata.alias)) - error = partial_format( - error, - constraint=attr - if not isinstance(attr, type(re.compile(r""))) - else attr.pattern, + error = preformat_error( + getattr(settings.errors, to_snake_case(metadata.alias)), + attr if not isinstance(attr, type(re.compile(r""))) else attr.pattern, ) constraint_cls = constraint_classes[ to_pascal_case(metadata.alias) + "Constraint" ] - result[metadata.cls] = ( - *result[metadata.cls], - constraint_cls(error) if attr is True else constraint_cls(error, attr), # type: ignore - ) + result[metadata.cls] = (*result[metadata.cls], constraint_cls(error, attr)) # type: ignore if float in result: result[int] = result[float] return result @@ -270,7 +278,7 @@ def factory(constraints: Optional[Constraints], _) -> DeserializationMethod: value_map = dict(zip(literal_values(values), values)) return LiteralMethod( value_map, - partial_format(settings.errors.one_of, constraint=list(value_map)), + preformat_error(settings.errors.one_of, list(value_map)), self.coercer, tuple(set(map(type, value_map))), ) @@ -420,7 +428,7 @@ def tuple(self, types: Sequence[AnyType]) -> DeserializationMethodFactory: elt_factories = [self.visit(tp) for tp in types] def factory(constraints: Optional[Constraints], _) -> DeserializationMethod: - def len_error(constraints: Constraints) -> str: + def len_error(constraints: Constraints) -> PreformatedConstraintError: return constraints_validators(constraints)[list][0].error return TupleMethod( diff --git a/apischema/deserialization/methods.py b/apischema/deserialization/methods.py index b56728fb..e1d67f48 100644 --- a/apischema/deserialization/methods.py +++ b/apischema/deserialization/methods.py @@ -9,6 +9,7 @@ Sequence, TYPE_CHECKING, Tuple, + Union, ) from apischema.aliases import Aliaser @@ -25,10 +26,12 @@ if TYPE_CHECKING: pass +PreformatedConstraintError = Union[str, Callable[[Any], str]] + @dataclass class Constraint: - error: str + error: PreformatedConstraintError def validate(self, data: Any) -> bool: raise NotImplementedError @@ -124,7 +127,13 @@ def to_hashable(data: Any) -> Any: return data +@dataclass class UniqueItemsConstraint(Constraint): + unique: bool + + def __post_init__(self): + assert self.unique + def validate(self, data: list) -> bool: return len(set(map(to_hashable, data))) == len(data) @@ -145,17 +154,21 @@ def validate(self, data: dict) -> bool: return len(data) <= self.max_properties +def format_error(err: PreformatedConstraintError, data: Any) -> str: + return err if isinstance(err, str) else err(data) + + def validate_constraints( data: Any, constraints: Tuple[Constraint, ...], children_errors: Optional[dict] ) -> Any: for i in 
range(len(constraints)):
         constraint: Constraint = constraints[i]
         if not constraint.validate(data):
-            errors: list = [constraint.error.format(data)]
+            errors: list = [format_error(constraint.error, data)]
             for j in range(i + 1, len(constraints)):
                 constraint = constraints[j]
                 if not constraint.validate(data):
-                    errors.append(constraint.error.format(data))
+                    errors.append(format_error(constraint.error, data))
             raise ValidationError(errors, children_errors or {})
     if children_errors:
         raise ValidationError([], children_errors)
@@ -248,7 +261,7 @@ def deserialize(self, data: Any) -> Any:
 @dataclass
 class LiteralMethod(DeserializationMethod):
     value_map: dict
-    error: str
+    error: PreformatedConstraintError
     coercer: Optional[Coercer]
     types: Tuple[type, ...]
 
@@ -262,7 +275,7 @@ def deserialize(self, data: Any) -> Any:
                 return self.value_map[self.coercer(cls, data)]
             except IndexError:
                 pass
-        raise ValidationError([self.error.format(data)])
+        raise ValidationError([format_error(self.error, data)])
 
@@ -559,8 +572,8 @@ def deserialize(self, data: Any) -> Any:
 @dataclass
 class TupleMethod(DeserializationMethod):
     constraints: Tuple[Constraint, ...]
-    min_len_error: str
-    max_len_error: str
+    min_len_error: PreformatedConstraintError
+    max_len_error: PreformatedConstraintError
     elt_methods: Tuple[DeserializationMethod, ...]
 
     def deserialize(self, data: Any) -> Any:
@@ -569,9 +582,9 @@ def deserialize(self, data: Any) -> Any:
         data2: list = data
         if len(data2) != len(self.elt_methods):
             if len(data2) < len(self.elt_methods):
-                raise ValidationError([self.min_len_error % len(data2)])
+                raise ValidationError([format_error(self.min_len_error, data2)])
             elif len(data2) > len(self.elt_methods):
-                raise ValidationError([self.max_len_error % len(data2)])
+                raise ValidationError([format_error(self.max_len_error, data2)])
             else:
                 raise NotImplementedError
         elt_errors: dict = {}
diff --git a/apischema/settings.py b/apischema/settings.py
index ee6168fa..8dca2bcc 100644
--- a/apischema/settings.py
+++ b/apischema/settings.py
@@ -1,6 +1,6 @@
 import warnings
 from inspect import Parameter
-from typing import Callable, Optional, Sequence
+from typing import Any, Callable, Optional, Sequence, Union
 
 from apischema import cache
 from apischema.aliases import Aliaser
@@ -48,6 +48,9 @@ def __setattr__(self, name, value):
         super().__setattr__(name, value)
 
 
+ConstraintError = Union[str, Callable[[Any, Any], str]]
+
+
 class settings(metaclass=MetaSettings):
     additional_properties: bool = False
     aliaser: Aliaser = lambda s: s
@@ -67,26 +70,31 @@ class base_schema:
     type: Callable[[AnyType], Optional[Schema]] = lambda *_: None
 
 class errors:
-    minimum: str = "less than {constraint} (minimum)"
-    maximum: str = "greater than {constraint} (maximum)"
-    exclusive_minimum: str = "less than or equal to {constraint} (exclusiveMinimum)"
-    exclusive_maximum: str = (
-        "greater than or equal to {constraint} (exclusiveMinimum)"
+    minimum: ConstraintError = "less than {} (minimum)"
+    maximum: ConstraintError = "greater than {} (maximum)"
+    exclusive_minimum: ConstraintError = (
+        "less than or equal to {} (exclusiveMinimum)"
+    )
+    exclusive_maximum: ConstraintError = (
+        "greater than or equal to {} (exclusiveMaximum)"
     )
-    multiple_of: str = "not a multiple of {constraint} (multipleOf)"
+    multiple_of: ConstraintError = "not a multiple of {} (multipleOf)"
 
-    min_length: str = "string length lower than {constraint} (minLength)"
-    max_length: str = "string length greater than {constraint} (maxLength)"
-    pattern: str = 'not matching pattern "{constraint}" (pattern)'
+    min_length: ConstraintError = "string length lower than {} (minLength)"
+    max_length: ConstraintError = "string length greater than {} (maxLength)"
+    pattern: ConstraintError = "not matching pattern {} (pattern)"
 
-    min_items: str = "item count lower than {constraint} (minItems)"
-    max_items: str = "item count greater than {constraint} (maxItems)"
-    unique_items: str = "duplicate items (uniqueItems)"
+    min_items: ConstraintError = "item count lower than {} (minItems)"
+    max_items: ConstraintError = "item count greater than {} (maxItems)"
+    unique_items: ConstraintError = "duplicate items (uniqueItems)"
+
+    min_properties: ConstraintError = "property count lower than {} (minProperties)"
+    max_properties: ConstraintError = (
+        "property count greater than {} (maxProperties)"
+    )
 
-    min_properties: str = "property count lower than {constraint} (minProperties)"
-    max_properties: str = "property count greater than {constraint} (maxProperties)"
+    one_of: ConstraintError = "not one of {} (oneOf)"
 
-    one_of: str = "not one of {constraint} (oneOf)"
     unexpected_property: str = "unexpected property"
     missing_property: str = "missing property"
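The two formatting steps introduced above can be summed up in a standalone sketch: `preformat_error` is applied once when the deserialization method is built, `format_error` only when validation fails. Names are shortened and messages illustrative; this is a sketch of the mechanism, not the library code itself:

```python
from functools import partial
from typing import Any, Callable, Union

# the type of a settings.errors attribute
ConstraintError = Union[str, Callable[[Any, Any], str]]


def preformat(error: ConstraintError, constraint: Any) -> Union[str, Callable[[Any], str]]:
    # a str is formatted once with the constraint value;
    # a callable is partially applied to it
    return error.format(constraint) if isinstance(error, str) else partial(error, constraint)


def format_error(error: Union[str, Callable[[Any], str]], data: Any) -> str:
    # called only on validation failure, with the invalid data
    return error if isinstance(error, str) else error(data)


assert format_error(preformat("less than {} (minimum)", 0), -1) == "less than 0 (minimum)"
assert format_error(preformat(lambda c, d: f"{d} is less than {c}", 0), -1) == "-1 is less than 0"
```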
diff --git a/apischema/utils.py b/apischema/utils.py
index 1b9f42cc..47aea8de 100644
--- a/apischema/utils.py
+++ b/apischema/utils.py
@@ -104,15 +104,6 @@ def to_pascal_case(s: str) -> str:
     return camel[0].upper() + camel[1:] if camel else camel
 
 
-class PartialFormatter(dict):
-    def __missing__(self, key):
-        return "{%s}" % key
-
-
-def partial_format(s: str, **kwargs) -> str:
-    return s.format_map(PartialFormatter(kwargs))
-
-
 def merge_opts(
     func: Callable[[T, T], T]
 ) -> Callable[[Optional[T], Optional[T]], Optional[T]]:
diff --git a/docs/difference_with_pydantic.md b/docs/difference_with_pydantic.md
index b86f3d6a..a324e618 100644
--- a/docs/difference_with_pydantic.md
+++ b/docs/difference_with_pydantic.md
@@ -2,9 +2,9 @@
 
 As the question is often asked, it is answered in a dedicated section. Here are some of the key differences between *apischema* and *pydantic*:
 
-### *apischema* is faster
+### *apischema* is (a lot) faster
 
-*pydantic* uses Cython to improve its performance; *apischema* doesn't need it and is still 1.5x faster according to [*pydantic* benchmark](performance_and_benchmark.md) — more than 2x when *pydantic* is not compiled with Cython.
+According to the [*pydantic* benchmark](performance_and_benchmark.md), *apischema* is a lot faster than *pydantic*, especially for serialization. Both use Cython to optimize the code, but even without compilation (running only Python modules), *apischema* is still faster than Cythonized *pydantic*.
 
 Better performance, but not at the cost of fewer functionalities; that's rather the opposite: [dynamic aliasing](json_schema.md#dynamic-aliasing-and-default-aliaser), [conversions](conversions.md), [flattened fields](data_model.md#composition-over-inheritance---composed-dataclasses-flattening), etc.
diff --git a/docs/json_schema.md b/docs/json_schema.md
index 109ac0a3..01be692e 100644
--- a/docs/json_schema.md
+++ b/docs/json_schema.md
@@ -92,6 +92,9 @@ JSON schema constrains the data deserialized; these constraints are naturally us
 {!validation_error.py!}
 ```
 
+!!! note
+    Error messages are fully [customizable](validation.md#constraint-errors-customization)
+
 ### Extra schema
 
 `schema` has two other arguments: `extra` and `override`, which give a finer control of the JSON schema generated. It can be used for example to build "strict" unions (using `oneOf` instead of `anyOf`)
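As context for the note added above, here is a minimal sketch of a constraint failing at deserialization with its default message. It mirrors the `examples/validation_error.py` file referenced in the hunk; `list[int]` assumes Python 3.9+:

```python
from pytest import raises

from apischema import ValidationError, deserialize, schema

# maxItems is violated, so deserialization raises with the default message
with raises(ValidationError) as err:
    deserialize(list[int], [0, 1, 2, 3], schema=schema(max_items=3))
assert err.value.errors == [
    {"loc": [], "msg": "item count greater than 3 (maxItems)"}
]
```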
diff --git a/docs/performance_and_benchmark.md b/docs/performance_and_benchmark.md
index 8e85360d..75602bb4 100644
--- a/docs/performance_and_benchmark.md
+++ b/docs/performance_and_benchmark.md
@@ -1,6 +1,6 @@
 # Performance and benchmark
 
-*apischema* is [faster](#benchmark) than its known alternatives, thanks to advanced optimizations.
+*apischema* is (a lot) [faster](#benchmark) than its known alternatives, thanks to advanced optimizations.
 
 ## Precomputed (de)serialization methods
 
@@ -81,6 +81,16 @@ Either a collection of types, or a predicate to determine if type has to be pass
 ```
 That's why passthrough optimization should be used wisely.
 
+## Binary compilation using Cython
+
+*apischema* uses Cython to compile critical parts of the code, i.e. the (de)serialization methods.
+
+However, *apischema* remains a pure Python library — it can work without binary modules. Cython source files (`.pyx`) are in fact generated from the Python modules. This notably keeps the code simple, while adding a *switch-case* optimization that replaces dynamic dispatch and avoids big chains of `elif` in the Python code.
+
+!!! note
+    Compilation is disabled when using PyPy, because PyPy is even faster with the bare Python code.
+    That's another benefit of generating `.pyx` files: keeping the Python source for PyPy.
+
 ## Benchmark
 
 !!! note
diff --git a/docs/validation.md b/docs/validation.md
index d8c7d201..68778471 100644
--- a/docs/validation.md
+++ b/docs/validation.md
@@ -15,6 +15,21 @@ As shown in the example, *apischema* will not stop at the first error met but tr
 !!! note
     `ValidationError` can also be serialized using `apischema.serialize` (this will use `errors` internally).
 
+## Constraint errors customization
+
+Constraints are validated at deserialization, with *apischema* providing default error messages.
+Messages can be customized by setting the corresponding attribute of `apischema.settings.errors`. They can be either a string, which will be formatted with the constraint value (using `str.format`), e.g. `less than {} (minimum)`, or a function taking two parameters: the constraint value and the invalid data.
+
+```python
+{!settings_errors.py!}
+```
+
+!!! note
+    Default error messages don't include the invalid data, for security reasons (the data could, for example, be a too-short password).
+
+!!! note
+    Other error messages can be customized as well, for example `missing property` for missing required properties.
+
 ## Dataclass validators
 
 Dataclass validation can be completed by custom validators. These are simple decorated methods which will be executed during validation, after all fields have been deserialized.
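The example file added below exercises the callable form of a constraint error; for completeness, here is a similar sketch of the string form, formatted with the constraint value only (the message text is hypothetical):

```python
from pytest import raises

from apischema import ValidationError, deserialize, schema, settings

# hypothetical custom message; "{}" is replaced by the constraint value (2)
settings.errors.min_items = "expected at least {} items (minItems)"

with raises(ValidationError) as err:
    deserialize(list[int], [0], schema=schema(min_items=2))
assert err.value.errors == [
    {"loc": [], "msg": "expected at least 2 items (minItems)"}
]
```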
diff --git a/examples/settings_errors.py b/examples/settings_errors.py new file mode 100644 index 00000000..a6726225 --- /dev/null +++ b/examples/settings_errors.py @@ -0,0 +1,12 @@ +from pytest import raises + +from apischema import ValidationError, deserialize, schema, settings + +settings.errors.max_items = ( + lambda constraint, data: f"too-many-items: {len(data)} > {constraint}" +) + + +with raises(ValidationError) as err: + deserialize(list[int], [0, 1, 2, 3], schema=schema(max_items=3)) +assert err.value.errors == [{"loc": [], "msg": "too-many-items: 4 > 3"}] diff --git a/examples/validation_error.py b/examples/validation_error.py index bd21d8a4..5f606550 100644 --- a/examples/validation_error.py +++ b/examples/validation_error.py @@ -27,6 +27,6 @@ class Resource: assert err.value.errors == [ {"loc": ["tags"], "msg": "item count greater than 3 (maxItems)"}, {"loc": ["tags"], "msg": "duplicate items (uniqueItems)"}, - {"loc": ["tags", 3], "msg": 'not matching pattern "^\\w*$" (pattern)'}, + {"loc": ["tags", 3], "msg": "not matching pattern ^\\w*$ (pattern)"}, {"loc": ["tags", 4], "msg": "string length lower than 3 (minLength)"}, ] From ec013521d19178a9ec307706bec10d702db05c70 Mon Sep 17 00:00:00 2001 From: Joseph Perez Date: Wed, 3 Nov 2021 23:42:59 +0100 Subject: [PATCH 14/15] Fix test wrapper --- scripts/test_wrapper.py | 1 + 1 file changed, 1 insertion(+) diff --git a/scripts/test_wrapper.py b/scripts/test_wrapper.py index 5e1e6f75..a6454956 100644 --- a/scripts/test_wrapper.py +++ b/scripts/test_wrapper.py @@ -73,6 +73,7 @@ def __subclasscheck__(self, subclass): settings_classes = ( settings, + settings.errors, settings.base_schema, settings.deserialization, settings.serialization, From 594adfb7f56d2742cb6dcc9f2fef437083023dd5 Mon Sep 17 00:00:00 2001 From: Joseph Perez Date: Wed, 3 Nov 2021 23:55:22 +0100 Subject: [PATCH 15/15] Fix compilation --- apischema/deserialization/__init__.py | 6 +++--- apischema/deserialization/methods.py | 12 +++++------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/apischema/deserialization/__init__.py b/apischema/deserialization/__init__.py index 2a7b651a..16c3ed2d 100644 --- a/apischema/deserialization/__init__.py +++ b/apischema/deserialization/__init__.py @@ -18,6 +18,7 @@ Tuple, Type, TypeVar, + Union, overload, ) @@ -56,7 +57,6 @@ ObjectMethod, OptionalMethod, PatternField, - PreformatedConstraintError, RecMethod, SetMethod, StrMethod, @@ -139,7 +139,7 @@ def get_constraints(schema: Optional[Schema]) -> Optional[Constraints]: def preformat_error( error: "ConstraintError", constraint: Any -) -> PreformatedConstraintError: +) -> Union[str, Callable[[Any], str]]: return ( error.format(constraint) if isinstance(error, str) @@ -428,7 +428,7 @@ def tuple(self, types: Sequence[AnyType]) -> DeserializationMethodFactory: elt_factories = [self.visit(tp) for tp in types] def factory(constraints: Optional[Constraints], _) -> DeserializationMethod: - def len_error(constraints: Constraints) -> PreformatedConstraintError: + def len_error(constraints: Constraints) -> Union[str, Callable[[Any], str]]: return constraints_validators(constraints)[list][0].error return TupleMethod( diff --git a/apischema/deserialization/methods.py b/apischema/deserialization/methods.py index e1d67f48..633f8afe 100644 --- a/apischema/deserialization/methods.py +++ b/apischema/deserialization/methods.py @@ -26,12 +26,10 @@ if TYPE_CHECKING: pass -PreformatedConstraintError = Union[str, Callable[[Any], str]] - @dataclass class Constraint: - error: 
PreformatedConstraintError + error: Union[str, Callable[[Any], str]] def validate(self, data: Any) -> bool: raise NotImplementedError @@ -154,7 +152,7 @@ def validate(self, data: dict) -> bool: return len(data) <= self.max_properties -def format_error(err: PreformatedConstraintError, data: Any) -> str: +def format_error(err: Union[str, Callable[[Any], str]], data: Any) -> str: return err if isinstance(err, str) else err(data) @@ -261,7 +259,7 @@ def deserialize(self, data: Any) -> Any: @dataclass class LiteralMethod(DeserializationMethod): value_map: dict - error: PreformatedConstraintError + error: Union[str, Callable[[Any], str]] coercer: Optional[Coercer] types: Tuple[type, ...] @@ -572,8 +570,8 @@ def deserialize(self, data: Any) -> Any: @dataclass class TupleMethod(DeserializationMethod): constraints: Tuple[Constraint, ...] - min_len_error: PreformatedConstraintError - max_len_error: PreformatedConstraintError + min_len_error: Union[str, Callable[[Any], str]] + max_len_error: Union[str, Callable[[Any], str]] elt_methods: Tuple[DeserializationMethod, ...] def deserialize(self, data: Any) -> Any:
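
As a closing illustration of the *switch-case* generation these patches rely on, here is a toy sketch of the dispatch numbering implemented by `rec_subclasses`/`get_dispatch` in `scripts/cythonize.py`. The class hierarchy below is illustrative, not the real method classes:

```python
from typing import Iterable


def rec_subclasses(cls: type) -> Iterable[type]:
    # depth-first enumeration of all subclasses, as in scripts/cythonize.py
    for sub_cls in cls.__subclasses__():
        yield sub_cls
        yield from rec_subclasses(sub_cls)


class Base: ...        # would get a `cdef int _dispatch` field in the generated .pyx
class A(Base): ...     # _dispatch == 0
class B(Base): ...     # _dispatch == 1
class C(B): ...        # _dispatch == 2


dispatch = {cls: i for i, cls in enumerate(rec_subclasses(Base))}
assert dispatch == {A: 0, B: 1, C: 2}
# The generated dispatch function branches on the stored int
# (`if _dispatch == 0: ... elif _dispatch == 1: ...`), which Cython can
# compile into a C switch instead of a dynamically dispatched method call.
```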