diff --git a/compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/MapKeyCreatorGenerator.kt b/compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/MapKeyCreatorGenerator.kt index c3e2e9283..13b9d3aee 100644 --- a/compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/MapKeyCreatorGenerator.kt +++ b/compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/MapKeyCreatorGenerator.kt @@ -1,12 +1,25 @@ package com.squareup.anvil.compiler.codegen.dagger import com.google.auto.service.AutoService +import com.google.devtools.ksp.getDeclaredProperties +import com.google.devtools.ksp.processing.Resolver +import com.google.devtools.ksp.processing.SymbolProcessorEnvironment +import com.google.devtools.ksp.processing.SymbolProcessorProvider +import com.google.devtools.ksp.symbol.KSAnnotated +import com.google.devtools.ksp.symbol.KSClassDeclaration +import com.google.devtools.ksp.symbol.KSPropertyDeclaration +import com.squareup.anvil.compiler.api.AnvilApplicabilityChecker import com.squareup.anvil.compiler.api.AnvilContext import com.squareup.anvil.compiler.api.CodeGenerator import com.squareup.anvil.compiler.api.GeneratedFile import com.squareup.anvil.compiler.api.createGeneratedFile import com.squareup.anvil.compiler.codegen.PrivateCodeGenerator -import com.squareup.anvil.compiler.internal.buildFile +import com.squareup.anvil.compiler.codegen.ksp.AnvilSymbolProcessor +import com.squareup.anvil.compiler.codegen.ksp.AnvilSymbolProcessorProvider +import com.squareup.anvil.compiler.codegen.ksp.KspAnvilException +import com.squareup.anvil.compiler.codegen.ksp.argumentAt +import com.squareup.anvil.compiler.codegen.ksp.isAnnotationClass +import com.squareup.anvil.compiler.internal.createAnvilSpec import com.squareup.anvil.compiler.internal.reference.AnvilCompilationExceptionClassReference import com.squareup.anvil.compiler.internal.reference.ClassReference import com.squareup.anvil.compiler.internal.reference.MemberPropertyReference @@ -43,6 +56,9 @@ import com.squareup.kotlinpoet.TypeName import com.squareup.kotlinpoet.TypeSpec import com.squareup.kotlinpoet.asClassName import com.squareup.kotlinpoet.joinToCode +import com.squareup.kotlinpoet.ksp.toClassName +import com.squareup.kotlinpoet.ksp.toTypeName +import com.squareup.kotlinpoet.ksp.writeTo import dagger.MapKey import org.jetbrains.kotlin.descriptors.ModuleDescriptor import org.jetbrains.kotlin.psi.KtFile @@ -55,105 +71,222 @@ import kotlin.reflect.KClass * * Implemented from eyeballing https://github.com/google/dagger/blob/b5990a0641a7860b760aa9055b90a99d06186af6/javatests/dagger/internal/codegen/MapKeyProcessorTest.java */ -@AutoService(CodeGenerator::class) -internal class MapKeyCreatorGenerator : PrivateCodeGenerator() { - +object MapKeyCreatorCodeGen : AnvilApplicabilityChecker { override fun isApplicable(context: AnvilContext) = context.generateFactories - override fun generateCodePrivate( - codeGenDir: File, - module: ModuleDescriptor, - projectFiles: Collection, - ) { - projectFiles - .classAndInnerClassReferences(module) - .filter { classRef -> - val mapKey = classRef.annotations.find { it.fqName == mapKeyFqName } - if (mapKey != null) { - mapKey.argumentAt("unwrapValue", 0)?.value() == false - } else { - false + internal class KspGenerator( + override val env: SymbolProcessorEnvironment, + ) : AnvilSymbolProcessor() { + @AutoService(SymbolProcessorProvider::class) + class Provider : AnvilSymbolProcessorProvider(MapKeyCreatorCodeGen, ::KspGenerator) + + override fun processChecked(resolver: Resolver): List { + resolver.getSymbolsWithAnnotation(mapKeyFqName.asString()) + .filterIsInstance() + .filter { clazz -> + val mapKey = clazz.annotations.find { it.shortName.asString() == "MapKey" } + ?: return@filter false + val unwrapValue = mapKey.argumentAt("unwrapValue")?.value as? Boolean ?: true + return@filter !unwrapValue + } + .forEach { clazz -> + generateCreatorClass(clazz) + .writeTo( + env.codeGenerator, + aggregating = false, + originatingKSFiles = listOf(clazz.containingFile!!), + ) } + + return emptyList() + } + + private fun generateCreatorClass( + clazz: KSClassDeclaration, + ): FileSpec { + // // Given this + // @MapKey(unwrapValue = false) + // annotation class ActivityKey( + // val value: KClass, + // val scope: KClass<*>, + // ) + // + // // Generate this + // object ActivityKeyCreator { + // @JvmStatic + // fun createActivityKey( + // value: Class, + // scope: Class<*> + // ): ActivityKey { + // return ActivityKey(value.kotlin, scope.kotlin) + // } + // } + + val className = clazz.toClassName() + + if (!clazz.isAnnotationClass()) { + throw KspAnvilException( + message = "@MapKey is only applicable to annotation classes.", + node = clazz, + ) } - .forEach { clazz -> - generateCreatorClass(codeGenDir, clazz) + + val creatorsToGenerate = mutableSetOf() + + fun visitAnnotations(clazz: KSClassDeclaration) { + if (clazz.isAnnotationClass()) { + val added = creatorsToGenerate.add(clazz) + if (added) { + for (property in clazz.getDeclaredProperties()) { + val type = property.type.resolve().declaration as? KSClassDeclaration? + if (type?.isAnnotationClass() == true) { + visitAnnotations(type) + } + } + } + } } + + // Populate all used annotations + visitAnnotations(clazz) + + val creatorFunctions = creatorsToGenerate + .associateBy { annotationClass -> + annotationClass.toClassName() + } + .toSortedMap() + .map { (className, clazz) -> + val properties = clazz.getDeclaredProperties() + .map { AnnotationProperty(it) } + .associateBy { it.name } + generateCreatorFunction(className, properties) + } + + return generateCreatorFileSpec(className, creatorFunctions) + } } - private fun generateCreatorClass( - codeGenDir: File, - clazz: ClassReference, - ): GeneratedFile { - // // Given this - // @MapKey(unwrapValue = false) - // annotation class ActivityKey( - // val value: KClass, - // val scope: KClass<*>, - // ) - // - // // Generate this - // object ActivityKeyCreator { - // @JvmStatic - // fun createActivityKey( - // value: Class, - // scope: Class<*> - // ): ActivityKey { - // return ActivityKey(value.kotlin, scope.kotlin) - // } - // } - - val packageName = clazz.packageFqName.safePackageString() - - if (!clazz.isAnnotationClass()) { - throw AnvilCompilationExceptionClassReference( - message = "@MapKey is only applicable to annotation classes.", - classReference = clazz, - ) + @AutoService(CodeGenerator::class) + internal class EmbeddedGenerator : PrivateCodeGenerator() { + + override fun isApplicable(context: AnvilContext) = MapKeyCreatorCodeGen.isApplicable(context) + + override fun generateCodePrivate( + codeGenDir: File, + module: ModuleDescriptor, + projectFiles: Collection, + ) { + projectFiles + .classAndInnerClassReferences(module) + .filter { classRef -> + val mapKey = classRef.annotations.find { it.fqName == mapKeyFqName } + if (mapKey != null) { + mapKey.argumentAt("unwrapValue", 0)?.value() == false + } else { + false + } + } + .forEach { clazz -> + generateCreatorClass(codeGenDir, clazz) + } } - val className = clazz.asClassName() + private fun generateCreatorClass( + codeGenDir: File, + clazz: ClassReference, + ): GeneratedFile { + // // Given this + // @MapKey(unwrapValue = false) + // annotation class ActivityKey( + // val value: KClass, + // val scope: KClass<*>, + // ) + // + // // Generate this + // object ActivityKeyCreator { + // @JvmStatic + // fun createActivityKey( + // value: Class, + // scope: Class<*> + // ): ActivityKey { + // return ActivityKey(value.kotlin, scope.kotlin) + // } + // } + + val packageName = clazz.packageFqName.safePackageString() + + if (!clazz.isAnnotationClass()) { + throw AnvilCompilationExceptionClassReference( + message = "@MapKey is only applicable to annotation classes.", + classReference = clazz, + ) + } + + val className = clazz.asClassName() - val creatorsToGenerate = mutableSetOf() + val creatorsToGenerate = mutableSetOf() - fun visitAnnotations(clazz: ClassReference) { - if (clazz.isAnnotationClass()) { - val added = creatorsToGenerate.add(clazz) - if (added) { - for (property in clazz.properties) { - val type = property.type().asClassReferenceOrNull() - if (type?.isAnnotationClass() == true) { - visitAnnotations(type) + fun visitAnnotations(clazz: ClassReference) { + if (clazz.isAnnotationClass()) { + val added = creatorsToGenerate.add(clazz) + if (added) { + for (property in clazz.properties) { + val type = property.type().asClassReferenceOrNull() + if (type?.isAnnotationClass() == true) { + visitAnnotations(type) + } } } } } - } - // Populate all used annotations - visitAnnotations(clazz) + // Populate all used annotations + visitAnnotations(clazz) - val creatorFunctions = creatorsToGenerate - .associateBy { annotationClass -> - annotationClass.asTypeName().rawTypeOrNull() - ?: throw AnvilCompilationExceptionClassReference( - message = "@MapKey is only applicable to non-generic annotation classes.", - classReference = annotationClass, - ) - } - .toSortedMap() - .map { (className, clazz) -> generateCreatorFunction(className, clazz) } + val creatorFunctions = creatorsToGenerate + .associateBy { annotationClass -> + annotationClass.asTypeName().rawTypeOrNull() + ?: throw AnvilCompilationExceptionClassReference( + message = "@MapKey is only applicable to non-generic annotation classes.", + classReference = annotationClass, + ) + } + .toSortedMap() + .map { (className, clazz) -> + val properties = clazz.properties + .map { AnnotationProperty(it) } + .associateBy { it.name } + generateCreatorFunction(className, properties) + } + + val spec = generateCreatorFileSpec(className, creatorFunctions) + val content = spec.toString() + + return createGeneratedFile(codeGenDir, packageName, spec.name, content) + } - val simpleName = className.simpleNames.joinToString("_") + private fun generateCreatorFunction( + className: ClassName, + annotationClass: ClassReference, + ): FunSpec { + val properties = annotationClass.properties + .map { AnnotationProperty(it) } + .associateBy { it.name } + return generateCreatorFunction(className, properties) + } + } + + private fun generateCreatorFileSpec(sourceClass: ClassName, creatorFunctions: List): FileSpec { + val simpleName = sourceClass.simpleNames.joinToString("_") val generatedClassName = "${simpleName}Creator" - val content = FileSpec.buildFile(packageName, generatedClassName) { + val spec = FileSpec.createAnvilSpec(sourceClass.packageName, generatedClassName) { addType( TypeSpec.objectBuilder(generatedClassName) .addFunctions(creatorFunctions) .build(), ) } - - return createGeneratedFile(codeGenDir, packageName, generatedClassName, content) + return spec } /** @@ -161,11 +294,8 @@ internal class MapKeyCreatorGenerator : PrivateCodeGenerator() { */ private fun generateCreatorFunction( className: ClassName, - annotationClass: ClassReference, + properties: Map, ): FunSpec { - val properties = annotationClass.properties - .map { AnnotationProperty(it) } - .associateBy { it.name } return FunSpec.builder("create${className.simpleName}") .addAnnotation(JvmStatic::class) .apply { @@ -189,17 +319,13 @@ private class AnnotationProperty( val callExpression: CodeBlock, ) { companion object { - operator fun invoke( - property: MemberPropertyReference, - ): AnnotationProperty { - val name = property.name - val typeName = property.type().asTypeName() - val javaType = typeName.resolveJavaType() + private fun create(name: String, type: TypeName): AnnotationProperty { + val javaType = type.resolveJavaType() val codeBlock = when { javaType.rawTypeOrNull() == CLASS_CLASS_NAME -> CodeBlock.of("%L.kotlin", name) - typeName is ParameterizedTypeName && - typeName.rawType == ARRAY && - typeName.typeArguments[0].rawTypeOrNull() == KCLASS_CLASS_NAME -> { + type is ParameterizedTypeName && + type.rawType == ARRAY && + type.typeArguments[0].rawTypeOrNull() == KCLASS_CLASS_NAME -> { // Dense but this avoids an intermediate list allocation compared to .map { ... }.toTypedArray() CodeBlock.of("%1T(%2L.size)·{·%2L[it].kotlin·}", ARRAY, name) } @@ -212,6 +338,14 @@ private class AnnotationProperty( codeBlock, ) } + + operator fun invoke( + property: KSPropertyDeclaration, + ): AnnotationProperty = create(property.simpleName.asString(), property.type.toTypeName()) + + operator fun invoke( + property: MemberPropertyReference, + ): AnnotationProperty = create(property.name, property.type().asTypeName()) } } diff --git a/compiler/src/main/java/com/squareup/anvil/compiler/codegen/ksp/KspUtil.kt b/compiler/src/main/java/com/squareup/anvil/compiler/codegen/ksp/KspUtil.kt index 42a7b81e1..589a4ab08 100644 --- a/compiler/src/main/java/com/squareup/anvil/compiler/codegen/ksp/KspUtil.kt +++ b/compiler/src/main/java/com/squareup/anvil/compiler/codegen/ksp/KspUtil.kt @@ -1,7 +1,9 @@ package com.squareup.anvil.compiler.codegen.ksp +import com.google.devtools.ksp.symbol.ClassKind.ANNOTATION_CLASS import com.google.devtools.ksp.symbol.KSAnnotated import com.google.devtools.ksp.symbol.KSAnnotation +import com.google.devtools.ksp.symbol.KSClassDeclaration import kotlin.reflect.KClass /** @@ -33,3 +35,5 @@ internal fun KSAnnotated.getKSAnnotationsByQualifiedName( internal fun KSAnnotated.isAnnotationPresent(qualifiedName: String): Boolean = getKSAnnotationsByQualifiedName(qualifiedName).firstOrNull() != null + +internal fun KSClassDeclaration.isAnnotationClass(): Boolean = classKind == ANNOTATION_CLASS diff --git a/compiler/src/test/java/com/squareup/anvil/compiler/dagger/MapKeyCreatorGeneratorTest.kt b/compiler/src/test/java/com/squareup/anvil/compiler/dagger/MapKeyCreatorGeneratorTest.kt index 8e3d556df..d3912faf4 100644 --- a/compiler/src/test/java/com/squareup/anvil/compiler/dagger/MapKeyCreatorGeneratorTest.kt +++ b/compiler/src/test/java/com/squareup/anvil/compiler/dagger/MapKeyCreatorGeneratorTest.kt @@ -1,7 +1,11 @@ package com.squareup.anvil.compiler.dagger +import com.google.common.collect.Lists.cartesianProduct import com.google.common.truth.Truth.assertThat import com.squareup.anvil.compiler.WARNINGS_AS_ERRORS +import com.squareup.anvil.compiler.internal.testing.AnvilCompilationMode +import com.squareup.anvil.compiler.internal.testing.AnvilCompilationMode.Embedded +import com.squareup.anvil.compiler.internal.testing.AnvilCompilationMode.Ksp import com.squareup.anvil.compiler.internal.testing.compileAnvil import com.squareup.anvil.compiler.internal.testing.isStatic import com.squareup.anvil.compiler.isFullTestRun @@ -17,13 +21,24 @@ import org.junit.runners.Parameterized.Parameters @RunWith(Parameterized::class) class MapKeyCreatorGeneratorTest( private val useDagger: Boolean, + private val mode: AnvilCompilationMode, ) { companion object { - @Parameters(name = "Use Dagger: {0}") + @Parameters(name = "Use Dagger: {0}, mode: {1}") @JvmStatic fun useDagger(): Collection { - return listOf(isFullTestRun(), false).distinct() + return cartesianProduct( + listOf(isFullTestRun(), false), + listOf(Embedded(), Ksp()), + ).mapNotNull { (useDagger, mode) -> + if (useDagger == true && mode is Ksp) { + // TODO Dagger is not supported with KSP in Anvil's tests yet + null + } else { + arrayOf(useDagger, mode) + } + }.distinct() } } @@ -251,6 +266,7 @@ class MapKeyCreatorGeneratorTest( enableDaggerAnnotationProcessor = useDagger, generateDaggerFactories = !useDagger, allWarningsAsErrors = WARNINGS_AS_ERRORS, + mode = mode, block = block, ) }