diff --git a/spring-graphql/src/main/java/org/springframework/graphql/execution/DefaultSchemaResourceGraphQlSourceBuilder.java b/spring-graphql/src/main/java/org/springframework/graphql/execution/DefaultSchemaResourceGraphQlSourceBuilder.java index ce77b884..9146d93e 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/execution/DefaultSchemaResourceGraphQlSourceBuilder.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/execution/DefaultSchemaResourceGraphQlSourceBuilder.java @@ -22,6 +22,7 @@ import java.util.Arrays; import java.util.LinkedHashSet; import java.util.List; +import java.util.Map; import java.util.Set; import java.util.function.BiFunction; import java.util.function.Consumer; @@ -30,6 +31,7 @@ import graphql.GraphQL; import graphql.language.InterfaceTypeDefinition; import graphql.language.UnionTypeDefinition; +import graphql.schema.DataFetcher; import graphql.schema.GraphQLSchema; import graphql.schema.TypeResolver; import graphql.schema.idl.CombinedWiringFactory; @@ -45,6 +47,8 @@ import org.springframework.core.io.Resource; import org.springframework.lang.Nullable; import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; /** @@ -149,6 +153,7 @@ protected GraphQLSchema initGraphQlSchema() { } RuntimeWiring runtimeWiring = initRuntimeWiring(); + updateForCustomRootOperationTypeNames(registry, runtimeWiring); TypeResolver typeResolver = initTypeResolver(); registry.types().values().forEach((def) -> { @@ -210,6 +215,23 @@ private RuntimeWiring initRuntimeWiring() { return builder.build(); } + @SuppressWarnings("rawtypes") + private static void updateForCustomRootOperationTypeNames( + TypeDefinitionRegistry registry, RuntimeWiring runtimeWiring) { + + if (registry.schemaDefinition().isEmpty()) { + return; + } + + registry.schemaDefinition().get().getOperationTypeDefinitions().forEach((definition) -> { + String name = StringUtils.capitalize(definition.getName()); + Map dataFetcherMap = runtimeWiring.getDataFetchers().remove(name); + if (!CollectionUtils.isEmpty(dataFetcherMap)) { + runtimeWiring.getDataFetchers().put(definition.getTypeName().getName(), dataFetcherMap); + } + }); + } + private TypeResolver initTypeResolver() { return (this.typeResolver != null) ? this.typeResolver : new ClassNameTypeResolver(); } diff --git a/spring-graphql/src/test/java/org/springframework/graphql/execution/DefaultSchemaResourceGraphQlSourceBuilderTests.java b/spring-graphql/src/test/java/org/springframework/graphql/execution/DefaultSchemaResourceGraphQlSourceBuilderTests.java index a1204b27..9ffe73bf 100644 --- a/spring-graphql/src/test/java/org/springframework/graphql/execution/DefaultSchemaResourceGraphQlSourceBuilderTests.java +++ b/spring-graphql/src/test/java/org/springframework/graphql/execution/DefaultSchemaResourceGraphQlSourceBuilderTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,6 +22,7 @@ import graphql.Scalars; import graphql.schema.DataFetcher; import graphql.schema.FieldCoordinates; +import graphql.schema.GraphQLCodeRegistry; import graphql.schema.GraphQLFieldDefinition; import graphql.schema.GraphQLObjectType; import graphql.schema.GraphQLSchema; @@ -78,13 +79,14 @@ public TraversalControl visitGraphQLObjectType( @Test void typeVisitorToTransformSchema() { - String schemaContent = "" + - "type Query {" + - " person: Person" + - "} " + - "type Person {" + - " firstName: String" + - "}"; + String schemaContent = """ + type Query { + person: Person + } + type Person { + firstName: String + } + """; GraphQLTypeVisitor visitor = new GraphQLTypeVisitorStub() { @@ -113,6 +115,34 @@ public TraversalControl visitGraphQLObjectType( assertThat(schema.getObjectType("Person").getFieldDefinition("lastName")).isNotNull(); } + @Test // gh-708 + void rootOperationTypesWithCustomNames() { + String schemaContent = """ + schema { + query: MyQuery + mutation: MyMutation + } + type MyQuery { + hello: String! + } + type MyMutation { + saveGreeting(greeting: String!): String! + } + """; + + GraphQLSchema schema = GraphQlSetup.schemaContent(schemaContent) + .runtimeWiring(wiringBuilder -> { + wiringBuilder.type("Query", builder -> builder.dataFetcher("hello", env -> "")); + wiringBuilder.type("Mutation", builder -> builder.dataFetcher("saveGreeting", env -> "")); + }) + .toGraphQlSource() + .schema(); + + GraphQLCodeRegistry registry = schema.getCodeRegistry(); + assertThat(registry.hasDataFetcher(FieldCoordinates.coordinates("MyQuery", "hello"))).isTrue(); + assertThat(registry.hasDataFetcher(FieldCoordinates.coordinates("MyMutation", "saveGreeting"))).isTrue(); + } + @Test void wiringFactoryList() {