Skip to content

Commit

Permalink
Support KSP in MapKeyCreatorGenerator
Browse files Browse the repository at this point in the history
  • Loading branch information
ZacSweers committed Nov 18, 2023
1 parent 18ae172 commit 6e94ca1
Show file tree
Hide file tree
Showing 3 changed files with 244 additions and 90 deletions.
Original file line number Diff line number Diff line change
@@ -1,12 +1,25 @@
package com.squareup.anvil.compiler.codegen.dagger

import com.google.auto.service.AutoService
import com.google.devtools.ksp.getDeclaredProperties
import com.google.devtools.ksp.processing.Resolver
import com.google.devtools.ksp.processing.SymbolProcessorEnvironment
import com.google.devtools.ksp.processing.SymbolProcessorProvider
import com.google.devtools.ksp.symbol.KSAnnotated
import com.google.devtools.ksp.symbol.KSClassDeclaration
import com.google.devtools.ksp.symbol.KSPropertyDeclaration
import com.squareup.anvil.compiler.api.AnvilApplicabilityChecker
import com.squareup.anvil.compiler.api.AnvilContext
import com.squareup.anvil.compiler.api.CodeGenerator
import com.squareup.anvil.compiler.api.GeneratedFile
import com.squareup.anvil.compiler.api.createGeneratedFile
import com.squareup.anvil.compiler.codegen.PrivateCodeGenerator
import com.squareup.anvil.compiler.internal.buildFile
import com.squareup.anvil.compiler.codegen.ksp.AnvilSymbolProcessor
import com.squareup.anvil.compiler.codegen.ksp.AnvilSymbolProcessorProvider
import com.squareup.anvil.compiler.codegen.ksp.KspAnvilException
import com.squareup.anvil.compiler.codegen.ksp.argumentAt
import com.squareup.anvil.compiler.codegen.ksp.isAnnotationClass
import com.squareup.anvil.compiler.internal.createAnvilSpec
import com.squareup.anvil.compiler.internal.reference.AnvilCompilationExceptionClassReference
import com.squareup.anvil.compiler.internal.reference.ClassReference
import com.squareup.anvil.compiler.internal.reference.MemberPropertyReference
Expand Down Expand Up @@ -43,6 +56,9 @@ import com.squareup.kotlinpoet.TypeName
import com.squareup.kotlinpoet.TypeSpec
import com.squareup.kotlinpoet.asClassName
import com.squareup.kotlinpoet.joinToCode
import com.squareup.kotlinpoet.ksp.toClassName
import com.squareup.kotlinpoet.ksp.toTypeName
import com.squareup.kotlinpoet.ksp.writeTo
import dagger.MapKey
import org.jetbrains.kotlin.descriptors.ModuleDescriptor
import org.jetbrains.kotlin.psi.KtFile
Expand All @@ -55,117 +71,231 @@ import kotlin.reflect.KClass
*
* Implemented from eyeballing https://github.com/google/dagger/blob/b5990a0641a7860b760aa9055b90a99d06186af6/javatests/dagger/internal/codegen/MapKeyProcessorTest.java
*/
@AutoService(CodeGenerator::class)
internal class MapKeyCreatorGenerator : PrivateCodeGenerator() {

object MapKeyCreatorCodeGen : AnvilApplicabilityChecker {
override fun isApplicable(context: AnvilContext) = context.generateFactories

override fun generateCodePrivate(
codeGenDir: File,
module: ModuleDescriptor,
projectFiles: Collection<KtFile>,
) {
projectFiles
.classAndInnerClassReferences(module)
.filter { classRef ->
val mapKey = classRef.annotations.find { it.fqName == mapKeyFqName }
if (mapKey != null) {
mapKey.argumentAt("unwrapValue", 0)?.value<Boolean>() == false
} else {
false
internal class KspGenerator(
override val env: SymbolProcessorEnvironment,
) : AnvilSymbolProcessor() {
@AutoService(SymbolProcessorProvider::class)
class Provider : AnvilSymbolProcessorProvider(MapKeyCreatorCodeGen, ::KspGenerator)

override fun processChecked(resolver: Resolver): List<KSAnnotated> {
resolver.getSymbolsWithAnnotation(mapKeyFqName.asString())
.filterIsInstance<KSClassDeclaration>()
.filter { clazz ->
val mapKey = clazz.annotations.find { it.shortName.asString() == "MapKey" }
?: return@filter false
val unwrapValue = mapKey.argumentAt("unwrapValue")?.value as? Boolean ?: true
return@filter !unwrapValue
}
.forEach { clazz ->
generateCreatorClass(clazz)
.writeTo(
env.codeGenerator,
aggregating = false,
originatingKSFiles = listOf(clazz.containingFile!!),
)
}

return emptyList()
}

private fun generateCreatorClass(
clazz: KSClassDeclaration,
): FileSpec {
// // Given this
// @MapKey(unwrapValue = false)
// annotation class ActivityKey(
// val value: KClass<out Activity>,
// val scope: KClass<*>,
// )
//
// // Generate this
// object ActivityKeyCreator {
// @JvmStatic
// fun createActivityKey(
// value: Class<out Activity>,
// scope: Class<*>
// ): ActivityKey {
// return ActivityKey(value.kotlin, scope.kotlin)
// }
// }

val className = clazz.toClassName()

if (!clazz.isAnnotationClass()) {
throw KspAnvilException(
message = "@MapKey is only applicable to annotation classes.",
node = clazz,
)
}
.forEach { clazz ->
generateCreatorClass(codeGenDir, clazz)

val creatorsToGenerate = mutableSetOf<KSClassDeclaration>()

fun visitAnnotations(clazz: KSClassDeclaration) {
if (clazz.isAnnotationClass()) {
val added = creatorsToGenerate.add(clazz)
if (added) {
for (property in clazz.getDeclaredProperties()) {
val type = property.type.resolve().declaration as? KSClassDeclaration?
if (type?.isAnnotationClass() == true) {
visitAnnotations(type)
}
}
}
}
}

// Populate all used annotations
visitAnnotations(clazz)

val creatorFunctions = creatorsToGenerate
.associateBy { annotationClass ->
annotationClass.toClassName()
}
.toSortedMap()
.map { (className, clazz) ->
val properties = clazz.getDeclaredProperties()
.map { AnnotationProperty(it) }
.associateBy { it.name }
generateCreatorFunction(className, properties)
}

return generateCreatorFileSpec(className, creatorFunctions)
}
}

private fun generateCreatorClass(
codeGenDir: File,
clazz: ClassReference,
): GeneratedFile {
// // Given this
// @MapKey(unwrapValue = false)
// annotation class ActivityKey(
// val value: KClass<out Activity>,
// val scope: KClass<*>,
// )
//
// // Generate this
// object ActivityKeyCreator {
// @JvmStatic
// fun createActivityKey(
// value: Class<out Activity>,
// scope: Class<*>
// ): ActivityKey {
// return ActivityKey(value.kotlin, scope.kotlin)
// }
// }

val packageName = clazz.packageFqName.safePackageString()

if (!clazz.isAnnotationClass()) {
throw AnvilCompilationExceptionClassReference(
message = "@MapKey is only applicable to annotation classes.",
classReference = clazz,
)
@AutoService(CodeGenerator::class)
internal class EmbeddedGenerator : PrivateCodeGenerator() {

override fun isApplicable(context: AnvilContext) = MapKeyCreatorCodeGen.isApplicable(context)

override fun generateCodePrivate(
codeGenDir: File,
module: ModuleDescriptor,
projectFiles: Collection<KtFile>,
) {
projectFiles
.classAndInnerClassReferences(module)
.filter { classRef ->
val mapKey = classRef.annotations.find { it.fqName == mapKeyFqName }
if (mapKey != null) {
mapKey.argumentAt("unwrapValue", 0)?.value<Boolean>() == false
} else {
false
}
}
.forEach { clazz ->
generateCreatorClass(codeGenDir, clazz)
}
}

val className = clazz.asClassName()
private fun generateCreatorClass(
codeGenDir: File,
clazz: ClassReference,
): GeneratedFile {
// // Given this
// @MapKey(unwrapValue = false)
// annotation class ActivityKey(
// val value: KClass<out Activity>,
// val scope: KClass<*>,
// )
//
// // Generate this
// object ActivityKeyCreator {
// @JvmStatic
// fun createActivityKey(
// value: Class<out Activity>,
// scope: Class<*>
// ): ActivityKey {
// return ActivityKey(value.kotlin, scope.kotlin)
// }
// }

val packageName = clazz.packageFqName.safePackageString()

if (!clazz.isAnnotationClass()) {
throw AnvilCompilationExceptionClassReference(
message = "@MapKey is only applicable to annotation classes.",
classReference = clazz,
)
}

val className = clazz.asClassName()

val creatorsToGenerate = mutableSetOf<ClassReference>()
val creatorsToGenerate = mutableSetOf<ClassReference>()

fun visitAnnotations(clazz: ClassReference) {
if (clazz.isAnnotationClass()) {
val added = creatorsToGenerate.add(clazz)
if (added) {
for (property in clazz.properties) {
val type = property.type().asClassReferenceOrNull()
if (type?.isAnnotationClass() == true) {
visitAnnotations(type)
fun visitAnnotations(clazz: ClassReference) {
if (clazz.isAnnotationClass()) {
val added = creatorsToGenerate.add(clazz)
if (added) {
for (property in clazz.properties) {
val type = property.type().asClassReferenceOrNull()
if (type?.isAnnotationClass() == true) {
visitAnnotations(type)
}
}
}
}
}
}

// Populate all used annotations
visitAnnotations(clazz)
// Populate all used annotations
visitAnnotations(clazz)

val creatorFunctions = creatorsToGenerate
.associateBy { annotationClass ->
annotationClass.asTypeName().rawTypeOrNull()
?: throw AnvilCompilationExceptionClassReference(
message = "@MapKey is only applicable to non-generic annotation classes.",
classReference = annotationClass,
)
}
.toSortedMap()
.map { (className, clazz) -> generateCreatorFunction(className, clazz) }
val creatorFunctions = creatorsToGenerate
.associateBy { annotationClass ->
annotationClass.asTypeName().rawTypeOrNull()
?: throw AnvilCompilationExceptionClassReference(
message = "@MapKey is only applicable to non-generic annotation classes.",
classReference = annotationClass,
)
}
.toSortedMap()
.map { (className, clazz) ->
val properties = clazz.properties
.map { AnnotationProperty(it) }
.associateBy { it.name }
generateCreatorFunction(className, properties)
}

val spec = generateCreatorFileSpec(className, creatorFunctions)
val content = spec.toString()

return createGeneratedFile(codeGenDir, packageName, spec.name, content)
}

val simpleName = className.simpleNames.joinToString("_")
private fun generateCreatorFunction(
className: ClassName,
annotationClass: ClassReference,
): FunSpec {
val properties = annotationClass.properties
.map { AnnotationProperty(it) }
.associateBy { it.name }
return generateCreatorFunction(className, properties)
}
}

private fun generateCreatorFileSpec(sourceClass: ClassName, creatorFunctions: List<FunSpec>): FileSpec {
val simpleName = sourceClass.simpleNames.joinToString("_")
val generatedClassName = "${simpleName}Creator"
val content = FileSpec.buildFile(packageName, generatedClassName) {
val spec = FileSpec.createAnvilSpec(sourceClass.packageName, generatedClassName) {
addType(
TypeSpec.objectBuilder(generatedClassName)
.addFunctions(creatorFunctions)
.build(),
)
}

return createGeneratedFile(codeGenDir, packageName, generatedClassName, content)
return spec
}

/**
* Generates a single static creator function for a given annotation [annotationClass].
*/
private fun generateCreatorFunction(
className: ClassName,
annotationClass: ClassReference,
properties: Map<String, AnnotationProperty>,
): FunSpec {
val properties = annotationClass.properties
.map { AnnotationProperty(it) }
.associateBy { it.name }
return FunSpec.builder("create${className.simpleName}")
.addAnnotation(JvmStatic::class)
.apply {
Expand All @@ -189,17 +319,13 @@ private class AnnotationProperty(
val callExpression: CodeBlock,
) {
companion object {
operator fun invoke(
property: MemberPropertyReference,
): AnnotationProperty {
val name = property.name
val typeName = property.type().asTypeName()
val javaType = typeName.resolveJavaType()
private fun create(name: String, type: TypeName): AnnotationProperty {
val javaType = type.resolveJavaType()
val codeBlock = when {
javaType.rawTypeOrNull() == CLASS_CLASS_NAME -> CodeBlock.of("%L.kotlin", name)
typeName is ParameterizedTypeName &&
typeName.rawType == ARRAY &&
typeName.typeArguments[0].rawTypeOrNull() == KCLASS_CLASS_NAME -> {
type is ParameterizedTypeName &&
type.rawType == ARRAY &&
type.typeArguments[0].rawTypeOrNull() == KCLASS_CLASS_NAME -> {
// Dense but this avoids an intermediate list allocation compared to .map { ... }.toTypedArray()
CodeBlock.of("%1T(%2L.size)·{·%2L[it].kotlin·}", ARRAY, name)
}
Expand All @@ -212,6 +338,14 @@ private class AnnotationProperty(
codeBlock,
)
}

operator fun invoke(
property: KSPropertyDeclaration,
): AnnotationProperty = create(property.simpleName.asString(), property.type.toTypeName())

operator fun invoke(
property: MemberPropertyReference,
): AnnotationProperty = create(property.name, property.type().asTypeName())
}
}

Expand Down
Loading

0 comments on commit 6e94ca1

Please sign in to comment.