Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: rechunk cross join output in streaming #12511

Merged
merged 2 commits into from
Nov 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions crates/polars-pipe/src/executors/sinks/joins/cross.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,13 @@ impl Operator for CrossJoinProbe {
let iter_right = self.in_process_right.as_mut().unwrap();
let offset = iter_right.next().unwrap();
let right_df = chunk.data.slice(offset as i64, size);
let df = self.in_process_left_df.cross_join(
let mut df = self.in_process_left_df.cross_join(
&right_df,
Some(self.suffix.as_ref()),
None,
)?;
// Cross joins can produce multiple chunks.
df.as_single_chunk_par();
Ok(OperatorResult::HaveMoreOutPut(chunk.with_data(df)))
},
}
Expand All @@ -135,7 +137,7 @@ impl Operator for CrossJoinProbe {

// we use the first join to determine the output names
// this we can amortize the name allocations.
let df = match &self.output_names {
let mut df = match &self.output_names {
None => {
let df = self.in_process_left_df.cross_join(
&right_df,
Expand All @@ -149,6 +151,8 @@ impl Operator for CrossJoinProbe {
.in_process_left_df
._cross_join_with_names(&right_df, names)?,
};
// Cross joins can produce multiple chunks.
df.as_single_chunk_par();

Ok(OperatorResult::HaveMoreOutPut(chunk.with_data(df)))
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ impl SortSinkMultiple {
)
};

debug_assert_eq!(column.chunks().len(), 1);
// Safety: length is correct
unsafe { chunk.data.with_column_unchecked(column) };
Ok(())
Expand Down
12 changes: 8 additions & 4 deletions crates/polars-pipe/src/operators/chunks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,17 @@ pub struct DataChunk {

impl DataChunk {
pub(crate) fn new(chunk_index: IdxSize, data: DataFrame) -> Self {
// Check the invariant that all columns have a single chunk.
#[cfg(debug_assertions)]
{
for c in data.get_columns() {
assert_eq!(c.chunks().len(), 1);
}
}
Self { chunk_index, data }
}
pub(crate) fn with_data(&self, data: DataFrame) -> Self {
DataChunk {
chunk_index: self.chunk_index,
data,
}
Self::new(self.chunk_index, data)
}
pub(crate) fn is_empty(&self) -> bool {
self.data.height() == 0
Expand Down
13 changes: 13 additions & 0 deletions py-polars/tests/unit/streaming/test_streaming_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,16 @@ def test_streaming_cross_join_empty() -> None:
).collect(streaming=True)
assert out.shape == (0, 2)
assert out.columns == ["col1", "col1_right"]


def test_streaming_join_rechunk_12498() -> None:
rows = pl.int_range(0, 2)

a = pl.select(A=rows).lazy()
b = pl.select(B=rows).lazy()

q = a.join(b, how="cross")
assert q.collect(streaming=True).to_dict(as_series=False) == {
"A": [0, 1, 0, 1],
"B": [0, 0, 1, 1],
}