@@ -27,9 +27,11 @@ use arrow::record_batch::RecordBatch;
27
27
use datafusion:: error:: Result ;
28
28
use datafusion:: logical_expr:: Volatility ;
29
29
use datafusion:: prelude:: * ;
30
- use datafusion_common:: { internal_err, ScalarValue } ;
30
+ use datafusion_common:: { exec_err , internal_err, ScalarValue } ;
31
31
use datafusion_expr:: sort_properties:: { ExprProperties , SortProperties } ;
32
- use datafusion_expr:: { ColumnarValue , ScalarUDF , ScalarUDFImpl , Signature } ;
32
+ use datafusion_expr:: {
33
+ ColumnarValue , ScalarFunctionArgs , ScalarUDF , ScalarUDFImpl , Signature ,
34
+ } ;
33
35
34
36
/// This example shows how to use the full ScalarUDFImpl API to implement a user
35
37
/// defined function. As in the `simple_udf.rs` example, this struct implements
@@ -83,23 +85,27 @@ impl ScalarUDFImpl for PowUdf {
83
85
Ok ( DataType :: Float64 )
84
86
}
85
87
86
- /// This is the function that actually calculates the results.
88
+ /// This function actually calculates the results of the scalar function.
89
+ ///
90
+ /// This is the same way that functions provided with DataFusion are invoked,
91
+ /// which permits important special cases:
87
92
///
88
- /// This is the same way that functions built into DataFusion are invoked,
89
- /// which permits important special cases when one or both of the arguments
90
- /// are single values (constants). For example `pow(a, 2)`
93
+ /// 1. When one or both of the arguments are single values (constants).
94
+ /// For example `pow(a, 2)`
95
+ /// 2. When the input arrays can be reused (avoid allocating a new output array)
91
96
///
92
97
/// However, it also means the implementation is more complex than when
93
98
/// using `create_udf`.
94
- fn invoke_batch (
95
- & self ,
96
- args : & [ ColumnarValue ] ,
97
- _number_rows : usize ,
98
- ) -> Result < ColumnarValue > {
99
+ fn invoke_with_args ( & self , args : ScalarFunctionArgs ) -> Result < ColumnarValue > {
100
+ // The other fields of the `args` struct are used for more specialized
101
+ // uses, and are not needed in this example
102
+ let ScalarFunctionArgs { mut args, .. } = args;
99
103
// DataFusion has arranged for the correct inputs to be passed to this
100
104
// function, but we check again to make sure
101
105
assert_eq ! ( args. len( ) , 2 ) ;
102
- let ( base, exp) = ( & args[ 0 ] , & args[ 1 ] ) ;
106
+ // take ownership of arguments by popping in reverse order
107
+ let exp = args. pop ( ) . unwrap ( ) ;
108
+ let base = args. pop ( ) . unwrap ( ) ;
103
109
assert_eq ! ( base. data_type( ) , DataType :: Float64 ) ;
104
110
assert_eq ! ( exp. data_type( ) , DataType :: Float64 ) ;
105
111
@@ -118,7 +124,7 @@ impl ScalarUDFImpl for PowUdf {
118
124
) => {
119
125
// compute the output. Note DataFusion treats `None` as NULL.
120
126
let res = match ( base, exp) {
121
- ( Some ( base) , Some ( exp) ) => Some ( base. powf ( * exp) ) ,
127
+ ( Some ( base) , Some ( exp) ) => Some ( base. powf ( exp) ) ,
122
128
// one or both arguments were NULL
123
129
_ => None ,
124
130
} ;
@@ -140,31 +146,33 @@ impl ScalarUDFImpl for PowUdf {
140
146
// kernel creates very fast "vectorized" code and
141
147
// handles things like null values for us.
142
148
let res: Float64Array =
143
- compute:: unary ( base_array, |base| base. powf ( * exp) ) ;
149
+ compute:: unary ( base_array, |base| base. powf ( exp) ) ;
144
150
Arc :: new ( res)
145
151
}
146
152
} ;
147
153
Ok ( ColumnarValue :: Array ( result_array) )
148
154
}
149
155
150
- // special case if the base is a constant (note this code is quite
151
- // similar to the previous case, so we omit comments)
156
+ // special case if the base is a constant.
157
+ //
158
+ // Note this case is very similar to the previous case, so we could
159
+ // use the same pattern. However, for this case we demonstrate an
160
+ // even more advanced pattern to potentially avoid allocating a new array
152
161
(
153
162
ColumnarValue :: Scalar ( ScalarValue :: Float64 ( base) ) ,
154
163
ColumnarValue :: Array ( exp_array) ,
155
164
) => {
156
165
let res = match base {
157
166
None => new_null_array ( exp_array. data_type ( ) , exp_array. len ( ) ) ,
158
- Some ( base) => {
159
- let exp_array = exp_array. as_primitive :: < Float64Type > ( ) ;
160
- let res: Float64Array =
161
- compute:: unary ( exp_array, |exp| base. powf ( exp) ) ;
162
- Arc :: new ( res)
163
- }
167
+ Some ( base) => maybe_pow_in_place ( base, exp_array) ?,
164
168
} ;
165
169
Ok ( ColumnarValue :: Array ( res) )
166
170
}
167
- // Both arguments are arrays so we have to perform the calculation for every row
171
+ // Both arguments are arrays so we have to perform the calculation
172
+ // for every row
173
+ //
174
+ // Note this could also be done in place using `binary_mut` as
175
+ // is done in `maybe_pow_in_place` but here we use binary for simplicity
168
176
( ColumnarValue :: Array ( base_array) , ColumnarValue :: Array ( exp_array) ) => {
169
177
let res: Float64Array = compute:: binary (
170
178
base_array. as_primitive :: < Float64Type > ( ) ,
@@ -191,6 +199,52 @@ impl ScalarUDFImpl for PowUdf {
191
199
}
192
200
}
193
201
202
+ /// Evaluate `base ^ exp` for each value in `exp_array` *without* allocating a new array, if possible; returns an execution error when the array cannot be reused
203
+ fn maybe_pow_in_place ( base : f64 , exp_array : ArrayRef ) -> Result < ArrayRef > {
204
+ // Calling `unary` creates a new array for the results. Avoiding
205
+ // allocations is a common optimization in performance critical code.
206
+ // arrow-rs allows this optimization via the `unary_mut`
207
+ // and `binary_mut` kernels in certain cases
208
+ //
209
+ // These kernels can only be used if there are no other references to
210
+ // the arrays (exp_array has to be the last remaining reference).
211
+ let owned_array = exp_array
212
+ // as in the previous example, we first downcast to &Float64Array
213
+ . as_primitive :: < Float64Type > ( )
214
+ // non-obviously, we call clone here to get an owned `Float64Array`.
215
+ // (Calling clone() is relatively inexpensive as it increments
216
+ // some ref counts but doesn't clone the data)
217
+ //
218
+ // Once we have the owned Float64Array we can drop the original
219
+ // exp_array (untyped) reference
220
+ . clone ( ) ;
221
+
222
+ // We *MUST* drop the reference to `exp_array` explicitly so that
223
+ // owned_array is the only reference remaining in this function.
224
+ //
225
+ // Note that depending on the query there may still be other references
226
+ // to the underlying buffers, which would prevent reuse. The only way to
227
+ // know for sure is the result of `compute::unary_mut`
228
+ drop ( exp_array) ;
229
+
230
+ // If we have the only reference, compute the result directly into the same
231
+ // allocation as was used for the input array
232
+ match compute:: unary_mut ( owned_array, |exp| base. powf ( exp) ) {
233
+ Err ( _orig_array) => {
234
+ // unary_mut will return the original array if there are other
235
+ // references into the underlying buffer (and thus reuse is
236
+ // impossible)
237
+ //
238
+ // In a real implementation, this case should fall back to
239
+ // calling `unary` and allocate a new array; in this example
240
+ // we will return an error for demonstration purposes
241
+ exec_err ! ( "Could not reuse array for maybe_pow_in_place" )
242
+ }
243
+ // a result of `Ok` means the operation was run successfully
244
+ Ok ( res) => Ok ( Arc :: new ( res) ) ,
245
+ }
246
+ }
247
+
194
248
/// In this example we register `PowUdf` as a user defined function
195
249
/// and invoke it via the DataFrame API and SQL
196
250
#[ tokio:: main]
@@ -215,9 +269,29 @@ async fn main() -> Result<()> {
215
269
// print the results
216
270
df. show ( ) . await ?;
217
271
218
- // You can also invoke both pow(2, 10) and its alias my_pow(a, b) using SQL
219
- let sql_df = ctx. sql ( "SELECT pow(2, 10), my_pow(a, b) FROM t" ) . await ?;
220
- sql_df. show ( ) . await ?;
272
+ // You can also invoke both pow(2, 10) and its alias my_pow(a, b) using SQL
273
+ ctx. sql ( "SELECT pow(2, 10), my_pow(a, b) FROM t" )
274
+ . await ?
275
+ . show ( )
276
+ . await ?;
277
+
278
+ // You can also invoke pow_in_place by passing a constant base and a
279
+ // column `a` as the exponent. If there is only a single
280
+ // reference to `a` the code works well
281
+ ctx. sql ( "SELECT pow(2, a) FROM t" ) . await ?. show ( ) . await ?;
282
+
283
+ // However, if there are multiple references to `a` in the evaluation
284
+ // the array storage can not be reused
285
+ let err = ctx
286
+ . sql ( "SELECT pow(2, a), pow(3, a) FROM t" )
287
+ . await ?
288
+ . show ( )
289
+ . await
290
+ . unwrap_err ( ) ;
291
+ assert_eq ! (
292
+ err. to_string( ) ,
293
+ "Execution error: Could not reuse array for maybe_pow_in_place"
294
+ ) ;
221
295
222
296
Ok ( ( ) )
223
297
}
0 commit comments