Skip to content

Commit

Permalink
Support encoding sets and lists
Browse files Browse the repository at this point in the history
  • Loading branch information
sksamuel committed Apr 14, 2024
1 parent c86dd97 commit 32631d6
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,6 @@ fun interface Encoder<T> {

companion object {

/**
* Returns an [Encoder] that encodes using the supplied function.
*/
operator fun <T> invoke(f: (T) -> Any) = Encoder<T> { _, value -> f(value) }

/**
* Returns an [Encoder] that encodes by simply returning the input value.
*/
Expand All @@ -43,6 +38,8 @@ fun interface Encoder<T> {
Int::class -> IntEncoder
Long::class -> LongEncoder
BigDecimal::class -> BigDecimalStringEncoder
Set::class -> GenericArraySetEncoder(encoderFor(type.arguments.first().type!!))
List::class -> GenericArrayListEncoder(encoderFor(type.arguments.first().type!!))
is KClass<*> -> if (classifier.java.isEnum) EnumEncoder<Enum<*>>() else error("Unsupported type $type")
else -> error("Unsupported type $type")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class GenericArrayArrayEncoder<T>(private val encoder: Encoder<T>) : Encoder<Arr
override fun encode(schema: Schema, value: Array<T>): Any {
require(schema.type == Schema.Type.ARRAY)
val elements = value.map { encoder.encode(schema.elementType, it) }
return GenericData.Array<T>(elements.size, schema.elementType)
return GenericData.Array<T>(elements.size, schema).also { it.addAll(elements as Collection<T>) }
}
}

Expand All @@ -21,6 +21,17 @@ class GenericArrayListEncoder<T>(private val encoder: Encoder<T>) : Encoder<List
override fun encode(schema: Schema, value: List<T>): Any {
require(schema.type == Schema.Type.ARRAY)
val elements = value.map { encoder.encode(schema.elementType, it) }
return GenericData.Array<T>(elements.size, schema.elementType)
return GenericData.Array<T>(elements.size, schema).also { it.addAll(elements as Collection<T>) }
}
}

/**
* An [Encoder] for Sets of [T] that encodes into an Avro [GenericArray].
*/
class GenericArraySetEncoder<T>(private val encoder: Encoder<T>) : Encoder<Set<T>> {
override fun encode(schema: Schema, value: Set<T>): Any {
require(schema.type == Schema.Type.ARRAY)
val elements = value.map { encoder.encode(schema.elementType, it) }
return GenericData.Array<T>(elements.size, schema).also { it.addAll(elements as Collection<T>) }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ class NullEncoder<T>(private val encoder: Encoder<T>) : Encoder<T?> {
// nullables must be encoded with a union of 2 elements, where null is the first type
require(schema.type == Schema.Type.UNION) { "Nulls can only be encoded with a UNION schema" }
require(schema.types.size == 2) { "Nulls can only be encoded with a 2 element union schema" }
require(schema.types[0].type == Schema.Type.NULL) { "Nullable unions must have NULL as the first element type" }
return if (value == null) null else encoder.encode(schema.types[1], value)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import org.apache.avro.SchemaBuilder
import org.apache.avro.generic.GenericData
import org.apache.avro.util.Utf8

class RecordEncoderTest : FunSpec({
class ReflectionRecordEncoderTest : FunSpec({

test("basic test") {
data class Foo(val a: String, val b: Boolean)
Expand Down Expand Up @@ -37,6 +37,36 @@ class RecordEncoderTest : FunSpec({

actual shouldBe expected
}

test("sets") {
data class Foo(val set1: Set<Int>, val set2: Set<Long?>)

val schema = SchemaBuilder.record("Foo").fields()
.name("set1").type().array().items().intType().noDefault()
.name("set2").type().array().items().type(SchemaBuilder.nullable().longType()).noDefault()
.endRecord()

val expected = GenericData.Record(schema)
expected.put("set1", listOf(1, 2))
expected.put("set2", listOf(1L, null, 2L))

ReflectionRecordEncoder().encode(schema, Foo(setOf(1, 2), setOf(1L, null, 2L))) shouldBe expected
}

test("list") {
data class Foo(val list1: List<Int>, val list2: List<Long?>)

val schema = SchemaBuilder.record("Foo").fields()
.name("list1").type().array().items().intType().noDefault()
.name("list2").type().array().items().type(SchemaBuilder.nullable().longType()).noDefault()
.endRecord()

val expected = GenericData.Record(schema)
expected.put("list1", listOf(1, 2))
expected.put("list2", listOf(1L, null, 2L))

ReflectionRecordEncoder().encode(schema, Foo(listOf(1, 2), listOf(1L, null, 2L))) shouldBe expected
}
})

enum class Wine { Shiraz, Malbec }

0 comments on commit 32631d6

Please sign in to comment.