Skip to content

Commit

Permalink
Add base schema and deprecate default_schema. (#215)
Browse files Browse the repository at this point in the history
  • Loading branch information
wyfo authored Oct 10, 2021
1 parent 78dad2c commit 6d3516a
Show file tree
Hide file tree
Showing 15 changed files with 339 additions and 99 deletions.
4 changes: 3 additions & 1 deletion apischema/graphql/resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Any,
Awaitable,
Callable,
Collection,
Dict,
Iterator,
Mapping,
Expand Down Expand Up @@ -124,7 +125,7 @@ def return_type(self, return_type: AnyType) -> AnyType:
)


def get_resolvers(tp: AnyType) -> Mapping[str, Tuple[Resolver, Mapping[str, AnyType]]]:
def get_resolvers(tp: AnyType) -> Collection[Tuple[Resolver, Mapping[str, AnyType]]]:
return _get_methods(tp, _resolvers)


Expand Down Expand Up @@ -194,6 +195,7 @@ def register(func: Callable, owner: Type, alias2: str):
error_handler2 = None
resolver = Resolver(
func,
alias2,
conversion,
error_handler2,
order,
Expand Down
81 changes: 45 additions & 36 deletions apischema/graphql/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
resolver_parameters,
resolver_resolve,
)
from apischema.json_schema.schema import get_field_schema, get_method_schema, get_schema
from apischema.metadata.keys import SCHEMA_METADATA
from apischema.objects import ObjectField
from apischema.objects.visitor import (
Expand All @@ -54,7 +55,7 @@
)
from apischema.ordering import Ordering, sort_by_order
from apischema.recursion import RecursiveConversionsVisitor
from apischema.schemas import Schema, get_schema, merge_schema
from apischema.schemas import Schema, merge_schema
from apischema.serialization import SerializationMethod, serialize
from apischema.serialization.serialized_methods import ErrorHandler
from apischema.type_names import TypeName, TypeNameFactory, get_type_name
Expand Down Expand Up @@ -111,6 +112,16 @@ def exec_thunk(thunk: TypeThunk, *, non_null=None) -> Any:
return result


def get_parameter_schema(
func: Callable, parameter: Parameter, field: ObjectField
) -> Optional[Schema]:
from apischema import settings

return merge_schema(
settings.base_schema.parameter(func, parameter, field.alias), field.schema
)


def merged_schema(
schema: Optional[Schema], tp: Optional[AnyType]
) -> Tuple[Optional[Schema], Mapping[str, Any]]:
Expand Down Expand Up @@ -151,7 +162,6 @@ def get_deprecated(

@dataclass(frozen=True)
class ResolverField:
alias: str
resolver: Resolver
types: Mapping[str, AnyType]
parameters: Sequence[Parameter]
Expand Down Expand Up @@ -480,7 +490,9 @@ class InputSchemaBuilder(
):
types = graphql.type.definition.graphql_input_types

def _field(self, field: ObjectField) -> Lazy[graphql.GraphQLInputField]:
def _field(
self, tp: AnyType, field: ObjectField
) -> Lazy[graphql.GraphQLInputField]:
field_type = field.type
field_default = graphql.Undefined if field.required else field.get_default()
default: Any = graphql.Undefined
Expand All @@ -501,7 +513,7 @@ def _field(self, field: ObjectField) -> Lazy[graphql.GraphQLInputField]:
return lambda: graphql.GraphQLInputField(
factory.type, # type: ignore
default_value=default,
description=get_description(field.schema, field.type),
description=get_description(get_field_schema(tp, field), field.type),
)

@cache_type
Expand All @@ -514,7 +526,7 @@ def object(
normal_field = NormalField(
self.aliaser(field.alias),
field.name,
self._field(field),
self._field(tp, field),
field.ordering,
)
visited_fields.append(normal_field)
Expand Down Expand Up @@ -602,7 +614,7 @@ def resolve_wrapper(__obj, __info, **kwargs):

return cast(Func, resolve_wrapper)

def _field(self, field: ObjectField) -> Lazy[graphql.GraphQLField]:
def _field(self, tp: AnyType, field: ObjectField) -> Lazy[graphql.GraphQLField]:
field_name = field.name
partial_serialize = self._field_serialization_method(field)

Expand All @@ -611,15 +623,18 @@ def resolve(obj, _):
return partial_serialize(getattr(obj, field_name))

factory = self.visit_with_conv(field.type, field.serialization)
field_schema = get_field_schema(tp, field)
return lambda: graphql.GraphQLField(
factory.type,
None,
resolve,
description=get_description(field.schema, field.type),
deprecation_reason=get_deprecated(field.schema, field.type),
description=get_description(field_schema, field.type),
deprecation_reason=get_deprecated(field_schema, field.type),
)

def _resolver(self, field: ResolverField) -> Lazy[graphql.GraphQLField]:
def _resolver(
self, tp: AnyType, field: ResolverField
) -> Lazy[graphql.GraphQLField]:
resolve = self._wrap_resolve(
resolver_resolve(
field.resolver,
Expand Down Expand Up @@ -665,7 +680,10 @@ def _resolver(self, field: ResolverField) -> Lazy[graphql.GraphQLField]:
arg_factory = self.input_builder.visit_with_conv(
param_type, param_field.deserialization
)
description = get_description(param_field.schema, param_field.type)
description = get_description(
get_parameter_schema(field.resolver.func, param, param_field),
param_field.type,
)

def arg_thunk(
arg_factory=arg_factory, default=default, description=description
Expand All @@ -676,13 +694,14 @@ def arg_thunk(

args[self.aliaser(param_field.alias)] = arg_thunk
factory = self.visit_with_conv(field.types["return"], field.resolver.conversion)
field_schema = get_method_schema(tp, field.resolver)
return lambda: graphql.GraphQLField(
factory.type, # type: ignore
{name: arg() for name, arg in args.items()} if args else None,
resolve,
field.subscribe,
get_description(field.resolver.schema),
get_deprecated(field.resolver.schema),
get_description(field_schema),
get_deprecated(field_schema),
)

def _visit_flattened(
Expand Down Expand Up @@ -716,7 +735,7 @@ def object(
normal_field = NormalField(
self.aliaser(field.name),
field.name,
self._field(field),
self._field(tp, field),
field.ordering,
)
visited_fields.append(normal_field)
Expand All @@ -727,20 +746,16 @@ def object(
FlattenedField(field.name, field.ordering, flattened_factory)
)
resolvers = list(resolvers)
for alias, (resolver, types) in get_resolvers(tp).items():
for resolver, types in get_resolvers(tp):
resolver_field = ResolverField(
alias,
resolver,
types,
resolver.parameters,
resolver.parameters_metadata,
resolver, types, resolver.parameters, resolver.parameters_metadata
)
resolvers.append(resolver_field)
for resolver_field in resolvers:
normal_field = NormalField(
self.aliaser(resolver_field.alias),
self.aliaser(resolver_field.resolver.alias),
resolver_field.resolver.func.__name__,
self._resolver(resolver_field),
self._resolver(tp, resolver_field),
resolver_field.resolver.ordering,
)
visited_fields.append(normal_field)
Expand Down Expand Up @@ -838,9 +853,7 @@ class Subscription(Operation[AsyncIterable]):
Op = TypeVar("Op", bound=Operation)


def operation_resolver(
operation: Union[Callable, Op], op_class: Type[Op]
) -> Tuple[str, Resolver]:
def operation_resolver(operation: Union[Callable, Op], op_class: Type[Op]) -> Resolver:
if not isinstance(operation, op_class):
operation = op_class(operation) # type: ignore
error_handler: Optional[Callable]
Expand All @@ -864,8 +877,9 @@ def wrapper(_, *args, **kwargs):
wrapper.__annotations__ = op.__annotations__

(*parameters,) = resolver_parameters(operation.function, check_first=True)
return operation.alias or operation.function.__name__, Resolver(
return Resolver(
wrapper,
operation.alias or operation.function.__name__,
operation.conversion,
error_handler,
operation.order,
Expand Down Expand Up @@ -912,9 +926,8 @@ def graphql_schema(
(mutation, Mutation, mutation_fields),
]:
for operation in operations: # type: ignore
alias, resolver = operation_resolver(operation, op_class)
resolver = operation_resolver(operation, op_class)
resolver_field = ResolverField(
alias,
resolver,
resolver.types(),
resolver.parameters,
Expand All @@ -926,11 +939,11 @@ def graphql_schema(
sub_op = Subscription(sub_op) # type: ignore
sub_parameters: Sequence[Parameter]
if sub_op.resolver is not None:
alias = sub_op.alias or sub_op.resolver.__name__
_, subscriber2 = operation_resolver(sub_op, Subscription)
subscriber2 = operation_resolver(sub_op, Subscription)
_, *sub_parameters = resolver_parameters(sub_op.resolver, check_first=False)
resolver = Resolver(
sub_op.resolver,
sub_op.alias or sub_op.resolver.__name__,
sub_op.conversion,
subscriber2.error_handler,
sub_op.order,
Expand All @@ -949,9 +962,10 @@ def graphql_schema(
serialized=False,
)
else:
alias, subscriber2 = operation_resolver(sub_op, Subscription)
subscriber2 = operation_resolver(sub_op, Subscription)
resolver = Resolver(
lambda _: _,
subscriber2.alias,
sub_op.conversion,
subscriber2.error_handler,
sub_op.order,
Expand All @@ -978,12 +992,7 @@ def graphql_schema(
sub_types = {**sub_types, "return": resolver.return_type(event_type)}

resolver_field = ResolverField(
alias,
resolver,
sub_types,
sub_parameters,
sub_op.parameters_metadata,
subscribe,
resolver, sub_types, sub_parameters, sub_op.parameters_metadata, subscribe
)
subscription_fields.append(resolver_field)

Expand Down
Loading

0 comments on commit 6d3516a

Please sign in to comment.