diff --git a/apischema/graphql/resolvers.py b/apischema/graphql/resolvers.py index 1f556c2c..14e0c451 100644 --- a/apischema/graphql/resolvers.py +++ b/apischema/graphql/resolvers.py @@ -7,6 +7,7 @@ Any, Awaitable, Callable, + Collection, Dict, Iterator, Mapping, @@ -124,7 +125,7 @@ def return_type(self, return_type: AnyType) -> AnyType: ) -def get_resolvers(tp: AnyType) -> Mapping[str, Tuple[Resolver, Mapping[str, AnyType]]]: +def get_resolvers(tp: AnyType) -> Collection[Tuple[Resolver, Mapping[str, AnyType]]]: return _get_methods(tp, _resolvers) @@ -194,6 +195,7 @@ def register(func: Callable, owner: Type, alias2: str): error_handler2 = None resolver = Resolver( func, + alias2, conversion, error_handler2, order, diff --git a/apischema/graphql/schema.py b/apischema/graphql/schema.py index abdb9422..5d013dc0 100644 --- a/apischema/graphql/schema.py +++ b/apischema/graphql/schema.py @@ -45,6 +45,7 @@ resolver_parameters, resolver_resolve, ) +from apischema.json_schema.schema import get_field_schema, get_method_schema, get_schema from apischema.metadata.keys import SCHEMA_METADATA from apischema.objects import ObjectField from apischema.objects.visitor import ( @@ -54,7 +55,7 @@ ) from apischema.ordering import Ordering, sort_by_order from apischema.recursion import RecursiveConversionsVisitor -from apischema.schemas import Schema, get_schema, merge_schema +from apischema.schemas import Schema, merge_schema from apischema.serialization import SerializationMethod, serialize from apischema.serialization.serialized_methods import ErrorHandler from apischema.type_names import TypeName, TypeNameFactory, get_type_name @@ -111,6 +112,16 @@ def exec_thunk(thunk: TypeThunk, *, non_null=None) -> Any: return result +def get_parameter_schema( + func: Callable, parameter: Parameter, field: ObjectField +) -> Optional[Schema]: + from apischema import settings + + return merge_schema( + settings.base_schema.parameter(func, parameter, field.alias), field.schema + ) + + def merged_schema( schema: Optional[Schema], tp: Optional[AnyType] ) -> Tuple[Optional[Schema], Mapping[str, Any]]: @@ -151,7 +162,6 @@ def get_deprecated( @dataclass(frozen=True) class ResolverField: - alias: str resolver: Resolver types: Mapping[str, AnyType] parameters: Sequence[Parameter] @@ -480,7 +490,9 @@ class InputSchemaBuilder( ): types = graphql.type.definition.graphql_input_types - def _field(self, field: ObjectField) -> Lazy[graphql.GraphQLInputField]: + def _field( + self, tp: AnyType, field: ObjectField + ) -> Lazy[graphql.GraphQLInputField]: field_type = field.type field_default = graphql.Undefined if field.required else field.get_default() default: Any = graphql.Undefined @@ -501,7 +513,7 @@ def _field(self, field: ObjectField) -> Lazy[graphql.GraphQLInputField]: return lambda: graphql.GraphQLInputField( factory.type, # type: ignore default_value=default, - description=get_description(field.schema, field.type), + description=get_description(get_field_schema(tp, field), field.type), ) @cache_type @@ -514,7 +526,7 @@ def object( normal_field = NormalField( self.aliaser(field.alias), field.name, - self._field(field), + self._field(tp, field), field.ordering, ) visited_fields.append(normal_field) @@ -602,7 +614,7 @@ def resolve_wrapper(__obj, __info, **kwargs): return cast(Func, resolve_wrapper) - def _field(self, field: ObjectField) -> Lazy[graphql.GraphQLField]: + def _field(self, tp: AnyType, field: ObjectField) -> Lazy[graphql.GraphQLField]: field_name = field.name partial_serialize = self._field_serialization_method(field) @@ -611,15 +623,18 @@ def resolve(obj, _): return partial_serialize(getattr(obj, field_name)) factory = self.visit_with_conv(field.type, field.serialization) + field_schema = get_field_schema(tp, field) return lambda: graphql.GraphQLField( factory.type, None, resolve, - description=get_description(field.schema, field.type), - deprecation_reason=get_deprecated(field.schema, field.type), + description=get_description(field_schema, field.type), + deprecation_reason=get_deprecated(field_schema, field.type), ) - def _resolver(self, field: ResolverField) -> Lazy[graphql.GraphQLField]: + def _resolver( + self, tp: AnyType, field: ResolverField + ) -> Lazy[graphql.GraphQLField]: resolve = self._wrap_resolve( resolver_resolve( field.resolver, @@ -665,7 +680,10 @@ def _resolver(self, field: ResolverField) -> Lazy[graphql.GraphQLField]: arg_factory = self.input_builder.visit_with_conv( param_type, param_field.deserialization ) - description = get_description(param_field.schema, param_field.type) + description = get_description( + get_parameter_schema(field.resolver.func, param, param_field), + param_field.type, + ) def arg_thunk( arg_factory=arg_factory, default=default, description=description @@ -676,13 +694,14 @@ def arg_thunk( args[self.aliaser(param_field.alias)] = arg_thunk factory = self.visit_with_conv(field.types["return"], field.resolver.conversion) + field_schema = get_method_schema(tp, field.resolver) return lambda: graphql.GraphQLField( factory.type, # type: ignore {name: arg() for name, arg in args.items()} if args else None, resolve, field.subscribe, - get_description(field.resolver.schema), - get_deprecated(field.resolver.schema), + get_description(field_schema), + get_deprecated(field_schema), ) def _visit_flattened( @@ -716,7 +735,7 @@ def object( normal_field = NormalField( self.aliaser(field.name), field.name, - self._field(field), + self._field(tp, field), field.ordering, ) visited_fields.append(normal_field) @@ -727,20 +746,16 @@ def object( FlattenedField(field.name, field.ordering, flattened_factory) ) resolvers = list(resolvers) - for alias, (resolver, types) in get_resolvers(tp).items(): + for resolver, types in get_resolvers(tp): resolver_field = ResolverField( - alias, - resolver, - types, - resolver.parameters, - resolver.parameters_metadata, + resolver, types, resolver.parameters, resolver.parameters_metadata ) resolvers.append(resolver_field) for resolver_field in resolvers: normal_field = NormalField( - self.aliaser(resolver_field.alias), + self.aliaser(resolver_field.resolver.alias), resolver_field.resolver.func.__name__, - self._resolver(resolver_field), + self._resolver(tp, resolver_field), resolver_field.resolver.ordering, ) visited_fields.append(normal_field) @@ -838,9 +853,7 @@ class Subscription(Operation[AsyncIterable]): Op = TypeVar("Op", bound=Operation) -def operation_resolver( - operation: Union[Callable, Op], op_class: Type[Op] -) -> Tuple[str, Resolver]: +def operation_resolver(operation: Union[Callable, Op], op_class: Type[Op]) -> Resolver: if not isinstance(operation, op_class): operation = op_class(operation) # type: ignore error_handler: Optional[Callable] @@ -864,8 +877,9 @@ def wrapper(_, *args, **kwargs): wrapper.__annotations__ = op.__annotations__ (*parameters,) = resolver_parameters(operation.function, check_first=True) - return operation.alias or operation.function.__name__, Resolver( + return Resolver( wrapper, + operation.alias or operation.function.__name__, operation.conversion, error_handler, operation.order, @@ -912,9 +926,8 @@ def graphql_schema( (mutation, Mutation, mutation_fields), ]: for operation in operations: # type: ignore - alias, resolver = operation_resolver(operation, op_class) + resolver = operation_resolver(operation, op_class) resolver_field = ResolverField( - alias, resolver, resolver.types(), resolver.parameters, @@ -926,11 +939,11 @@ def graphql_schema( sub_op = Subscription(sub_op) # type: ignore sub_parameters: Sequence[Parameter] if sub_op.resolver is not None: - alias = sub_op.alias or sub_op.resolver.__name__ - _, subscriber2 = operation_resolver(sub_op, Subscription) + subscriber2 = operation_resolver(sub_op, Subscription) _, *sub_parameters = resolver_parameters(sub_op.resolver, check_first=False) resolver = Resolver( sub_op.resolver, + sub_op.alias or sub_op.resolver.__name__, sub_op.conversion, subscriber2.error_handler, sub_op.order, @@ -949,9 +962,10 @@ def graphql_schema( serialized=False, ) else: - alias, subscriber2 = operation_resolver(sub_op, Subscription) + subscriber2 = operation_resolver(sub_op, Subscription) resolver = Resolver( lambda _: _, + subscriber2.alias, sub_op.conversion, subscriber2.error_handler, sub_op.order, @@ -978,12 +992,7 @@ def graphql_schema( sub_types = {**sub_types, "return": resolver.return_type(event_type)} resolver_field = ResolverField( - alias, - resolver, - sub_types, - sub_parameters, - sub_op.parameters_metadata, - subscribe, + resolver, sub_types, sub_parameters, sub_op.parameters_metadata, subscribe ) subscription_fields.append(resolver_field) diff --git a/apischema/json_schema/schema.py b/apischema/json_schema/schema.py index 2cddd242..a20ae6bd 100644 --- a/apischema/json_schema/schema.py +++ b/apischema/json_schema/schema.py @@ -1,7 +1,6 @@ from contextlib import suppress from dataclasses import dataclass from enum import Enum -from functools import reduce from itertools import chain from typing import ( AbstractSet, @@ -51,9 +50,12 @@ SerializationObjectVisitor, ) from apischema.ordering import Ordering, sort_by_order -from apischema.schemas import Schema, get_schema +from apischema.schemas import Schema, get_schema as _get_schema, merge_schema from apischema.serialization import serialize -from apischema.serialization.serialized_methods import get_serialized_methods +from apischema.serialization.serialized_methods import ( + SerializedMethod, + get_serialized_methods, +) from apischema.type_names import TypeNameFactory, get_type_name from apischema.types import AnyType, OrderedDict, UndefinedType from apischema.typing import get_args, is_typed_dict @@ -66,6 +68,29 @@ ) +def get_schema(tp: AnyType) -> Optional[Schema]: + from apischema import settings + + return merge_schema(settings.base_schema.type(tp), _get_schema(tp)) + + +def get_field_schema(tp: AnyType, field: ObjectField) -> Optional[Schema]: + from apischema import settings + + assert not field.is_aggregate + return merge_schema( + settings.base_schema.field(tp, field.name, field.alias), field.schema + ) + + +def get_method_schema(tp: AnyType, method: SerializedMethod) -> Optional[Schema]: + from apischema import settings + + return merge_schema( + settings.base_schema.method(tp, method.func, method.alias), method.schema + ) + + def full_schema(base_schema: JsonSchema, schema: Optional[Schema]) -> JsonSchema: if schema is not None: base_schema = JsonSchema(base_schema) @@ -76,9 +101,9 @@ def full_schema(base_schema: JsonSchema, schema: Optional[Schema]) -> JsonSchema Method = TypeVar("Method", bound=Callable) -@dataclass +@dataclass(frozen=True) class Property: - alias: str + alias: AliasedStr name: str ordering: Optional[Ordering] required: bool @@ -115,18 +140,16 @@ def ref_schema(self, ref: Optional[str]) -> Optional[JsonSchema]: return JsonSchema({"$ref": self.ref_factory(ref)}) def annotated(self, tp: AnyType, annotations: Sequence[Any]) -> JsonSchema: - schemas: List[Optional[Schema]] = [] + schema = None for annotation in reversed(annotations): if isinstance(annotation, TypeNameFactory): ref = annotation.to_type_name(tp).json_schema ref_schema = self.ref_schema(ref) if ref_schema is not None: - return reduce(full_schema, reversed(schemas), ref_schema) + return full_schema(ref_schema, schema) if isinstance(annotation, Mapping): - schemas.append(annotation.get(SCHEMA_METADATA)) - return reduce( - full_schema, reversed(schemas), super().annotated(tp, annotations) - ) + schema = merge_schema(annotation.get(SCHEMA_METADATA), schema) + return full_schema(super().annotated(tp, annotations), schema) def any(self) -> JsonSchema: return JsonSchema() @@ -170,12 +193,15 @@ def mapping( else: return json_schema(type=JsonType.OBJECT, additionalProperties=value) - def visit_field(self, field: ObjectField, required: bool = True) -> JsonSchema: + def visit_field( + self, tp: AnyType, field: ObjectField, required: bool = True + ) -> JsonSchema: + assert not field.is_aggregate result = full_schema( self.visit_with_conv(field.type, self._field_conversion(field)), - field.schema, + get_field_schema(tp, field) if tp is not None else field.schema, ) - if not field.is_aggregate and not required and "default" not in result: + if not required and "default" not in result: result = JsonSchema(result) with suppress(Exception): result["default"] = serialize( @@ -191,7 +217,10 @@ def _object_schema(self, cls: type, field: ObjectField) -> JsonSchema: assert field.is_aggregate with context_setter(self): self._ignore_first_ref = True - object_schema = self.visit_field(field) + object_schema = full_schema( + self.visit_with_conv(field.type, self._field_conversion(field)), + field.schema, + ) if object_schema.get("type") not in {JsonType.OBJECT, "object"}: field_type = "Flattened" if field.flattened else "Properties" raise TypeError( @@ -233,7 +262,12 @@ def object(self, tp: AnyType, fields: Sequence[ObjectField]) -> JsonSchema: for field in fields: if field.flattened: self._object_schema(cls, field) # check the field is an object - flattened_schemas.append(self.visit_field(field)) + flattened_schemas.append( + full_schema( + self.visit_with_conv(field.type, self._field_conversion(field)), + field.schema, + ) + ) elif field.pattern_properties is not None: if field.pattern_properties is ...: pattern = infer_pattern(field.type, self.default_conversion) @@ -318,17 +352,17 @@ def visit_conversion( dynamic: bool, next_conversion: Optional[AnyConversion] = None, ) -> JsonSchema: - schemas = [] + schema = None if not dynamic: for ref_tp in self.resolve_conversion(tp): ref_schema = self.ref_schema(get_type_name(ref_tp).json_schema) if ref_schema is not None: return ref_schema if get_args(tp): - schemas.append(get_schema(get_origin_or_type(tp))) - schemas.append(get_schema(tp)) + schema = merge_schema(schema, get_schema(get_origin_or_type(tp))) + schema = merge_schema(schema, get_schema(tp)) result = super().visit_conversion(tp, conversion, dynamic, next_conversion) - return reduce(full_schema, schemas, result) + return full_schema(result, schema) RefsExtractor: ClassVar[Type[RefsExtractor_]] @@ -345,11 +379,11 @@ def properties( ) -> Sequence[Property]: return [ Property( - field.alias, + AliasedStr(field.alias), field.name, field.ordering, field.required, - self.visit_field(field, field.required), + self.visit_field(tp, field, field.required), ) for field in fields if not field.is_aggregate @@ -378,11 +412,11 @@ def properties( return [ Property( - field.alias, + AliasedStr(field.alias), field.name, field.ordering, required, - self.visit_field(field, required), + self.visit_field(tp, field, required), ) for field in fields if not field.is_aggregate @@ -396,16 +430,16 @@ def properties( ] ] + [ Property( - AliasedStr(alias), + AliasedStr(serialized.alias), serialized.func.__name__, serialized.ordering, not is_union_of(types["return"], UndefinedType), full_schema( self.visit_with_conv(types["return"], serialized.conversion), - serialized.schema, + get_method_schema(tp, serialized), ), ) - for alias, (serialized, types) in get_serialized_methods(tp).items() + for serialized, types in get_serialized_methods(tp) ] diff --git a/apischema/schemas/__init__.py b/apischema/schemas/__init__.py index 7e02d80e..1914d32a 100644 --- a/apischema/schemas/__init__.py +++ b/apischema/schemas/__init__.py @@ -128,13 +128,11 @@ def default_schema(tp: AnyType) -> Optional[Schema]: def get_schema(tp: AnyType) -> Optional[Schema]: - from apischema import settings - tp = replace_builtins(tp) try: - return _schemas[tp] - except (KeyError, TypeError): - return settings.default_schema(tp) + return _schemas.get(tp) + except TypeError: + return None @merge_opts diff --git a/apischema/serialization/__init__.py b/apischema/serialization/__init__.py index 8b652250..424b0189 100644 --- a/apischema/serialization/__init__.py +++ b/apischema/serialization/__init__.py @@ -272,7 +272,7 @@ def object(self, tp: AnyType, fields: Sequence[ObjectField]) -> SerializationMet ] + [ ( serialized.func.__name__, - self.aliaser(name), + self.aliaser(serialized.alias), serialized.func, True, None, @@ -283,7 +283,7 @@ def object(self, tp: AnyType, fields: Sequence[ObjectField]) -> SerializationMet self.visit_with_conv(ret_type, serialized.conversion), serialized.ordering, ) - for name, (serialized, types) in get_serialized_methods(tp).items() + for serialized, types in get_serialized_methods(tp) for ret_type in [types["return"]] ] serialization_fields = sort_by_order( # type: ignore diff --git a/apischema/serialization/serialized_methods.py b/apischema/serialization/serialized_methods.py index fc058301..0ba02d4d 100644 --- a/apischema/serialization/serialized_methods.py +++ b/apischema/serialization/serialized_methods.py @@ -5,6 +5,7 @@ from typing import ( Any, Callable, + Collection, Dict, Mapping, MutableMapping, @@ -37,6 +38,7 @@ @dataclass(frozen=True) class SerializedMethod: func: Callable + alias: str conversion: Optional[AnyConversion] error_handler: Optional[Callable] ordering: Optional[Ordering] @@ -85,17 +87,17 @@ def types(self, owner: AnyType = None) -> Mapping[str, AnyType]: def _get_methods( tp: AnyType, all_methods: Mapping[Type, Mapping[str, S]] -) -> Mapping[str, Tuple[S, Mapping[str, AnyType]]]: +) -> Collection[Tuple[S, Mapping[str, AnyType]]]: result = {} for base in reversed(generic_mro(tp)): for name, method in all_methods[get_origin_or_type(base)].items(): result[name] = (method, method.types(base)) - return result + return result.values() def get_serialized_methods( tp: AnyType, -) -> Mapping[str, Tuple[SerializedMethod, Mapping[str, AnyType]]]: +) -> Collection[Tuple[SerializedMethod, Mapping[str, AnyType]]]: return _get_methods(tp, _serialized_methods) @@ -163,8 +165,9 @@ def func(self): return error_handler(error, self, alias2) assert not isinstance(error_handler2, UndefinedType) - serialized = SerializedMethod(func, conversion, error_handler2, order, schema) - _serialized_methods[owner][alias2] = serialized + _serialized_methods[owner][alias2] = SerializedMethod( + func, alias2, conversion, error_handler2, order, schema + ) if isinstance(__arg, str): alias = __arg diff --git a/apischema/settings.py b/apischema/settings.py index 59cf5c62..c4a26df2 100644 --- a/apischema/settings.py +++ b/apischema/settings.py @@ -1,3 +1,5 @@ +import warnings +from inspect import Parameter from typing import Callable, Optional, Sequence from apischema import cache @@ -11,7 +13,7 @@ from apischema.json_schema import JsonSchemaVersion from apischema.objects import ObjectField from apischema.objects.fields import default_object_fields as default_object_fields_ -from apischema.schemas import Schema, default_schema as default_schema_ +from apischema.schemas import Schema from apischema.serialization import PassThroughOptions from apischema.type_names import TypeName, default_type_name as default_type_name_ from apischema.types import AnyType @@ -33,6 +35,18 @@ def camel_case(self) -> bool: def camel_case(self, value: bool): settings.aliaser = to_camel_case if value else lambda s: s + def __setattr__(self, name, value): + if name == "default_schema" and not isinstance(value, ResetCache): + warnings.warn( + "settings.default_schema is deprecated," + " use settings.base_schema.type instead", + DeprecationWarning, + ) + assert self is settings + self.base_schema.type = value # type: ignore + else: + super().__setattr__(name, value) + class settings(metaclass=MetaSettings): additional_properties: bool = False @@ -40,10 +54,18 @@ class settings(metaclass=MetaSettings): default_object_fields: Callable[ [type], Optional[Sequence[ObjectField]] ] = default_object_fields_ - default_schema: Callable[[AnyType], Optional[Schema]] = default_schema_ + default_schema: Callable[[AnyType], Optional[Schema]] = lambda *_: None default_type_name: Callable[[AnyType], Optional[TypeName]] = default_type_name_ json_schema_version: JsonSchemaVersion = JsonSchemaVersion.DRAFT_2020_12 + class base_schema: + field: Callable[[AnyType, str, str], Optional[Schema]] = lambda *_: None + method: Callable[[AnyType, Callable, str], Optional[Schema]] = lambda *_: None + parameter: Callable[ + [Callable, Parameter, str], Optional[Schema] + ] = lambda *_: None + type: Callable[[AnyType], Optional[Schema]] = lambda *_: None + class deserialization(metaclass=ResetCache): coerce: bool = False coercer: Coercer = coerce_ diff --git a/apischema/type_names.py b/apischema/type_names.py index 40cb6dff..24643aa9 100644 --- a/apischema/type_names.py +++ b/apischema/type_names.py @@ -1,3 +1,4 @@ +import collections.abc import warnings from contextlib import suppress from dataclasses import dataclass @@ -5,7 +6,7 @@ from apischema.cache import CacheAwareDict from apischema.types import AnyType, PRIMITIVE_TYPES -from apischema.typing import get_args, get_origin, is_type_var +from apischema.typing import get_args, get_origin, is_named_tuple, is_type_var from apischema.utils import has_type_vars, merge_opts, replace_builtins @@ -70,6 +71,11 @@ def default_type_name(tp: AnyType) -> Optional[TypeName]: and not get_args(tp) and not has_type_vars(tp) and tp not in PRIMITIVE_TYPES + and ( + not isinstance(tp, type) + or not issubclass(tp, collections.abc.Collection) + or is_named_tuple(tp) + ) ): return TypeName(tp.__name__, tp.__name__) else: diff --git a/docs/graphql/data_model_and_resolvers.md b/docs/graphql/data_model_and_resolvers.md index a0cdd3f7..d110da65 100644 --- a/docs/graphql/data_model_and_resolvers.md +++ b/docs/graphql/data_model_and_resolvers.md @@ -95,6 +95,13 @@ Resolvers parameters can have metadata like dataclass fields. They can be passed !!! note Metadata can also be passed with `parameters_metadata` parameter; it takes a mapping of parameter names as key and mapped metadata as value. +### Parameters base schema + +Following the example of [type/field/method base schema](../json_schema.md#base-schema), resolver parameters also support a base schema definition + +```python +{!base_schema_parameter.py!} +``` ## ID type *GraphQL* `ID` has no precise specification and is defined according API needs; it can be a UUID or and ObjectId, etc. diff --git a/docs/json_schema.md b/docs/json_schema.md index 95273d79..109ac0a3 100644 --- a/docs/json_schema.md +++ b/docs/json_schema.md @@ -92,32 +92,24 @@ JSON schema constrains the data deserialized; these constraints are naturally us {!validation_error.py!} ``` -### Default `schema` +### Extra schema -When no schema are defined, a default schema can be computed using `settings.default_schema` like this: +`schema` has two other arguments: `extra` and `override`, which give a finer control of the JSON schema generated: `extra` and `override`. It can be used for example to build "strict" unions (using `oneOf` instead of `anyOf`) ```python -from typing import Optional -from apischema import schema, settings -from apischema.schemas import Schema - - -def default_schema(cls) -> Optional[Schema]: - return schema(...) if ... else None - -settings.default_schema = default_schema -``` - -Default implementation returns `None` for every types. +{!strict_union.py!} +``` -### Extra schema +### Base `schema` -`schema` has two other arguments: `extra` and `override`, which give a finer control of the JSON schema generated: `extra` and `override`. It can be used for example to build "strict" unions (using `oneOf` instead of `anyOf`) +`apischema.settings.base_schema` can be used to define "base schema" for the different kind of objects: types, object fields or (serialized) methods. ```python -{!strict_union.py!} +{!base_schema.py!} ``` +Base schema will be merged with `schema` defined at type/field/method level. + ## Required field with default value By default, a dataclass/namedtuple field will be tagged `required` if it doesn't have a default value. diff --git a/examples/base_schema.py b/examples/base_schema.py new file mode 100644 index 00000000..57ef7410 --- /dev/null +++ b/examples/base_schema.py @@ -0,0 +1,67 @@ +from dataclasses import dataclass +from typing import Any, Callable, Optional, get_origin + +import docstring_parser + +from apischema import schema, serialized, settings +from apischema.json_schema import serialization_schema +from apischema.schemas import Schema +from apischema.type_names import get_type_name + + +@dataclass +class Foo: + """Foo class + + :var bar: bar attribute""" + + bar: str + + @serialized + @property + def baz(self) -> int: + """baz method""" + ... + + +def type_base_schema(tp: Any) -> Optional[Schema]: + if not hasattr(tp, "__doc__"): + return None + return schema( + title=get_type_name(tp).json_schema, + description=docstring_parser.parse(tp.__doc__).short_description, + ) + + +def field_base_schema(tp: Any, name: str, alias: str) -> Optional[Schema]: + title = alias.replace("_", " ").capitalize() + tp = get_origin(tp) or tp # tp can be generic + for meta in docstring_parser.parse(tp.__doc__).meta: + if meta.args == ["var", name]: + return schema(title=title, description=meta.description) + return schema(title=title) + + +def method_base_schema(tp: Any, method: Callable, alias: str) -> Optional[Schema]: + return schema( + title=alias.replace("_", " ").capitalize(), + description=docstring_parser.parse(method.__doc__).short_description, + ) + + +settings.base_schema.type = type_base_schema +settings.base_schema.field = field_base_schema +settings.base_schema.method = method_base_schema + +assert serialization_schema(Foo) == { + "$schema": "http://json-schema.org/draft/2020-12/schema#", + "additionalProperties": False, + "title": "Foo", + "description": "Foo class", + "properties": { + "bar": {"description": "bar attribute", "title": "Bar", "type": "string"}, + "baz": {"description": "baz method", "title": "Baz", "type": "integer"}, + }, + "required": ["bar", "baz"], + "type": "object", +} diff --git a/examples/base_schema_parameter.py b/examples/base_schema_parameter.py new file mode 100644 index 00000000..3df7a954 --- /dev/null +++ b/examples/base_schema_parameter.py @@ -0,0 +1,59 @@ +import inspect +from dataclasses import dataclass +from typing import Any, Callable, Optional + +import docstring_parser +from graphql.utilities import print_schema + +from apischema import schema, settings +from apischema.graphql import graphql_schema, resolver +from apischema.schemas import Schema + + +@dataclass +class Foo: + @resolver + def bar(self, arg: str) -> int: + """bar method + + :param arg: arg parameter + """ + ... + + +def method_base_schema(tp: Any, method: Callable, alias: str) -> Optional[Schema]: + return schema(description=docstring_parser.parse(method.__doc__).short_description) + + +def parameter_base_schema( + method: Callable, parameter: inspect.Parameter, alias: str +) -> Optional[Schema]: + for doc_param in docstring_parser.parse(method.__doc__).params: + if doc_param.arg_name == parameter.name: + return schema(description=doc_param.description) + return None + + +settings.base_schema.method = method_base_schema +settings.base_schema.parameter = parameter_base_schema + + +def foo() -> Foo: + ... + + +schema_ = graphql_schema(query=[foo]) +schema_str = '''\ +type Query { + foo: Foo! +} + +type Foo { + """bar method""" + bar( + """arg parameter""" + arg: String! + ): Int! +} +''' +assert print_schema(schema_) == schema_str diff --git a/scripts/test_wrapper.py b/scripts/test_wrapper.py index 0728daed..a631a955 100644 --- a/scripts/test_wrapper.py +++ b/scripts/test_wrapper.py @@ -9,9 +9,21 @@ from typing import * from unittest.mock import MagicMock -from apischema.typing import Annotated, Literal, TypedDict, get_args - +import pytest + +from apischema import settings +from apischema.typing import ( + Annotated, + Literal, + TypedDict, + get_args, + get_origin, + is_type, +) + +typing.get_origin, typing.get_args = get_origin, get_args typing.Annotated, typing.Literal, typing.TypedDict = Annotated, Literal, TypedDict +inspect.isclass = is_type if sys.version_info < (3, 9): class CollectionABC: @@ -51,8 +63,35 @@ def __subclasscheck__(self, subclass): if sys.version_info < (3, 7): asyncio.run = lambda coro: asyncio.get_event_loop().run_until_complete(coro) -inspect.isclass = lambda tp: isinstance(tp, type) and not get_args(tp) __timeit = timeit.timeit timeit.timeit = lambda stmt, number=None, **kwargs: __timeit(stmt, number=1, **kwargs) sys.modules["orjson"] = json + +settings_classes = ( + settings, + settings.base_schema, + settings.deserialization, + settings.serialization, +) +settings_dicts = {cls: dict(cls.__dict__) for cls in settings_classes} + +## test body + + +def set_settings(dicts: Mapping[type, Mapping[str, Any]]): + for cls, dict_ in dicts.items(): + for key, value in dict_.items(): + if not key.startswith("_"): + setattr(cls, key, value) + + +test_dicts = {cls: dict(cls.__dict__) for cls in settings_classes} +set_settings(settings_dicts) + + +@pytest.fixture(autouse=True) +def test_settings(monkeypatch): + set_settings(test_dicts) + yield + set_settings(settings_dicts) diff --git a/setup.py b/setup.py index f19b5ef8..ccf6d34b 100644 --- a/setup.py +++ b/setup.py @@ -22,6 +22,7 @@ "examples": [ "graphql-core>=3.1.2", "attrs", + "docstring_parser", "bson", "orjson", "pydantic", diff --git a/tests/requirements.txt b/tests/requirements.txt index 66c5d89d..8fb6ee51 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -2,6 +2,7 @@ dataclasses;python_version<'3.7' graphql-core attrs bson +docstring_parser pydantic pytest pytest-cov