Skip to content

Commit

Permalink
Support KSP in BindsMethodValidator (#831)
Browse files Browse the repository at this point in the history
* Support KSP in BindsMethodValidator

* Fix tests requiring dagger compilation in BindsMethodValidator

* use KSP built-in assignability check

* explain forceEmbeddedMode parameter in BindsMethodValidatorTest

---------

Co-authored-by: Zac Sweers <[email protected]>
  • Loading branch information
IlyaGulya and ZacSweers authored Apr 3, 2024
1 parent 95ade2c commit e5919f2
Show file tree
Hide file tree
Showing 4 changed files with 322 additions and 133 deletions.
Original file line number Diff line number Diff line change
@@ -1,17 +1,39 @@
package com.squareup.anvil.compiler.codegen.dagger

import com.google.auto.service.AutoService
import com.google.devtools.ksp.processing.Resolver
import com.google.devtools.ksp.processing.SymbolProcessorEnvironment
import com.google.devtools.ksp.processing.SymbolProcessorProvider
import com.google.devtools.ksp.symbol.KSAnnotated
import com.google.devtools.ksp.symbol.KSClassDeclaration
import com.google.devtools.ksp.symbol.KSFunctionDeclaration
import com.google.devtools.ksp.symbol.KSType
import com.google.devtools.ksp.symbol.KSTypeReference
import com.squareup.anvil.compiler.api.AnvilApplicabilityChecker
import com.squareup.anvil.compiler.api.AnvilContext
import com.squareup.anvil.compiler.api.CodeGenerator
import com.squareup.anvil.compiler.codegen.CheckOnlyCodeGenerator
import com.squareup.anvil.compiler.codegen.ksp.AnvilSymbolProcessor
import com.squareup.anvil.compiler.codegen.ksp.AnvilSymbolProcessorProvider
import com.squareup.anvil.compiler.codegen.ksp.KspAnvilException
import com.squareup.anvil.compiler.codegen.ksp.getAnnotatedFunctions
import com.squareup.anvil.compiler.codegen.ksp.getAnnotatedSymbols
import com.squareup.anvil.compiler.codegen.ksp.isExtensionDeclaration
import com.squareup.anvil.compiler.codegen.ksp.resolveKSClassDeclaration
import com.squareup.anvil.compiler.codegen.ksp.returnTypeOrNull
import com.squareup.anvil.compiler.codegen.ksp.superTypesExcludingAny
import com.squareup.anvil.compiler.codegen.ksp.withCompanion
import com.squareup.anvil.compiler.daggerBindsFqName
import com.squareup.anvil.compiler.daggerModuleFqName
import com.squareup.anvil.compiler.internal.reference.AnvilCompilationExceptionFunctionReference
import com.squareup.anvil.compiler.internal.reference.ClassReference
import com.squareup.anvil.compiler.internal.reference.MemberFunctionReference
import com.squareup.anvil.compiler.internal.reference.TypeReference
import com.squareup.anvil.compiler.internal.reference.allSuperTypeClassReferences
import com.squareup.anvil.compiler.internal.reference.asTypeName
import com.squareup.anvil.compiler.internal.reference.classAndInnerClassReferences
import com.squareup.anvil.compiler.internal.reference.toTypeReference
import com.squareup.kotlinpoet.ksp.toTypeName
import dagger.Binds
import dagger.Module
import org.jetbrains.kotlin.descriptors.ModuleDescriptor
import org.jetbrains.kotlin.psi.KtFile
import org.jetbrains.kotlin.psi.psiUtil.isExtensionDeclaration
Expand All @@ -24,106 +46,198 @@ import java.io.File
* that Dagger would consider a compile time failure would instead manifest as a runtime failure
* when Anvil generates Dagger factories.
*/
@AutoService(CodeGenerator::class)
internal class BindsMethodValidator : CheckOnlyCodeGenerator() {
internal object BindsMethodValidator : AnvilApplicabilityChecker {

override fun isApplicable(context: AnvilContext) = context.generateFactories

override fun checkCode(
codeGenDir: File,
module: ModuleDescriptor,
projectFiles: Collection<KtFile>,
) {
projectFiles
.classAndInnerClassReferences(module)
.filter { it.isAnnotatedWith(daggerModuleFqName) }
.forEach { clazz ->
(clazz.companionObjects() + clazz)
.asSequence()
.flatMap { it.functions }
.filter { it.isAnnotatedWith(daggerBindsFqName) }
.also { functions ->
assertNoDuplicateFunctions(clazz, functions)
}
.forEach { function ->
validateBindsFunction(function)
}
internal object Errors {
internal const val BINDS_CANT_BE_AN_EXTENSION =
"@Binds methods can not be an extension function"
internal const val BINDS_MUST_BE_ABSTRACT = "@Binds methods must be abstract"
internal const val BINDS_MUST_HAVE_SINGLE_PARAMETER =
"@Binds methods must have exactly one parameter, " +
"whose type is assignable to the return type"
internal const val BINDS_MUST_RETURN_A_VALUE = "@Binds methods must return a value (not void)"
internal fun bindsParameterMustBeAssignable(
bindingParameterSuperTypeNames: List<String>,
returnTypeName: String,
parameterType: String?,
): String {
val superTypesMessage = if (bindingParameterSuperTypeNames.isEmpty()) {
"has no supertypes."
} else {
"only has the following supertypes: $bindingParameterSuperTypeNames"
}
}

private fun validateBindsFunction(function: MemberFunctionReference.Psi) {
if (!function.isAbstract()) {
throw AnvilCompilationExceptionFunctionReference(
message = "@Binds methods must be abstract",
functionReference = function,
)
return "@Binds methods' parameter type must be assignable to the return type. " +
"Expected binding of type $returnTypeName but impl parameter of type " +
"$parameterType $superTypesMessage"
}
}

if (function.function.isExtensionDeclaration()) {
throw AnvilCompilationExceptionFunctionReference(
message = "@Binds methods can not be an extension function",
functionReference = function,
)
internal class KspValidator(
override val env: SymbolProcessorEnvironment,
) : AnvilSymbolProcessor() {
@AutoService(SymbolProcessorProvider::class)
class Provider : AnvilSymbolProcessorProvider(BindsMethodValidator, ::KspValidator)

override fun processChecked(resolver: Resolver): List<KSAnnotated> {
val actualValidator = ActualValidator(resolver)
resolver
.getAnnotatedSymbols<Module>()
.filterIsInstance<KSClassDeclaration>()
.forEach { clazz ->
clazz
.withCompanion()
.flatMap { it.getAnnotatedFunctions<Binds>() }
.also {
assertNoDuplicateFunctions(clazz, it)
}
.forEach { function ->
actualValidator.validateBindsFunction(function)
}
}
return emptyList()
}

val hasSingleBindingParameter =
function.parameters.size == 1 && !function.function.isExtensionDeclaration()
if (!hasSingleBindingParameter) {
throw AnvilCompilationExceptionFunctionReference(
message = "@Binds methods must have exactly one parameter, " +
"whose type is assignable to the return type",
functionReference = function,
)
internal class ActualValidator(
private val resolver: Resolver,
) {
internal fun validateBindsFunction(function: KSFunctionDeclaration) {
if (!function.isAbstract) {
throw KspAnvilException(
message = Errors.BINDS_MUST_BE_ABSTRACT,
node = function,
)
}

if (function.isExtensionDeclaration()) {
throw KspAnvilException(
message = Errors.BINDS_CANT_BE_AN_EXTENSION,
node = function,
)
}

val bindingParameter = function.singleParameterOrNull()

bindingParameter ?: throw KspAnvilException(
message = Errors.BINDS_MUST_HAVE_SINGLE_PARAMETER,
node = function,
)

val returnType = function.returnTypeOrNull()

returnType ?: throw KspAnvilException(
message = Errors.BINDS_MUST_RETURN_A_VALUE,
node = function,
)

if (!returnType.isAssignableFrom(bindingParameter.resolve())) {
val superTypeNames =
bindingParameter
.resolveSuperTypesExcludingAny()
.map { it.toTypeName().toString() }
.toList()

throw KspAnvilException(
message = Errors.bindsParameterMustBeAssignable(
bindingParameterSuperTypeNames = superTypeNames,
returnTypeName = returnType.toTypeName().toString(),
parameterType = bindingParameter.toTypeName().toString(),
),
node = function,
)
}
}

private fun KSTypeReference.resolveSuperTypesExcludingAny(): Sequence<KSType> {
val clazz = resolve().resolveKSClassDeclaration() ?: return emptySequence()
return clazz
.superTypesExcludingAny(resolver)
.map { it.resolve() }
}

private fun KSFunctionDeclaration.singleParameterOrNull(): KSTypeReference? =
parameters.singleOrNull()?.type
}
}

function.returnTypeOrNull() ?: throw AnvilCompilationExceptionFunctionReference(
message = "@Binds methods must return a value (not void)",
functionReference = function,
)
@AutoService(CodeGenerator::class)
internal class Embedded : CheckOnlyCodeGenerator() {
override fun isApplicable(context: AnvilContext): Boolean =
BindsMethodValidator.isApplicable(context)

if (!function.parameterMatchesReturnType() && !function.receiverMatchesReturnType()) {
val returnType = function.returnType().asClassReference().shortName
val paramSuperTypes = (function.parameterSuperTypes() ?: function.receiverSuperTypes())!!
.map { it.shortName }
.toList()
override fun checkCode(
codeGenDir: File,
module: ModuleDescriptor,
projectFiles: Collection<KtFile>,
) {
projectFiles
.classAndInnerClassReferences(module)
.filter { it.isAnnotatedWith(daggerModuleFqName) }
.forEach { clazz ->
(clazz.companionObjects() + clazz)
.asSequence()
.flatMap { it.functions }
.filter { it.isAnnotatedWith(daggerBindsFqName) }
.also { functions ->
assertNoDuplicateFunctions(clazz, functions)
}
.forEach { function ->
validateBindsFunction(function)
}
}
}

val superTypesMessage = if (paramSuperTypes.size == 1) {
"has no supertypes."
} else {
"only has the following supertypes: ${paramSuperTypes.drop(1)}"
private fun validateBindsFunction(function: MemberFunctionReference.Psi) {
if (!function.isAbstract()) {
throw AnvilCompilationExceptionFunctionReference(
message = Errors.BINDS_MUST_BE_ABSTRACT,
functionReference = function,
)
}
throw AnvilCompilationExceptionFunctionReference(
message = "@Binds methods' parameter type must be assignable to the return type. " +
"Expected binding of type $returnType but impl parameter of type " +
"${paramSuperTypes.first()} $superTypesMessage",

if (function.function.isExtensionDeclaration()) {
throw AnvilCompilationExceptionFunctionReference(
message = Errors.BINDS_CANT_BE_AN_EXTENSION,
functionReference = function,
)
}

val bindingParameter = function.singleParameterTypeOrNull()

bindingParameter ?: throw AnvilCompilationExceptionFunctionReference(
message = Errors.BINDS_MUST_HAVE_SINGLE_PARAMETER,
functionReference = function,
)
}
}

private fun MemberFunctionReference.Psi.parameterMatchesReturnType(): Boolean {
return parameterSuperTypes()
?.contains(returnType().asClassReference())
?: false
}
val returnType = function.returnTypeOrNull()?.asClassReference()

private fun MemberFunctionReference.Psi.parameterSuperTypes(): Sequence<ClassReference>? {
return parameters.singleOrNull()
?.type()
?.asClassReference()
?.allSuperTypeClassReferences(includeSelf = true)
}
returnType ?: throw AnvilCompilationExceptionFunctionReference(
message = Errors.BINDS_MUST_RETURN_A_VALUE,
functionReference = function,
)

private fun MemberFunctionReference.Psi.receiverMatchesReturnType(): Boolean {
return receiverSuperTypes()
?.contains(returnType().asClassReference())
?: false
}
val superTypes = bindingParameter
.asClassReference()
.allSuperTypeClassReferences(includeSelf = true)

private fun MemberFunctionReference.Psi.receiverSuperTypes(): Sequence<ClassReference>? {
return function.receiverTypeReference
?.toTypeReference(declaringClass, module)
?.asClassReference()
?.allSuperTypeClassReferences(includeSelf = true)
if (returnType !in superTypes) {
val superTypeNames = superTypes.map { it.asTypeName().toString() }

throw AnvilCompilationExceptionFunctionReference(
message = Errors.bindsParameterMustBeAssignable(
bindingParameterSuperTypeNames = superTypeNames.drop(1).toList(),
returnTypeName = returnType.asTypeName().toString(),
parameterType = superTypeNames.first(),
),
functionReference = function,
)
}
}

private fun MemberFunctionReference.Psi.singleParameterTypeOrNull(): TypeReference? {
return parameters.singleOrNull()?.type()
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package com.squareup.anvil.compiler.codegen.dagger

import com.google.auto.service.AutoService
import com.google.devtools.ksp.closestClassDeclaration
import com.google.devtools.ksp.getDeclaredFunctions
import com.google.devtools.ksp.getDeclaredProperties
import com.google.devtools.ksp.getVisibility
import com.google.devtools.ksp.processing.Resolver
Expand All @@ -25,9 +24,11 @@ import com.squareup.anvil.compiler.codegen.PrivateCodeGenerator
import com.squareup.anvil.compiler.codegen.ksp.AnvilSymbolProcessor
import com.squareup.anvil.compiler.codegen.ksp.AnvilSymbolProcessorProvider
import com.squareup.anvil.compiler.codegen.ksp.KspAnvilException
import com.squareup.anvil.compiler.codegen.ksp.getAnnotatedFunctions
import com.squareup.anvil.compiler.codegen.ksp.getKSAnnotationsByType
import com.squareup.anvil.compiler.codegen.ksp.isAnnotationPresent
import com.squareup.anvil.compiler.codegen.ksp.isInterface
import com.squareup.anvil.compiler.codegen.ksp.withCompanion
import com.squareup.anvil.compiler.codegen.ksp.withJvmSuppressWildcardsIfNeeded
import com.squareup.anvil.compiler.daggerModuleFqName
import com.squareup.anvil.compiler.daggerProvidesFqName
Expand Down Expand Up @@ -82,23 +83,21 @@ internal object ProvidesMethodFactoryCodeGen : AnvilApplicabilityChecker {
resolver.getSymbolsWithAnnotation(daggerModuleFqName.asString())
.filterIsInstance<KSClassDeclaration>()
.forEach { clazz ->
val classAndCompanion = sequenceOf(clazz)
.plus(
clazz.declarations.filterIsInstance<KSClassDeclaration>()
.filter { it.isCompanionObject },
)

val functions = classAndCompanion.flatMap { it.getDeclaredFunctions() }
.filter { it.isAnnotationPresent<Provides>() }
.onEach { function ->
checkFunctionIsNotAbstract(clazz, function)
}
.also { functions ->
assertNoDuplicateFunctions(clazz, functions)
}
.map { function ->
CallableReference.from(function)
}
val classAndCompanion = clazz.withCompanion()
val functions =
classAndCompanion
.flatMap {
it.getAnnotatedFunctions<Provides>()
}
.onEach { function ->
checkFunctionIsNotAbstract(clazz, function)
}
.also { functions ->
assertNoDuplicateFunctions(clazz, functions)
}
.map { function ->
CallableReference.from(function)
}

val properties = classAndCompanion.flatMap { it.getDeclaredProperties() }
.filter { it.isAnnotationPresent<Provides>() || it.getter?.isAnnotationPresent<Provides>() == true }
Expand Down
Loading

0 comments on commit e5919f2

Please sign in to comment.