Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add base schema and deprecate default_schema. #215

Merged
merged 3 commits into from
Oct 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion apischema/graphql/resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Any,
Awaitable,
Callable,
Collection,
Dict,
Iterator,
Mapping,
Expand Down Expand Up @@ -124,7 +125,7 @@ def return_type(self, return_type: AnyType) -> AnyType:
)


def get_resolvers(tp: AnyType) -> Mapping[str, Tuple[Resolver, Mapping[str, AnyType]]]:
def get_resolvers(tp: AnyType) -> Collection[Tuple[Resolver, Mapping[str, AnyType]]]:
return _get_methods(tp, _resolvers)


Expand Down Expand Up @@ -194,6 +195,7 @@ def register(func: Callable, owner: Type, alias2: str):
error_handler2 = None
resolver = Resolver(
func,
alias2,
conversion,
error_handler2,
order,
Expand Down
81 changes: 45 additions & 36 deletions apischema/graphql/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
resolver_parameters,
resolver_resolve,
)
from apischema.json_schema.schema import get_field_schema, get_method_schema, get_schema
from apischema.metadata.keys import SCHEMA_METADATA
from apischema.objects import ObjectField
from apischema.objects.visitor import (
Expand All @@ -54,7 +55,7 @@
)
from apischema.ordering import Ordering, sort_by_order
from apischema.recursion import RecursiveConversionsVisitor
from apischema.schemas import Schema, get_schema, merge_schema
from apischema.schemas import Schema, merge_schema
from apischema.serialization import SerializationMethod, serialize
from apischema.serialization.serialized_methods import ErrorHandler
from apischema.type_names import TypeName, TypeNameFactory, get_type_name
Expand Down Expand Up @@ -111,6 +112,16 @@ def exec_thunk(thunk: TypeThunk, *, non_null=None) -> Any:
return result


def get_parameter_schema(
func: Callable, parameter: Parameter, field: ObjectField
) -> Optional[Schema]:
from apischema import settings

return merge_schema(
settings.base_schema.parameter(func, parameter, field.alias), field.schema
)


def merged_schema(
schema: Optional[Schema], tp: Optional[AnyType]
) -> Tuple[Optional[Schema], Mapping[str, Any]]:
Expand Down Expand Up @@ -151,7 +162,6 @@ def get_deprecated(

@dataclass(frozen=True)
class ResolverField:
alias: str
resolver: Resolver
types: Mapping[str, AnyType]
parameters: Sequence[Parameter]
Expand Down Expand Up @@ -480,7 +490,9 @@ class InputSchemaBuilder(
):
types = graphql.type.definition.graphql_input_types

def _field(self, field: ObjectField) -> Lazy[graphql.GraphQLInputField]:
def _field(
self, tp: AnyType, field: ObjectField
) -> Lazy[graphql.GraphQLInputField]:
field_type = field.type
field_default = graphql.Undefined if field.required else field.get_default()
default: Any = graphql.Undefined
Expand All @@ -501,7 +513,7 @@ def _field(self, field: ObjectField) -> Lazy[graphql.GraphQLInputField]:
return lambda: graphql.GraphQLInputField(
factory.type, # type: ignore
default_value=default,
description=get_description(field.schema, field.type),
description=get_description(get_field_schema(tp, field), field.type),
)

@cache_type
Expand All @@ -514,7 +526,7 @@ def object(
normal_field = NormalField(
self.aliaser(field.alias),
field.name,
self._field(field),
self._field(tp, field),
field.ordering,
)
visited_fields.append(normal_field)
Expand Down Expand Up @@ -602,7 +614,7 @@ def resolve_wrapper(__obj, __info, **kwargs):

return cast(Func, resolve_wrapper)

def _field(self, field: ObjectField) -> Lazy[graphql.GraphQLField]:
def _field(self, tp: AnyType, field: ObjectField) -> Lazy[graphql.GraphQLField]:
field_name = field.name
partial_serialize = self._field_serialization_method(field)

Expand All @@ -611,15 +623,18 @@ def resolve(obj, _):
return partial_serialize(getattr(obj, field_name))

factory = self.visit_with_conv(field.type, field.serialization)
field_schema = get_field_schema(tp, field)
return lambda: graphql.GraphQLField(
factory.type,
None,
resolve,
description=get_description(field.schema, field.type),
deprecation_reason=get_deprecated(field.schema, field.type),
description=get_description(field_schema, field.type),
deprecation_reason=get_deprecated(field_schema, field.type),
)

def _resolver(self, field: ResolverField) -> Lazy[graphql.GraphQLField]:
def _resolver(
self, tp: AnyType, field: ResolverField
) -> Lazy[graphql.GraphQLField]:
resolve = self._wrap_resolve(
resolver_resolve(
field.resolver,
Expand Down Expand Up @@ -665,7 +680,10 @@ def _resolver(self, field: ResolverField) -> Lazy[graphql.GraphQLField]:
arg_factory = self.input_builder.visit_with_conv(
param_type, param_field.deserialization
)
description = get_description(param_field.schema, param_field.type)
description = get_description(
get_parameter_schema(field.resolver.func, param, param_field),
param_field.type,
)

def arg_thunk(
arg_factory=arg_factory, default=default, description=description
Expand All @@ -676,13 +694,14 @@ def arg_thunk(

args[self.aliaser(param_field.alias)] = arg_thunk
factory = self.visit_with_conv(field.types["return"], field.resolver.conversion)
field_schema = get_method_schema(tp, field.resolver)
return lambda: graphql.GraphQLField(
factory.type, # type: ignore
{name: arg() for name, arg in args.items()} if args else None,
resolve,
field.subscribe,
get_description(field.resolver.schema),
get_deprecated(field.resolver.schema),
get_description(field_schema),
get_deprecated(field_schema),
)

def _visit_flattened(
Expand Down Expand Up @@ -716,7 +735,7 @@ def object(
normal_field = NormalField(
self.aliaser(field.name),
field.name,
self._field(field),
self._field(tp, field),
field.ordering,
)
visited_fields.append(normal_field)
Expand All @@ -727,20 +746,16 @@ def object(
FlattenedField(field.name, field.ordering, flattened_factory)
)
resolvers = list(resolvers)
for alias, (resolver, types) in get_resolvers(tp).items():
for resolver, types in get_resolvers(tp):
resolver_field = ResolverField(
alias,
resolver,
types,
resolver.parameters,
resolver.parameters_metadata,
resolver, types, resolver.parameters, resolver.parameters_metadata
)
resolvers.append(resolver_field)
for resolver_field in resolvers:
normal_field = NormalField(
self.aliaser(resolver_field.alias),
self.aliaser(resolver_field.resolver.alias),
resolver_field.resolver.func.__name__,
self._resolver(resolver_field),
self._resolver(tp, resolver_field),
resolver_field.resolver.ordering,
)
visited_fields.append(normal_field)
Expand Down Expand Up @@ -838,9 +853,7 @@ class Subscription(Operation[AsyncIterable]):
Op = TypeVar("Op", bound=Operation)


def operation_resolver(
operation: Union[Callable, Op], op_class: Type[Op]
) -> Tuple[str, Resolver]:
def operation_resolver(operation: Union[Callable, Op], op_class: Type[Op]) -> Resolver:
if not isinstance(operation, op_class):
operation = op_class(operation) # type: ignore
error_handler: Optional[Callable]
Expand All @@ -864,8 +877,9 @@ def wrapper(_, *args, **kwargs):
wrapper.__annotations__ = op.__annotations__

(*parameters,) = resolver_parameters(operation.function, check_first=True)
return operation.alias or operation.function.__name__, Resolver(
return Resolver(
wrapper,
operation.alias or operation.function.__name__,
operation.conversion,
error_handler,
operation.order,
Expand Down Expand Up @@ -912,9 +926,8 @@ def graphql_schema(
(mutation, Mutation, mutation_fields),
]:
for operation in operations: # type: ignore
alias, resolver = operation_resolver(operation, op_class)
resolver = operation_resolver(operation, op_class)
resolver_field = ResolverField(
alias,
resolver,
resolver.types(),
resolver.parameters,
Expand All @@ -926,11 +939,11 @@ def graphql_schema(
sub_op = Subscription(sub_op) # type: ignore
sub_parameters: Sequence[Parameter]
if sub_op.resolver is not None:
alias = sub_op.alias or sub_op.resolver.__name__
_, subscriber2 = operation_resolver(sub_op, Subscription)
subscriber2 = operation_resolver(sub_op, Subscription)
_, *sub_parameters = resolver_parameters(sub_op.resolver, check_first=False)
resolver = Resolver(
sub_op.resolver,
sub_op.alias or sub_op.resolver.__name__,
sub_op.conversion,
subscriber2.error_handler,
sub_op.order,
Expand All @@ -949,9 +962,10 @@ def graphql_schema(
serialized=False,
)
else:
alias, subscriber2 = operation_resolver(sub_op, Subscription)
subscriber2 = operation_resolver(sub_op, Subscription)
resolver = Resolver(
lambda _: _,
subscriber2.alias,
sub_op.conversion,
subscriber2.error_handler,
sub_op.order,
Expand All @@ -978,12 +992,7 @@ def graphql_schema(
sub_types = {**sub_types, "return": resolver.return_type(event_type)}

resolver_field = ResolverField(
alias,
resolver,
sub_types,
sub_parameters,
sub_op.parameters_metadata,
subscribe,
resolver, sub_types, sub_parameters, sub_op.parameters_metadata, subscribe
)
subscription_fields.append(resolver_field)

Expand Down
Loading