diff --git a/spring-graphql/src/main/java/org/springframework/graphql/execution/DefaultExecutionGraphQlService.java b/spring-graphql/src/main/java/org/springframework/graphql/execution/DefaultExecutionGraphQlService.java index 4a206bb0..267e3db8 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/execution/DefaultExecutionGraphQlService.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/execution/DefaultExecutionGraphQlService.java @@ -24,6 +24,7 @@ import graphql.GraphQL; import graphql.GraphQLContext; import graphql.execution.ExecutionIdProvider; +import graphql.execution.instrumentation.dataloader.DataLoaderDispatcherInstrumentationState; import io.micrometer.context.ContextSnapshot; import org.dataloader.DataLoaderRegistry; import reactor.core.publisher.Mono; @@ -88,12 +89,21 @@ public final Mono execute(ExecutionGraphQlRequest requ private ExecutionInput registerDataLoaders(ExecutionInput executionInput) { if (!this.dataLoaderRegistrars.isEmpty()) { GraphQLContext graphQLContext = executionInput.getGraphQLContext(); - DataLoaderRegistry previousRegistry = executionInput.getDataLoaderRegistry(); - DataLoaderRegistry newRegistry = DataLoaderRegistry.newRegistry().registerAll(previousRegistry).build(); - this.dataLoaderRegistrars.forEach(registrar -> registrar.registerDataLoaders(newRegistry, graphQLContext)); - executionInput = executionInput.transform(builder -> builder.dataLoaderRegistry(newRegistry)); + DataLoaderRegistry existingRegistry = executionInput.getDataLoaderRegistry(); + if (existingRegistry == DataLoaderDispatcherInstrumentationState.EMPTY_DATALOADER_REGISTRY) { + DataLoaderRegistry newRegistry = DataLoaderRegistry.newRegistry().build(); + applyDataLoaderRegistrars(newRegistry, graphQLContext); + executionInput = executionInput.transform(builder -> builder.dataLoaderRegistry(newRegistry)); + } + else { + applyDataLoaderRegistrars(existingRegistry, graphQLContext); + } } return executionInput; } + private void applyDataLoaderRegistrars(DataLoaderRegistry registry, GraphQLContext graphQLContext) { + this.dataLoaderRegistrars.forEach(registrar -> registrar.registerDataLoaders(registry, graphQLContext)); + } + } diff --git a/spring-graphql/src/test/java/org/springframework/graphql/execution/DefaultExecutionGraphQlServiceTests.java b/spring-graphql/src/test/java/org/springframework/graphql/execution/DefaultExecutionGraphQlServiceTests.java new file mode 100644 index 00000000..34a2932a --- /dev/null +++ b/spring-graphql/src/test/java/org/springframework/graphql/execution/DefaultExecutionGraphQlServiceTests.java @@ -0,0 +1,66 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.graphql.execution; + +import java.util.Map; + +import org.dataloader.DataLoaderRegistry; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Flux; + +import org.springframework.graphql.Author; +import org.springframework.graphql.Book; +import org.springframework.graphql.ExecutionGraphQlRequest; +import org.springframework.graphql.ExecutionGraphQlResponse; +import org.springframework.graphql.GraphQlSetup; +import org.springframework.graphql.TestExecutionRequest; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Unit tests for {@link DefaultExecutionGraphQlService}. + * + * @author Rossen Stoyanchev + * @since 1.2.4 + */ +public class DefaultExecutionGraphQlServiceTests { + + @Test + void customDataLoaderRegistry() { + DefaultBatchLoaderRegistry batchLoaderRegistry = new DefaultBatchLoaderRegistry(); + batchLoaderRegistry.forTypePair(Book.class, Author.class) + .registerBatchLoader((books, batchLoaderEnvironment) -> Flux.empty()); + + GraphQlSource graphQlSource = GraphQlSetup.schemaContent("type Query { greeting: String }") + .queryFetcher("greeting", (env) -> "hi") + .toGraphQlSource(); + + DefaultExecutionGraphQlService graphQlService = new DefaultExecutionGraphQlService(graphQlSource); + graphQlService.addDataLoaderRegistrar(batchLoaderRegistry); + + DataLoaderRegistry myRegistry = new DataLoaderRegistry(); + + ExecutionGraphQlRequest request = TestExecutionRequest.forDocument("{ greeting }"); + request.configureExecutionInput((input, builder) -> builder.dataLoaderRegistry(myRegistry).build()); + + ExecutionGraphQlResponse response = graphQlService.execute(request).block(); + Map data = response.getExecutionResult().getData(); + assertThat(data).isEqualTo(Map.of("greeting", "hi")); + assertThat(myRegistry.getDataLoaders()).hasSize(1); + } + +}