@@ -19,19 +19,19 @@ use std::borrow::Cow;
19
19
use std:: hash:: Hash ;
20
20
use std:: { any:: Any , sync:: Arc } ;
21
21
22
- use crate :: expressions:: try_cast;
22
+ use crate :: expressions:: { try_cast, BinaryExpr , CastExpr } ;
23
23
use crate :: PhysicalExpr ;
24
24
25
25
use arrow:: array:: * ;
26
+ use arrow:: compute:: kernels:: cmp:: eq;
26
27
use arrow:: compute:: kernels:: zip:: zip;
27
28
use arrow:: compute:: { and, and_not, is_null, not, nullif, or, prep_null_mask_filter} ;
28
29
use arrow:: datatypes:: { DataType , Schema } ;
29
30
use datafusion_common:: cast:: as_boolean_array;
30
31
use datafusion_common:: { exec_err, internal_err, DataFusionError , Result , ScalarValue } ;
31
- use datafusion_expr:: ColumnarValue ;
32
+ use datafusion_expr:: { ColumnarValue , Operator } ;
32
33
33
34
use super :: { Column , Literal } ;
34
- use datafusion_physical_expr_common:: datum:: compare_with_eq;
35
35
use itertools:: Itertools ;
36
36
37
37
type WhenThen = ( Arc < dyn PhysicalExpr > , Arc < dyn PhysicalExpr > ) ;
@@ -57,9 +57,14 @@ enum EvalMethod {
57
57
InfallibleExprOrNull ,
58
58
/// This is a specialization for a specific use case where we can take a fast path
59
59
/// if there is just one when/then pair and both the `then` and `else` expressions
60
- /// are literal values
60
+ /// are literal values.
61
61
/// CASE WHEN condition THEN literal ELSE literal END
62
62
ScalarOrScalar ,
63
+ /// This is a specialization for a sprcific use case where we can take a fast path
64
+ /// for the divide-by-zero expression when the divisor is zero.
65
+ ///
66
+ /// CASE WHEN y > 0 THEN x / y ELSE NULL END
67
+ DivideZeroExpression ,
63
68
}
64
69
65
70
/// The CASE expression is similar to a series of nested if/else and there are two forms that
@@ -149,6 +154,51 @@ impl CaseExpr {
149
154
&& else_expr. as_ref ( ) . unwrap ( ) . as_any ( ) . is :: < Literal > ( )
150
155
{
151
156
EvalMethod :: ScalarOrScalar
157
+ } else if when_then_expr. len ( ) == 1
158
+ && when_then_expr[ 0 ] . 0 . as_any ( ) . is :: < BinaryExpr > ( )
159
+ {
160
+ let b = when_then_expr[ 0 ]
161
+ . 0
162
+ . as_any ( )
163
+ . downcast_ref :: < BinaryExpr > ( )
164
+ . expect ( "expected binary expression" ) ;
165
+
166
+ if b. op ( ) . eq ( & Operator :: Gt ) {
167
+ if let Some ( col) = b. left ( ) . as_any ( ) . downcast_ref :: < Column > ( ) {
168
+ if let Some ( lit) = b. right ( ) . as_any ( ) . downcast_ref :: < Literal > ( ) {
169
+ if matches ! ( lit. value( ) , ScalarValue :: Int32 ( Some ( 0 ) ) ) {
170
+ if let Some ( b) = when_then_expr[ 0 ]
171
+ . 1
172
+ . as_any ( )
173
+ . downcast_ref :: < BinaryExpr > ( )
174
+ {
175
+ if b. op ( ) . eq ( & Operator :: Divide ) {
176
+ if let Some ( cast) =
177
+ b. right ( ) . as_any ( ) . downcast_ref :: < CastExpr > ( )
178
+ {
179
+ if let Some ( col2) = cast
180
+ . expr ( )
181
+ . as_any ( )
182
+ . downcast_ref :: < Column > ( )
183
+ {
184
+ if col. name ( ) == col2. name ( ) {
185
+ return Ok ( Self {
186
+ expr : None ,
187
+ when_then_expr,
188
+ else_expr,
189
+ eval_method : EvalMethod :: DivideZeroExpression ,
190
+ } ) ;
191
+ }
192
+ }
193
+ }
194
+ }
195
+ }
196
+ }
197
+ }
198
+ }
199
+ }
200
+
201
+ EvalMethod :: NoExpression
152
202
} else {
153
203
EvalMethod :: NoExpression
154
204
} ;
@@ -203,13 +253,7 @@ impl CaseExpr {
203
253
. evaluate_selection ( batch, & remainder) ?;
204
254
let when_value = when_value. into_array ( batch. num_rows ( ) ) ?;
205
255
// build boolean array representing which rows match the "when" value
206
- let when_match = compare_with_eq (
207
- & when_value,
208
- & base_value,
209
- // The types of case and when expressions will be coerced to match.
210
- // We only need to check if the base_value is nested.
211
- base_value. data_type ( ) . is_nested ( ) ,
212
- ) ?;
256
+ let when_match = eq ( & when_value, & base_value) ?;
213
257
// Treat nulls as false
214
258
let when_match = match when_match. null_count ( ) {
215
259
0 => Cow :: Borrowed ( & when_match) ,
@@ -385,12 +429,49 @@ impl CaseExpr {
385
429
386
430
// keep `else_expr`'s data type and return type consistent
387
431
let e = self . else_expr . as_ref ( ) . unwrap ( ) ;
388
- let expr = try_cast ( Arc :: clone ( e) , & batch. schema ( ) , return_type)
432
+ let expr = try_cast ( Arc :: clone ( e) , & batch. schema ( ) , return_type. clone ( ) )
389
433
. unwrap_or_else ( |_| Arc :: clone ( e) ) ;
390
434
let else_ = Scalar :: new ( expr. evaluate ( batch) ?. into_array ( 1 ) ?) ;
391
435
392
436
Ok ( ColumnarValue :: Array ( zip ( & when_value, & then_value, & else_) ?) )
393
437
}
438
+
439
+ fn divide_by_zero_expr ( & self , batch : & RecordBatch ) -> Result < ColumnarValue > {
440
+ let return_type = self . data_type ( & batch. schema ( ) ) ?;
441
+
442
+ // start with nulls as default output
443
+ let mut current_value = new_null_array ( & return_type, batch. num_rows ( ) ) ;
444
+ let mut remainder = BooleanArray :: from ( vec ! [ true ; batch. num_rows( ) ] ) ;
445
+ let when_value = BooleanArray :: from ( vec ! [ true ; batch. num_rows( ) ] ) ;
446
+ let then_value = self . when_then_expr ( ) [ 0 ] . 1 . evaluate ( batch) ?;
447
+ current_value = match then_value {
448
+ ColumnarValue :: Scalar ( ScalarValue :: Null ) => {
449
+ nullif ( current_value. as_ref ( ) , & when_value) ?
450
+ }
451
+ ColumnarValue :: Scalar ( then_value) => {
452
+ zip ( & when_value, & then_value. to_scalar ( ) ?, & current_value) ?
453
+ }
454
+ ColumnarValue :: Array ( then_value) => {
455
+ zip ( & when_value, & then_value, & current_value) ?
456
+ }
457
+ } ;
458
+
459
+ // Succeed tuples should be filtered out for short-circuit evaluation,
460
+ // null values for the current when expr should be kept
461
+ remainder = and_not ( & remainder, & when_value) ?;
462
+
463
+ if let Some ( e) = & self . else_expr {
464
+ // keep `else_expr`'s data type and return type consistent
465
+ let expr = try_cast ( Arc :: clone ( e) , & batch. schema ( ) , return_type. clone ( ) )
466
+ . unwrap_or_else ( |_| Arc :: clone ( e) ) ;
467
+ let else_ = expr
468
+ . evaluate_selection ( batch, & remainder) ?
469
+ . into_array ( batch. num_rows ( ) ) ?;
470
+ current_value = zip ( & remainder, & else_, & current_value) ?;
471
+ }
472
+
473
+ Ok ( ColumnarValue :: Array ( current_value) )
474
+ }
394
475
}
395
476
396
477
impl PhysicalExpr for CaseExpr {
@@ -454,6 +535,7 @@ impl PhysicalExpr for CaseExpr {
454
535
self . case_column_or_null ( batch)
455
536
}
456
537
EvalMethod :: ScalarOrScalar => self . scalar_or_scalar ( batch) ,
538
+ EvalMethod :: DivideZeroExpression => self . divide_by_zero_expr ( batch) ,
457
539
}
458
540
}
459
541
@@ -741,6 +823,13 @@ mod tests {
741
823
Ok ( batch)
742
824
}
743
825
826
+ fn case_test_batch2 ( ) -> Result < RecordBatch > {
827
+ let schema = Schema :: new ( vec ! [ Field :: new( "y" , DataType :: Int32 , true ) ] ) ;
828
+ let a = Int32Array :: from ( vec ! [ Some ( 1 ) , Some ( 0 ) , None , Some ( 5 ) ] ) ;
829
+ let batch = RecordBatch :: try_new ( Arc :: new ( schema) , vec ! [ Arc :: new( a) ] ) ?;
830
+ Ok ( batch)
831
+ }
832
+
744
833
#[ test]
745
834
fn case_without_expr_else ( ) -> Result < ( ) > {
746
835
let batch = case_test_batch ( ) ?;
@@ -1212,4 +1301,58 @@ mod tests {
1212
1301
comparison_coercion ( & left_type, right_type)
1213
1302
} )
1214
1303
}
1304
+
1305
+ #[ test]
1306
+ fn gen_optimize_case_for_div_zero ( ) -> Result < ( ) > {
1307
+ let batch = case_test_batch1 ( ) ?;
1308
+ let schema = batch. schema ( ) ;
1309
+
1310
+ let batch2 = case_test_batch2 ( ) ?;
1311
+ let schema2 = batch2. schema ( ) ;
1312
+
1313
+ // DivideZeroExpression: CASE WHEN a > 0 THEN 25.0 / cast(a, float64) ELSE float64(null) END
1314
+ let when1 = binary ( col ( "a" , & schema) ?, Operator :: Gt , lit ( 0i32 ) , & batch. schema ( ) ) ?;
1315
+ let then1 = binary (
1316
+ lit ( 25.0f64 ) ,
1317
+ Operator :: Divide ,
1318
+ cast ( col ( "a" , & schema) ?, & batch. schema ( ) , Float64 ) ?,
1319
+ & batch. schema ( ) ,
1320
+ ) ?;
1321
+ let x = lit ( ScalarValue :: Float64 ( None ) ) ;
1322
+ let expr = generate_case_when_with_type_coercion (
1323
+ None ,
1324
+ vec ! [ ( when1, then1) ] ,
1325
+ Some ( x) ,
1326
+ schema. as_ref ( ) ,
1327
+ ) ?;
1328
+ let case = expr
1329
+ . as_any ( )
1330
+ . downcast_ref :: < CaseExpr > ( )
1331
+ . expect ( "expected case expression" ) ;
1332
+ assert_eq ! ( case. eval_method, EvalMethod :: DivideZeroExpression ) ;
1333
+
1334
+ // NoExpression: CASE WHEN a > 0 THEN 25.0 / cast(y, float64) ELSE float64(null) END
1335
+ let when1 = binary ( col ( "a" , & schema) ?, Operator :: Gt , lit ( 0i32 ) , & batch. schema ( ) ) ?;
1336
+ let then1 = binary (
1337
+ lit ( 25.0f64 ) ,
1338
+ Operator :: Divide ,
1339
+ cast ( col ( "y" , & schema2) ?, & batch2. schema ( ) , Float64 ) ?,
1340
+ & batch2. schema ( ) ,
1341
+ ) ?;
1342
+ let x = lit ( ScalarValue :: Float64 ( None ) ) ;
1343
+
1344
+ let expr = generate_case_when_with_type_coercion (
1345
+ None ,
1346
+ vec ! [ ( when1, then1) ] ,
1347
+ Some ( x) ,
1348
+ schema. as_ref ( ) ,
1349
+ ) ?;
1350
+ let case = expr
1351
+ . as_any ( )
1352
+ . downcast_ref :: < CaseExpr > ( )
1353
+ . expect ( "expected case expression" ) ;
1354
+ assert_eq ! ( case. eval_method, EvalMethod :: NoExpression ) ;
1355
+
1356
+ Ok ( ( ) )
1357
+ }
1215
1358
}
0 commit comments