Skip to content

Commit

Permalink
only check when no cse
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Sep 10, 2024
1 parent e1735d9 commit 6c47e3e
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 42 deletions.
4 changes: 2 additions & 2 deletions crates/polars-mem-engine/src/executors/projection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ impl ProjectionExec {
self.has_windows,
self.options.run_parallel,
)?;
check_expand_literals(&self.expr, selected_cols, df.is_empty(), self.options)
check_expand_literals(&df, &self.expr, selected_cols, df.is_empty(), self.options)
});

let df = POOL.install(|| iter.collect::<PolarsResult<Vec<_>>>())?;
Expand All @@ -53,7 +53,7 @@ impl ProjectionExec {
self.has_windows,
self.options.run_parallel,
)?;
check_expand_literals(&self.expr, selected_cols, df.is_empty(), self.options)?
check_expand_literals(&df, &self.expr, selected_cols, df.is_empty(), self.options)?
};

// This only runs during testing and checks that the runtime type matches the predicted schema.
Expand Down
21 changes: 18 additions & 3 deletions crates/polars-mem-engine/src/executors/projection_utils.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use polars_plan::constants::CSE_REPLACED;
use polars_utils::itertools::Itertools;

use super::*;
Expand Down Expand Up @@ -243,6 +244,7 @@ pub(super) fn evaluate_physical_expressions(
}

pub(super) fn check_expand_literals(
df: &DataFrame,
phys_expr: &[Arc<dyn PhysicalExpr>],
mut selected_columns: Vec<Series>,
zero_length: bool,
Expand All @@ -253,6 +255,16 @@ pub(super) fn check_expand_literals(
};
let duplicate_check = options.duplicate_check;
let should_broadcast = options.should_broadcast;

// When CSE is active (the last column's name starts with `CSE_REPLACED`), we cannot verify scalars yet.
let verify_scalar = if !df.get_columns().is_empty() {
!df.get_columns()[df.width() - 1]
.name()
.starts_with(CSE_REPLACED)
} else {
true
};

let mut df_height = 0;
let mut has_empty = false;
let mut all_equal_len = true;
Expand Down Expand Up @@ -297,11 +309,14 @@ pub(super) fn check_expand_literals(
} else if df_height == 1 {
series
} else {
polars_ensure!(phys.is_scalar(),
InvalidOperation: "Series length {} doesn't match the DataFrame height of {}\n\n\
if verify_scalar {
polars_ensure!(phys.is_scalar(),
InvalidOperation: "Series: {}, length {} doesn't match the DataFrame height of {}\n\n\
If you want this Series to be broadcasted, ensure it is a scalar (for instance by adding '.first()').",
series.len(), df_height
series.name(), series.len(), df_height
);

}
series.new_from_index(0, df_height)
}
},
Expand Down
97 changes: 60 additions & 37 deletions crates/polars-mem-engine/src/executors/stack.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use polars_core::utils::accumulate_dataframes_vertical_unchecked;
use polars_plan::constants::CSE_REPLACED;

use super::*;

Expand All @@ -21,52 +22,74 @@ impl StackExec {
let schema = &*self.input_schema;

// Vertical and horizontal parallelism.
let df =
if self.streamable && df.n_chunks() > 1 && df.height() > 0 && self.options.run_parallel
{
let chunks = df.split_chunks().collect::<Vec<_>>();
let iter = chunks.into_par_iter().map(|mut df| {
let res = evaluate_physical_expressions(
&mut df,
&self.exprs,
state,
self.has_windows,
self.options.run_parallel,
)?;
// No broadcast check is needed here, as CSE is not allowed to hit this code path.
df._add_columns(res, schema)?;
Ok(df)
});

let df = POOL.install(|| iter.collect::<PolarsResult<Vec<_>>>())?;
accumulate_dataframes_vertical_unchecked(df)
}
// Only horizontal parallelism
else {
let df = if self.streamable
&& df.n_chunks() > 1
&& df.height() > 0
&& self.options.run_parallel
{
let chunks = df.split_chunks().collect::<Vec<_>>();
let iter = chunks.into_par_iter().map(|mut df| {
let res = evaluate_physical_expressions(
&mut df,
&self.exprs,
state,
self.has_windows,
self.options.run_parallel,
)?;
if !self.options.should_broadcast {
debug_assert!(
res.iter()
.all(|column| column.name().starts_with("__POLARS_CSER_0x")),
"non-broadcasting hstack should only be used for CSE columns"
);
// Safety: this case only appears as a result of
// CSE optimization, and the usage there produces
// new, unique column names. It is immediately
// followed by a projection which pulls out the
// possibly mismatching column lengths.
unsafe { df.get_columns_mut().extend(res) };
// No broadcast check is needed here, as CSE is not allowed to hit this code path.
df._add_columns(res, schema)?;
Ok(df)
});

let df = POOL.install(|| iter.collect::<PolarsResult<Vec<_>>>())?;
accumulate_dataframes_vertical_unchecked(df)
}
// Only horizontal parallelism
else {
let res = evaluate_physical_expressions(
&mut df,
&self.exprs,
state,
self.has_windows,
self.options.run_parallel,
)?;
if !self.options.should_broadcast {
debug_assert!(
res.iter()
.all(|column| column.name().starts_with("__POLARS_CSER_0x")),
"non-broadcasting hstack should only be used for CSE columns"
);
// Safety: this case only appears as a result of
// CSE optimization, and the usage there produces
// new, unique column names. It is immediately
// followed by a projection which pulls out the
// possibly mismatching column lengths.
unsafe { df.get_columns_mut().extend(res) };
} else {
let height = df.height();

// When CSE is active (the last column's name starts with `CSE_REPLACED`), we cannot verify scalars yet.
let verify_scalar = if !df.get_columns().is_empty() {
!df.get_columns()[df.width() - 1]
.name()
.starts_with(CSE_REPLACED)
} else {
df._add_columns(res, schema)?;
true
};
for (i, c) in res.iter().enumerate() {
let len = c.len();
if verify_scalar && len != height && len == 1 {
polars_ensure!(self.exprs[i].is_scalar(),
InvalidOperation: "Series {}, length {} doesn't match the DataFrame height of {}\n\n\
If you want this Series to be broadcasted, ensure it is a scalar (for instance by adding '.first()').",
c.name(), len, height
);
}
}
df
};
df._add_columns(res, schema)?;
}
df
};

state.clear_window_expr_cache();

Expand Down

0 comments on commit 6c47e3e

Please sign in to comment.