From 477a80e28fe519c02843f22dd072906385077819 Mon Sep 17 00:00:00 2001 From: barak1412 Date: Mon, 30 Sep 2024 18:49:20 +0300 Subject: [PATCH] fix: Force nested struct `missing` equality (#19031) --- .../src/chunked_array/comparison/mod.rs | 6 +++--- .../tests/unit/operations/test_comparison.py | 19 +++++++++++++++++++ 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/crates/polars-core/src/chunked_array/comparison/mod.rs b/crates/polars-core/src/chunked_array/comparison/mod.rs index 13f557d20886..0a58d5192d41 100644 --- a/crates/polars-core/src/chunked_array/comparison/mod.rs +++ b/crates/polars-core/src/chunked_array/comparison/mod.rs @@ -780,7 +780,7 @@ where .reduce(reduce) .unwrap(); - if !is_missing & (a.null_count() > 0 || b.null_count() > 0) { + if !is_missing && (a.null_count() > 0 || b.null_count() > 0) { let mut a = a.into_owned(); a.zip_outer_validity(&b); unsafe { @@ -801,7 +801,7 @@ impl ChunkCompareEq<&StructChunked> for StructChunked { struct_helper( self, rhs, - |l, r| l.equal(r).unwrap(), + |l, r| l.equal_missing(r).unwrap(), |a, b| a.bitand(b), false, false, @@ -823,7 +823,7 @@ impl ChunkCompareEq<&StructChunked> for StructChunked { struct_helper( self, rhs, - |l, r| l.not_equal(r).unwrap(), + |l, r| l.not_equal_missing(r).unwrap(), |a, b| a | b, true, false, diff --git a/py-polars/tests/unit/operations/test_comparison.py b/py-polars/tests/unit/operations/test_comparison.py index 71ddf6157e99..523f2050ca36 100644 --- a/py-polars/tests/unit/operations/test_comparison.py +++ b/py-polars/tests/unit/operations/test_comparison.py @@ -184,6 +184,25 @@ def test_struct_equality_18870() -> None: assert result == expected +def test_struct_nested_equality() -> None: + df = pl.DataFrame( + { + "a": [{"foo": 0, "bar": "1"}, {"foo": None, "bar": "1"}, None], + "b": [{"foo": 0, "bar": "1"}] * 3, + } + ) + + # eq + ans = df.select(pl.col("a").eq(pl.col("b"))) + expected = pl.DataFrame({"a": [True, False, None]}) + assert_frame_equal(ans, expected) + + # ne + ans = df.select(pl.col("a").ne(pl.col("b"))) + expected = pl.DataFrame({"a": [False, True, None]}) + assert_frame_equal(ans, expected) + + def isnan(x: Any) -> bool: return isinstance(x, float) and math.isnan(x)