Skip to content

Commit

Permalink
feat: Add IPC source node for new streaming engine (#19454)
Browse files Browse the repository at this point in the history
Co-authored-by: Orson Peters <[email protected]>
  • Loading branch information
coastalwhite and orlp authored Nov 13, 2024
1 parent 18786ac commit 8cb7839
Show file tree
Hide file tree
Showing 18 changed files with 843 additions and 49 deletions.
18 changes: 13 additions & 5 deletions crates/polars-arrow/src/io/ipc/read/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -318,10 +318,14 @@ pub fn read_dictionary<R: Read + Seek>(
Ok(())
}

pub fn prepare_projection(
schema: &ArrowSchema,
mut projection: Vec<usize>,
) -> (Vec<usize>, PlHashMap<usize, usize>, ArrowSchema) {
#[derive(Clone)]
pub struct ProjectionInfo {
pub columns: Vec<usize>,
pub map: PlHashMap<usize, usize>,
pub schema: ArrowSchema,
}

pub fn prepare_projection(schema: &ArrowSchema, mut projection: Vec<usize>) -> ProjectionInfo {
let schema = projection
.iter()
.map(|x| {
Expand Down Expand Up @@ -355,7 +359,11 @@ pub fn prepare_projection(
}
}

(projection, map, schema)
ProjectionInfo {
columns: projection,
map,
schema,
}
}

pub fn apply_projection(
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-arrow/src/io/ipc/read/file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ fn get_message_from_block_offset<'a, R: Read + Seek>(
.map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferMessage(err)))
}

fn get_message_from_block<'a, R: Read + Seek>(
pub(super) fn get_message_from_block<'a, R: Read + Seek>(
reader: &mut R,
block: &arrow_format::ipc::Block,
message_scratch: &'a mut Vec<u8>,
Expand Down
1 change: 1 addition & 0 deletions crates/polars-arrow/src/io/ipc/read/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ mod schema;
mod stream;

pub(crate) use common::first_dict_field;
pub use common::{prepare_projection, ProjectionInfo};
pub use error::OutOfSpecKind;
pub use file::{
deserialize_footer, get_row_count, read_batch, read_file_dictionaries, read_file_metadata,
Expand Down
90 changes: 80 additions & 10 deletions crates/polars-arrow/src/io/ipc/read/reader.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use std::io::{Read, Seek};

use polars_error::PolarsResult;
use polars_utils::aliases::PlHashMap;

use super::common::*;
use super::file::{get_message_from_block, get_record_batch};
use super::{read_batch, read_file_dictionaries, Dictionaries, FileMetadata};
use crate::array::Array;
use crate::datatypes::ArrowSchema;
Expand All @@ -16,7 +16,7 @@ pub struct FileReader<R: Read + Seek> {
// the dictionaries are going to be read
dictionaries: Option<Dictionaries>,
current_block: usize,
projection: Option<(Vec<usize>, PlHashMap<usize, usize>, ArrowSchema)>,
projection: Option<ProjectionInfo>,
remaining: usize,
data_scratch: Vec<u8>,
message_scratch: Vec<u8>,
Expand All @@ -32,10 +32,29 @@ impl<R: Read + Seek> FileReader<R> {
projection: Option<Vec<usize>>,
limit: Option<usize>,
) -> Self {
let projection = projection.map(|projection| {
let (p, h, schema) = prepare_projection(&metadata.schema, projection);
(p, h, schema)
});
let projection =
projection.map(|projection| prepare_projection(&metadata.schema, projection));
Self {
reader,
metadata,
dictionaries: Default::default(),
projection,
remaining: limit.unwrap_or(usize::MAX),
current_block: 0,
data_scratch: Default::default(),
message_scratch: Default::default(),
}
}

/// Creates a new [`FileReader`]. Use `projection` to only take certain columns.
/// # Panic
/// Panics iff the projection is not in increasing order (e.g. `[1, 0]` nor `[0, 1, 1]` are valid)
pub fn new_with_projection_info(
reader: R,
metadata: FileMetadata,
projection: Option<ProjectionInfo>,
limit: Option<usize>,
) -> Self {
Self {
reader,
metadata,
Expand All @@ -52,7 +71,7 @@ impl<R: Read + Seek> FileReader<R> {
pub fn schema(&self) -> &ArrowSchema {
self.projection
.as_ref()
.map(|x| &x.2)
.map(|x| &x.schema)
.unwrap_or(&self.metadata.schema)
}

Expand All @@ -66,9 +85,23 @@ impl<R: Read + Seek> FileReader<R> {
self.reader
}

pub fn set_current_block(&mut self, idx: usize) {
self.current_block = idx;
}

pub fn get_current_block(&self) -> usize {
self.current_block
}

/// Get the inner memory scratches so they can be reused in a new writer.
/// This can be utilized to save memory allocations for performance reasons.
pub fn take_projection_info(&mut self) -> Option<ProjectionInfo> {
std::mem::take(&mut self.projection)
}

/// Get the inner memory scratches so they can be reused in a new writer.
/// This can be utilized to save memory allocations for performance reasons.
pub fn get_scratches(&mut self) -> (Vec<u8>, Vec<u8>) {
pub fn take_scratches(&mut self) -> (Vec<u8>, Vec<u8>) {
(
std::mem::take(&mut self.data_scratch),
std::mem::take(&mut self.message_scratch),
Expand All @@ -91,6 +124,43 @@ impl<R: Read + Seek> FileReader<R> {
};
Ok(())
}

/// Skip over blocks until we have seen at most `offset` rows, returning how many rows we are
/// still too see.
///
/// This will never go over the `offset`. Meaning that if the `offset < current_block.len()`,
/// the block will not be skipped.
pub fn skip_blocks_till_limit(&mut self, offset: u64) -> PolarsResult<u64> {
let mut remaining_offset = offset;

for (i, block) in self.metadata.blocks.iter().enumerate() {
let message =
get_message_from_block(&mut self.reader, block, &mut self.message_scratch)?;
let record_batch = get_record_batch(message)?;

let length = record_batch.length()?;
let length = length as u64;

if length > remaining_offset {
self.current_block = i;
return Ok(remaining_offset);
}

remaining_offset -= length;
}

self.current_block = self.metadata.blocks.len();
Ok(remaining_offset)
}

pub fn next_record_batch(
&mut self,
) -> Option<PolarsResult<arrow_format::ipc::RecordBatchRef<'_>>> {
let block = self.metadata.blocks.get(self.current_block)?;
self.current_block += 1;
let message = get_message_from_block(&mut self.reader, block, &mut self.message_scratch);
Some(message.and_then(|m| get_record_batch(m)))
}
}

impl<R: Read + Seek> Iterator for FileReader<R> {
Expand All @@ -114,15 +184,15 @@ impl<R: Read + Seek> Iterator for FileReader<R> {
&mut self.reader,
self.dictionaries.as_ref().unwrap(),
&self.metadata,
self.projection.as_ref().map(|x| x.0.as_ref()),
self.projection.as_ref().map(|x| x.columns.as_ref()),
Some(self.remaining),
block,
&mut self.message_scratch,
&mut self.data_scratch,
);
self.remaining -= chunk.as_ref().map(|x| x.len()).unwrap_or_default();

let chunk = if let Some((_, map, _)) = &self.projection {
let chunk = if let Some(ProjectionInfo { map, .. }) = &self.projection {
// re-order according to projection
chunk.map(|chunk| apply_projection(chunk, map))
} else {
Expand Down
17 changes: 7 additions & 10 deletions crates/polars-arrow/src/io/ipc/read/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use std::io::Read;

use arrow_format::ipc::planus::ReadAsRoot;
use polars_error::{polars_bail, polars_err, PolarsError, PolarsResult};
use polars_utils::aliases::PlHashMap;

use super::super::CONTINUATION_MARKER;
use super::common::*;
Expand Down Expand Up @@ -93,7 +92,7 @@ fn read_next<R: Read>(
dictionaries: &mut Dictionaries,
message_buffer: &mut Vec<u8>,
data_buffer: &mut Vec<u8>,
projection: &Option<(Vec<usize>, PlHashMap<usize, usize>, ArrowSchema)>,
projection: &Option<ProjectionInfo>,
scratch: &mut Vec<u8>,
) -> PolarsResult<Option<StreamState>> {
// determine metadata length
Expand Down Expand Up @@ -169,7 +168,7 @@ fn read_next<R: Read>(
batch,
&metadata.schema,
&metadata.ipc_schema,
projection.as_ref().map(|x| x.0.as_ref()),
projection.as_ref().map(|x| x.columns.as_ref()),
None,
dictionaries,
metadata.version,
Expand All @@ -179,7 +178,7 @@ fn read_next<R: Read>(
scratch,
);

if let Some((_, map, _)) = projection {
if let Some(ProjectionInfo { map, .. }) = projection {
// re-order according to projection
chunk
.map(|chunk| apply_projection(chunk, map))
Expand Down Expand Up @@ -238,7 +237,7 @@ pub struct StreamReader<R: Read> {
finished: bool,
data_buffer: Vec<u8>,
message_buffer: Vec<u8>,
projection: Option<(Vec<usize>, PlHashMap<usize, usize>, ArrowSchema)>,
projection: Option<ProjectionInfo>,
scratch: Vec<u8>,
}

Expand All @@ -249,10 +248,8 @@ impl<R: Read> StreamReader<R> {
/// encounter a schema.
/// To check if the reader is done, use `is_finished(self)`
pub fn new(reader: R, metadata: StreamMetadata, projection: Option<Vec<usize>>) -> Self {
let projection = projection.map(|projection| {
let (p, h, schema) = prepare_projection(&metadata.schema, projection);
(p, h, schema)
});
let projection =
projection.map(|projection| prepare_projection(&metadata.schema, projection));

Self {
reader,
Expand All @@ -275,7 +272,7 @@ impl<R: Read> StreamReader<R> {
pub fn schema(&self) -> &ArrowSchema {
self.projection
.as_ref()
.map(|x| &x.2)
.map(|x| &x.schema)
.unwrap_or(&self.metadata.schema)
}

Expand Down
22 changes: 16 additions & 6 deletions crates/polars-arrow/src/record_batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::array::{Array, ArrayRef};
/// the same length, [`RecordBatchT::len`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RecordBatchT<A: AsRef<dyn Array>> {
length: usize,
height: usize,
arrays: Vec<A>,
}

Expand All @@ -29,14 +29,14 @@ impl<A: AsRef<dyn Array>> RecordBatchT<A> {
///
/// # Error
///
/// I.f.f. the length does not match the length of any of the arrays
pub fn try_new(length: usize, arrays: Vec<A>) -> PolarsResult<Self> {
/// I.f.f. the height does not match the length of any of the arrays
pub fn try_new(height: usize, arrays: Vec<A>) -> PolarsResult<Self> {
polars_ensure!(
arrays.iter().all(|arr| arr.as_ref().len() == length),
arrays.iter().all(|arr| arr.as_ref().len() == height),
ComputeError: "RecordBatch requires all its arrays to have an equal number of rows",
);

Ok(Self { length, arrays })
Ok(Self { height, arrays })
}

/// returns the [`Array`]s in [`RecordBatchT`]
Expand All @@ -51,7 +51,17 @@ impl<A: AsRef<dyn Array>> RecordBatchT<A> {

/// returns the number of rows of every array
pub fn len(&self) -> usize {
self.length
self.height
}

/// returns the number of rows of every array
pub fn height(&self) -> usize {
self.height
}

/// returns the number of arrays
pub fn width(&self) -> usize {
self.arrays.len()
}

/// returns whether the columns have any rows
Expand Down
26 changes: 26 additions & 0 deletions crates/polars-core/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
use std::borrow::Cow;
use std::{mem, ops};

use polars_row::ArrayRef;
use polars_utils::itertools::Itertools;
use rayon::prelude::*;

Expand Down Expand Up @@ -3334,6 +3335,31 @@ impl DataFrame {
pub(crate) fn infer_height(cols: &[Column]) -> usize {
cols.first().map_or(0, Column::len)
}

pub fn append_record_batch(&mut self, rb: RecordBatchT<ArrayRef>) -> PolarsResult<()> {
polars_ensure!(
rb.arrays().len() == self.width(),
InvalidOperation: "attempt to extend dataframe of width {} with record batch of width {}",
self.width(),
rb.arrays().len(),
);

if rb.height() == 0 {
return Ok(());
}

// SAFETY:
// - we don't adjust the names of the columns
// - each column gets appended the same number of rows, which is an invariant of
// record_batch.
let columns = unsafe { self.get_columns_mut() };
for (col, arr) in columns.iter_mut().zip(rb.into_arrays()) {
let arr_series = Series::from_arrow_chunks(PlSmallStr::EMPTY, vec![arr])?.into_column();
col.append(&arr_series)?;
}

Ok(())
}
}

pub struct RecordBatchIter<'a> {
Expand Down
28 changes: 28 additions & 0 deletions crates/polars-core/src/frame/upstream_traits.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use std::ops::{Index, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive};

use arrow::record_batch::RecordBatchT;

use crate::prelude::*;

impl FromIterator<Series> for DataFrame {
Expand All @@ -22,6 +24,32 @@ impl FromIterator<Column> for DataFrame {
}
}

impl TryExtend<RecordBatchT<Box<dyn Array>>> for DataFrame {
fn try_extend<I: IntoIterator<Item = RecordBatchT<Box<dyn Array>>>>(
&mut self,
iter: I,
) -> PolarsResult<()> {
for record_batch in iter {
self.append_record_batch(record_batch)?;
}

Ok(())
}
}

impl TryExtend<PolarsResult<RecordBatchT<Box<dyn Array>>>> for DataFrame {
fn try_extend<I: IntoIterator<Item = PolarsResult<RecordBatchT<Box<dyn Array>>>>>(
&mut self,
iter: I,
) -> PolarsResult<()> {
for record_batch in iter {
self.append_record_batch(record_batch?)?;
}

Ok(())
}
}

impl Index<usize> for DataFrame {
type Output = Column;

Expand Down
Loading

0 comments on commit 8cb7839

Please sign in to comment.