Skip to content

Commit

Permalink
fix: Fix floordiv / modulo with scalar 0 on LHS (#19143)
Browse files Browse the repository at this point in the history
  • Loading branch information
nameexhaustion authored Oct 8, 2024
1 parent 4c2546b commit 815e31f
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 25 deletions.
30 changes: 15 additions & 15 deletions crates/polars-compute/src/arithmetic/signed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,13 +133,13 @@ macro_rules! impl_signed_arith_kernel {
}

fn prim_wrapping_floor_div_scalar_lhs(lhs: $T, rhs: PArr<$T>) -> PArr<$T> {
if lhs == 0 {
return rhs.fill_with(0);
}

let mask = rhs.tot_ne_kernel_broadcast(&0);
let valid = combine_validities_and(rhs.validity(), Some(&mask));
let ret = prim_unary_values(rhs, |x| lhs.wrapping_floor_div_mod(x).0);
let ret = if lhs == 0 {
rhs.fill_with(0)
} else {
prim_unary_values(rhs, |x| lhs.wrapping_floor_div_mod(x).0)
};
ret.with_validity(valid)
}

Expand All @@ -165,13 +165,13 @@ macro_rules! impl_signed_arith_kernel {
}

fn prim_wrapping_trunc_div_scalar_lhs(lhs: $T, rhs: PArr<$T>) -> PArr<$T> {
if lhs == 0 {
return rhs.fill_with(0);
}

let mask = rhs.tot_ne_kernel_broadcast(&0);
let valid = combine_validities_and(rhs.validity(), Some(&mask));
let ret = prim_unary_values(rhs, |x| if x != 0 { lhs.wrapping_div(x) } else { 0 });
let ret = if lhs == 0 {
rhs.fill_with(0)
} else {
prim_unary_values(rhs, |x| if x != 0 { lhs.wrapping_div(x) } else { 0 })
};
ret.with_validity(valid)
}

Expand Down Expand Up @@ -205,13 +205,13 @@ macro_rules! impl_signed_arith_kernel {
}

fn prim_wrapping_mod_scalar_lhs(lhs: $T, rhs: PArr<$T>) -> PArr<$T> {
if lhs == 0 {
return rhs.fill_with(0);
}

let mask = rhs.tot_ne_kernel_broadcast(&0);
let valid = combine_validities_and(rhs.validity(), Some(&mask));
let ret = prim_unary_values(rhs, |x| lhs.wrapping_floor_div_mod(x).1);
let ret = if lhs == 0 {
rhs.fill_with(0)
} else {
prim_unary_values(rhs, |x| lhs.wrapping_floor_div_mod(x).1)
};
ret.with_validity(valid)
}

Expand Down
20 changes: 10 additions & 10 deletions crates/polars-compute/src/arithmetic/unsigned.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,13 +95,13 @@ macro_rules! impl_unsigned_arith_kernel {
}

fn prim_wrapping_floor_div_scalar_lhs(lhs: $T, rhs: PArr<$T>) -> PArr<$T> {
if lhs == 0 {
return rhs.fill_with(0);
}

let mask = rhs.tot_ne_kernel_broadcast(&0);
let valid = combine_validities_and(rhs.validity(), Some(&mask));
let ret = prim_unary_values(rhs, |x| if x != 0 { lhs / x } else { 0 });
let ret = if lhs == 0 {
rhs.fill_with(0)
} else {
prim_unary_values(rhs, |x| if x != 0 { lhs / x } else { 0 })
};
ret.with_validity(valid)
}

Expand All @@ -125,13 +125,13 @@ macro_rules! impl_unsigned_arith_kernel {
}

fn prim_wrapping_mod_scalar_lhs(lhs: $T, rhs: PArr<$T>) -> PArr<$T> {
if lhs == 0 {
return rhs.fill_with(0);
}

let mask = rhs.tot_ne_kernel_broadcast(&0);
let valid = combine_validities_and(rhs.validity(), Some(&mask));
let ret = prim_unary_values(rhs, |x| if x != 0 { lhs % x } else { 0 });
let ret = if lhs == 0 {
rhs.fill_with(0)
} else {
prim_unary_values(rhs, |x| if x != 0 { lhs % x } else { 0 })
};
ret.with_validity(valid)
}

Expand Down
5 changes: 5 additions & 0 deletions py-polars/tests/unit/operations/arithmetic/test_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,3 +893,8 @@ def test_date_datetime_sub() -> None:
def test_raise_invalid_shape() -> None:
with pytest.raises(pl.exceptions.InvalidOperationError):
pl.DataFrame([[1, 2], [3, 4]]) * pl.DataFrame([1, 2, 3])


def test_integer_divide_scalar_zero_lhs_19142() -> None:
assert_series_equal(pl.Series([0]) // pl.Series([1, 0]), pl.Series([0, None]))
assert_series_equal(pl.Series([0]) % pl.Series([1, 0]), pl.Series([0, None]))

0 comments on commit 815e31f

Please sign in to comment.