Skip to content

Commit caeabc1

Browse files
tlm365 and alamb authored
Optimize performance of math::trunc (#12909)
Signed-off-by: Tai Le Manh <[email protected]>
Co-authored-by: Andrew Lamb <[email protected]>
1 parent 875aaa6 commit caeabc1

File tree

3 files changed, with 104 additions and 30 deletions.

datafusion/functions/Cargo.toml

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -201,3 +201,8 @@ required-features = ["math_expressions"]
201201
harness = false
202202
name = "strpos"
203203
required-features = ["unicode_expressions"]
204+
205+
[[bench]]
206+
harness = false
207+
name = "trunc"
208+
required-features = ["math_expressions"]

datafusion/functions/benches/trunc.rs

Lines changed: 47 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,47 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
extern crate criterion;
19+
20+
use arrow::{
21+
datatypes::{Float32Type, Float64Type},
22+
util::bench_util::create_primitive_array,
23+
};
24+
use criterion::{black_box, criterion_group, criterion_main, Criterion};
25+
use datafusion_expr::ColumnarValue;
26+
use datafusion_functions::math::trunc;
27+
28+
use std::sync::Arc;
29+
30+
fn criterion_benchmark(c: &mut Criterion) {
31+
let trunc = trunc();
32+
for size in [1024, 4096, 8192] {
33+
let f32_array = Arc::new(create_primitive_array::<Float32Type>(size, 0.2));
34+
let f32_args = vec![ColumnarValue::Array(f32_array)];
35+
c.bench_function(&format!("trunc f32 array: {}", size), |b| {
36+
b.iter(|| black_box(trunc.invoke(&f32_args).unwrap()))
37+
});
38+
let f64_array = Arc::new(create_primitive_array::<Float64Type>(size, 0.2));
39+
let f64_args = vec![ColumnarValue::Array(f64_array)];
40+
c.bench_function(&format!("trunc f64 array: {}", size), |b| {
41+
b.iter(|| black_box(trunc.invoke(&f64_args).unwrap()))
42+
});
43+
}
44+
}
45+
46+
criterion_group!(benches, criterion_benchmark);
47+
criterion_main!(benches);

datafusion/functions/src/math/trunc.rs

Lines changed: 52 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -20,11 +20,11 @@ use std::sync::{Arc, OnceLock};
2020

2121
use crate::utils::make_scalar_function;
2222

23-
use arrow::array::{ArrayRef, Float32Array, Float64Array, Int64Array};
24-
use arrow::datatypes::DataType;
23+
use arrow::array::{ArrayRef, AsArray, PrimitiveArray};
2524
use arrow::datatypes::DataType::{Float32, Float64};
25+
use arrow::datatypes::{DataType, Float32Type, Float64Type, Int64Type};
2626
use datafusion_common::ScalarValue::Int64;
27-
use datafusion_common::{exec_err, DataFusionError, Result};
27+
use datafusion_common::{exec_err, Result};
2828
use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH;
2929
use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
3030
use datafusion_expr::TypeSignature::Exact;
@@ -139,44 +139,66 @@ fn trunc(args: &[ArrayRef]) -> Result<ArrayRef> {
139139
);
140140
}
141141

142-
//if only one arg then invoke toolchain trunc(num) and precision = 0 by default
143-
//or then invoke the compute_truncate method to process precision
142+
// If only one arg then invoke toolchain trunc(num) and precision = 0 by default
143+
// or then invoke the compute_truncate method to process precision
144144
let num = &args[0];
145145
let precision = if args.len() == 1 {
146146
ColumnarValue::Scalar(Int64(Some(0)))
147147
} else {
148148
ColumnarValue::Array(Arc::clone(&args[1]))
149149
};
150150

151-
match args[0].data_type() {
151+
match num.data_type() {
152152
Float64 => match precision {
153-
ColumnarValue::Scalar(Int64(Some(0))) => Ok(Arc::new(
154-
make_function_scalar_inputs!(num, "num", Float64Array, { f64::trunc }),
155-
) as ArrayRef),
156-
ColumnarValue::Array(precision) => Ok(Arc::new(make_function_inputs2!(
157-
num,
158-
precision,
159-
"x",
160-
"y",
161-
Float64Array,
162-
Int64Array,
163-
{ compute_truncate64 }
164-
)) as ArrayRef),
153+
ColumnarValue::Scalar(Int64(Some(0))) => {
154+
Ok(Arc::new(
155+
args[0]
156+
.as_primitive::<Float64Type>()
157+
.unary::<_, Float64Type>(|x: f64| {
158+
if x == 0_f64 {
159+
0_f64
160+
} else {
161+
x.trunc()
162+
}
163+
}),
164+
) as ArrayRef)
165+
}
166+
ColumnarValue::Array(precision) => {
167+
let num_array = num.as_primitive::<Float64Type>();
168+
let precision_array = precision.as_primitive::<Int64Type>();
169+
let result: PrimitiveArray<Float64Type> =
170+
arrow::compute::binary(num_array, precision_array, |x, y| {
171+
compute_truncate64(x, y)
172+
})?;
173+
174+
Ok(Arc::new(result) as ArrayRef)
175+
}
165176
_ => exec_err!("trunc function requires a scalar or array for precision"),
166177
},
167178
Float32 => match precision {
168-
ColumnarValue::Scalar(Int64(Some(0))) => Ok(Arc::new(
169-
make_function_scalar_inputs!(num, "num", Float32Array, { f32::trunc }),
170-
) as ArrayRef),
171-
ColumnarValue::Array(precision) => Ok(Arc::new(make_function_inputs2!(
172-
num,
173-
precision,
174-
"x",
175-
"y",
176-
Float32Array,
177-
Int64Array,
178-
{ compute_truncate32 }
179-
)) as ArrayRef),
179+
ColumnarValue::Scalar(Int64(Some(0))) => {
180+
Ok(Arc::new(
181+
args[0]
182+
.as_primitive::<Float32Type>()
183+
.unary::<_, Float32Type>(|x: f32| {
184+
if x == 0_f32 {
185+
0_f32
186+
} else {
187+
x.trunc()
188+
}
189+
}),
190+
) as ArrayRef)
191+
}
192+
ColumnarValue::Array(precision) => {
193+
let num_array = num.as_primitive::<Float32Type>();
194+
let precision_array = precision.as_primitive::<Int64Type>();
195+
let result: PrimitiveArray<Float32Type> =
196+
arrow::compute::binary(num_array, precision_array, |x, y| {
197+
compute_truncate32(x, y)
198+
})?;
199+
200+
Ok(Arc::new(result) as ArrayRef)
201+
}
180202
_ => exec_err!("trunc function requires a scalar or array for precision"),
181203
},
182204
other => exec_err!("Unsupported data type {other:?} for function trunc"),

0 commit comments

Comments
 (0)