Skip to content

Commit

Permalink
Perform hasDataLoaderRegistrations check at runtime
Browse files Browse the repository at this point in the history
Closes gh-1020
  • Loading branch information
rstoyanchev committed Jul 4, 2024
1 parent 9965c03 commit a6508f4
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 29 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2021 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -30,6 +30,17 @@
*/
public interface DataLoaderRegistrar {


/**
* Whether the registrar has any {@code DataLoader} registrations to make.
* @since 1.2.8
*/
default boolean hasRegistrations() {
DataLoaderRegistry registry = DataLoaderRegistry.newRegistry().build();
registerDataLoaders(registry, GraphQLContext.newContext().build());
return !registry.getDataLoaders().isEmpty();
}

/**
* Callback that provides access to the {@link DataLoaderRegistry} from the
* the {@link graphql.ExecutionInput}.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,11 @@ public <K, V> RegistrationSpec<K, V> forName(String name) {
return new DefaultRegistrationSpec<>(name);
}

@Override
public boolean hasRegistrations() {
return (!this.loaders.isEmpty() || !this.mappedLoaders.isEmpty());
}

@Override
public void registerDataLoaders(DataLoaderRegistry registry, GraphQLContext context) {
BatchLoaderContextProvider contextProvider = () -> context;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.springframework.graphql.ExecutionGraphQlResponse;
import org.springframework.graphql.ExecutionGraphQlService;
import org.springframework.graphql.support.DefaultExecutionGraphQlResponse;
import org.springframework.lang.Nullable;

/**
* {@link ExecutionGraphQlService} that uses a {@link GraphQlSource} to obtain a
Expand All @@ -51,7 +52,8 @@ public class DefaultExecutionGraphQlService implements ExecutionGraphQlService {

private final List<DataLoaderRegistrar> dataLoaderRegistrars = new ArrayList<>();

private boolean hasDataLoaderRegistrations;
@Nullable
private Boolean hasDataLoaderRegistrations;

private final boolean isDefaultExecutionIdProvider;

Expand All @@ -70,13 +72,6 @@ public DefaultExecutionGraphQlService(GraphQlSource graphQlSource) {
*/
public void addDataLoaderRegistrar(DataLoaderRegistrar registrar) {
this.dataLoaderRegistrars.add(registrar);
this.hasDataLoaderRegistrations = (this.hasDataLoaderRegistrations || hasRegistrations(registrar));
}

private static boolean hasRegistrations(DataLoaderRegistrar registrar) {
DataLoaderRegistry registry = DataLoaderRegistry.newRegistry().build();
registrar.registerDataLoaders(registry, GraphQLContext.newContext().build());
return !registry.getDataLoaders().isEmpty();
}


Expand All @@ -93,28 +88,41 @@ public final Mono<ExecutionGraphQlResponse> execute(ExecutionGraphQlRequest requ
GraphQLContext graphQLContext = executionInput.getGraphQLContext();
ContextSnapshot.captureFrom(contextView).updateContext(graphQLContext);

ExecutionInput updatedExecutionInput =
(this.hasDataLoaderRegistrations ? registerDataLoaders(executionInput) : executionInput);
ExecutionInput executionInputToUse = registerDataLoaders(executionInput);

return Mono.fromFuture(this.graphQlSource.graphQl().executeAsync(updatedExecutionInput))
.map((result) -> new DefaultExecutionGraphQlResponse(updatedExecutionInput, result));
return Mono.fromFuture(this.graphQlSource.graphQl().executeAsync(executionInputToUse))
.map((result) -> new DefaultExecutionGraphQlResponse(executionInputToUse, result));
});
}

private ExecutionInput registerDataLoaders(ExecutionInput executionInput) {
GraphQLContext graphQLContext = executionInput.getGraphQLContext();
DataLoaderRegistry existingRegistry = executionInput.getDataLoaderRegistry();
if (existingRegistry == DataLoaderDispatcherInstrumentationState.EMPTY_DATALOADER_REGISTRY) {
DataLoaderRegistry newRegistry = DataLoaderRegistry.newRegistry().build();
applyDataLoaderRegistrars(newRegistry, graphQLContext);
executionInput = executionInput.transform((builder) -> builder.dataLoaderRegistry(newRegistry));
if (this.hasDataLoaderRegistrations == null) {
this.hasDataLoaderRegistrations = initHasDataLoaderRegistrations();
}
else {
applyDataLoaderRegistrars(existingRegistry, graphQLContext);
if (this.hasDataLoaderRegistrations) {
GraphQLContext graphQLContext = executionInput.getGraphQLContext();
DataLoaderRegistry existingRegistry = executionInput.getDataLoaderRegistry();
if (existingRegistry == DataLoaderDispatcherInstrumentationState.EMPTY_DATALOADER_REGISTRY) {
DataLoaderRegistry newRegistry = DataLoaderRegistry.newRegistry().build();
applyDataLoaderRegistrars(newRegistry, graphQLContext);
executionInput = executionInput.transform((builder) -> builder.dataLoaderRegistry(newRegistry));
}
else {
applyDataLoaderRegistrars(existingRegistry, graphQLContext);
}
}
return executionInput;
}

private boolean initHasDataLoaderRegistrations() {
for (DataLoaderRegistrar registrar : this.dataLoaderRegistrars) {
if (registrar.hasRegistrations()) {
return true;
}
}
return false;
}

private void applyDataLoaderRegistrars(DataLoaderRegistry registry, GraphQLContext graphQLContext) {
this.dataLoaderRegistrars.forEach((registrar) -> registrar.registerDataLoaders(registry, graphQLContext));
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2023 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -41,26 +41,27 @@ public class DefaultExecutionGraphQlServiceTests {

@Test
void customDataLoaderRegistry() {
DefaultBatchLoaderRegistry batchLoaderRegistry = new DefaultBatchLoaderRegistry();
batchLoaderRegistry.forTypePair(Book.class, Author.class)
.registerBatchLoader((books, batchLoaderEnvironment) -> Flux.empty());

GraphQlSource graphQlSource = GraphQlSetup.schemaContent("type Query { greeting: String }")
.queryFetcher("greeting", (env) -> "hi")
.toGraphQlSource();

BatchLoaderRegistry batchLoaderRegistry = new DefaultBatchLoaderRegistry();
DefaultExecutionGraphQlService graphQlService = new DefaultExecutionGraphQlService(graphQlSource);
graphQlService.addDataLoaderRegistrar(batchLoaderRegistry);

DataLoaderRegistry myRegistry = new DataLoaderRegistry();
// gh-1020: register loader after adding the registry to DefaultExecutionGraphQlService
batchLoaderRegistry.forTypePair(Book.class, Author.class)
.registerBatchLoader((books, batchLoaderEnvironment) -> Flux.empty());

DataLoaderRegistry dataLoaderRegistry = new DataLoaderRegistry();

ExecutionGraphQlRequest request = TestExecutionRequest.forDocument("{ greeting }");
request.configureExecutionInput((input, builder) -> builder.dataLoaderRegistry(myRegistry).build());
request.configureExecutionInput((input, builder) -> builder.dataLoaderRegistry(dataLoaderRegistry).build());

ExecutionGraphQlResponse response = graphQlService.execute(request).block();
Map<?, ?> data = response.getExecutionResult().getData();
assertThat(data).isEqualTo(Map.of("greeting", "hi"));
assertThat(myRegistry.getDataLoaders()).hasSize(1);
assertThat(dataLoaderRegistry.getDataLoaders()).hasSize(1);
}

}

0 comments on commit a6508f4

Please sign in to comment.