Skip to content

Commit 8e1666a

Browse files
authored
Support writing nested lists to parquet (#1746)
* Support writing arbitrarily nested arrow arrays (#1744) * More tests * Port more tests * More tests * Review feedback * Reduce test churn * Port remaining tests * Review feedback * Fix clippy
1 parent 9722f06 commit 8e1666a

File tree

2 files changed

+897
-1318
lines changed

2 files changed

+897
-1318
lines changed

parquet/src/arrow/arrow_writer.rs

Lines changed: 44 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ use super::schema::{
3333
decimal_length_from_precision,
3434
};
3535

36+
use crate::arrow::levels::calculate_array_levels;
3637
use crate::column::writer::ColumnWriter;
3738
use crate::errors::{ParquetError, Result};
3839
use crate::file::properties::WriterProperties;
@@ -173,16 +174,15 @@ impl<W: Write> ArrowWriter<W> {
173174
}
174175
}
175176

176-
let mut levels: Vec<_> = arrays
177+
let mut levels = arrays
177178
.iter()
178179
.map(|array| {
179-
let batch_level = LevelInfo::new(0, array.len());
180-
let mut levels = batch_level.calculate_array_levels(array, field);
180+
let mut levels = calculate_array_levels(array, field)?;
181181
// Reverse levels as we pop() them when writing arrays
182182
levels.reverse();
183-
levels
183+
Ok(levels)
184184
})
185-
.collect();
185+
.collect::<Result<Vec<_>>>()?;
186186

187187
write_leaves(&mut row_group_writer, &arrays, &mut levels)?;
188188
}
@@ -341,26 +341,23 @@ fn write_leaf(
341341
column: &ArrayRef,
342342
levels: LevelInfo,
343343
) -> Result<i64> {
344-
let indices = levels.filter_array_indices();
345-
// Slice array according to computed offset and length
346-
let column = column.slice(levels.offset, levels.length);
344+
let indices = levels.non_null_indices();
347345
let written = match writer {
348346
ColumnWriter::Int32ColumnWriter(ref mut typed) => {
349347
let values = match column.data_type() {
350348
ArrowDataType::Date64 => {
351349
// If the column is a Date64, we cast it to a Date32, and then interpret that as Int32
352350
let array = if let ArrowDataType::Date64 = column.data_type() {
353-
let array =
354-
arrow::compute::cast(&column, &ArrowDataType::Date32)?;
351+
let array = arrow::compute::cast(column, &ArrowDataType::Date32)?;
355352
arrow::compute::cast(&array, &ArrowDataType::Int32)?
356353
} else {
357-
arrow::compute::cast(&column, &ArrowDataType::Int32)?
354+
arrow::compute::cast(column, &ArrowDataType::Int32)?
358355
};
359356
let array = array
360357
.as_any()
361358
.downcast_ref::<arrow_array::Int32Array>()
362359
.expect("Unable to get int32 array");
363-
get_numeric_array_slice::<Int32Type, _>(array, &indices)
360+
get_numeric_array_slice::<Int32Type, _>(array, indices)
364361
}
365362
ArrowDataType::UInt32 => {
366363
// follow C++ implementation and use overflow/reinterpret cast from u32 to i32 which will map
@@ -373,21 +370,21 @@ fn write_leaf(
373370
array,
374371
|x| x as i32,
375372
);
376-
get_numeric_array_slice::<Int32Type, _>(&array, &indices)
373+
get_numeric_array_slice::<Int32Type, _>(&array, indices)
377374
}
378375
_ => {
379-
let array = arrow::compute::cast(&column, &ArrowDataType::Int32)?;
376+
let array = arrow::compute::cast(column, &ArrowDataType::Int32)?;
380377
let array = array
381378
.as_any()
382379
.downcast_ref::<arrow_array::Int32Array>()
383380
.expect("Unable to get i32 array");
384-
get_numeric_array_slice::<Int32Type, _>(array, &indices)
381+
get_numeric_array_slice::<Int32Type, _>(array, indices)
385382
}
386383
};
387384
typed.write_batch(
388385
values.as_slice(),
389-
Some(levels.definition.as_slice()),
390-
levels.repetition.as_deref(),
386+
levels.def_levels(),
387+
levels.rep_levels(),
391388
)?
392389
}
393390
ColumnWriter::BoolColumnWriter(ref mut typed) => {
@@ -396,9 +393,9 @@ fn write_leaf(
396393
.downcast_ref::<arrow_array::BooleanArray>()
397394
.expect("Unable to get boolean array");
398395
typed.write_batch(
399-
get_bool_array_slice(array, &indices).as_slice(),
400-
Some(levels.definition.as_slice()),
401-
levels.repetition.as_deref(),
396+
get_bool_array_slice(array, indices).as_slice(),
397+
levels.def_levels(),
398+
levels.rep_levels(),
402399
)?
403400
}
404401
ColumnWriter::Int64ColumnWriter(ref mut typed) => {
@@ -408,7 +405,7 @@ fn write_leaf(
408405
.as_any()
409406
.downcast_ref::<arrow_array::Int64Array>()
410407
.expect("Unable to get i64 array");
411-
get_numeric_array_slice::<Int64Type, _>(array, &indices)
408+
get_numeric_array_slice::<Int64Type, _>(array, indices)
412409
}
413410
ArrowDataType::UInt64 => {
414411
// follow C++ implementation and use overflow/reinterpret cast from u64 to i64 which will map
@@ -421,21 +418,21 @@ fn write_leaf(
421418
array,
422419
|x| x as i64,
423420
);
424-
get_numeric_array_slice::<Int64Type, _>(&array, &indices)
421+
get_numeric_array_slice::<Int64Type, _>(&array, indices)
425422
}
426423
_ => {
427-
let array = arrow::compute::cast(&column, &ArrowDataType::Int64)?;
424+
let array = arrow::compute::cast(column, &ArrowDataType::Int64)?;
428425
let array = array
429426
.as_any()
430427
.downcast_ref::<arrow_array::Int64Array>()
431428
.expect("Unable to get i64 array");
432-
get_numeric_array_slice::<Int64Type, _>(array, &indices)
429+
get_numeric_array_slice::<Int64Type, _>(array, indices)
433430
}
434431
};
435432
typed.write_batch(
436433
values.as_slice(),
437-
Some(levels.definition.as_slice()),
438-
levels.repetition.as_deref(),
434+
levels.def_levels(),
435+
levels.rep_levels(),
439436
)?
440437
}
441438
ColumnWriter::Int96ColumnWriter(ref mut _typed) => {
@@ -447,9 +444,9 @@ fn write_leaf(
447444
.downcast_ref::<arrow_array::Float32Array>()
448445
.expect("Unable to get Float32 array");
449446
typed.write_batch(
450-
get_numeric_array_slice::<FloatType, _>(array, &indices).as_slice(),
451-
Some(levels.definition.as_slice()),
452-
levels.repetition.as_deref(),
447+
get_numeric_array_slice::<FloatType, _>(array, indices).as_slice(),
448+
levels.def_levels(),
449+
levels.rep_levels(),
453450
)?
454451
}
455452
ColumnWriter::DoubleColumnWriter(ref mut typed) => {
@@ -458,9 +455,9 @@ fn write_leaf(
458455
.downcast_ref::<arrow_array::Float64Array>()
459456
.expect("Unable to get Float64 array");
460457
typed.write_batch(
461-
get_numeric_array_slice::<DoubleType, _>(array, &indices).as_slice(),
462-
Some(levels.definition.as_slice()),
463-
levels.repetition.as_deref(),
458+
get_numeric_array_slice::<DoubleType, _>(array, indices).as_slice(),
459+
levels.def_levels(),
460+
levels.rep_levels(),
464461
)?
465462
}
466463
ColumnWriter::ByteArrayColumnWriter(ref mut typed) => match column.data_type() {
@@ -471,8 +468,8 @@ fn write_leaf(
471468
.expect("Unable to get BinaryArray array");
472469
typed.write_batch(
473470
get_binary_array(array).as_slice(),
474-
Some(levels.definition.as_slice()),
475-
levels.repetition.as_deref(),
471+
levels.def_levels(),
472+
levels.rep_levels(),
476473
)?
477474
}
478475
ArrowDataType::Utf8 => {
@@ -482,8 +479,8 @@ fn write_leaf(
482479
.expect("Unable to get LargeBinaryArray array");
483480
typed.write_batch(
484481
get_string_array(array).as_slice(),
485-
Some(levels.definition.as_slice()),
486-
levels.repetition.as_deref(),
482+
levels.def_levels(),
483+
levels.rep_levels(),
487484
)?
488485
}
489486
ArrowDataType::LargeBinary => {
@@ -493,8 +490,8 @@ fn write_leaf(
493490
.expect("Unable to get LargeBinaryArray array");
494491
typed.write_batch(
495492
get_large_binary_array(array).as_slice(),
496-
Some(levels.definition.as_slice()),
497-
levels.repetition.as_deref(),
493+
levels.def_levels(),
494+
levels.rep_levels(),
498495
)?
499496
}
500497
ArrowDataType::LargeUtf8 => {
@@ -504,8 +501,8 @@ fn write_leaf(
504501
.expect("Unable to get LargeUtf8 array");
505502
typed.write_batch(
506503
get_large_string_array(array).as_slice(),
507-
Some(levels.definition.as_slice()),
508-
levels.repetition.as_deref(),
504+
levels.def_levels(),
505+
levels.rep_levels(),
509506
)?
510507
}
511508
_ => unreachable!("Currently unreachable because data type not supported"),
@@ -518,14 +515,14 @@ fn write_leaf(
518515
.as_any()
519516
.downcast_ref::<arrow_array::IntervalYearMonthArray>()
520517
.unwrap();
521-
get_interval_ym_array_slice(array, &indices)
518+
get_interval_ym_array_slice(array, indices)
522519
}
523520
IntervalUnit::DayTime => {
524521
let array = column
525522
.as_any()
526523
.downcast_ref::<arrow_array::IntervalDayTimeArray>()
527524
.unwrap();
528-
get_interval_dt_array_slice(array, &indices)
525+
get_interval_dt_array_slice(array, indices)
529526
}
530527
_ => {
531528
return Err(ParquetError::NYI(
@@ -541,14 +538,14 @@ fn write_leaf(
541538
.as_any()
542539
.downcast_ref::<arrow_array::FixedSizeBinaryArray>()
543540
.unwrap();
544-
get_fsb_array_slice(array, &indices)
541+
get_fsb_array_slice(array, indices)
545542
}
546543
ArrowDataType::Decimal(_, _) => {
547544
let array = column
548545
.as_any()
549546
.downcast_ref::<arrow_array::DecimalArray>()
550547
.unwrap();
551-
get_decimal_array_slice(array, &indices)
548+
get_decimal_array_slice(array, indices)
552549
}
553550
_ => {
554551
return Err(ParquetError::NYI(
@@ -559,8 +556,8 @@ fn write_leaf(
559556
};
560557
typed.write_batch(
561558
bytes.as_slice(),
562-
Some(levels.definition.as_slice()),
563-
levels.repetition.as_deref(),
559+
levels.def_levels(),
560+
levels.rep_levels(),
564561
)?
565562
}
566563
};
@@ -593,6 +590,7 @@ macro_rules! def_get_binary_array_fn {
593590
};
594591
}
595592

593+
// TODO: These methods don't handle non null indices correctly (#1753)
596594
def_get_binary_array_fn!(get_binary_array, arrow_array::BinaryArray);
597595
def_get_binary_array_fn!(get_string_array, arrow_array::StringArray);
598596
def_get_binary_array_fn!(get_large_binary_array, arrow_array::LargeBinaryArray);

0 commit comments

Comments
 (0)