@@ -20,11 +20,11 @@ use std::sync::{Arc, OnceLock};
20
20
21
21
use crate :: utils:: make_scalar_function;
22
22
23
- use arrow:: array:: { ArrayRef , Float32Array , Float64Array , Int64Array } ;
24
- use arrow:: datatypes:: DataType ;
23
+ use arrow:: array:: { ArrayRef , AsArray , PrimitiveArray } ;
25
24
use arrow:: datatypes:: DataType :: { Float32 , Float64 } ;
25
+ use arrow:: datatypes:: { DataType , Float32Type , Float64Type , Int64Type } ;
26
26
use datafusion_common:: ScalarValue :: Int64 ;
27
- use datafusion_common:: { exec_err, DataFusionError , Result } ;
27
+ use datafusion_common:: { exec_err, Result } ;
28
28
use datafusion_expr:: scalar_doc_sections:: DOC_SECTION_MATH ;
29
29
use datafusion_expr:: sort_properties:: { ExprProperties , SortProperties } ;
30
30
use datafusion_expr:: TypeSignature :: Exact ;
@@ -139,44 +139,66 @@ fn trunc(args: &[ArrayRef]) -> Result<ArrayRef> {
139
139
) ;
140
140
}
141
141
142
- //if only one arg then invoke toolchain trunc(num) and precision = 0 by default
143
- //or then invoke the compute_truncate method to process precision
142
+ // If only one arg then invoke toolchain trunc(num) and precision = 0 by default
143
+ // or then invoke the compute_truncate method to process precision
144
144
let num = & args[ 0 ] ;
145
145
let precision = if args. len ( ) == 1 {
146
146
ColumnarValue :: Scalar ( Int64 ( Some ( 0 ) ) )
147
147
} else {
148
148
ColumnarValue :: Array ( Arc :: clone ( & args[ 1 ] ) )
149
149
} ;
150
150
151
- match args [ 0 ] . data_type ( ) {
151
+ match num . data_type ( ) {
152
152
Float64 => match precision {
153
- ColumnarValue :: Scalar ( Int64 ( Some ( 0 ) ) ) => Ok ( Arc :: new (
154
- make_function_scalar_inputs ! ( num, "num" , Float64Array , { f64 :: trunc } ) ,
155
- ) as ArrayRef ) ,
156
- ColumnarValue :: Array ( precision) => Ok ( Arc :: new ( make_function_inputs2 ! (
157
- num,
158
- precision,
159
- "x" ,
160
- "y" ,
161
- Float64Array ,
162
- Int64Array ,
163
- { compute_truncate64 }
164
- ) ) as ArrayRef ) ,
153
+ ColumnarValue :: Scalar ( Int64 ( Some ( 0 ) ) ) => {
154
+ Ok ( Arc :: new (
155
+ args[ 0 ]
156
+ . as_primitive :: < Float64Type > ( )
157
+ . unary :: < _ , Float64Type > ( |x : f64 | {
158
+ if x == 0_f64 {
159
+ 0_f64
160
+ } else {
161
+ x. trunc ( )
162
+ }
163
+ } ) ,
164
+ ) as ArrayRef )
165
+ }
166
+ ColumnarValue :: Array ( precision) => {
167
+ let num_array = num. as_primitive :: < Float64Type > ( ) ;
168
+ let precision_array = precision. as_primitive :: < Int64Type > ( ) ;
169
+ let result: PrimitiveArray < Float64Type > =
170
+ arrow:: compute:: binary ( num_array, precision_array, |x, y| {
171
+ compute_truncate64 ( x, y)
172
+ } ) ?;
173
+
174
+ Ok ( Arc :: new ( result) as ArrayRef )
175
+ }
165
176
_ => exec_err ! ( "trunc function requires a scalar or array for precision" ) ,
166
177
} ,
167
178
Float32 => match precision {
168
- ColumnarValue :: Scalar ( Int64 ( Some ( 0 ) ) ) => Ok ( Arc :: new (
169
- make_function_scalar_inputs ! ( num, "num" , Float32Array , { f32 :: trunc } ) ,
170
- ) as ArrayRef ) ,
171
- ColumnarValue :: Array ( precision) => Ok ( Arc :: new ( make_function_inputs2 ! (
172
- num,
173
- precision,
174
- "x" ,
175
- "y" ,
176
- Float32Array ,
177
- Int64Array ,
178
- { compute_truncate32 }
179
- ) ) as ArrayRef ) ,
179
+ ColumnarValue :: Scalar ( Int64 ( Some ( 0 ) ) ) => {
180
+ Ok ( Arc :: new (
181
+ args[ 0 ]
182
+ . as_primitive :: < Float32Type > ( )
183
+ . unary :: < _ , Float32Type > ( |x : f32 | {
184
+ if x == 0_f32 {
185
+ 0_f32
186
+ } else {
187
+ x. trunc ( )
188
+ }
189
+ } ) ,
190
+ ) as ArrayRef )
191
+ }
192
+ ColumnarValue :: Array ( precision) => {
193
+ let num_array = num. as_primitive :: < Float32Type > ( ) ;
194
+ let precision_array = precision. as_primitive :: < Int64Type > ( ) ;
195
+ let result: PrimitiveArray < Float32Type > =
196
+ arrow:: compute:: binary ( num_array, precision_array, |x, y| {
197
+ compute_truncate32 ( x, y)
198
+ } ) ?;
199
+
200
+ Ok ( Arc :: new ( result) as ArrayRef )
201
+ }
180
202
_ => exec_err ! ( "trunc function requires a scalar or array for precision" ) ,
181
203
} ,
182
204
other => exec_err ! ( "Unsupported data type {other:?} for function trunc" ) ,
0 commit comments