From 4894e24d34a8388ef9155b24deab78ed29fe1ef5 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Sun, 15 Sep 2024 10:06:44 +0200 Subject: [PATCH] fix: Fix accidental raise on shape 1 (#18748) --- crates/polars-core/src/frame/mod.rs | 6 ++++++ crates/polars-mem-engine/src/executors/stack.rs | 6 +++--- py-polars/tests/unit/dataframe/test_extend.py | 11 +++++++++++ 3 files changed, 20 insertions(+), 3 deletions(-) diff --git a/crates/polars-core/src/frame/mod.rs b/crates/polars-core/src/frame/mod.rs index a72da39d915d..dae71dd44ff0 100644 --- a/crates/polars-core/src/frame/mod.rs +++ b/crates/polars-core/src/frame/mod.rs @@ -813,6 +813,12 @@ impl DataFrame { self.shape().0 } + /// Returns the size as number of rows * number of columns + pub fn size(&self) -> usize { + let s = self.shape(); + s.0 * s.1 + } + /// Returns `true` if the [`DataFrame`] contains no rows. /// /// # Example diff --git a/crates/polars-mem-engine/src/executors/stack.rs b/crates/polars-mem-engine/src/executors/stack.rs index 43c884b1f563..074bcbaf9e79 100644 --- a/crates/polars-mem-engine/src/executors/stack.rs +++ b/crates/polars-mem-engine/src/executors/stack.rs @@ -66,7 +66,7 @@ impl StackExec { // possibly mismatching column lengths. unsafe { df.get_columns_mut() }.extend(res.into_iter().map(Column::from)); } else { - let height = df.height(); + let (df_height, df_width) = df.shape(); // When we have CSE we cannot verify scalars yet. let verify_scalar = if !df.get_columns().is_empty() { @@ -78,11 +78,11 @@ impl StackExec { }; for (i, c) in res.iter().enumerate() { let len = c.len(); - if verify_scalar && len != height && len == 1 { + if verify_scalar && len != df_height && len == 1 && df_width > 0 { polars_ensure!(self.exprs[i].is_scalar(), InvalidOperation: "Series {}, length {} doesn't match the DataFrame height of {}\n\n\ If you want this Series to be broadcasted, ensure it is a scalar (for instance by adding '.first()').", - c.name(), len, height + c.name(), len, df_height ); } } diff --git a/py-polars/tests/unit/dataframe/test_extend.py b/py-polars/tests/unit/dataframe/test_extend.py index a1f4a451d627..9d4f46c95e2d 100644 --- a/py-polars/tests/unit/dataframe/test_extend.py +++ b/py-polars/tests/unit/dataframe/test_extend.py @@ -84,3 +84,14 @@ def test_extend_column_name_mismatch() -> None: with pytest.raises(ShapeError): df1.extend(df2) + + +def test_initialize_df_18736() -> None: + # Completely empty initialization + df = pl.DataFrame() + s_0 = pl.Series([]) + s_1 = pl.Series([None]) + s_2 = pl.Series([None, None]) + assert df.with_columns(s_0).shape == (0, 1) + assert df.with_columns(s_1).shape == (1, 1) + assert df.with_columns(s_2).shape == (2, 1)