Skip to content

Commit 10da22e

Browse files
goldmedaladbmal
authored and committed
Introduce INFORMATION_SCHEMA.ROUTINES table (apache#13255)
* tmp * introduce routines table * add is_deterministic field * cargo fmt * rollback the session_state changed
1 parent b0b6e44 commit 10da22e

File tree

1 file changed

+155
-12
lines changed
  • datafusion/physical-expr/src/expressions

1 file changed

+155
-12
lines changed

datafusion/physical-expr/src/expressions/case.rs

Lines changed: 155 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,19 +19,19 @@ use std::borrow::Cow;
1919
use std::hash::Hash;
2020
use std::{any::Any, sync::Arc};
2121

22-
use crate::expressions::try_cast;
22+
use crate::expressions::{try_cast, BinaryExpr, CastExpr};
2323
use crate::PhysicalExpr;
2424

2525
use arrow::array::*;
26+
use arrow::compute::kernels::cmp::eq;
2627
use arrow::compute::kernels::zip::zip;
2728
use arrow::compute::{and, and_not, is_null, not, nullif, or, prep_null_mask_filter};
2829
use arrow::datatypes::{DataType, Schema};
2930
use datafusion_common::cast::as_boolean_array;
3031
use datafusion_common::{exec_err, internal_err, DataFusionError, Result, ScalarValue};
31-
use datafusion_expr::ColumnarValue;
32+
use datafusion_expr::{ColumnarValue, Operator};
3233

3334
use super::{Column, Literal};
34-
use datafusion_physical_expr_common::datum::compare_with_eq;
3535
use itertools::Itertools;
3636

3737
type WhenThen = (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>);
@@ -57,9 +57,14 @@ enum EvalMethod {
5757
InfallibleExprOrNull,
5858
/// This is a specialization for a specific use case where we can take a fast path
5959
/// if there is just one when/then pair and both the `then` and `else` expressions
60-
/// are literal values
60+
/// are literal values.
6161
/// CASE WHEN condition THEN literal ELSE literal END
6262
ScalarOrScalar,
63+
/// This is a specialization for a specific use case where we can take a fast path
64+
/// for a division whose divisor is guarded against zero by the `when` condition.
65+
///
66+
/// CASE WHEN y > 0 THEN x / y ELSE NULL END
67+
DivideZeroExpression,
6368
}
6469

6570
/// The CASE expression is similar to a series of nested if/else and there are two forms that
@@ -149,6 +154,51 @@ impl CaseExpr {
149154
&& else_expr.as_ref().unwrap().as_any().is::<Literal>()
150155
{
151156
EvalMethod::ScalarOrScalar
157+
} else if when_then_expr.len() == 1
158+
&& when_then_expr[0].0.as_any().is::<BinaryExpr>()
159+
{
160+
let b = when_then_expr[0]
161+
.0
162+
.as_any()
163+
.downcast_ref::<BinaryExpr>()
164+
.expect("expected binary expression");
165+
166+
if b.op().eq(&Operator::Gt) {
167+
if let Some(col) = b.left().as_any().downcast_ref::<Column>() {
168+
if let Some(lit) = b.right().as_any().downcast_ref::<Literal>() {
169+
if matches!(lit.value(), ScalarValue::Int32(Some(0))) {
170+
if let Some(b) = when_then_expr[0]
171+
.1
172+
.as_any()
173+
.downcast_ref::<BinaryExpr>()
174+
{
175+
if b.op().eq(&Operator::Divide) {
176+
if let Some(cast) =
177+
b.right().as_any().downcast_ref::<CastExpr>()
178+
{
179+
if let Some(col2) = cast
180+
.expr()
181+
.as_any()
182+
.downcast_ref::<Column>()
183+
{
184+
if col.name() == col2.name() {
185+
return Ok(Self {
186+
expr: None,
187+
when_then_expr,
188+
else_expr,
189+
eval_method: EvalMethod::DivideZeroExpression,
190+
});
191+
}
192+
}
193+
}
194+
}
195+
}
196+
}
197+
}
198+
}
199+
}
200+
201+
EvalMethod::NoExpression
152202
} else {
153203
EvalMethod::NoExpression
154204
};
@@ -203,13 +253,7 @@ impl CaseExpr {
203253
.evaluate_selection(batch, &remainder)?;
204254
let when_value = when_value.into_array(batch.num_rows())?;
205255
// build boolean array representing which rows match the "when" value
206-
let when_match = compare_with_eq(
207-
&when_value,
208-
&base_value,
209-
// The types of case and when expressions will be coerced to match.
210-
// We only need to check if the base_value is nested.
211-
base_value.data_type().is_nested(),
212-
)?;
256+
let when_match = eq(&when_value, &base_value)?;
213257
// Treat nulls as false
214258
let when_match = match when_match.null_count() {
215259
0 => Cow::Borrowed(&when_match),
@@ -385,12 +429,49 @@ impl CaseExpr {
385429

386430
// keep `else_expr`'s data type and return type consistent
387431
let e = self.else_expr.as_ref().unwrap();
388-
let expr = try_cast(Arc::clone(e), &batch.schema(), return_type)
432+
let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())
389433
.unwrap_or_else(|_| Arc::clone(e));
390434
let else_ = Scalar::new(expr.evaluate(batch)?.into_array(1)?);
391435

392436
Ok(ColumnarValue::Array(zip(&when_value, &then_value, &else_)?))
393437
}
438+
439+
fn divide_by_zero_expr(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
440+
let return_type = self.data_type(&batch.schema())?;
441+
442+
// start with nulls as default output
443+
let mut current_value = new_null_array(&return_type, batch.num_rows());
444+
let mut remainder = BooleanArray::from(vec![true; batch.num_rows()]);
445+
let when_value = BooleanArray::from(vec![true; batch.num_rows()]);
446+
let then_value = self.when_then_expr()[0].1.evaluate(batch)?;
447+
current_value = match then_value {
448+
ColumnarValue::Scalar(ScalarValue::Null) => {
449+
nullif(current_value.as_ref(), &when_value)?
450+
}
451+
ColumnarValue::Scalar(then_value) => {
452+
zip(&when_value, &then_value.to_scalar()?, &current_value)?
453+
}
454+
ColumnarValue::Array(then_value) => {
455+
zip(&when_value, &then_value, &current_value)?
456+
}
457+
};
458+
459+
// Tuples that succeeded should be filtered out for short-circuit evaluation,
460+
// null values for the current when expr should be kept
461+
remainder = and_not(&remainder, &when_value)?;
462+
463+
if let Some(e) = &self.else_expr {
464+
// keep `else_expr`'s data type and return type consistent
465+
let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())
466+
.unwrap_or_else(|_| Arc::clone(e));
467+
let else_ = expr
468+
.evaluate_selection(batch, &remainder)?
469+
.into_array(batch.num_rows())?;
470+
current_value = zip(&remainder, &else_, &current_value)?;
471+
}
472+
473+
Ok(ColumnarValue::Array(current_value))
474+
}
394475
}
395476

396477
impl PhysicalExpr for CaseExpr {
@@ -454,6 +535,7 @@ impl PhysicalExpr for CaseExpr {
454535
self.case_column_or_null(batch)
455536
}
456537
EvalMethod::ScalarOrScalar => self.scalar_or_scalar(batch),
538+
EvalMethod::DivideZeroExpression => self.divide_by_zero_expr(batch),
457539
}
458540
}
459541

@@ -741,6 +823,13 @@ mod tests {
741823
Ok(batch)
742824
}
743825

826+
fn case_test_batch2() -> Result<RecordBatch> {
827+
let schema = Schema::new(vec![Field::new("y", DataType::Int32, true)]);
828+
let a = Int32Array::from(vec![Some(1), Some(0), None, Some(5)]);
829+
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
830+
Ok(batch)
831+
}
832+
744833
#[test]
745834
fn case_without_expr_else() -> Result<()> {
746835
let batch = case_test_batch()?;
@@ -1212,4 +1301,58 @@ mod tests {
12121301
comparison_coercion(&left_type, right_type)
12131302
})
12141303
}
1304+
1305+
#[test]
1306+
fn gen_optimize_case_for_div_zero() -> Result<()> {
1307+
let batch = case_test_batch1()?;
1308+
let schema = batch.schema();
1309+
1310+
let batch2 = case_test_batch2()?;
1311+
let schema2 = batch2.schema();
1312+
1313+
// DivideZeroExpression: CASE WHEN a > 0 THEN 25.0 / cast(a, float64) ELSE float64(null) END
1314+
let when1 = binary(col("a", &schema)?, Operator::Gt, lit(0i32), &batch.schema())?;
1315+
let then1 = binary(
1316+
lit(25.0f64),
1317+
Operator::Divide,
1318+
cast(col("a", &schema)?, &batch.schema(), Float64)?,
1319+
&batch.schema(),
1320+
)?;
1321+
let x = lit(ScalarValue::Float64(None));
1322+
let expr = generate_case_when_with_type_coercion(
1323+
None,
1324+
vec![(when1, then1)],
1325+
Some(x),
1326+
schema.as_ref(),
1327+
)?;
1328+
let case = expr
1329+
.as_any()
1330+
.downcast_ref::<CaseExpr>()
1331+
.expect("expected case expression");
1332+
assert_eq!(case.eval_method, EvalMethod::DivideZeroExpression);
1333+
1334+
// NoExpression: CASE WHEN a > 0 THEN 25.0 / cast(y, float64) ELSE float64(null) END
1335+
let when1 = binary(col("a", &schema)?, Operator::Gt, lit(0i32), &batch.schema())?;
1336+
let then1 = binary(
1337+
lit(25.0f64),
1338+
Operator::Divide,
1339+
cast(col("y", &schema2)?, &batch2.schema(), Float64)?,
1340+
&batch2.schema(),
1341+
)?;
1342+
let x = lit(ScalarValue::Float64(None));
1343+
1344+
let expr = generate_case_when_with_type_coercion(
1345+
None,
1346+
vec![(when1, then1)],
1347+
Some(x),
1348+
schema.as_ref(),
1349+
)?;
1350+
let case = expr
1351+
.as_any()
1352+
.downcast_ref::<CaseExpr>()
1353+
.expect("expected case expression");
1354+
assert_eq!(case.eval_method, EvalMethod::NoExpression);
1355+
1356+
Ok(())
1357+
}
12151358
}

0 commit comments

Comments
 (0)