Skip to content

Commit 8aad936

Browse files
Compute gcd with u64 instead of i64 because of overflows (#11036)
* compute gcd with unsigned ints * add test for the i64::MAX cases * move unsigned_abs below zero test to remove unnecessary casts * add slt test for gcd on max values instead of unit tests
1 parent 098ba30 commit 8aad936

File tree

2 files changed

+19
-8
lines changed

2 files changed

+19
-8
lines changed

datafusion/functions/src/math/gcd.rs

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -88,16 +88,16 @@ fn gcd(args: &[ArrayRef]) -> Result<ArrayRef> {
8888

8989
/// Computes greatest common divisor using Binary GCD algorithm.
9090
pub fn compute_gcd(x: i64, y: i64) -> i64 {
91-
let mut a = x.wrapping_abs();
92-
let mut b = y.wrapping_abs();
93-
94-
if a == 0 {
95-
return b;
91+
if x == 0 {
92+
return y;
9693
}
97-
if b == 0 {
98-
return a;
94+
if y == 0 {
95+
return x;
9996
}
10097

98+
let mut a = x.unsigned_abs();
99+
let mut b = y.unsigned_abs();
100+
101101
let shift = (a | b).trailing_zeros();
102102
a >>= shift;
103103
b >>= shift;
@@ -112,7 +112,8 @@ pub fn compute_gcd(x: i64, y: i64) -> i64 {
112112
b -= a;
113113

114114
if b == 0 {
115-
return a << shift;
115+
// because the input values are i64, casting this back to i64 is safe
116+
return (a << shift) as i64;
116117
}
117118
}
118119
}

datafusion/sqllogictest/test_files/scalar.slt

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -474,6 +474,16 @@ select gcd(null, null);
474474
----
475475
NULL
476476

477+
# scalar maxes and/or negative 1
478+
query III rowsort
479+
select
480+
gcd(9223372036854775807, -9223372036854775808), -- i64::MIN, i64::MAX
481+
-- wait till fix, cause it fails gcd(-9223372036854775808, -9223372036854775808), -- -i64::MIN, i64::MIN
482+
gcd(9223372036854775807, -1), -- i64::MAX, -1
483+
gcd(-9223372036854775808, -1); -- i64::MIN, -1
484+
----
485+
1 1 1
486+
477487
# gcd with columns
478488
query III rowsort
479489
select gcd(a, b), gcd(c, d), gcd(e, f) from signed_integers;

0 commit comments

Comments
 (0)