From 15259f9ef9a810de441964d5b760d3348bb72d29 Mon Sep 17 00:00:00 2001 From: Edward Davis Date: Sat, 13 Apr 2024 13:17:25 +1000 Subject: [PATCH] add checks --- .../polars-plan/src/dsl/function_expr/struct_.rs | 14 ++++++++------ py-polars/tests/unit/datatypes/test_struct.py | 15 +++++++++++++++ 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/crates/polars-plan/src/dsl/function_expr/struct_.rs b/crates/polars-plan/src/dsl/function_expr/struct_.rs index fd9bed142d12..29e980a197b6 100644 --- a/crates/polars-plan/src/dsl/function_expr/struct_.rs +++ b/crates/polars-plan/src/dsl/function_expr/struct_.rs @@ -43,24 +43,25 @@ impl StructFunction { polars_bail!(StructFieldNotFound: "{}", name.as_ref()); } }), - RenameFields(names) => mapper.map_dtype(|dt| match dt { + RenameFields(names) => mapper.try_map_dtype(|dt| match dt { DataType::Struct(fields) => { + polars_ensure!(fields.len() == names.len(), ComputeError: "expected {} names, got {}", fields.len(), names.len()); let fields = fields .iter() .zip(names.as_ref()) .map(|(fld, name)| Field::new(name, fld.data_type().clone())) .collect(); - DataType::Struct(fields) + Ok(DataType::Struct(fields)) }, // The types will be incorrect, but its better than nothing // we can get an incorrect type with python lambdas, because we only know return type when running // the query - dt => DataType::Struct( + dt => Ok(DataType::Struct( names .iter() .map(|name| Field::new(name, dt.clone())) .collect(), - ), + )), }), PrefixFields(prefix) => mapper.try_map_dtype(|dt| match dt { DataType::Struct(fields) => { @@ -131,8 +132,9 @@ pub(super) fn get_by_name(s: &Series, name: Arc) -> PolarsResult { pub(super) fn rename_fields(s: &Series, names: Arc>) -> PolarsResult { let ca = s.struct_()?; - let fields = ca - .fields() + let fields = ca.fields(); + polars_ensure!(fields.len() == names.len(), ComputeError: "expected {} names, got {}", fields.len(), names.len()); + let fields = fields .iter() .zip(names.as_ref()) .map(|(s, name)| { diff --git a/py-polars/tests/unit/datatypes/test_struct.py b/py-polars/tests/unit/datatypes/test_struct.py index 5f7c6850465b..2c6ba67ca50c 100644 --- a/py-polars/tests/unit/datatypes/test_struct.py +++ b/py-polars/tests/unit/datatypes/test_struct.py @@ -868,3 +868,18 @@ def test_struct_null_count_10130() -> None: s = pl.Series([{"a": None}]) assert s.null_count() == 1 + + +def test_struct_rename_mismatch_9052() -> None: + df = pl.DataFrame({"A": [{"p": 1, "q": 2}]}) + + with pytest.raises(pl.ComputeError, match=r"expected 2 names, got 1"): + df.select(pl.col("A").struct.rename_fields(["x"])) + + # Additional cases + # too many fields + with pytest.raises(pl.ComputeError, match=r"expected 2 names, got 3"): + df.select(pl.col("A").struct.rename_fields(["too", "many", "fields"])) + # during schema evaluation + with pytest.raises(pl.ComputeError, match=r"expected 2 names, got 1"): + df.lazy().select(pl.col("A").struct.rename_fields(["x"])).schema