Skip to content

Commit

Permalink
fix(python): dot product of two integer series is cast to float (#15502)
Browse files: browse the repository at this point in the history
  • Loading branch information
CanglongCl authored Apr 8, 2024
1 parent a8c9738 commit ad45545
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 4 deletions.
2 changes: 1 addition & 1 deletion py-polars/polars/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -4947,7 +4947,7 @@ def round_sig_figs(self, digits: int) -> Series:
]
"""

def dot(self, other: Series | ArrayLike) -> float | None:
def dot(self, other: Series | ArrayLike) -> int | float | None:
"""
Compute the dot/inner product between two Series.
Expand Down
27 changes: 24 additions & 3 deletions py-polars/src/series/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -604,9 +604,30 @@ impl PySeries {
self.series.shrink_to_fit();
}

fn dot(&self, other: &PySeries) -> PyResult<f64> {
let out = self.series.dot(&other.series).map_err(PyPolarsErr::from)?;
Ok(out)
/// Compute the dot product of this series with `other` and return it as a Python object.
///
/// Both operands must have a numeric dtype; otherwise an error built with
/// `polars_err!(opq = dot, ..)` for the offending dtype is returned, checking
/// the left-hand operand before the right-hand one.
///
/// The element-wise product is summed as `f64` when either operand has a float
/// dtype and as `i64` otherwise, so the result keeps an integer type for
/// integer inputs.
fn dot(&self, other: &PySeries, py: Python) -> PyResult<PyObject> {
    let left = self.series.dtype();
    let right = other.series.dtype();

    // Validate the operands in order so the left-hand dtype is reported first.
    for dtype in [left, right] {
        if !dtype.is_numeric() {
            return Err(PyPolarsErr::from(polars_err!(opq = dot, dtype)).into());
        }
    }

    let product = &self.series * &other.series;
    let any_float = left.is_float() || right.is_float();

    // Pick the accumulator type from the operand dtypes.
    let total: AnyValue = if any_float {
        product.sum::<f64>().map_err(PyPolarsErr::from)?.into()
    } else {
        product.sum::<i64>().map_err(PyPolarsErr::from)?.into()
    };

    Ok(Wrap(total).into_py(py))
}

#[cfg(feature = "ipc_streaming")]
Expand Down
22 changes: 22 additions & 0 deletions py-polars/tests/unit/dataframe/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -1407,6 +1407,28 @@ def test_dot_product() -> None:
assert df["a"].dot(df["b"]) == 20
assert typing.cast(int, df.select([pl.col("a").dot("b")])[0, "a"]) == 20

result = pl.Series([1, 2, 3]) @ pl.Series([4, 5, 6])
assert isinstance(result, int)
assert result == 32

result = pl.Series([1, 2, 3]) @ pl.Series([4.0, 5.0, 6.0])
assert isinstance(result, float)
assert result == 32.0

result = pl.Series([1.0, 2.0, 3.0]) @ pl.Series([4.0, 5.0, 6.0])
assert isinstance(result, float)
assert result == 32.0

with pytest.raises(
pl.InvalidOperationError, match="`dot` operation not supported for dtype `bool`"
):
pl.Series([True, False, False, True]) @ pl.Series([4, 5, 6, 7])

with pytest.raises(
pl.InvalidOperationError, match="`dot` operation not supported for dtype `str`"
):
pl.Series([1, 2, 3, 4]) @ pl.Series(["True", "False", "False", "True"])


def test_hash_rows() -> None:
df = pl.DataFrame({"a": [1, 2, 3, 4], "b": [2, 2, 2, 2]})
Expand Down

0 comments on commit ad45545

Please sign in to comment.