Skip to content

Commit

Permalink
Updated RecordDecoderGenerator to run from KClass
Browse files Browse the repository at this point in the history
  • Loading branch information
sksamuel committed Apr 28, 2024
1 parent 3a9d8d5 commit 99c4bb0
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 43 deletions.
Original file line number Diff line number Diff line change
@@ -1,50 +1,82 @@
package com.sksamuel.centurion.avro.generation

import kotlin.reflect.KClass
import kotlin.reflect.KProperty1
import kotlin.reflect.KType
import kotlin.reflect.full.declaredMemberProperties

/**
* A code generator that outputs Kotlin that will deserialize a generic record for a given
* type into an instance of that data class.
*/
class RecordDecoderGenerator {
fun generate(ds: DataClass): String {

fun generate(kclass: KClass<*>): String {
return buildString {
appendLine("package ${ds.packageName}")
appendLine("package ${kclass.java.packageName}")
appendLine()
appendLine("import com.sksamuel.centurion.avro.decoders.*")
appendLine("import org.apache.avro.Schema")
appendLine("import org.apache.avro.generic.GenericData")
appendLine("import org.apache.avro.generic.GenericRecord")
appendLine()
appendLine("/**")
appendLine(" * This is a generated [Decoder] that deserializes Avro [GenericRecord]s to [${ds.className}]s")
appendLine(" * This is a generated [Decoder] that deserializes Avro [GenericRecord]s to [${kclass.java.simpleName}]s")
appendLine(" */")
appendLine("object ${ds.className}Decoder : Decoder<${ds.className}> {")
appendLine(" override fun decode(schema: Schema, value: Any?): ${ds.className} {")
appendLine(" require(value is GenericRecord)")
appendLine(" return ${ds.className}(")
ds.members.forEach { member ->
appendLine(" ${member.name} = ${decoderFor(member)},")
appendLine("object ${kclass.java.simpleName}Decoder : Decoder<${kclass.java.simpleName}> {")
appendLine()
appendLine(" override fun decode(schema: Schema): (Any?) -> ${kclass.java.simpleName} {")
appendLine()
kclass.declaredMemberProperties.forEach { property ->
appendLine(" val ${property.name}Schema = schema.getField(\"${property.name}\").schema()")
appendLine(" val ${property.name}Pos = schema.getField(\"${property.name}\").pos()")
appendLine(" val ${property.name}Decode = ${decode(property)}")
}
appendLine(" )")
appendLine()
appendLine(" return { record ->")
appendLine(" require(record is GenericRecord)")
appendLine(" ${kclass.java.simpleName}(")
kclass.declaredMemberProperties.forEach { property ->
appendLine(" ${property.name} = ${property.name}Decode(record[${property.name}Pos]),")
}
appendLine(" )")
appendLine(" }")
appendLine(" }")
appendLine("}")
}
}

private fun decode(property: KProperty1<out Any, *>): String {
val baseDecoder = decoderFor(property.returnType)
val wrapped = if (property.returnType.isMarkedNullable) "NullDecoder($baseDecoder)" else baseDecoder
return "$wrapped.decode(${property.name}Schema)"
}

private fun decoderFor(type: KType): String {
return when (val classifier = type.classifier) {
Boolean::class -> "BooleanDecoder"
Double::class -> "DoubleDecoder"
Float::class -> "FloatDecoder"
Int::class -> "IntDecoder"
Long::class -> "LongDecoder"
String::class -> "StringDecoder"
Set::class -> {
val elementDecoder = decoderFor(type.arguments.first().type!!)
"SetDecoder($elementDecoder)"
}

List::class -> {
val elementDecoder = decoderFor(type.arguments.first().type!!)
"ListDecoder($elementDecoder)"
}

Map::class -> {
val valueDecoder = decoderFor(type.arguments[1].type!!)
"MapDecoder($valueDecoder)"
}

private fun decoderFor(member: Member): String {
val getSchema = "schema.getField(\"${member.name}\").schema()"
val getValue = "value.get(\"${member.name}\")"
return when (member.type) {
Type.BooleanType -> "BooleanDecoder.decode($getSchema, $getValue)"
Type.DoubleType -> "DoubleDecoder.decode($getSchema, $getValue)"
Type.FloatType -> "FloatDecoder.decode($getSchema, $getValue)"
Type.IntType -> "IntDecoder.decode($getSchema, $getValue)"
Type.LongType -> "LongDecoder.decode($getSchema, $getValue)"
is Type.Nullable -> TODO()
is Type.RecordType -> TODO()
Type.StringType -> "StringDecoder.decode($getSchema, $getValue)"
is Type.ArrayType -> "ListDecoder.decode($getSchema, $getValue)"
is Type.MapType -> "MapDecoder.decode($getSchema, $getValue)"
is KClass<*> -> if (classifier.java.isEnum) "EnumDecoder<${classifier.java.name}>()" else error("Unsupported type: $type")
else -> error("Unsupported type: $type")
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,33 +6,67 @@ import io.kotest.matchers.shouldBe
class RecordDecoderGeneratorTest : FunSpec({

test("simple decoder") {
RecordDecoderGenerator().generate(
DataClass(
"a.b",
"Foo",
listOf(
Member("a", Type.BooleanType),
Member("b", Type.StringType),
)
)
).trim() shouldBe """
package a.b
RecordDecoderGenerator().generate(MyFoo::class).trim() shouldBe """
package com.sksamuel.centurion.avro.generation
import com.sksamuel.centurion.avro.decoders.*
import org.apache.avro.Schema
import org.apache.avro.generic.GenericData
import org.apache.avro.generic.GenericRecord
/**
* This is a generated [Decoder] that deserializes Avro [GenericRecord]s to [Foo]s
* This is a generated [Decoder] that deserializes Avro [GenericRecord]s to [MyFoo]s
*/
object FooDecoder : Decoder<Foo> {
override fun decode(schema: Schema, value: Any?): Foo {
require(value is GenericRecord)
return Foo(
a = BooleanDecoder.decode(schema.getField("a").schema(), value.get("a")),
b = StringDecoder.decode(schema.getField("b").schema(), value.get("b")),
)
object MyFooDecoder : Decoder<MyFoo> {
override fun decode(schema: Schema): (Any?) -> MyFoo {
val bSchema = schema.getField("b").schema()
val bPos = schema.getField("b").pos()
val bDecode = BooleanDecoder.decode(bSchema)
val cSchema = schema.getField("c").schema()
val cPos = schema.getField("c").pos()
val cDecode = LongDecoder.decode(cSchema)
val dSchema = schema.getField("d").schema()
val dPos = schema.getField("d").pos()
val dDecode = DoubleDecoder.decode(dSchema)
val fSchema = schema.getField("f").schema()
val fPos = schema.getField("f").pos()
val fDecode = FloatDecoder.decode(fSchema)
val iSchema = schema.getField("i").schema()
val iPos = schema.getField("i").pos()
val iDecode = IntDecoder.decode(iSchema)
val listsSchema = schema.getField("lists").schema()
val listsPos = schema.getField("lists").pos()
val listsDecode = ListDecoder(IntDecoder).decode(listsSchema)
val mapsSchema = schema.getField("maps").schema()
val mapsPos = schema.getField("maps").pos()
val mapsDecode = MapDecoder(DoubleDecoder).decode(mapsSchema)
val sSchema = schema.getField("s").schema()
val sPos = schema.getField("s").pos()
val sDecode = NullDecoder(StringDecoder).decode(sSchema)
val setsSchema = schema.getField("sets").schema()
val setsPos = schema.getField("sets").pos()
val setsDecode = SetDecoder(StringDecoder).decode(setsSchema)
val wineSchema = schema.getField("wine").schema()
val winePos = schema.getField("wine").pos()
val wineDecode = NullDecoder(EnumDecoder<com.sksamuel.centurion.avro.encoders.Wine>()).decode(wineSchema)
return { record ->
require(record is GenericRecord)
MyFoo(
b = bDecode(record[bPos]),
c = cDecode(record[cPos]),
d = dDecode(record[dPos]),
f = fDecode(record[fPos]),
i = iDecode(record[iPos]),
lists = listsDecode(record[listsPos]),
maps = mapsDecode(record[mapsPos]),
s = sDecode(record[sPos]),
sets = setsDecode(record[setsPos]),
wine = wineDecode(record[winePos]),
)
}
}
}
""".trim()
Expand Down

0 comments on commit 99c4bb0

Please sign in to comment.