diff --git a/apischema/__init__.py b/apischema/__init__.py index f9f6255d..cd3b34ab 100644 --- a/apischema/__init__.py +++ b/apischema/__init__.py @@ -10,6 +10,7 @@ "deserialize", "deserializer", "identity", + "order", "properties", "schema", "schema_ref", @@ -42,6 +43,7 @@ from .dependencies import dependent_required from .deserialization import deserialization_method, deserialize from .metadata import properties +from .ordering import order from .schemas import schema from .serialization import serialization_default, serialization_method, serialize from .serialization.pass_through import PassThroughOptions diff --git a/apischema/graphql/relay/global_identification.py b/apischema/graphql/relay/global_identification.py index d67304d2..ccc92d09 100644 --- a/apischema/graphql/relay/global_identification.py +++ b/apischema/graphql/relay/global_identification.py @@ -18,6 +18,7 @@ from apischema import deserialize, deserializer, serialize, serializer, type_name from apischema.graphql import ID, interface, resolver from apischema.metadata import skip +from apischema.ordering import order from apischema.type_names import get_type_name from apischema.typing import generic_mro, get_args, get_origin from apischema.utils import PREFIX, has_type_vars, wrap_generic_init_subclass @@ -117,7 +118,9 @@ def __init_subclass__(cls, not_a_node: bool = False, **kwargs): _tmp_nodes.append(cls) -resolver(alias="id")(Node.global_id) # cannot directly decorate property because py36 +resolver("id", order=order(-1))( + Node.global_id +) # cannot directly decorate property because py36 _tmp_nodes: List[Type[Node]] = [] _nodes: Dict[str, Type[Node]] = {} diff --git a/apischema/graphql/resolvers.py b/apischema/graphql/resolvers.py index 2bedddec..1f556c2c 100644 --- a/apischema/graphql/resolvers.py +++ b/apischema/graphql/resolvers.py @@ -29,6 +29,7 @@ from apischema.deserialization import deserialization_method from apischema.methods import method_registerer from apischema.objects import ObjectField +from apischema.ordering import Ordering from apischema.schemas import Schema from apischema.serialization import ( PassThroughOptions, @@ -160,8 +161,9 @@ def resolver( alias: str = None, *, conversion: AnyConversion = None, - schema: Schema = None, error_handler: ErrorHandler = Undefined, + order: Optional[Ordering] = None, + schema: Schema = None, parameters_metadata: Mapping[str, Mapping] = None, serialized: bool = False, owner: Type = None, @@ -175,8 +177,9 @@ def resolver( *, alias: str = None, conversion: AnyConversion = None, - schema: Schema = None, error_handler: ErrorHandler = Undefined, + order: Optional[Ordering] = None, + schema: Schema = None, parameters_metadata: Mapping[str, Mapping] = None, serialized: bool = False, owner: Type = None, @@ -192,8 +195,9 @@ def register(func: Callable, owner: Type, alias2: str): resolver = Resolver( func, conversion, - schema, error_handler2, + order, + schema, parameters, parameters_metadata or {}, ) diff --git a/apischema/graphql/schema.py b/apischema/graphql/schema.py index da6e1143..abdb9422 100644 --- a/apischema/graphql/schema.py +++ b/apischema/graphql/schema.py @@ -2,6 +2,7 @@ from enum import Enum from functools import wraps from inspect import Parameter, iscoroutinefunction +from itertools import chain from typing import ( Any, AsyncIterable, @@ -51,6 +52,7 @@ ObjectVisitor, SerializationObjectVisitor, ) +from apischema.ordering import Ordering, sort_by_order from apischema.recursion import RecursiveConversionsVisitor from apischema.schemas import Schema, 
get_schema, merge_schema from apischema.serialization import SerializationMethod, serialize @@ -69,7 +71,6 @@ get_origin_or_type, identity, is_union_of, - sort_by_annotations_position, to_camel_case, ) @@ -157,22 +158,6 @@ class ResolverField: metadata: Mapping[str, Mapping] subscribe: Optional[Callable] = None - @property - def name(self) -> str: - return self.resolver.func.__name__ - - @property - def type(self) -> AnyType: - return self.types["return"] - - @property - def description(self) -> Optional[str]: - return get_description(self.resolver.schema) - - @property - def deprecated(self) -> Optional[str]: - return get_deprecated(self.resolver.schema) - IdPredicate = Callable[[AnyType], bool] UnionNameFactory = Callable[[Sequence[str]], str] @@ -320,9 +305,6 @@ def collection( ) -> TypeFactory[GraphQLTp]: return TypeFactory(lambda *_: graphql.GraphQLList(self.visit(value_type).type)) - def _visit_flattened(self, field: ObjectField) -> TypeFactory[GraphQLTp]: - raise NotImplementedError - @cache_type def enum(self, cls: Type[Enum]) -> TypeFactory[GraphQLTp]: def factory( @@ -434,30 +416,61 @@ def visit_conversion( FieldType = TypeVar("FieldType", graphql.GraphQLInputField, graphql.GraphQLField) -def merge_fields( - cls: Type, - fields: Mapping[str, Lazy[FieldType]], - flattened_types: Mapping[str, TypeFactory], -) -> Dict[str, FieldType]: - all_flattened_fields: Dict[str, FieldType] = {} - for flattened_name, flattened_factory in flattened_types.items(): - flattened_type = flattened_factory.raw_type +class BaseField(Generic[FieldType]): + name: str + ordering: Optional[Ordering] + + def items(self) -> Iterable[Tuple[str, FieldType]]: + raise NotImplementedError + + +@dataclass +class NormalField(BaseField[FieldType]): + alias: str + name: str + field: Lazy[FieldType] + ordering: Optional[Ordering] + + def items(self) -> Iterable[Tuple[str, FieldType]]: + yield self.alias, self.field() + + +@dataclass +class FlattenedField(BaseField[FieldType]): + name: str + ordering: Optional[Ordering] + type: TypeFactory + + def items(self) -> Iterable[Tuple[str, FieldType]]: + tp = self.type.raw_type if not isinstance( - flattened_type, + tp, ( graphql.GraphQLObjectType, graphql.GraphQLInterfaceType, graphql.GraphQLInputObjectType, ), ): - raise TypeError( - f"Flattened field {cls.__name__}.{flattened_name} must have an object type" - ) - flattened_fields: Mapping[str, FieldType] = flattened_type.fields - if flattened_fields.keys() & all_flattened_fields.keys() & fields.keys(): - raise TypeError(f"Conflict in flattened fields of {cls}") - all_flattened_fields.update(flattened_fields) - return {**{name: field() for name, field in fields.items()}, **all_flattened_fields} + raise FlattenedError(self) + yield from tp.fields.items() + + +class FlattenedError(Exception): + def __init__(self, field: FlattenedField): + self.field = field + + +def merge_fields(cls: type, fields: Sequence[BaseField]) -> Dict[str, FieldType]: + try: + sorted_fields = sort_by_order( + cls, fields, lambda f: f.name, lambda f: f.ordering + ) + except FlattenedError as err: + raise TypeError( + f"Flattened field {cls.__name__}.{err.field.name}" + f" must have an object type" + ) + return OrderedDict(chain.from_iterable(map(lambda f: f.items(), sorted_fields))) class InputSchemaBuilder( @@ -467,11 +480,6 @@ class InputSchemaBuilder( ): types = graphql.type.definition.graphql_input_types - def _visit_flattened( - self, field: ObjectField - ) -> TypeFactory[graphql.GraphQLInputType]: - return self.visit_with_conv(field.type, 
field.deserialization) - def _field(self, field: ObjectField) -> Lazy[graphql.GraphQLInputField]: field_type = field.type field_default = graphql.Undefined if field.required else field.get_default() @@ -494,19 +502,29 @@ def _field(self, field: ObjectField) -> Lazy[graphql.GraphQLInputField]: factory.type, # type: ignore default_value=default, description=get_description(field.schema, field.type), - extensions={field.name: ""}, ) @cache_type def object( self, tp: AnyType, fields: Sequence[ObjectField] ) -> TypeFactory[graphql.GraphQLInputType]: - visited_fields = { - self.aliaser(f.alias): self._field(f) for f in fields if not f.is_aggregate - } - flattened_types = { - f.name: self._visit_flattened(f) for f in fields if f.flattened - } + visited_fields: List[BaseField] = [] + for field in fields: + if not field.is_aggregate: + normal_field = NormalField( + self.aliaser(field.alias), + field.name, + self._field(field), + field.ordering, + ) + visited_fields.append(normal_field) + elif field.flattened: + flattened_fields = FlattenedField( + field.name, + field.ordering, + self.visit_with_conv(field.type, field.deserialization), + ) + visited_fields.append(flattened_fields) def factory( name: Optional[str], description: Optional[str] @@ -516,7 +534,7 @@ def factory( name += "Input" return graphql.GraphQLInputObjectType( name, - lambda: merge_fields(tp, visited_fields, flattened_types), + lambda: merge_fields(get_origin_or_type(tp), visited_fields), description, ) @@ -657,14 +675,14 @@ def arg_thunk( ) args[self.aliaser(param_field.alias)] = arg_thunk - factory = self.visit_with_conv(field.type, field.resolver.conversion) + factory = self.visit_with_conv(field.types["return"], field.resolver.conversion) return lambda: graphql.GraphQLField( factory.type, # type: ignore {name: arg() for name, arg in args.items()} if args else None, resolve, field.subscribe, - field.description, - field.deprecated, + get_description(field.resolver.schema), + get_deprecated(field.resolver.schema), ) def _visit_flattened( @@ -691,10 +709,24 @@ def object( resolvers: Sequence[ResolverField] = (), ) -> TypeFactory[graphql.GraphQLOutputType]: cls = get_origin_or_type(tp) - all_fields = {f.alias: self._field(f) for f in fields if not f.is_aggregate} - name_by_aliases = {f.alias: f.name for f in fields} - all_fields.update({r.alias: self._resolver(r) for r in resolvers}) - name_by_aliases.update({r.alias: r.resolver.func.__name__ for r in resolvers}) + visited_fields: List[BaseField[graphql.GraphQLField]] = [] + flattened_factories = [] + for field in fields: + if not field.is_aggregate: + normal_field = NormalField( + self.aliaser(field.name), + field.name, + self._field(field), + field.ordering, + ) + visited_fields.append(normal_field) + elif field.flattened: + flattened_factory = self._visit_flattened(field) + flattened_factories.append(flattened_factory) + visited_fields.append( + FlattenedField(field.name, field.ordering, flattened_factory) + ) + resolvers = list(resolvers) for alias, (resolver, types) in get_resolvers(tp).items(): resolver_field = ResolverField( alias, @@ -703,36 +735,34 @@ def object( resolver.parameters, resolver.parameters_metadata, ) - all_fields[alias] = self._resolver(resolver_field) - name_by_aliases[alias] = resolver.func.__name__ - sorted_fields = sort_by_annotations_position( - cls, all_fields, name_by_aliases.__getitem__ - ) - visited_fields = OrderedDict( - (self.aliaser(a), all_fields[a]) for a in sorted_fields - ) - flattened_types = { - f.name: self._visit_flattened(f) for f in 
fields if f.flattened - } - - def field_thunk() -> graphql.GraphQLFieldMap: - return merge_fields(cls, visited_fields, flattened_types) + resolvers.append(resolver_field) + for resolver_field in resolvers: + normal_field = NormalField( + self.aliaser(resolver_field.alias), + resolver_field.resolver.func.__name__, + self._resolver(resolver_field), + resolver_field.resolver.ordering, + ) + visited_fields.append(normal_field) - interfaces = list(map(self.visit, get_interfaces(cls))) interface_thunk = None - if interfaces: + interfaces = list(map(self.visit, get_interfaces(cls))) + if interfaces or flattened_factories: def interface_thunk() -> Collection[graphql.GraphQLInterfaceType]: - result = { + all_interfaces = { cast(graphql.GraphQLInterfaceType, i.raw_type) for i in interfaces } - for flattened_factory in flattened_types.values(): + for flattened_factory in flattened_factories: flattened = cast( Union[graphql.GraphQLObjectType, graphql.GraphQLInterfaceType], flattened_factory.raw_type, ) - result.update(flattened.interfaces) - return sorted(result, key=lambda i: i.name) + if isinstance(flattened, graphql.GraphQLObjectType): + all_interfaces.update(flattened.interfaces) + elif isinstance(flattened, graphql.GraphQLInterfaceType): + all_interfaces.add(flattened) + return sorted(all_interfaces, key=lambda i: i.name) def factory( name: Optional[str], description: Optional[str] @@ -740,12 +770,15 @@ def factory( name = unwrap_name(name, cls) if is_interface(cls): return graphql.GraphQLInterfaceType( - name, field_thunk, interface_thunk, description=description + name, + lambda: merge_fields(cls, visited_fields), + interface_thunk, + description=description, ) else: return graphql.GraphQLObjectType( name, - field_thunk, + lambda: merge_fields(cls, visited_fields), interface_thunk, is_type_of=lambda obj, _: isinstance(obj, cls), description=description, @@ -783,8 +816,9 @@ class Operation(Generic[T]): function: Callable[..., T] alias: Optional[str] = None conversion: Optional[AnyConversion] = None - schema: Optional[Schema] = None error_handler: ErrorHandler = Undefined + order: Optional[Ordering] = None + schema: Optional[Schema] = None parameters_metadata: Mapping[str, Mapping] = field_(default_factory=dict) @@ -833,8 +867,9 @@ def wrapper(_, *args, **kwargs): return operation.alias or operation.function.__name__, Resolver( wrapper, operation.conversion, - operation.schema, error_handler, + operation.order, + operation.schema, parameters, operation.parameters_metadata, ) @@ -897,8 +932,9 @@ def graphql_schema( resolver = Resolver( sub_op.resolver, sub_op.conversion, - sub_op.schema, subscriber2.error_handler, + sub_op.order, + sub_op.schema, sub_parameters, sub_op.parameters_metadata, ) @@ -917,8 +953,9 @@ def graphql_schema( resolver = Resolver( lambda _: _, sub_op.conversion, - sub_op.schema, subscriber2.error_handler, + sub_op.order, + sub_op.schema, (), {}, ) diff --git a/apischema/json_schema/patterns.py b/apischema/json_schema/patterns.py index 35f44209..a52488d8 100644 --- a/apischema/json_schema/patterns.py +++ b/apischema/json_schema/patterns.py @@ -9,7 +9,7 @@ def infer_pattern(tp: AnyType, default_conversion: DefaultConversion) -> Pattern try: builder = DeserializationSchemaBuilder( - False, lambda s: s, default_conversion, False, lambda r: r, {} + False, default_conversion, False, lambda r: r, {} ) prop_schema = builder.visit(tp) except RecursionError: diff --git a/apischema/json_schema/schema.py b/apischema/json_schema/schema.py index 5590888b..bce52c6b 100644 --- 
a/apischema/json_schema/schema.py +++ b/apischema/json_schema/schema.py @@ -20,6 +20,8 @@ Union, ) +from dataclasses import dataclass + from apischema.aliases import Aliaser from apischema.conversions import converters from apischema.conversions.conversions import AnyConversion, DefaultConversion @@ -43,12 +45,13 @@ from apischema.json_schema.types import JsonSchema, JsonType, json_schema from apischema.json_schema.versions import JsonSchemaVersion, RefFactory from apischema.metadata.keys import SCHEMA_METADATA -from apischema.objects import ObjectField +from apischema.objects import AliasedStr, ObjectField from apischema.objects.visitor import ( DeserializationObjectVisitor, ObjectVisitor, SerializationObjectVisitor, ) +from apischema.ordering import Ordering, sort_by_order from apischema.schemas import Schema, get_schema from apischema.serialization import serialize from apischema.serialization.serialized_methods import get_serialized_methods @@ -61,7 +64,6 @@ get_origin_or_type, is_union_of, literal_values, - sort_by_annotations_position, ) @@ -75,6 +77,15 @@ def full_schema(base_schema: JsonSchema, schema: Optional[Schema]) -> JsonSchema Method = TypeVar("Method", bound=Callable) +@dataclass +class Property: + alias: str + name: str + ordering: Optional[Ordering] + required: bool + schema: JsonSchema + + class SchemaBuilder( ConversionsVisitor[Conv, JsonSchema], ObjectVisitor[JsonSchema], @@ -83,7 +94,6 @@ class SchemaBuilder( def __init__( self, additional_properties: bool, - aliaser: Aliaser, default_conversion: DefaultConversion, ignore_first_ref: bool, ref_factory: RefFactory, @@ -91,7 +101,6 @@ def __init__( ): super().__init__(default_conversion) self.additional_properties = additional_properties - self.aliaser = aliaser self._ignore_first_ref = ignore_first_ref self.ref_factory = ref_factory self.refs = refs @@ -162,18 +171,12 @@ def mapping( else: return json_schema(type=JsonType.OBJECT, additionalProperties=value) - def visit_field(self, field: ObjectField, required: bool = False) -> JsonSchema: + def visit_field(self, field: ObjectField, required: bool = True) -> JsonSchema: result = full_schema( self.visit_with_conv(field.type, self._field_conversion(field)), field.schema, ) - if ( - not field.flattened - and not field.pattern_properties - and not field.additional_properties - and not required - and "default" not in result - ): + if not field.is_aggregate and not required and "default" not in result: result = JsonSchema(result) with suppress(Exception): result["default"] = serialize( @@ -185,51 +188,52 @@ def visit_field(self, field: ObjectField, required: bool = False) -> JsonSchema: ) return result - def _properties_schema(self, field: ObjectField) -> JsonSchema: - assert field.pattern_properties is not None or field.additional_properties + def _object_schema(self, cls: type, field: ObjectField) -> JsonSchema: + assert field.is_aggregate with context_setter(self): self._ignore_first_ref = True - props_schema = self.visit_field(field) - if not props_schema.get("type") == JsonType.OBJECT: - raise TypeError("properties field must have an 'object' type") - if "patternProperties" in props_schema: - if ( - len(props_schema["patternProperties"]) != 1 - or "additionalProperties" in props_schema - ): # don't try to merge the schemas - pass - else: - return next(iter(props_schema["patternProperties"].values())) - elif "additionalProperties" in props_schema: - if isinstance(props_schema["additionalProperties"], JsonSchema): - return props_schema["additionalProperties"] - else: # 
there is maybe only properties - pass - return JsonSchema() + object_schema = self.visit_field(field) + if object_schema.get("type") not in {JsonType.OBJECT, "object"}: + field_type = "Flattened" if field.flattened else "Properties" + raise TypeError( + f"{field_type} field {cls.__name__}.{field.name}" + f" must have an object type" + ) + return object_schema - def _check_flattened_schema(self, cls: Type, field: ObjectField): - assert field.flattened - with context_setter(self): - self._ignore_first_ref = True - if self.visit_field(field).get("type") not in {JsonType.OBJECT, "object"}: - raise TypeError( - f"Flattened field {cls.__name__}.{field.name} must have an object type" - ) + def _properties_schema( + self, object_schema: JsonSchema, pattern: Optional[Pattern] = None + ): + if "patternProperties" in object_schema: + if pattern is not None: + for p in (pattern, pattern.pattern): + if p in object_schema["patternProperties"]: + return object_schema["patternProperties"][p] + elif ( + len(object_schema["patternProperties"]) == 1 + and "additionalProperties" not in object_schema + ): + return next(iter(object_schema["patternProperties"].values())) + if isinstance(object_schema.get("additionalProperties"), Mapping): + return object_schema["additionalProperties"] + return JsonSchema() - @staticmethod - def _field_required(field: ObjectField): - return field.required + def properties( + self, tp: AnyType, fields: Sequence[ObjectField] + ) -> Sequence[Property]: + raise NotImplementedError def object(self, tp: AnyType, fields: Sequence[ObjectField]) -> JsonSchema: cls = get_origin_or_type(tp) + properties = sort_by_order( + cls, self.properties(tp, fields), lambda p: p.name, lambda p: p.ordering + ) flattened_schemas: List[JsonSchema] = [] pattern_properties = {} additional_properties: Union[bool, JsonSchema] = self.additional_properties - properties = {} - required = [] for field in fields: if field.flattened: - self._check_flattened_schema(cls, field) + self._object_schema(cls, field) # check the field is an object flattened_schemas.append(self.visit_field(field)) elif field.pattern_properties is not None: if field.pattern_properties is ...: @@ -237,24 +241,19 @@ def object(self, tp: AnyType, fields: Sequence[ObjectField]) -> JsonSchema: else: assert isinstance(field.pattern_properties, Pattern) pattern = field.pattern_properties - pattern_properties[pattern] = self._properties_schema(field) + pattern_properties[pattern] = self._properties_schema( + self._object_schema(cls, field), pattern + ) elif field.additional_properties: - additional_properties = self._properties_schema(field) - else: - alias = self.aliaser(field.alias) - if is_typed_dict(cls): - is_required = field.required - else: - is_required = self._field_required(field) - properties[alias] = self.visit_field(field, is_required) - if is_required: - required.append(alias) + additional_properties = self._properties_schema( + self._object_schema(cls, field) + ) alias_by_names = {f.name: f.alias for f in fields}.__getitem__ dependent_required = get_dependent_required(cls) result = json_schema( type=JsonType.OBJECT, - properties=properties, - required=required, + properties={p.alias: p.schema for p in properties}, + required=[p.alias for p in properties if p.required], additionalProperties=additional_properties, patternProperties=pattern_properties, dependentRequired=OrderedDict( @@ -341,6 +340,21 @@ class DeserializationSchemaBuilder( ): RefsExtractor = DeserializationRefsExtractor + def properties( + self, tp: AnyType, fields: 
Sequence[ObjectField] + ) -> Sequence[Property]: + return [ + Property( + field.alias, + field.name, + field.ordering, + field.required, + self.visit_field(field, field.required), + ) + for field in fields + if not field.is_aggregate + ] + class SerializationSchemaBuilder( SchemaBuilder[Serialization], @@ -357,34 +371,42 @@ def _field_required(field: ObjectField): settings.serialization.exclude_defaults, settings.serialization.exclude_none ) - def object(self, tp: AnyType, fields: Sequence[ObjectField]) -> JsonSchema: - result = super().object(tp, fields) - name_by_aliases = {f.alias: f.name for f in fields} - properties = {} - required = [] - for alias, (serialized, types) in get_serialized_methods(tp).items(): - return_type = types["return"] - properties[self.aliaser(alias)] = full_schema( - self.visit_with_conv(return_type, serialized.conversion), - serialized.schema, + def properties( + self, tp: AnyType, fields: Sequence[ObjectField] + ) -> Sequence[Property]: + from apischema import settings + + return [ + Property( + field.alias, + field.name, + field.ordering, + required, + self.visit_field(field, required), ) - if not is_union_of(return_type, UndefinedType): - required.append(alias) - name_by_aliases[alias] = serialized.func.__name__ - if "allOf" not in result: - to_update = result - else: - to_update = result["allOf"][0] - if required: - required.extend(to_update.get("required", ())) - to_update["required"] = sorted(required) - if properties: - properties.update(to_update.get("properties", {})) - props = sort_by_annotations_position( - get_origin_or_type(tp), properties, lambda p: name_by_aliases[p] + for field in fields + if not field.is_aggregate + for required in [ + field.required + if is_typed_dict(get_origin_or_type(tp)) + else not field.skippable( + settings.serialization.exclude_defaults, + settings.serialization.exclude_none, + ) + ] + ] + [ + Property( + AliasedStr(alias), + serialized.func.__name__, + serialized.ordering, + not is_union_of(types["return"], UndefinedType), + full_schema( + self.visit_with_conv(types["return"], serialized.conversion), + serialized.schema, + ), ) - to_update["properties"] = {p: properties[p] for p in props} - return result + for alias, (serialized, types) in get_serialized_methods(tp).items() + ] TypesWithConversion = Collection[Union[AnyType, Tuple[AnyType, AnyConversion]]] @@ -424,7 +446,6 @@ def _extract_refs( def _refs_schema( builder: Type[SchemaBuilder], - aliaser: Aliaser, default_conversion: DefaultConversion, refs: Mapping[str, AnyType], ref_factory: RefFactory, @@ -432,7 +453,7 @@ def _refs_schema( ) -> Mapping[str, JsonSchema]: return { ref: builder( - additional_properties, aliaser, default_conversion, True, ref_factory, refs + additional_properties, default_conversion, True, ref_factory, refs ).visit(tp) for ref, tp in refs.items() } @@ -461,17 +482,12 @@ def _schema( version, ref_factory, all_refs = _default_version(version, ref_factory, all_refs) refs = _extract_refs([(tp, conversion)], default_conversion, builder, all_refs) json_schema = builder( - additional_properties, aliaser, default_conversion, False, ref_factory, refs + additional_properties, default_conversion, False, ref_factory, refs ).visit_with_conv(tp, conversion) json_schema = full_schema(json_schema, schema) if add_defs: defs = _refs_schema( - builder, - aliaser, - default_conversion, - refs, - ref_factory, - additional_properties, + builder, default_conversion, refs, ref_factory, additional_properties ) if defs: json_schema["$defs"] = defs @@ -479,10 +495,10 
@@ def _schema( JsonSchema, json_schema, aliaser=aliaser, - fall_back_on_any=True, check_type=True, conversion=version.conversion, default_conversion=converters.default_serialization, + fall_back_on_any=True, ) if with_schema and version.schema is not None: result["$schema"] = version.schema @@ -555,14 +571,12 @@ def _defs_schema( types: TypesWithConversion, default_conversion: DefaultConversion, builder: Type[SchemaBuilder], - aliaser: Aliaser, ref_factory: RefFactory, all_refs: bool, additional_properties: bool, ) -> Mapping[str, JsonSchema]: return _refs_schema( builder, - aliaser, default_conversion, _extract_refs(types, default_conversion, builder, all_refs), ref_factory, @@ -647,7 +661,6 @@ def definitions_schema( deserialization, default_deserialization, DeserializationSchemaBuilder, - aliaser, ref_factory, all_refs, additional_properties, @@ -656,7 +669,6 @@ def definitions_schema( serialization, default_serialization, SerializationSchemaBuilder, - aliaser, ref_factory, all_refs, additional_properties, diff --git a/apischema/metadata/__init__.py b/apischema/metadata/__init__.py index 82094ee5..3fb6556f 100644 --- a/apischema/metadata/__init__.py +++ b/apischema/metadata/__init__.py @@ -6,6 +6,7 @@ "flatten", "init_var", "none_as_undefined", + "order", "post_init", "properties", "required", @@ -18,6 +19,7 @@ import warnings from apischema.aliases import alias +from apischema.ordering import order from apischema.schemas import schema from .implem import ( conversion, diff --git a/apischema/metadata/keys.py b/apischema/metadata/keys.py index 883aa46d..de3d9241 100644 --- a/apischema/metadata/keys.py +++ b/apischema/metadata/keys.py @@ -8,6 +8,7 @@ FLATTEN_METADATA = f"{PREFIX}flattened" INIT_VAR_METADATA = f"{PREFIX}init_var" NONE_AS_UNDEFINED_METADATA = f"{PREFIX}none_as_undefined" +ORDERING_METADATA = f"{PREFIX}ordering" POST_INIT_METADATA = f"{PREFIX}post_init" PROPERTIES_METADATA = f"{PREFIX}properties" REQUIRED_METADATA = f"{PREFIX}required" diff --git a/apischema/objects/fields.py b/apischema/objects/fields.py index 38c7b4a1..b9b5d9a6 100644 --- a/apischema/objects/fields.py +++ b/apischema/objects/fields.py @@ -1,5 +1,6 @@ from dataclasses import Field, InitVar, MISSING, dataclass, field from enum import Enum, auto +from types import FunctionType from typing import ( Any, Callable, @@ -30,6 +31,7 @@ FALL_BACK_ON_DEFAULT_METADATA, FLATTEN_METADATA, NONE_AS_UNDEFINED_METADATA, + ORDERING_METADATA, POST_INIT_METADATA, PROPERTIES_METADATA, REQUIRED_METADATA, @@ -48,6 +50,7 @@ ) if TYPE_CHECKING: + from apischema.ordering import Ordering from apischema.schemas import Schema from apischema.validation.validators import Validator @@ -112,7 +115,7 @@ def override_alias(self) -> bool: @property def _conversion(self) -> Optional[ConversionMetadata]: - return self.metadata.get(CONVERSION_METADATA, None) + return self.metadata.get(CONVERSION_METADATA) @property def default_as_set(self) -> bool: @@ -152,17 +155,21 @@ def is_aggregate(self) -> bool: def none_as_undefined(self): return NONE_AS_UNDEFINED_METADATA in self.full_metadata + @property + def ordering(self) -> Optional["Ordering"]: + return self.full_metadata.get(ORDERING_METADATA) + @property def post_init(self) -> bool: return POST_INIT_METADATA in self.full_metadata @property def pattern_properties(self) -> Union[Pattern, "ellipsis", None]: # noqa: F821 - return self.full_metadata.get(PROPERTIES_METADATA, None) + return self.full_metadata.get(PROPERTIES_METADATA) @property def schema(self) -> Optional["Schema"]: - return 
self.metadata.get(SCHEMA_METADATA, None) + return self.metadata.get(SCHEMA_METADATA) @property def serialization(self) -> Optional[AnyConversion]: @@ -202,24 +209,34 @@ def validators(self) -> Sequence["Validator"]: FieldOrName = Union[str, ObjectField, Field] -def _bad_field(obj: Any) -> NoReturn: +def _bad_field(obj: Any, methods: bool) -> NoReturn: + method_types = "property/types.FunctionType" if methods else "" raise TypeError( - f"Expected dataclasses.Field/apischema.ObjectField/str, found {obj}" + f"Expected dataclasses.Field/apischema.ObjectField/str{method_types}, found {obj}" ) -def check_field_or_name(field_or_name: Any): - if not isinstance(field_or_name, (str, ObjectField, Field)): - _bad_field(field_or_name) +def check_field_or_name(field_or_name: Any, *, methods: bool = False): + method_types = (property, FunctionType) if methods else () + if not isinstance(field_or_name, (str, ObjectField, Field, *method_types)): + _bad_field(field_or_name, methods) -def get_field_name(field_or_name: Any) -> str: +def get_field_name(field_or_name: Any, *, methods: bool = False) -> str: if isinstance(field_or_name, (Field, ObjectField)): return field_or_name.name elif isinstance(field_or_name, str): return field_or_name + elif ( + methods + and isinstance(field_or_name, property) + and field_or_name.fget is not None + ): + return field_or_name.fget.__name__ + elif methods and isinstance(field_or_name, FunctionType): + return field_or_name.__name__ else: - _bad_field(field_or_name) + _bad_field(field_or_name, methods) _class_fields: MutableMapping[ diff --git a/apischema/ordering.py b/apischema/ordering.py new file mode 100644 index 00000000..9f57877a --- /dev/null +++ b/apischema/ordering.py @@ -0,0 +1,135 @@ +from collections import defaultdict +from dataclasses import dataclass +from typing import ( + Any, + Callable, + Collection, + Dict, + List, + Mapping, + MutableMapping, + Optional, + Sequence, + TypeVar, + overload, +) + +from apischema.cache import CacheAwareDict +from apischema.metadata.keys import ORDERING_METADATA +from apischema.types import MetadataMixin +from apischema.utils import stop_signature_abuse + +Cls = TypeVar("Cls", bound=type) + + +@dataclass(frozen=True) +class Ordering(MetadataMixin): + key = ORDERING_METADATA + order: Optional[int] = None + after: Optional[Any] = None + before: Optional[Any] = None + + def __post_init__(self): + from apischema.objects.fields import check_field_or_name + + if self.after is not None: + check_field_or_name(self.after, methods=True) + if self.before is not None: + check_field_or_name(self.before, methods=True) + + +_order_overriding: MutableMapping[type, Mapping[Any, Ordering]] = CacheAwareDict({}) + + +@overload +def order(__value: int) -> Ordering: + ... + + +@overload +def order(*, after: Any) -> Ordering: + ... + + +@overload +def order(*, before: Any) -> Ordering: + ... + + +@overload +def order(__fields: Sequence[Any]) -> Callable[[Cls], Cls]: + ... + + +@overload +def order(__override: Mapping[Any, Ordering]) -> Callable[[Cls], Cls]: + ... 
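+# Illustrative recap of the overloads above (hypothetical values, inferred
+# from the implementation below):
+#   order(-1)              -> Ordering(order=-1)      # field metadata
+#   order(after="foo")     -> Ordering(after="foo")   # relative placement
+#   order(["a", "b"])      -> class decorator chaining order(after=...) pairs
+#   order({"a": order(1)}) -> class decorator overriding fields' ordering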
+ + +def order(__arg=None, *, before=None, after=None): + if len([arg for arg in (__arg, before, after) if arg is not None]) != 1: + stop_signature_abuse() + if isinstance(__arg, Sequence): + __arg = {field: order(after=prev) for field, prev in zip(__arg[1:], __arg)} + if isinstance(__arg, Mapping): + if not all(isinstance(val, Ordering) for val in __arg.values()): + stop_signature_abuse() + + def decorator(cls: Cls) -> Cls: + _order_overriding[cls] = __arg + return cls + + return decorator + elif __arg is not None and not isinstance(__arg, int): + stop_signature_abuse() + else: + return Ordering(__arg, after, before) + + +def get_order_overriding(cls: type) -> Mapping[str, Ordering]: + from apischema.objects.fields import get_field_name + + return { + get_field_name(field, methods=True): ordering + for sub_cls in reversed(cls.__mro__) + if sub_cls in _order_overriding + for field, ordering in _order_overriding[sub_cls].items() + } + + +T = TypeVar("T") + + +def sort_by_order( + cls: type, + elts: Collection[T], + name: Callable[[T], str], + order: Callable[[T], Optional[Ordering]], +) -> Sequence[T]: + from apischema.objects.fields import get_field_name + + order_overriding = get_order_overriding(cls) + groups: Dict[int, List[T]] = defaultdict(list) + after: Dict[str, List[T]] = defaultdict(list) + before: Dict[str, List[T]] = defaultdict(list) + for elt in elts: + ordering = order_overriding.get(name(elt), order(elt)) + if ordering is None: + groups[0].append(elt) + elif ordering.order is not None: + groups[ordering.order].append(elt) + elif ordering.after is not None: + after[get_field_name(ordering.after, methods=True)].append(elt) + elif ordering.before is not None: + before[get_field_name(ordering.before, methods=True)].append(elt) + else: + raise NotImplementedError + if not after and not before and len(groups) == 1: + return next(iter(groups.values())) + result = [] + for value in sorted(groups): + for elt in groups[value]: + result.extend(before[name(elt)]) + result.append(elt) + result.extend(after[name(elt)]) + return result diff --git a/apischema/serialization/__init__.py b/apischema/serialization/__init__.py index 40b6e2ca..2bde24bf 100644 --- a/apischema/serialization/__init__.py +++ b/apischema/serialization/__init__.py @@ -28,6 +28,7 @@ from apischema.fields import FIELDS_SET_ATTR, support_fields_set from apischema.objects import AliasedStr, ObjectField from apischema.objects.visitor import SerializationObjectVisitor +from apischema.ordering import sort_by_order from apischema.recursion import RecursiveConversionsVisitor from apischema.serialization.pass_through import PassThroughOptions, pass_through from apischema.serialization.serialized_methods import get_serialized_methods @@ -180,6 +181,8 @@ def wrapper(obj: Any) -> Any: if isinstance(obj, cls_to_check): try: return method(obj) + except Unsupported: + raise except Exception: return fallback(obj) else: @@ -241,7 +244,7 @@ def object(self, tp: AnyType, fields: Sequence[ObjectField]) -> SerializationMet ) serialization_fields = [ ( - cast(Optional[str], field.name), + field.name, self.aliaser(field.alias) if not field.is_aggregate else None, getter(field.name), field.required, @@ -254,12 +257,13 @@ def object(self, tp: AnyType, fields: Sequence[ObjectField]) -> SerializationMet and default not in (None, Undefined), default, identity_as_none(self.visit_with_conv(field.type, field.serialization)), + field.ordering, ) for field in fields for default in [... 
if field.required else field.get_default()] ] + [ ( - None, + serialized.func.__name__, self.aliaser(name), serialized.func, True, @@ -269,10 +273,14 @@ def object(self, tp: AnyType, fields: Sequence[ObjectField]) -> SerializationMet False, ..., self.visit_with_conv(ret_type, serialized.conversion), + serialized.ordering, ) for name, (serialized, types) in get_serialized_methods(tp).items() for ret_type in [types["return"]] ] + serialization_fields = sort_by_order( # type: ignore + cls, serialization_fields, lambda f: f[0], lambda f: f[-1] + ) field_names = {f.name for f in fields} any_method = self.any() exclude_unset = self.exclude_unset and support_fields_set(cls) @@ -291,6 +299,7 @@ def method(obj: Any) -> Any: skip_default, default, serialize_field, + _, ) in serialization_fields: if (exclude_unset and name not in getattr(obj, FIELDS_SET_ATTR)) or ( typed_dict and not required and name not in obj @@ -374,6 +383,8 @@ def method(obj: Any) -> Any: if is_instance(obj, cls): try: return serialize_alt(obj) + except Unsupported: + raise except Exception: pass return fallback(obj) diff --git a/apischema/serialization/serialized_methods.py b/apischema/serialization/serialized_methods.py index 22b52f89..893e9e1c 100644 --- a/apischema/serialization/serialized_methods.py +++ b/apischema/serialization/serialized_methods.py @@ -21,6 +21,7 @@ from apischema.conversions.conversions import AnyConversion from apischema.conversions.dataclass_models import get_model_origin, has_model_origin from apischema.methods import method_registerer +from apischema.ordering import Ordering from apischema.schemas import Schema from apischema.types import AnyType, Undefined, UndefinedType from apischema.typing import generic_mro, get_args, get_origin, get_type_hints @@ -39,8 +40,9 @@ class SerializedMethod: func: Callable conversion: Optional[AnyConversion] - schema: Optional[Schema] error_handler: Optional[Callable] + ordering: Optional[Ordering] + schema: Optional[Schema] def error_type(self) -> AnyType: assert self.error_handler is not None @@ -125,8 +127,9 @@ def serialized( alias: str = None, *, conversion: AnyConversion = None, - schema: Schema = None, error_handler: ErrorHandler = Undefined, + order: Optional[Ordering] = None, + schema: Schema = None, owner: Type = None, ) -> Callable[[MethodOrProp], MethodOrProp]: ... @@ -138,8 +141,9 @@ def serialized( *, alias: str = None, conversion: AnyConversion = None, - schema: Schema = None, error_handler: ErrorHandler = Undefined, + order: Optional[Ordering] = None, + schema: Schema = None, owner: Type = None, ): def register(func: Callable, owner: Type, alias2: str): @@ -167,7 +171,7 @@ def func(self): return error_handler(error, self, alias2) assert not isinstance(error_handler2, UndefinedType) - serialized = SerializedMethod(func, conversion, schema, error_handler2) + serialized = SerializedMethod(func, conversion, error_handler2, order, schema) _serialized_methods[owner][alias2] = serialized if isinstance(__arg, str): diff --git a/docs/de_serialization.md b/docs/de_serialization.md index 827519bd..df984fc2 100644 --- a/docs/de_serialization.md +++ b/docs/de_serialization.md @@ -226,12 +226,43 @@ These settings can also be set directly using `serialize` parameters, like in th {!exclude_defaults_none.py!} ``` -### TypedDict additional properties +### Field ordering -`TypedDict` can contain additional keys, which are not serialized by default. 
Setting `additional_properties` parameter to `True` (or `apischema.settings.additional_properties`) will toggle on their serialization (without aliasing). +### Field ordering + +Usually, JSON object properties are unordered, but sometimes order does matter. By default, fields are ordered according to their declaration; serialized methods are appended after the fields. + +However, it's possible to change the ordering using `apischema.order`. + +#### Class-level ordering + +`order` can be used to decorate a class with the fields ordered as expected: + +```python +{!class_ordering.py!} +``` + +#### Field-level ordering + +Each field has an order "value" (0 by default), and ordering is done by sorting fields using this value; if several fields have the same order value, they are sorted by their declaration order. For instance, assigning `-1` to a field will put it before every other field, and `999` will surely put it at the end. + +This order value is set using `order`, this time as field metadata (or passed to the `order` argument of [serialized methods/properties](#serialized-methodsproperties)). It has the following overloaded signatures: -## Performances +- `order(value: int, /)`: set the order value of the field +- `order(*, after)`: ignore the order value and put the field after the given field/method/property +- `order(*, before)`: ignore the order value and put the field before the given field/method/property + +!!! note +    `after` and `before` can be raw strings, but also dataclass fields, methods or properties. + +Also, `order` can again be used as a class decorator to override ordering metadata, by passing this time a mapping of fields to their overridden order. + +```python +{!ordering.py!} +``` + + +### TypedDict additional properties + +`TypedDict` can contain additional keys, which are not serialized by default. Setting `additional_properties` parameter to `True` (or `apischema.settings.additional_properties`) will toggle on their serialization (without aliasing). ## FAQ diff --git a/docs/graphql/schema.md b/docs/graphql/schema.md index 96f06db7..d592312b 100644 --- a/docs/graphql/schema.md +++ b/docs/graphql/schema.md @@ -8,7 +8,7 @@ In fact, `graphql_schema` is just a wrapper around `graphql.GraphQLSchema` (same ## Operations metadata -*GraphQL* operations can be passed to `graphql_schema` either using simple functions or wrapping it into `apischema.graphql.Query`/`apischema.graphql.Mutation`/`apischema.graphql.Subscription`. These wrappers have the same parameters as `apischema.graphql.resolver`: `alias`, `conversions`, `error_handler` and `schema` (`Subscription` has an [additional parameter](#subscriptions)). +*GraphQL* operations can be passed to `graphql_schema` either using simple functions or wrapping them into `apischema.graphql.Query`/`apischema.graphql.Mutation`/`apischema.graphql.Subscription`. These wrappers have the same parameters as `apischema.graphql.resolver`: `alias`, `conversion`, `error_handler`, `order` and `schema` (`Subscription` has an [additional parameter](#subscriptions)). 
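For instance, the new `order` parameter can be set through these wrappers (a minimal sketch; the `latest` operation and its ordering value are illustrative, not taken from the documented example below):

```python
from typing import List

from apischema import order
from apischema.graphql import Query, graphql_schema


def latest() -> List[str]:  # hypothetical query operation
    return ["event1", "event2"]


def hello() -> str:
    return "world"


# order(-1) should sort the "latest" field before the other query fields
schema = graphql_schema(query=[hello, Query(latest, order=order(-1))])
```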
```python {!operation.py!} diff --git a/examples/class_ordering.py b/examples/class_ordering.py new file mode 100644 index 00000000..d35734a9 --- /dev/null +++ b/examples/class_ordering.py @@ -0,0 +1,14 @@ +import json +from dataclasses import dataclass + +from apischema import order, serialize + + +@order(["baz", "bar"]) +@dataclass +class Foo: + bar: int + baz: int + + +assert json.dumps(serialize(Foo, Foo(0, 0))) == '{"baz": 0, "bar": 0}' diff --git a/examples/ordering.py b/examples/ordering.py new file mode 100644 index 00000000..6d61d085 --- /dev/null +++ b/examples/ordering.py @@ -0,0 +1,42 @@ +import json +from dataclasses import dataclass, field +from datetime import date + +from apischema import order, serialize, serialized + + +@order({"trigram": order(-1)}) +@dataclass +class User: + firstname: str + lastname: str + address: str = field(metadata=order(after="birthdate")) + birthdate: date = field() + + @serialized + @property + def trigram(self) -> str: + return (self.firstname[0] + self.lastname[0] + self.lastname[-1]).lower() + + @serialized(order=order(before=birthdate)) + @property + def age(self) -> int: + age = date.today().year - self.birthdate.year + if age > 0 and (date.today().month, date.today().day) < ( + self.birthdate.month, + self.birthdate.day, + ): + age -= 1 + return age + + +user = User("Harry", "Potter", "London", date(1980, 7, 31)) +dump = """{ + "trigram": "hpr", + "firstname": "Harry", + "lastname": "Potter", + "age": 41, + "birthdate": "1980-07-31", + "address": "London" +}""" +assert json.dumps(serialize(User, user), indent=4) == dump diff --git a/examples/serialized.py b/examples/serialized.py index 55a478b7..2d9ee245 100644 --- a/examples/serialized.py +++ b/examples/serialized.py @@ -39,6 +39,6 @@ def function(foo: Foo) -> int: "baz": {"type": "integer"}, "function": {"type": "integer"}, }, - "required": ["aliased", "bar", "baz", "function"], + "required": ["bar", "baz", "aliased", "function"], "additionalProperties": False, } diff --git a/tests/test_flattened_conversion.py b/tests/test_flattened_conversion.py index b5e5fa7d..31888e57 100644 --- a/tests/test_flattened_conversion.py +++ b/tests/test_flattened_conversion.py @@ -23,31 +23,6 @@ class Data: data_field: Field = field(metadata=flatten) -json_schema = { - "$schema": "http://json-schema.org/draft/2019-09/schema#", - "type": "object", - "allOf": [ - {"type": "object", "additionalProperties": False}, - { - "type": "object", - "properties": {"attr": {"type": "integer"}}, - "required": ["attr"], - "additionalProperties": False, - }, - ], - "unevaluatedProperties": False, -} -graphql_schema_str = """\ -type Query { - getData: Data -} - -type Data { - attr: Int! -} -""" - - def get_data() -> Data: return Data(Field(0)) diff --git a/tests/test_resolver_position.py b/tests/test_resolver_position.py deleted file mode 100644 index d1e46a6d..00000000 --- a/tests/test_resolver_position.py +++ /dev/null @@ -1,65 +0,0 @@ -from dataclasses import dataclass -from typing import Callable, ClassVar - -from graphql.utilities import print_schema - -from apischema.graphql import graphql_schema, resolver -from apischema.json_schema import serialization_schema - - -@dataclass -class A: - a: int - b: ClassVar[Callable] - _: ClassVar[Callable] - - @resolver(serialized=True) # type: ignore - def b(self) -> int: - ... - - @resolver("c", serialized=True) # type: ignore - def _(self) -> int: - ... - - d: int - - -@dataclass -class B(A): - e: int - - -def query() -> B: - ... 
- - -def test_resolver_position(): - assert serialization_schema(B) == { - "type": "object", - "properties": { - "a": {"type": "integer"}, - "b": {"type": "integer"}, - "c": {"type": "integer"}, - "d": {"type": "integer"}, - "e": {"type": "integer"}, - }, - "required": ["a", "b", "c", "d", "e"], - "additionalProperties": False, - "$schema": "http://json-schema.org/draft/2019-09/schema#", - } - assert ( - print_schema(graphql_schema(query=[query])) - == """\ -type Query { - query: B! -} - -type B { - a: Int! - b: Int! - c: Int! - d: Int! - e: Int! -} -""" - ) diff --git a/tests/test_subscriptions.py b/tests/test_subscriptions.py index 0a6d0f7a..34d5b03a 100644 --- a/tests/test_subscriptions.py +++ b/tests/test_subscriptions.py @@ -58,7 +58,13 @@ async def test_subscription(alias, conversion, error_handler, resolver): sub_op = events else: sub_op = Subscription( - events, alias, conversion, None, error_handler, resolver=resolver + events, + alias=alias, + conversion=conversion, + error_handler=error_handler, + order=None, + schema=None, + resolver=resolver, ) schema = graphql_schema(query=[hello], subscription=[sub_op], types=[Event]) sub_field = sub_name
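To see the ordering primitive end to end, here is a minimal standalone sketch of the new `sort_by_order` helper (it is internal, so this usage is illustrative; `Item` and `Owner` are hypothetical names, not from the diff):

```python
from dataclasses import dataclass
from typing import Optional

from apischema import order
from apischema.ordering import Ordering, sort_by_order


@dataclass
class Item:  # hypothetical element wrapper, standing in for fields/resolvers
    name: str
    ordering: Optional[Ordering] = None


class Owner:  # class key used for the order-overriding lookup
    pass


items = [
    Item("b"),
    Item("a", order(-1)),         # negative order group is emitted first
    Item("c", order(after="b")),  # placed right after "b"
]
sorted_items = sort_by_order(Owner, items, lambda i: i.name, lambda i: i.ordering)
assert [i.name for i in sorted_items] == ["a", "b", "c"]
```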