From 3148b2794b5f60ad443bfeb7253288a9a818f797 Mon Sep 17 00:00:00 2001 From: Martin Bonnin Date: Fri, 8 Nov 2024 02:26:47 +0100 Subject: [PATCH] Fix generating unions (#43) --- .../execution/processor/definitions.kt | 266 +++++++++++------- tests/directives/graphql/schema.graphqls | 16 +- tests/federation/graphql/schema.graphqls | 12 +- tests/integration/build.gradle.kts | 1 + tests/integration/graphql/schema.graphqls | 20 ++ tests/integration/src/main/kotlin/graphql.kt | 18 +- tests/integration/src/test/kotlin/MainTest.kt | 35 +++ 7 files changed, 252 insertions(+), 116 deletions(-) create mode 100644 tests/integration/src/test/kotlin/MainTest.kt diff --git a/apollo-execution-processor/src/main/kotlin/com/apollographql/execution/processor/definitions.kt b/apollo-execution-processor/src/main/kotlin/com/apollographql/execution/processor/definitions.kt index 212d7095..941e9e02 100644 --- a/apollo-execution-processor/src/main/kotlin/com/apollographql/execution/processor/definitions.kt +++ b/apollo-execution-processor/src/main/kotlin/com/apollographql/execution/processor/definitions.kt @@ -72,7 +72,14 @@ private class TypeDefinitionContext( val declarationsToVisit = mutableListOf() + val usedTypeNames = mutableSetOf() + val unions = mutableMapOf>() + /** + * Walk the Kotlin type graph. It goes: + * - recursively (depth first) for supertypes so we can filter out union markers + * - breadth first for subtypes/fields so we don't loop on circular field/interfaces references + */ fun walk( query: KSClassDeclaration, mutation: KSClassDeclaration?, @@ -86,84 +93,9 @@ private class TypeDefinitionContext( declarationsToVisit.add(DeclarationToVisit(subscription, VisitContext.OUTPUT, "subscription")) } - val usedNames = mutableSetOf() while (declarationsToVisit.isNotEmpty()) { val declarationToVisit = declarationsToVisit.removeFirst() - val declaration = declarationToVisit.declaration - val context = declarationToVisit.context - - val qualifiedName = declaration.asClassName().asString() - if (typeDefinitions.containsKey(qualifiedName)) { - // Already visited - continue - } - - if (builtinTypes.contains(qualifiedName)) { - typeDefinitions.put(qualifiedName, builtinScalarDefinition(qualifiedName)) - continue - } - - val name = declaration.graphqlName() - if (usedNames.contains(name)) { - logger.error("Duplicate type '$name'. Either rename the declaration or use @GraphQLName.", declaration) - typeDefinitions.put(qualifiedName, null) - continue - } - usedNames.add(name) - - if (declaration.typeParameters.isNotEmpty()) { - logger.error("Generic classes are not supported") - typeDefinitions.put(qualifiedName, null) - continue - } - - if (unsupportedTypes.contains(qualifiedName)) { - logger.error( - "'$qualifiedName' is not a supported built-in type. Either use one of the built-in types (Boolean, String, Int, Double) or use a custom scalar.", - declaration - ) - typeDefinitions.put(qualifiedName, null) - continue - } - - if (declaration.isExternal()) { - logger.error( - "'$qualifiedName' doesn't have a containing file and probably comes from a dependency.", - declaration - ) - typeDefinitions.put(qualifiedName, null) - continue - } - - /** - * Track the files - */ - ksFiles.add(declaration.containingFile) - - if (declaration is KSTypeAlias) { - typeDefinitions.put(qualifiedName, declaration.toSirScalarDefinition(qualifiedName)) - continue - } - if (declaration !is KSClassDeclaration) { - logger.error("Unsupported type", declaration) - continue - } - if (declaration.classKind == ClassKind.ENUM_CLASS) { - typeDefinitions.put(qualifiedName, declaration.toSirEnumDefinition()) - continue - } - if (declaration.findAnnotation("GraphQLScalar") != null) { - typeDefinitions.put(qualifiedName, declaration.toSirScalarDefinition(qualifiedName)) - continue - } - if (context == VisitContext.INPUT) { - typeDefinitions.put(qualifiedName, declaration.toSirInputObject()) - continue - } - if (context == VisitContext.OUTPUT) { - typeDefinitions.put(qualifiedName, declaration.toSirComposite(declarationToVisit.isoperationType)) - continue - } + getOrResolve(declarationToVisit) } val finalizedDirectiveDefinitions = directiveDefinitions.mapNotNull { @@ -180,11 +112,91 @@ private class TypeDefinitionContext( } return TraversalResults( - definitions = finalizedDirectiveDefinitions + typeDefinitions.values.filterNotNull().toList(), + /** + * Not 100% sure what order to use for the types. + * Fields in source order make sense but for classes that may be defined in different files, it's a lot less clear + */ + definitions = typeDefinitions.patchUnions(unions).sortedBy { it.type() + it.name } + finalizedDirectiveDefinitions.sortedBy { it.name }, analyzedFiles = ksFiles.filterNotNull() ) } + private fun getOrResolve(declarationToVisit: DeclarationToVisit): SirTypeDefinition? { + val qualifiedName = declarationToVisit. declaration.asClassName().asString() + if (typeDefinitions.containsKey(qualifiedName)) { + // Already visited (maybe error) + return typeDefinitions.get(qualifiedName) + } + + val typeDefinition = resolveType(qualifiedName, declarationToVisit) + typeDefinitions.put(qualifiedName, typeDefinition) + return typeDefinition + } + + /** + * If returning null, this function also logs an error to fail the processor. + * + * @return the definition or null if there was an error + */ + private fun resolveType(qualifiedName: String, declarationToVisit: DeclarationToVisit): SirTypeDefinition? { + val declaration = declarationToVisit.declaration + val context = declarationToVisit.context + + if (builtinTypes.contains(qualifiedName)) { + return builtinScalarDefinition(qualifiedName) + } + if (unsupportedTypes.contains(qualifiedName)) { + logger.error( + "'$qualifiedName' is not a supported built-in type. Either use one of the built-in types (Boolean, String, Int, Double) or use a custom scalar.", + declaration + ) + return null + } + if (declaration.containingFile == null) { + logger.error( + "'$qualifiedName' doesn't have a containing file and probably comes from a dependency.", + declaration + ) + return null + } + + /** + * Track the files + */ + ksFiles.add(declaration.containingFile) + + val name = declaration.graphqlName() + if (usedTypeNames.contains(name)) { + logger.error("Duplicate type '$name'. Either rename the declaration or use @GraphQLName.", declaration) + return null + } + usedTypeNames.add(name) + + if (declaration.typeParameters.isNotEmpty()) { + logger.error("Generic classes are not supported") + return null + } + + if (declaration is KSTypeAlias) { + return declaration.toSirScalarDefinition(qualifiedName) + } + if (declaration !is KSClassDeclaration) { + logger.error("Unsupported type", declaration) + return null + } + if (declaration.classKind == ClassKind.ENUM_CLASS) { + return declaration.toSirEnumDefinition() + } + if (declaration.findAnnotation("GraphQLScalar") != null) { + return declaration.toSirScalarDefinition(qualifiedName) + } + + return when(context) { + VisitContext.OUTPUT -> declaration.toSirComposite(declarationToVisit.operationType) + VisitContext.INPUT -> declaration.toSirInputObject() + } + } + /** * Same code for both type aliases and classes */ @@ -353,6 +365,7 @@ private class TypeDefinitionContext( GQLEnumValue(null, simpleName.asString()) } } + else -> { logger.error("Cannot convert $this to a GQLValue", argument) GQLNullValue(null) // not correct but compilation should fail anyway @@ -410,7 +423,7 @@ private class TypeDefinitionContext( name = name, description = description, qualifiedName = qualifiedName, - interfaces = interfaces(), + interfaces = interfaces(name), targetClassName = asClassName(), instantiation = instantiation(), operationType = operationType, @@ -426,18 +439,31 @@ private class TypeDefinitionContext( return null } - val subclasses = getSealedSubclasses().map { - // Look into subclasses + getSealedSubclasses().forEach { + /** + * We go depth first on the superclasses but need to escape the callstack and + * remember to also go the other direction to not miss anything from the graph. + * + * If we were to go depth first only, we would miss all the concrete animal types + * below: + * + * ```graphql + * type Query { + * animal: Animal + * } + * + * union Animal = Cat | Dog | Lion ... + * ``` + */ declarationsToVisit.add(DeclarationToVisit(it, VisitContext.OUTPUT, null)) - it.graphqlName() - }.toList() + } if (allFields.isEmpty()) { SirUnionDefinition( name = name, description = description, qualifiedName = qualifiedName, - memberTypes = subclasses, + memberTypes = emptyList(), // we'll patch that later directives = directives(GQLDirectiveLocation.UNION), ) } else { @@ -448,7 +474,7 @@ private class TypeDefinitionContext( name = name, description = description, qualifiedName = qualifiedName, - interfaces = interfaces(), + interfaces = interfaces(null), fields = allFields, directives = directives(GQLDirectiveLocation.INTERFACE), ) @@ -474,7 +500,7 @@ private class TypeDefinitionContext( ) } - private fun KSClassDeclaration.interfaces(): List { + private fun KSClassDeclaration.interfaces(objectName: String?): List { return getAllSuperTypes().mapNotNull { val declaration = it.declaration if (it.arguments.isNotEmpty()) { @@ -482,16 +508,25 @@ private class TypeDefinitionContext( null } else if (declaration is KSClassDeclaration) { if (declaration.asClassName().asString() == "kotlin.Any") { - null - } else if (declaration.containingFile == null) { - logger.error( - "Class '${simpleName.asString()}' has a super class without a containing file that probably comes from a dependency.", - this - ) + // kotlin.Any is a super type of everything, just ignore it null } else { - declarationsToVisit.add(DeclarationToVisit(declaration, VisitContext.OUTPUT, null)) - declaration.graphqlName() + val supertype = getOrResolve(DeclarationToVisit(declaration, VisitContext.OUTPUT, null)) + if (supertype is SirInterfaceDefinition) { + supertype.name + } else if (supertype is SirUnionDefinition) { + if (objectName == null) { + logger.error("Interfaces are not allowed to extend union markers. Only classes can") + } else { + unions.compute(supertype.name) { _, oldValue -> + oldValue.orEmpty() + objectName + } + } + null + } else { + // error + null + } } } else { logger.error("Unrecognized super class", this) @@ -658,6 +693,11 @@ private class TypeDefinitionContext( } if (!argumentType.isMarkedNullable) { + /* + * Note: it's still possible to have a missing variable at runtime in a non-null position. + * Those cases trigger request error before reaching the resolver and the argument cannot + * be of Optional type. + */ logger.error("Input value is not nullable and cannot be optional", debugContext.node) return SirErrorType } @@ -666,13 +706,7 @@ private class TypeDefinitionContext( } else { if (!hasDefaultValue && isMarkedNullable) { logger.error( - """ - Input value is nullable and doesn't have a default value: it must also be optional. - - If the type is nullable with a default value and no value is provided by the user, the default value is passed to the resolver, the resolver code does not need to handle the `Absent` case. - If the type is non-nullable and there is no default value, variable values may still be absent at runtime. These cases are caught during coercion before it reaches the resolver code. - - """.trimIndent(), + "Input value is nullable and doesn't have a default value: it must also be optional.", debugContext.node ) return SirErrorType @@ -780,6 +814,36 @@ private class TypeDefinitionContext( } } +/** + * Sorting helper function. Not 100% sure of the order here + */ +private fun SirTypeDefinition.type(): String { + return when (this) { + is SirScalarDefinition -> "0" + is SirEnumDefinition -> "1" + is SirObjectDefinition -> "2" + is SirInterfaceDefinition -> "3" + is SirUnionDefinition -> "4" + is SirInputObjectDefinition -> "5" + } +} + +private fun Map.patchUnions(unions: Map>): List { + return values.filterNotNull().map { + if (it is SirUnionDefinition) { + SirUnionDefinition( + it.name, + it.description, + it.qualifiedName, + unions.get(it.name)!!.toList(), + it.directives + ) + } else { + it + } + } +} + private fun KSDeclaration.isApolloOptional(): Boolean { return asClassName().asString() == "com.apollographql.apollo.api.Optional" @@ -807,7 +871,7 @@ private val builtinTypes = listOf("Double", "String", "Boolean", "Int").map { private class DeclarationToVisit( val declaration: KSDeclaration, val context: VisitContext, - val isoperationType: String? = null + val operationType: String? = null ) private enum class VisitContext { diff --git a/tests/directives/graphql/schema.graphqls b/tests/directives/graphql/schema.graphqls index 1673bb6c..8ca5c595 100644 --- a/tests/directives/graphql/schema.graphqls +++ b/tests/directives/graphql/schema.graphqls @@ -2,7 +2,13 @@ schema { query: Query } -directive @requiresOptIn (feature: OptInFeature!) on FIELD_DEFINITION +enum OptInLevel { + Ignore + + Warning + + Error +} type Query { experimentalField: String! @requiresOptIn(feature: { @@ -18,13 +24,7 @@ input OptInFeature { level: OptInLevel! } -enum OptInLevel { - Ignore - - Warning - - Error -} +directive @requiresOptIn (feature: OptInFeature!) on FIELD_DEFINITION type __Schema { description: String diff --git a/tests/federation/graphql/schema.graphqls b/tests/federation/graphql/schema.graphqls index 3bf6a18f..c5b4f6a5 100644 --- a/tests/federation/graphql/schema.graphqls +++ b/tests/federation/graphql/schema.graphqls @@ -2,6 +2,12 @@ schema { query: Query } +type Product @key(fields: "id") { + id: String! + + name: String! +} + type Query { _entities(representations: [_Any!]!): [_Entity!]! @@ -10,12 +16,6 @@ type Query { products: [Product!]! } -type Product @key(fields: "id") { - id: String! - - name: String! -} - union _Entity = Product type _Service { diff --git a/tests/integration/build.gradle.kts b/tests/integration/build.gradle.kts index 808b12e1..215f9023 100644 --- a/tests/integration/build.gradle.kts +++ b/tests/integration/build.gradle.kts @@ -12,4 +12,5 @@ apolloExecution { dependencies { implementation(libs.apollo.execution.runtime) + testImplementation(libs.kotlin.test) } \ No newline at end of file diff --git a/tests/integration/graphql/schema.graphqls b/tests/integration/graphql/schema.graphqls index 1a84e8a1..b20c4997 100644 --- a/tests/integration/graphql/schema.graphqls +++ b/tests/integration/graphql/schema.graphqls @@ -2,10 +2,30 @@ schema { query: Query } +type Cat { + meow: String! +} + +type Dog { + barf: String! +} + type Query { field: String! + + animal: Cat! + + a(arg: Int!): Int! + + c(arg: Int): Int! + + e(arg: Int! = 10): Int! + + f(arg: Int = 10): Int! } +union Animal = Cat|Dog + type __Schema { description: String diff --git a/tests/integration/src/main/kotlin/graphql.kt b/tests/integration/src/main/kotlin/graphql.kt index 47f0718e..29fcb219 100644 --- a/tests/integration/src/main/kotlin/graphql.kt +++ b/tests/integration/src/main/kotlin/graphql.kt @@ -1,6 +1,22 @@ +import com.apollographql.apollo.api.Optional +import com.apollographql.execution.annotation.GraphQLDefault import com.apollographql.execution.annotation.GraphQLQuery @GraphQLQuery class Query { fun field(): String = "hello" -} \ No newline at end of file + val animal = Cat("meeeooooooowwwww") + fun a(arg: Int) = 0 + //fun b(arg: Int?) = 0 // Input value is nullable and doesn't have a default value: it must also be optional. + fun c(arg: Optional) = 0 + //fun d(arg: Optional) = 0 // Input value is not nullable and cannot be optional + fun e(@GraphQLDefault("10") arg: Int) = 0 + fun f(@GraphQLDefault("10") arg: Int?) = 0 + //fun g(@GraphQLDefault("10") arg: Optional) = 0 // Input value has a default value and cannot be optional + //fun h(@GraphQLDefault("10") arg: Optional) = 0 // Input value has a default value and cannot be optional +} + +sealed interface Animal + +class Cat(val meow:String): Animal +class Dog(val barf:String): Animal \ No newline at end of file diff --git a/tests/integration/src/test/kotlin/MainTest.kt b/tests/integration/src/test/kotlin/MainTest.kt new file mode 100644 index 00000000..9f72dc46 --- /dev/null +++ b/tests/integration/src/test/kotlin/MainTest.kt @@ -0,0 +1,35 @@ +import com.apollographql.execution.toGraphQLRequest +import com.example.ServiceExecutableSchemaBuilder +import kotlinx.coroutines.runBlocking +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertIs + +class MainTest { + @Test + fun test() { + ServiceExecutableSchemaBuilder() + .build() + .let { + runBlocking { + it.execute( + """ + { + field + animal { + ... on Cat { meow } + } + } + """.trimIndent().toGraphQLRequest()) + } + }.data + .apply { + assertIs>(this) + assertEquals(get("field"), "hello") + get("animal").apply { + assertIs>(this) + assertEquals(get("meow"), "meeeooooooowwwww") + } + } + } +}