
Commit 1ad1e8e

Xuanwo authored and svencowart committed
feat(parquet): Add next_row_group API for ParquetRecordBatchStream (apache#6907)
* feat(parquet): Add next_row_group API for ParquetRecordBatchStream

  Signed-off-by: Xuanwo <[email protected]>

* chore: Returning error instead of using unreachable

  Signed-off-by: Xuanwo <[email protected]>

---------

Signed-off-by: Xuanwo <[email protected]>
1 parent ef65a10 · commit 1ad1e8e

File tree

  • parquet/src/arrow/async_reader

1 file changed: +132 −0 lines

parquet/src/arrow/async_reader/mod.rs

Lines changed: 132 additions & 0 deletions
@@ -613,6 +613,9 @@ impl<T> std::fmt::Debug for StreamState<T> {
 
 /// An asynchronous [`Stream`](https://docs.rs/futures/latest/futures/stream/trait.Stream.html) of [`RecordBatch`]
 /// for a parquet file that can be constructed using [`ParquetRecordBatchStreamBuilder`].
+///
+/// `ParquetRecordBatchStream` also provides [`ParquetRecordBatchStream::next_row_group`] for fetching row groups,
+/// allowing users to decode record batches separately from I/O.
 pub struct ParquetRecordBatchStream<T> {
     metadata: Arc<ParquetMetaData>,
 
@@ -654,6 +657,70 @@ impl<T> ParquetRecordBatchStream<T> {
     }
 }
 
+impl<T> ParquetRecordBatchStream<T>
+where
+    T: AsyncFileReader + Unpin + Send + 'static,
+{
+    /// Fetches the next row group from the stream.
+    ///
+    /// Users can continue to call this function to get row groups and decode them concurrently.
+    ///
+    /// ## Notes
+    ///
+    /// ParquetRecordBatchStream should be used either as a `Stream` or with `next_row_group`; they should not be used simultaneously.
+    ///
+    /// ## Returns
+    ///
+    /// - `Ok(None)` if the stream has ended.
+    /// - `Err(error)` if the stream has errored. All subsequent calls will return `Ok(None)`.
+    /// - `Ok(Some(reader))` which holds all the data for the row group.
+    pub async fn next_row_group(&mut self) -> Result<Option<ParquetRecordBatchReader>> {
+        loop {
+            match &mut self.state {
+                StreamState::Decoding(_) | StreamState::Reading(_) => {
+                    return Err(ParquetError::General(
+                        "Cannot combine the use of next_row_group with the Stream API".to_string(),
+                    ))
+                }
+                StreamState::Init => {
+                    let row_group_idx = match self.row_groups.pop_front() {
+                        Some(idx) => idx,
+                        None => return Ok(None),
+                    };
+
+                    let row_count = self.metadata.row_group(row_group_idx).num_rows() as usize;
+
+                    let selection = self.selection.as_mut().map(|s| s.split_off(row_count));
+
+                    let reader_factory = self.reader.take().expect("lost reader");
+
+                    let (reader_factory, maybe_reader) = reader_factory
+                        .read_row_group(
+                            row_group_idx,
+                            selection,
+                            self.projection.clone(),
+                            self.batch_size,
+                        )
+                        .await
+                        .map_err(|err| {
+                            self.state = StreamState::Error;
+                            err
+                        })?;
+                    self.reader = Some(reader_factory);
+
+                    if let Some(reader) = maybe_reader {
+                        return Ok(Some(reader));
+                    } else {
+                        // All rows skipped, read next row group
+                        continue;
+                    }
+                }
+                StreamState::Error => return Ok(None), // An earlier error has ended the stream.
+            }
+        }
+    }
+}
+
 impl<T> Stream for ParquetRecordBatchStream<T>
 where
     T: AsyncFileReader + Unpin + Send + 'static,
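The new doc comment promises that record batches can be decoded separately from I/O. Below is a minimal sketch of that pattern, not part of this commit: the async side fetches a row group with next_row_group, and the returned ParquetRecordBatchReader, which already holds all of the row group's bytes, is drained on a blocking thread. The file name data.parquet and the tokio runtime are assumptions for illustration.

use parquet::arrow::async_reader::ParquetRecordBatchStreamBuilder;
use parquet::errors::ParquetError;

async fn decode_off_the_io_task() -> Result<(), ParquetError> {
    // Hypothetical input; any AsyncFileReader works, and tokio::fs::File implements it.
    let file = tokio::fs::File::open("data.parquet").await.unwrap();

    let mut stream = ParquetRecordBatchStreamBuilder::new(file)
        .await?
        .with_batch_size(1024)
        .build()?;

    // next_row_group does the fetching; the reader it returns is a plain
    // synchronous iterator of RecordBatch, so decoding is pure CPU work
    // that can run on a blocking thread without stalling the I/O task.
    while let Some(reader) = stream.next_row_group().await? {
        let batches = tokio::task::spawn_blocking(move || reader.collect::<Result<Vec<_>, _>>())
            .await
            .expect("decode task panicked")?;
        println!("decoded {} batches", batches.len());
    }
    Ok(())
}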
@@ -1020,6 +1087,71 @@ mod tests {
         );
     }
 
+    #[tokio::test]
+    async fn test_async_reader_with_next_row_group() {
+        let testdata = arrow::util::test_util::parquet_test_data();
+        let path = format!("{testdata}/alltypes_plain.parquet");
+        let data = Bytes::from(std::fs::read(path).unwrap());
+
+        let metadata = ParquetMetaDataReader::new()
+            .parse_and_finish(&data)
+            .unwrap();
+        let metadata = Arc::new(metadata);
+
+        assert_eq!(metadata.num_row_groups(), 1);
+
+        let async_reader = TestReader {
+            data: data.clone(),
+            metadata: metadata.clone(),
+            requests: Default::default(),
+        };
+
+        let requests = async_reader.requests.clone();
+        let builder = ParquetRecordBatchStreamBuilder::new(async_reader)
+            .await
+            .unwrap();
+
+        let mask = ProjectionMask::leaves(builder.parquet_schema(), vec![1, 2]);
+        let mut stream = builder
+            .with_projection(mask.clone())
+            .with_batch_size(1024)
+            .build()
+            .unwrap();
+
+        let mut readers = vec![];
+        while let Some(reader) = stream.next_row_group().await.unwrap() {
+            readers.push(reader);
+        }
+
+        let async_batches: Vec<_> = readers
+            .into_iter()
+            .flat_map(|r| r.map(|v| v.unwrap()).collect::<Vec<_>>())
+            .collect();
+
+        let sync_batches = ParquetRecordBatchReaderBuilder::try_new(data)
+            .unwrap()
+            .with_projection(mask)
+            .with_batch_size(104)
+            .build()
+            .unwrap()
+            .collect::<ArrowResult<Vec<_>>>()
+            .unwrap();
+
+        assert_eq!(async_batches, sync_batches);
+
+        let requests = requests.lock().unwrap();
+        let (offset_1, length_1) = metadata.row_group(0).column(1).byte_range();
+        let (offset_2, length_2) = metadata.row_group(0).column(2).byte_range();
+
+        assert_eq!(
+            &requests[..],
+            &[
+                offset_1 as usize..(offset_1 + length_1) as usize,
+                offset_2 as usize..(offset_2 + length_2) as usize
+            ]
+        );
+    }
+
     #[tokio::test]
     async fn test_async_reader_with_index() {
         let testdata = arrow::util::test_util::parquet_test_data();
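The Notes section of the new doc comment says the Stream API and next_row_group must not be mixed; the Decoding/Reading match arm added in this commit is what enforces that. A hedged sketch of the failure mode from the caller's side, assuming data.parquet is a hypothetical file containing at least one row group:

use futures::StreamExt;
use parquet::arrow::async_reader::ParquetRecordBatchStreamBuilder;

async fn mixing_apis_is_rejected() {
    let file = tokio::fs::File::open("data.parquet").await.unwrap();
    let mut stream = ParquetRecordBatchStreamBuilder::new(file)
        .await
        .unwrap()
        .build()
        .unwrap();

    // Poll the Stream once: the stream is now mid-row-group internally.
    let _first = stream.next().await;

    // next_row_group refuses to run while the Stream API owns the state.
    let err = stream.next_row_group().await.unwrap_err();
    assert!(err.to_string().contains("Cannot combine"));
}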
