diff --git a/spring-graphql/src/main/java/org/springframework/graphql/execution/DataLoaderRegistrar.java b/spring-graphql/src/main/java/org/springframework/graphql/execution/DataLoaderRegistrar.java index 52557fff..6178e9dd 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/execution/DataLoaderRegistrar.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/execution/DataLoaderRegistrar.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -30,6 +30,17 @@ */ public interface DataLoaderRegistrar { + + /** + * Whether the registrar has any {@code DataLoader} registrations to make. + * @since 1.2.8 + */ + default boolean hasRegistrations() { + DataLoaderRegistry registry = DataLoaderRegistry.newRegistry().build(); + registerDataLoaders(registry, GraphQLContext.newContext().build()); + return !registry.getDataLoaders().isEmpty(); + } + /** * Callback that provides access to the {@link DataLoaderRegistry} from the * the {@link graphql.ExecutionInput}. diff --git a/spring-graphql/src/main/java/org/springframework/graphql/execution/DefaultBatchLoaderRegistry.java b/spring-graphql/src/main/java/org/springframework/graphql/execution/DefaultBatchLoaderRegistry.java index c1fe8324..c8127b3a 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/execution/DefaultBatchLoaderRegistry.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/execution/DefaultBatchLoaderRegistry.java @@ -89,6 +89,11 @@ public RegistrationSpec forName(String name) { return new DefaultRegistrationSpec<>(name); } + @Override + public boolean hasRegistrations() { + return (!this.loaders.isEmpty() || !this.mappedLoaders.isEmpty()); + } + @Override public void registerDataLoaders(DataLoaderRegistry registry, GraphQLContext context) { BatchLoaderContextProvider contextProvider = () -> context; diff --git a/spring-graphql/src/main/java/org/springframework/graphql/execution/DefaultExecutionGraphQlService.java b/spring-graphql/src/main/java/org/springframework/graphql/execution/DefaultExecutionGraphQlService.java index ebd5e4b8..d28762c2 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/execution/DefaultExecutionGraphQlService.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/execution/DefaultExecutionGraphQlService.java @@ -33,6 +33,7 @@ import org.springframework.graphql.ExecutionGraphQlResponse; import org.springframework.graphql.ExecutionGraphQlService; import org.springframework.graphql.support.DefaultExecutionGraphQlResponse; +import org.springframework.lang.Nullable; /** * {@link ExecutionGraphQlService} that uses a {@link GraphQlSource} to obtain a @@ -51,7 +52,8 @@ public class DefaultExecutionGraphQlService implements ExecutionGraphQlService { private final List dataLoaderRegistrars = new ArrayList<>(); - private boolean hasDataLoaderRegistrations; + @Nullable + private Boolean hasDataLoaderRegistrations; private final boolean isDefaultExecutionIdProvider; @@ -70,13 +72,6 @@ public DefaultExecutionGraphQlService(GraphQlSource graphQlSource) { */ public void addDataLoaderRegistrar(DataLoaderRegistrar registrar) { this.dataLoaderRegistrars.add(registrar); - this.hasDataLoaderRegistrations = (this.hasDataLoaderRegistrations || hasRegistrations(registrar)); - } - - private static boolean hasRegistrations(DataLoaderRegistrar registrar) { - DataLoaderRegistry registry = DataLoaderRegistry.newRegistry().build(); - registrar.registerDataLoaders(registry, GraphQLContext.newContext().build()); - return !registry.getDataLoaders().isEmpty(); } @@ -93,28 +88,41 @@ public final Mono execute(ExecutionGraphQlRequest requ GraphQLContext graphQLContext = executionInput.getGraphQLContext(); ContextSnapshot.captureFrom(contextView).updateContext(graphQLContext); - ExecutionInput updatedExecutionInput = - (this.hasDataLoaderRegistrations ? registerDataLoaders(executionInput) : executionInput); + ExecutionInput executionInputToUse = registerDataLoaders(executionInput); - return Mono.fromFuture(this.graphQlSource.graphQl().executeAsync(updatedExecutionInput)) - .map((result) -> new DefaultExecutionGraphQlResponse(updatedExecutionInput, result)); + return Mono.fromFuture(this.graphQlSource.graphQl().executeAsync(executionInputToUse)) + .map((result) -> new DefaultExecutionGraphQlResponse(executionInputToUse, result)); }); } private ExecutionInput registerDataLoaders(ExecutionInput executionInput) { - GraphQLContext graphQLContext = executionInput.getGraphQLContext(); - DataLoaderRegistry existingRegistry = executionInput.getDataLoaderRegistry(); - if (existingRegistry == DataLoaderDispatcherInstrumentationState.EMPTY_DATALOADER_REGISTRY) { - DataLoaderRegistry newRegistry = DataLoaderRegistry.newRegistry().build(); - applyDataLoaderRegistrars(newRegistry, graphQLContext); - executionInput = executionInput.transform((builder) -> builder.dataLoaderRegistry(newRegistry)); + if (this.hasDataLoaderRegistrations == null) { + this.hasDataLoaderRegistrations = initHasDataLoaderRegistrations(); } - else { - applyDataLoaderRegistrars(existingRegistry, graphQLContext); + if (this.hasDataLoaderRegistrations) { + GraphQLContext graphQLContext = executionInput.getGraphQLContext(); + DataLoaderRegistry existingRegistry = executionInput.getDataLoaderRegistry(); + if (existingRegistry == DataLoaderDispatcherInstrumentationState.EMPTY_DATALOADER_REGISTRY) { + DataLoaderRegistry newRegistry = DataLoaderRegistry.newRegistry().build(); + applyDataLoaderRegistrars(newRegistry, graphQLContext); + executionInput = executionInput.transform((builder) -> builder.dataLoaderRegistry(newRegistry)); + } + else { + applyDataLoaderRegistrars(existingRegistry, graphQLContext); + } } return executionInput; } + private boolean initHasDataLoaderRegistrations() { + for (DataLoaderRegistrar registrar : this.dataLoaderRegistrars) { + if (registrar.hasRegistrations()) { + return true; + } + } + return false; + } + private void applyDataLoaderRegistrars(DataLoaderRegistry registry, GraphQLContext graphQLContext) { this.dataLoaderRegistrars.forEach((registrar) -> registrar.registerDataLoaders(registry, graphQLContext)); } diff --git a/spring-graphql/src/test/java/org/springframework/graphql/execution/DefaultExecutionGraphQlServiceTests.java b/spring-graphql/src/test/java/org/springframework/graphql/execution/DefaultExecutionGraphQlServiceTests.java index 34a2932a..ef9293d9 100644 --- a/spring-graphql/src/test/java/org/springframework/graphql/execution/DefaultExecutionGraphQlServiceTests.java +++ b/spring-graphql/src/test/java/org/springframework/graphql/execution/DefaultExecutionGraphQlServiceTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -41,26 +41,27 @@ public class DefaultExecutionGraphQlServiceTests { @Test void customDataLoaderRegistry() { - DefaultBatchLoaderRegistry batchLoaderRegistry = new DefaultBatchLoaderRegistry(); - batchLoaderRegistry.forTypePair(Book.class, Author.class) - .registerBatchLoader((books, batchLoaderEnvironment) -> Flux.empty()); - GraphQlSource graphQlSource = GraphQlSetup.schemaContent("type Query { greeting: String }") .queryFetcher("greeting", (env) -> "hi") .toGraphQlSource(); + BatchLoaderRegistry batchLoaderRegistry = new DefaultBatchLoaderRegistry(); DefaultExecutionGraphQlService graphQlService = new DefaultExecutionGraphQlService(graphQlSource); graphQlService.addDataLoaderRegistrar(batchLoaderRegistry); - DataLoaderRegistry myRegistry = new DataLoaderRegistry(); + // gh-1020: register loader after adding the registry to DefaultExecutionGraphQlService + batchLoaderRegistry.forTypePair(Book.class, Author.class) + .registerBatchLoader((books, batchLoaderEnvironment) -> Flux.empty()); + + DataLoaderRegistry dataLoaderRegistry = new DataLoaderRegistry(); ExecutionGraphQlRequest request = TestExecutionRequest.forDocument("{ greeting }"); - request.configureExecutionInput((input, builder) -> builder.dataLoaderRegistry(myRegistry).build()); + request.configureExecutionInput((input, builder) -> builder.dataLoaderRegistry(dataLoaderRegistry).build()); ExecutionGraphQlResponse response = graphQlService.execute(request).block(); Map data = response.getExecutionResult().getData(); assertThat(data).isEqualTo(Map.of("greeting", "hi")); - assertThat(myRegistry.getDataLoaders()).hasSize(1); + assertThat(dataLoaderRegistry.getDataLoaders()).hasSize(1); } }