From ad45545d802f3ebd1ccf6f08a565b29a9e62656d Mon Sep 17 00:00:00 2001
From: Lava <34743145+CanglongCl@users.noreply.github.com>
Date: Mon, 8 Apr 2024 09:24:37 -0700
Subject: [PATCH] fix(python): dot product of two integer series is cast to float (#15502)

---
 py-polars/polars/series/series.py         |  2 +-
 py-polars/src/series/mod.rs               | 27 ++++++++++++++++++++---
 py-polars/tests/unit/dataframe/test_df.py | 22 ++++++++++++++++++
 3 files changed, 47 insertions(+), 4 deletions(-)

diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py
index a88b571b9364..31a5577b1d4a 100644
--- a/py-polars/polars/series/series.py
+++ b/py-polars/polars/series/series.py
@@ -4947,7 +4947,7 @@ def round_sig_figs(self, digits: int) -> Series:
         ]
         """
 
-    def dot(self, other: Series | ArrayLike) -> float | None:
+    def dot(self, other: Series | ArrayLike) -> int | float | None:
         """
         Compute the dot/inner product between two Series.
 
diff --git a/py-polars/src/series/mod.rs b/py-polars/src/series/mod.rs
index b3175bc9e731..7d33f8858152 100644
--- a/py-polars/src/series/mod.rs
+++ b/py-polars/src/series/mod.rs
@@ -604,9 +604,30 @@ impl PySeries {
         self.series.shrink_to_fit();
     }
 
-    fn dot(&self, other: &PySeries) -> PyResult<f64> {
-        let out = self.series.dot(&other.series).map_err(PyPolarsErr::from)?;
-        Ok(out)
+    fn dot(&self, other: &PySeries, py: Python) -> PyResult<PyObject> {
+        let lhs_dtype = self.series.dtype();
+        let rhs_dtype = other.series.dtype();
+
+        if !lhs_dtype.is_numeric() {
+            return Err(PyPolarsErr::from(polars_err!(opq = dot, lhs_dtype)).into());
+        };
+        if !rhs_dtype.is_numeric() {
+            return Err(PyPolarsErr::from(polars_err!(opq = dot, rhs_dtype)).into());
+        }
+
+        let result: AnyValue = if lhs_dtype.is_float() || rhs_dtype.is_float() {
+            (&self.series * &other.series)
+                .sum::<f64>()
+                .map_err(PyPolarsErr::from)?
+                .into()
+        } else {
+            (&self.series * &other.series)
+                .sum::<i64>()
+                .map_err(PyPolarsErr::from)?
+                .into()
+        };
+
+        Ok(Wrap(result).into_py(py))
     }
 
     #[cfg(feature = "ipc_streaming")]
diff --git a/py-polars/tests/unit/dataframe/test_df.py b/py-polars/tests/unit/dataframe/test_df.py
index aaf95548a709..8f33a0fef9d1 100644
--- a/py-polars/tests/unit/dataframe/test_df.py
+++ b/py-polars/tests/unit/dataframe/test_df.py
@@ -1407,6 +1407,28 @@ def test_dot_product() -> None:
     assert df["a"].dot(df["b"]) == 20
     assert typing.cast(int, df.select([pl.col("a").dot("b")])[0, "a"]) == 20
 
+    result = pl.Series([1, 2, 3]) @ pl.Series([4, 5, 6])
+    assert isinstance(result, int)
+    assert result == 32
+
+    result = pl.Series([1, 2, 3]) @ pl.Series([4.0, 5.0, 6.0])
+    assert isinstance(result, float)
+    assert result == 32.0
+
+    result = pl.Series([1.0, 2.0, 3.0]) @ pl.Series([4.0, 5.0, 6.0])
+    assert isinstance(result, float)
+    assert result == 32.0
+
+    with pytest.raises(
+        pl.InvalidOperationError, match="`dot` operation not supported for dtype `bool`"
+    ):
+        pl.Series([True, False, False, True]) @ pl.Series([4, 5, 6, 7])
+
+    with pytest.raises(
+        pl.InvalidOperationError, match="`dot` operation not supported for dtype `str`"
+    ):
+        pl.Series([1, 2, 3, 4]) @ pl.Series(["True", "False", "False", "True"])
+
 
 def test_hash_rows() -> None:
     df = pl.DataFrame({"a": [1, 2, 3, 4], "b": [2, 2, 2, 2]})