Skip to content

Commit 0defedd

Browse files
committed
arrow parquet tests from #577
1 parent 964360c commit 0defedd

File tree

8 files changed

+256
-50
lines changed

8 files changed

+256
-50
lines changed

dataframe-arrow/build.gradle.kts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ dependencies {
2020
implementation(libs.arrow.vector)
2121
implementation(libs.arrow.format)
2222
implementation(libs.arrow.memory)
23+
implementation(libs.arrow.dataset)
2324
implementation(libs.commonsCompress)
2425
implementation(libs.kotlin.reflect)
2526
implementation(libs.kotlin.datetimeJvm)

dataframe-arrow/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/arrowReading.kt

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package org.jetbrains.kotlinx.dataframe.io
22

3+
import org.apache.arrow.dataset.file.FileFormat
34
import org.apache.arrow.memory.RootAllocator
45
import org.apache.arrow.vector.ipc.ArrowReader
56
import org.apache.commons.compress.utils.SeekableInMemoryByteChannel
@@ -185,3 +186,17 @@ public fun DataFrame.Companion.readArrow(
185186
*/
186187
public fun ArrowReader.toDataFrame(nullability: NullabilityOptions = NullabilityOptions.Infer): AnyFrame =
187188
DataFrame.Companion.readArrowImpl(this, nullability)
189+
190+
/**
 * Reads [Parquet](https://parquet.apache.org/) data located at [url] into a data frame
 * by means of [Arrow Dataset](https://arrow.apache.org/docs/java/dataset.html).
 *
 * @param url location of the Parquet data; it is handed to [readArrowDataset] in its string form
 * @param nullability forwarded unchanged to [readArrowDataset], [NullabilityOptions.Infer] by default
 * @return the data frame produced by [readArrowDataset] for [FileFormat.PARQUET]
 */
public fun DataFrame.Companion.readParquet(
    url: URL,
    nullability: NullabilityOptions = NullabilityOptions.Infer,
): AnyFrame {
    val location = url.toString()
    return readArrowDataset(location, fileFormat = FileFormat.PARQUET, nullability = nullability)
}
197+
198+
/**
 * Reads the data behind the given [fileUri] values, interpreted as [fileFormat], into a data frame.
 *
 * @param fileUri one or more URIs, passed on as an array to the internal dataset reader
 * @param fileFormat the Arrow Dataset [FileFormat] of the referenced files
 * @param nullability forwarded unchanged to the internal reader, [NullabilityOptions.Infer] by default
 * @return the data frame built by the internal dataset reader
 */
public fun DataFrame.Companion.readArrowDataset(
    vararg fileUri: String,
    fileFormat: FileFormat,
    nullability: NullabilityOptions = NullabilityOptions.Infer,
): AnyFrame {
    return readArrowDatasetImpl(fileUris = fileUri, fileFormat = fileFormat, nullability = nullability)
}

dataframe-arrow/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/arrowReadingImpl.kt

Lines changed: 72 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@ import kotlinx.datetime.LocalTime
66
import kotlinx.datetime.toKotlinLocalDate
77
import kotlinx.datetime.toKotlinLocalDateTime
88
import kotlinx.datetime.toKotlinLocalTime
9+
import org.apache.arrow.dataset.file.FileFormat
10+
import org.apache.arrow.dataset.file.FileSystemDatasetFactory
11+
import org.apache.arrow.dataset.jni.DirectReservationListener
12+
import org.apache.arrow.dataset.jni.NativeMemoryPool
13+
import org.apache.arrow.dataset.scanner.ScanOptions
914
import org.apache.arrow.memory.RootAllocator
1015
import org.apache.arrow.vector.BigIntVector
1116
import org.apache.arrow.vector.BitVector
@@ -55,6 +60,7 @@ import org.apache.arrow.vector.util.DateUtility.getLocalDateTimeFromEpochMilli
5560
import org.apache.arrow.vector.util.DateUtility.getLocalDateTimeFromEpochNano
5661
import org.jetbrains.kotlinx.dataframe.AnyBaseCol
5762
import org.jetbrains.kotlinx.dataframe.AnyFrame
63+
import org.jetbrains.kotlinx.dataframe.AnyRow
5864
import org.jetbrains.kotlinx.dataframe.DataColumn
5965
import org.jetbrains.kotlinx.dataframe.DataFrame
6066
import org.jetbrains.kotlinx.dataframe.api.Infer
@@ -65,6 +71,7 @@ import org.jetbrains.kotlinx.dataframe.api.cast
6571
import org.jetbrains.kotlinx.dataframe.api.dataFrameOf
6672
import org.jetbrains.kotlinx.dataframe.api.emptyDataFrame
6773
import org.jetbrains.kotlinx.dataframe.api.getColumn
74+
import org.jetbrains.kotlinx.dataframe.api.single
6875
import org.jetbrains.kotlinx.dataframe.api.toDataFrame
6976
import org.jetbrains.kotlinx.dataframe.impl.asList
7077
import java.math.BigDecimal
@@ -244,9 +251,13 @@ private fun TimeStampSecTZVector.values(range: IntRange): List<LocalDateTime?> =
244251
}
245252
}
246253

247-
private fun StructVector.values(range: IntRange): List<Map<String, Any?>?> =
254+
// TODO, use recursive type-mapping instead of inference

/**
 * Converts every struct in [range] into a single data row.
 *
 * Each field value of a struct is wrapped in a one-element list so the struct becomes a
 * one-row data frame, from which that only row is taken. A `null` struct yields a `null` row.
 */
private fun StructVector.values(range: IntRange): List<AnyRow?> =
    range.map { index ->
        val struct = getObject(index) ?: return@map null
        struct
            .mapValues { (_, fieldValue) -> listOf(fieldValue) }
            .toDataFrame()
            .single()
    }
251262

252263
private fun ListVector.values(range: IntRange): List<List<Any?>?> =
@@ -432,7 +443,7 @@ private fun readField(root: VectorSchemaRoot, field: Field, nullability: Nullabi
432443
else -> throw NotImplementedError("reading from ${vector.javaClass.canonicalName} is not implemented")
433444
}
434445

435-
return DataColumn.createValueColumn(name = field.name, values = list, type = type, infer = infer)
446+
return DataColumn.createByType(name = field.name, values = list, type = type, infer = infer)
436447
} catch (_: NullabilityException) {
437448
throw IllegalArgumentException("Column `${field.name}` should be not nullable but has nulls")
438449
}
@@ -469,23 +480,73 @@ internal fun DataFrame.Companion.readArrowImpl(
469480
is ArrowFileReader -> {
470481
reader.recordBlocks.forEach { block ->
471482
reader.loadRecordBatch(block)
472-
val root = reader.vectorSchemaRoot
473-
val schema = root.schema
474-
val df = schema.fields.map { f -> readField(root, f, nullability) }.toDataFrame()
475-
add(df)
483+
reader.vectorSchemaRoot.use { root ->
484+
val schema = root.schema
485+
val df = schema.fields.map { f -> readField(root, f, nullability) }.toDataFrame()
486+
add(df)
487+
}
476488
}
477489
}
478490

479491
else -> {
480-
val root = reader.vectorSchemaRoot
481-
val schema = root.schema
482-
while (reader.loadNextBatch()) {
483-
val df = schema.fields.map { f -> readField(root, f, nullability) }.toDataFrame()
484-
add(df)
492+
reader.vectorSchemaRoot.use { root ->
493+
val schema = root.schema
494+
while (reader.loadNextBatch()) {
495+
val df = schema.fields.map { f -> readField(root, f, nullability) }.toDataFrame()
496+
add(df)
497+
}
485498
}
486499
}
487500
}
488501
}
489502
return flattened.concatKeepingSchema()
490503
}
491504
}
505+
506+
/**
 * Reads the files at [fileUris], interpreted as [fileFormat], through the Arrow Dataset API.
 *
 * The allocator, dataset factory, dataset, scanner and batch reader are each opened inside a
 * nested `use` block, so they are closed innermost-first once the data frame has been built
 * or as soon as any step throws.
 *
 * @param fileUris URIs handed to [FileSystemDatasetFactory]
 * @param fileFormat format handed to [FileSystemDatasetFactory]
 * @param nullability forwarded unchanged to `readArrowImpl`
 * @return the data frame produced by `readArrowImpl` from the scanned batches
 */
internal fun DataFrame.Companion.readArrowDatasetImpl2(
    fileUris: Array<out String>,
    fileFormat: FileFormat,
    nullability: NullabilityOptions = NullabilityOptions.Infer,
): AnyFrame {
    return RootAllocator().use { rootAllocator ->
        val nativePool = NativeMemoryPool.createListenable(DirectReservationListener.instance())
        // Java constructor, positional only: (allocator, memoryPool, format, uris)
        val factory = FileSystemDatasetFactory(rootAllocator, nativePool, fileFormat, fileUris)
        factory.use { openFactory ->
            openFactory.finish().use { openDataset ->
                openDataset.newScan(ScanOptions(32_768)).use { openScanner ->
                    openScanner.scanBatches().use { batchReader ->
                        readArrowImpl(reader = batchReader, nullability = nullability)
                    }
                }
            }
        }
    }
}
531+
532+
/**
 * Reads the files at [fileUris], interpreted as [fileFormat], through the Arrow Dataset API.
 *
 * Delegates to [readArrowDatasetImpl2], which opens the allocator, dataset factory, dataset,
 * scanner and batch reader in nested `use` blocks. Compared to the previous body this means:
 * - a failure while opening, scanning or closing is propagated to the caller instead of being
 *   printed to stderr and replaced by a `null` result typed as a non-null [AnyFrame];
 * - the resources are closed innermost-first rather than in the order they were registered.
 *
 * @param fileUris URIs of the files to read
 * @param fileFormat Arrow Dataset format of those files
 * @param nullability forwarded unchanged to the underlying reader
 * @return the data frame built from the scanned record batches
 */
internal fun DataFrame.Companion.readArrowDatasetImpl(
    fileUris: Array<out String>,
    fileFormat: FileFormat,
    nullability: NullabilityOptions = NullabilityOptions.Infer,
): AnyFrame = readArrowDatasetImpl2(fileUris = fileUris, fileFormat = fileFormat, nullability = nullability)
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
package org.jetbrains.kotlinx.dataframe.io
2+
3+
import java.util.concurrent.ConcurrentLinkedQueue
4+
import kotlin.contracts.ExperimentalContracts
5+
import kotlin.contracts.InvocationKind
6+
import kotlin.contracts.contract
7+
8+
/**
 * Runs [block] with a [UsingResources] receiver, closes every resource registered on that
 * receiver afterwards, and returns a [Catcher] holding the outcome.
 *
 * Nothing is thrown from this function: the block's result or the failure is stored on the
 * resource manager and surfaced later through [Catcher.catch] / [Catcher.finally].
 *
 * If [block] throws and closing a resource fails as well, the close failure is attached to the
 * block's exception as a suppressed exception (previously it was silently overwritten).
 *
 * @param block code that registers resources via [UsingResources] and computes the result
 * @return a [Catcher] wrapping the manager that holds the result or the failure
 */
@OptIn(ExperimentalContracts::class)
internal inline fun <R> using(crossinline block: UsingResources<R>.() -> R): Catcher<R> {
    contract {
        callsInPlace(block, InvocationKind.EXACTLY_ONCE)
    }
    val manager = UsingResourcesImpl<R>()
    try {
        manager.result = manager.use(block)
    } catch (t: Throwable) {
        // `manager.close()` does not throw; it records the first close failure in
        // `manager.throwable`. When the block itself failed, keep that close failure
        // reachable as a suppressed exception instead of dropping it.
        val closeFailure = manager.throwable
        if (closeFailure != null && closeFailure !== t) {
            t.addSuppressed(closeFailure)
        }
        manager.throwable = t
    }
    return manager.getCatcher()
}
22+
23+
/**
 * Receiver scope of a `using { }` block: lets the block register [AutoCloseable] resources.
 *
 * What registering means is decided by the implementation of [use]; the two operators below
 * are only alternative spellings of that call.
 */
internal interface UsingResources<R> {

    /** Registers the receiver with this scope and returns it, so the call can be chained. */
    fun <T : AutoCloseable> T.use(): T

    /** Prefix form of [use]: `+resource`. */
    operator fun <T : AutoCloseable> T.unaryPlus(): T {
        return use()
    }

    /** Call form of [use]: `resource()`. */
    operator fun <T : AutoCloseable> T.invoke(): T {
        return use()
    }
}
31+
32+
/**
 * Collects the resources registered inside a `using { }` block and closes them together.
 *
 * [close] never throws: the first failure is stored in [throwable] and later failures are
 * added to it as suppressed exceptions.
 */
internal class UsingResourcesImpl<R> :
    AutoCloseable,
    UsingResources<R> {

    /** First failure recorded so far, either set by the caller or by [close]. */
    var throwable: Throwable? = null

    /** Registered resources, in registration order. */
    val resourceQueue = ConcurrentLinkedQueue<AutoCloseable>()

    /** Result of the block, assigned by the caller once the block completed. */
    var result: R? = null

    /** Registers the receiver for closing and returns it unchanged. */
    override fun <T : AutoCloseable> T.use(): T {
        resourceQueue.offer(this)
        return this
    }

    /**
     * Closes all registered resources in reverse registration order, matching
     * try-with-resources and nested `use` blocks: a resource registered later may depend on
     * one registered earlier, so it has to be released first.
     */
    override fun close() {
        for (closeable in resourceQueue.toList().asReversed()) {
            try {
                closeable.close()
            } catch (t: Throwable) {
                val first = throwable
                when {
                    first == null -> throwable = t
                    // addSuppressed rejects the exception itself
                    first !== t -> first.addSuppressed(t)
                }
            }
        }
    }

    /** Wraps this manager in a [Catcher] that exposes the result or the recorded failure. */
    fun getCatcher(): Catcher<R> = Catcher(this)
}
61+
62+
/**
 * Second half of the `using { } catch { } finally { }` construct.
 *
 * Starts out holding whatever failure the [manager] recorded. [catch] handlers may consume it;
 * [finally] runs the last block and then either rethrows what is still pending or returns the
 * manager's result.
 */
internal class Catcher<R>(val manager: UsingResourcesImpl<R>) {

    /** Failure recorded by the manager that no [catch] handler has consumed yet. */
    var throwable: Throwable? = manager.throwable

    /** Failure raised by a [catch] handler itself, if any. */
    var thrown: Throwable? = null

    /**
     * Invokes [block] when the pending failure is a [T]. The failure counts as handled
     * afterwards, even when [block] throws; in that case the new exception is kept in [thrown].
     */
    inline infix fun <reified T : Throwable> catch(block: (T) -> Unit): Catcher<R> {
        val pending = throwable
        if (pending !is T) return this
        try {
            block(pending)
        } catch (handlerFailure: Throwable) {
            thrown = handlerFailure
        } finally {
            // The failure has been handled, so nothing is pending any more
            throwable = null
        }
        return this
    }

    /**
     * Runs [block], then rethrows any outstanding failure or returns the manager's result.
     *
     * When [block] throws, the exception that wins is, in order of precedence: the failure no
     * handler consumed, the failure raised by a handler, the failure of [block] itself.
     * The failure of [block] is added to the winner as a suppressed exception.
     */
    inline infix fun finally(block: () -> Unit): R {
        try {
            block()
        } catch (finallyFailure: Throwable) {
            val primary = throwable ?: thrown ?: throw finallyFailure
            primary.addSuppressed(finallyFailure)
            throw primary
        }

        // The finally block completed normally; surface whatever is still outstanding
        val unhandled = throwable
        if (unhandled != null) {
            thrown?.let { unhandled.addSuppressed(it) }
            throw unhandled
        }
        thrown?.let { throw it }

        return manager.result as R
    }
}

dataframe-arrow/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/ArrowKtTest.kt

Lines changed: 21 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,11 @@ package org.jetbrains.kotlinx.dataframe.io
33
import io.kotest.assertions.throwables.shouldThrow
44
import io.kotest.matchers.collections.shouldContain
55
import io.kotest.matchers.shouldBe
6-
import io.zonky.test.db.postgres.junit.EmbeddedPostgresRules
7-
import io.zonky.test.db.postgres.junit.SingleInstancePostgresRule
86
import kotlinx.datetime.LocalDate
97
import kotlinx.datetime.LocalDateTime
108
import kotlinx.datetime.UtcOffset
119
import kotlinx.datetime.toInstant
1210
import kotlinx.datetime.toJavaInstant
13-
import org.apache.arrow.adapter.jdbc.JdbcFieldInfo
14-
import org.apache.arrow.adapter.jdbc.JdbcToArrowConfigBuilder
15-
import org.apache.arrow.adapter.jdbc.JdbcToArrowUtils
16-
import org.apache.arrow.adbc.core.AdbcDriver
17-
import org.apache.arrow.adbc.driver.jdbc.JdbcConnection
18-
import org.apache.arrow.adbc.driver.jdbc.JdbcDriver
19-
import org.apache.arrow.adbc.driver.jdbc.JdbcQuirks
2011
import org.apache.arrow.memory.RootAllocator
2112
import org.apache.arrow.vector.TimeStampMicroVector
2213
import org.apache.arrow.vector.TimeStampMilliVector
@@ -25,10 +16,8 @@ import org.apache.arrow.vector.TimeStampSecVector
2516
import org.apache.arrow.vector.VectorSchemaRoot
2617
import org.apache.arrow.vector.ipc.ArrowFileReader
2718
import org.apache.arrow.vector.ipc.ArrowFileWriter
28-
import org.apache.arrow.vector.ipc.ArrowReader
2919
import org.apache.arrow.vector.ipc.ArrowStreamReader
3020
import org.apache.arrow.vector.ipc.ArrowStreamWriter
31-
import org.apache.arrow.vector.types.DateUnit
3221
import org.apache.arrow.vector.types.FloatingPointPrecision
3322
import org.apache.arrow.vector.types.TimeUnit
3423
import org.apache.arrow.vector.types.pojo.ArrowType
@@ -37,15 +26,11 @@ import org.apache.arrow.vector.types.pojo.FieldType
3726
import org.apache.arrow.vector.types.pojo.Schema
3827
import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel
3928
import org.apache.arrow.vector.util.Text
40-
import org.duckdb.DuckDBConnection
41-
import org.duckdb.DuckDBResultSet
42-
import org.intellij.lang.annotations.Language
4329
import org.jetbrains.kotlinx.dataframe.AnyFrame
4430
import org.jetbrains.kotlinx.dataframe.DataColumn
4531
import org.jetbrains.kotlinx.dataframe.DataFrame
4632
import org.jetbrains.kotlinx.dataframe.api.NullabilityOptions
4733
import org.jetbrains.kotlinx.dataframe.api.add
48-
import org.jetbrains.kotlinx.dataframe.api.asIterable
4934
import org.jetbrains.kotlinx.dataframe.api.columnOf
5035
import org.jetbrains.kotlinx.dataframe.api.convertToBoolean
5136
import org.jetbrains.kotlinx.dataframe.api.copy
@@ -56,26 +41,14 @@ import org.jetbrains.kotlinx.dataframe.api.print
5641
import org.jetbrains.kotlinx.dataframe.api.remove
5742
import org.jetbrains.kotlinx.dataframe.api.toColumn
5843
import org.jetbrains.kotlinx.dataframe.exceptions.TypeConverterNotFoundException
59-
import org.junit.Assert
60-
import org.junit.Rule
6144
import org.junit.Test
62-
import org.postgresql.ds.PGSimpleDataSource
6345
import java.io.ByteArrayInputStream
6446
import java.io.ByteArrayOutputStream
6547
import java.io.File
66-
import java.math.BigDecimal
48+
import java.net.URI
6749
import java.net.URL
6850
import java.nio.channels.Channels
69-
import java.sql.Connection
70-
import java.sql.Date
71-
import java.sql.DriverManager
72-
import java.sql.Time
73-
import java.sql.Timestamp
74-
import java.sql.Types
7551
import java.util.Locale
76-
import java.util.UUID
77-
import kotlin.reflect.full.memberProperties
78-
import kotlin.reflect.jvm.isAccessible
7952
import kotlin.reflect.typeOf
8053

8154
class ArrowKtTest {
@@ -656,4 +629,24 @@ class ArrowKtTest {
656629
val arrowStreamReader = ArrowStreamReader(ipcInputStream, RootAllocator())
657630
arrowStreamReader.toDataFrame() shouldBe expected
658631
}
632+
633+
// Reads the bundled Parquet resource and checks the row count and the column estimations.
@Test
fun testReadParquet() {
    val path = testResource("test.arrow.parquet").path
    val parquetUrl = URI("file:$path").toURL()

    val frame = DataFrame.readParquet(parquetUrl)

    frame.rowsCount() shouldBe 300
    assertEstimations(
        exampleFrame = frame,
        expectedNullable = false,
        hasNulls = false,
        fromParquet = true,
    )
}
645+
646+
// Smoke test: reads the "snappy.parquet" resource and prints it; nothing is asserted.
@Test
fun testReadParquet2() {
    val path = testResource("snappy.parquet").path
    val parquetUrl = URI("file:$path").toURL()

    val frame = DataFrame.readParquet(parquetUrl)

    frame.print(columnTypes = true, borders = true)
}
659652
}

0 commit comments

Comments
 (0)