Skip to content

Commit

Permalink
feat: natively support more data types for the abs function. (#7568)
Browse files Browse the repository at this point in the history
* refactor: the return type of `abs()` should be the same as the input type

* add sqllogic tests
  • Loading branch information
jonahgao authored Sep 18, 2023
1 parent 5718a3f commit d512bee
Show file tree
Hide file tree
Showing 9 changed files with 216 additions and 30 deletions.
2 changes: 1 addition & 1 deletion datafusion/core/tests/sql/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -635,7 +635,7 @@ pub fn make_partition(sz: i32) -> RecordBatch {

/// Specialised String representation
fn col_str(column: &ArrayRef, row_index: usize) -> String {
if column.is_null(row_index) {
if column.data_type() == &DataType::Null || column.is_null(row_index) {
return "NULL".to_string();
}

Expand Down
10 changes: 6 additions & 4 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -789,8 +789,9 @@ impl BuiltinScalarFunction {

BuiltinScalarFunction::ArrowTypeof => Ok(Utf8),

BuiltinScalarFunction::Abs
| BuiltinScalarFunction::Acos
BuiltinScalarFunction::Abs => Ok(input_expr_types[0].clone()),

BuiltinScalarFunction::Acos
| BuiltinScalarFunction::Asin
| BuiltinScalarFunction::Atan
| BuiltinScalarFunction::Acosh
Expand Down Expand Up @@ -1162,8 +1163,9 @@ impl BuiltinScalarFunction {
Signature::uniform(2, vec![Int64], self.volatility())
}
BuiltinScalarFunction::ArrowTypeof => Signature::any(1, self.volatility()),
BuiltinScalarFunction::Abs
| BuiltinScalarFunction::Acos
BuiltinScalarFunction::Abs => Signature::any(1, self.volatility()),

BuiltinScalarFunction::Acos
| BuiltinScalarFunction::Asin
| BuiltinScalarFunction::Atan
| BuiltinScalarFunction::Acosh
Expand Down
4 changes: 2 additions & 2 deletions datafusion/optimizer/src/analyzer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -883,14 +883,14 @@ mod test {
// Builds a projection over an empty relation containing a single built-in scalar
// function applied to an Int64 literal, runs the `TypeCoercion` analyzer rule on it,
// and asserts that the literal argument ends up wrapped in a cast to Float64.
fn scalar_function() -> Result<()> {
let empty = empty();
let lit_expr = lit(10i64);
// NOTE(review): this view appears to show the pre-change line (`Abs`) directly followed
// by its replacement (`Acos`); read as plain Rust, the second binding shadows the first.
let fun: BuiltinScalarFunction = BuiltinScalarFunction::Abs;
let fun: BuiltinScalarFunction = BuiltinScalarFunction::Acos;
let scalar_function_expr =
Expr::ScalarFunction(ScalarFunction::new(fun, vec![lit_expr]));
let plan = LogicalPlan::Projection(Projection::try_new(
vec![scalar_function_expr],
empty,
)?);
// NOTE(review): same old/new pairing as above; only the `acos` expectation is live
// because the later binding shadows the earlier one.
let expected = "Projection: abs(CAST(Int64(10) AS Float64))\n EmptyRelation";
let expected = "Projection: acos(CAST(Int64(10) AS Float64))\n EmptyRelation";
assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected)
}

Expand Down
6 changes: 5 additions & 1 deletion datafusion/physical-expr/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,11 @@ pub fn create_physical_expr(
)))))
})
}
BuiltinScalarFunction::Abs => {
let input_data_type = input_phy_exprs[0].data_type(input_schema)?;
let abs_fun = math_expressions::create_abs_function(&input_data_type)?;
Arc::new(move |args| make_scalar_function(abs_fun)(args))
}
// These don't need args and input schema
_ => create_physical_fun(fun, execution_props)?,
};
Expand Down Expand Up @@ -360,7 +365,6 @@ pub fn create_physical_fun(
) -> Result<ScalarFunctionImplementation> {
Ok(match fun {
// math functions
BuiltinScalarFunction::Abs => Arc::new(math_expressions::abs),
BuiltinScalarFunction::Acos => Arc::new(math_expressions::acos),
BuiltinScalarFunction::Asin => Arc::new(math_expressions::asin),
BuiltinScalarFunction::Atan => Arc::new(math_expressions::atan),
Expand Down
74 changes: 72 additions & 2 deletions datafusion/physical-expr/src/math_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,15 @@
//! Math expressions

use arrow::array::ArrayRef;
use arrow::array::{BooleanArray, Float32Array, Float64Array, Int64Array};
use arrow::array::{
BooleanArray, Decimal128Array, Float32Array, Float64Array, Int16Array, Int32Array,
Int64Array, Int8Array,
};
use arrow::datatypes::DataType;
use datafusion_common::internal_err;
use arrow::error::ArrowError;
use datafusion_common::ScalarValue;
use datafusion_common::ScalarValue::{Float32, Int64};
use datafusion_common::{internal_err, not_impl_err};
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::ColumnarValue;
use rand::{thread_rng, Rng};
Expand All @@ -31,6 +35,8 @@ use std::iter;
use std::mem::swap;
use std::sync::Arc;

/// Plain function pointer type for an array-level math kernel: it receives the argument
/// arrays as a slice and returns the computed array, or an error.
type MathArrayFunction = fn(&[ArrayRef]) -> Result<ArrayRef>;

macro_rules! downcast_compute_op {
($ARRAY:expr, $NAME:expr, $FUNC:ident, $TYPE:ident) => {{
let n = $ARRAY.as_any().downcast_ref::<$TYPE>();
Expand Down Expand Up @@ -667,6 +673,70 @@ fn compute_truncate64(x: f64, y: i64) -> f64 {
(x * factor).round() / factor
}

/// Expands to a non-capturing closure that downcasts its first argument to
/// `$ARRAY_TYPE` and applies `abs()` to every element, producing a new array of the
/// same type. The first argument failing to downcast is reported through
/// `downcast_arg!`.
macro_rules! make_abs_function {
    ($ARRAY_TYPE:ident) => {{
        |args: &[ArrayRef]| {
            let input = downcast_arg!(&args[0], "abs arg", $ARRAY_TYPE);
            let magnitudes: $ARRAY_TYPE = input.unary(|value| value.abs());
            let output: ArrayRef = Arc::new(magnitudes);
            Ok(output)
        }
    }};
}

/// Expands to a non-capturing closure that downcasts its first argument to
/// `$ARRAY_TYPE` and applies `checked_abs()` to every element.
///
/// When `checked_abs()` yields `None` for an element, the closure fails with an
/// `ArrowError::ComputeError` naming the array type and the offending value instead
/// of producing a wrapped result.
macro_rules! make_try_abs_function {
    ($ARRAY_TYPE:ident) => {{
        |args: &[ArrayRef]| {
            let input = downcast_arg!(&args[0], "abs arg", $ARRAY_TYPE);
            let magnitudes: $ARRAY_TYPE =
                input.try_unary(|value| match value.checked_abs() {
                    Some(magnitude) => Ok(magnitude),
                    None => Err(ArrowError::ComputeError(format!(
                        "{} overflow on abs({})",
                        stringify!($ARRAY_TYPE),
                        value
                    ))),
                })?;
            let output: ArrayRef = Arc::new(magnitudes);
            Ok(output)
        }
    }};
}

/// Abs SQL function.
///
/// Selects the `abs` kernel once, from the input data type, so that no per-call type
/// dispatch is needed during execution. The returned kernel always produces an array
/// of the same data type as its input.
///
/// * `Float32` / `Float64`: element-wise `abs()`.
/// * `Int8` .. `Int64`: element-wise `checked_abs()`; the kernel returns an error when
///   the absolute value is not representable (e.g. `abs(-128_i8)`).
/// * `Null` and unsigned integers: the input array is returned unchanged.
/// * `Decimal128`: element-wise absolute value, keeping the input precision and scale.
///
/// # Errors
///
/// Returns a "not implemented" error for any other data type.
pub(super) fn create_abs_function(
    input_data_type: &DataType,
) -> Result<MathArrayFunction> {
    match input_data_type {
        DataType::Float32 => Ok(make_abs_function!(Float32Array)),
        DataType::Float64 => Ok(make_abs_function!(Float64Array)),

        // Types that may overflow, such as abs(-128_i8).
        DataType::Int8 => Ok(make_try_abs_function!(Int8Array)),
        DataType::Int16 => Ok(make_try_abs_function!(Int16Array)),
        DataType::Int32 => Ok(make_try_abs_function!(Int32Array)),
        DataType::Int64 => Ok(make_try_abs_function!(Int64Array)),

        // Already non-negative (or all null): hand the input array back as-is.
        DataType::Null
        | DataType::UInt8
        | DataType::UInt16
        | DataType::UInt32
        | DataType::UInt64 => Ok(|args: &[ArrayRef]| Ok(args[0].clone())),

        // Decimal should keep the same precision and scale by using `with_data_type()`.
        // https://github.com/apache/arrow-rs/issues/4644
        //
        // `wrapping_abs` is used instead of `abs` because `i128::abs` panics on
        // `i128::MIN` when overflow checks are enabled; a malformed value must not
        // abort the query. Every other value yields exactly the same result as `abs`.
        DataType::Decimal128(_, _) => Ok(|args: &[ArrayRef]| {
            let array = downcast_arg!(&args[0], "abs arg", Decimal128Array);
            let res: Decimal128Array = array
                .unary(i128::wrapping_abs)
                .with_data_type(args[0].data_type().clone());
            Ok(Arc::new(res) as ArrayRef)
        }),

        other => not_impl_err!("Unsupported data type {other:?} for function abs"),
    }
}

#[cfg(test)]
mod tests {

Expand Down
4 changes: 2 additions & 2 deletions datafusion/proto/src/physical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1835,12 +1835,12 @@ mod roundtrip_tests {
let execution_props = ExecutionProps::new();

let fun_expr = functions::create_physical_fun(
&BuiltinScalarFunction::Abs,
&BuiltinScalarFunction::Acos,
&execution_props,
)?;

let expr = ScalarFunctionExpr::new(
"abs",
"acos",
fun_expr,
vec![col("a", &schema)?],
&DataType::Int64,
Expand Down
2 changes: 1 addition & 1 deletion datafusion/sqllogictest/test_files/decimal.slt
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ select c1%c5 from decimal_simple;
query T
select arrow_typeof(abs(c1)) from decimal_simple limit 1;
----
Float64
Decimal128(10, 6)


query R rowsort
Expand Down
Loading

0 comments on commit d512bee

Please sign in to comment.