Skip to content

Commit 0a2e422

Browse files
authored
simplify array_has UDF to InList expr when haystack is constant (#15354)
* simplify `array_has` UDF to `InList` expr when haystack is constant * add `.slt` tests, also simplify with `make_array` * tweak comment * add test for `make_array` arg simplification
1 parent f3975da commit 0a2e422

File tree

2 files changed

+315
-1
lines changed

2 files changed

+315
-1
lines changed

datafusion/functions-nested/src/array_has.rs

Lines changed: 133 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,16 @@ use datafusion_common::cast::as_generic_list_array;
2727
use datafusion_common::utils::string_utils::string_array_to_vec;
2828
use datafusion_common::utils::take_function_args;
2929
use datafusion_common::{exec_err, Result, ScalarValue};
30+
use datafusion_expr::expr::{InList, ScalarFunction};
31+
use datafusion_expr::simplify::ExprSimplifyResult;
3032
use datafusion_expr::{
31-
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
33+
ColumnarValue, Documentation, Expr, ScalarUDFImpl, Signature, Volatility,
3234
};
3335
use datafusion_macros::user_doc;
3436
use datafusion_physical_expr_common::datum::compare_with_eq;
3537
use itertools::Itertools;
3638

39+
use crate::make_array::make_array_udf;
3740
use crate::utils::make_scalar_function;
3841

3942
use std::any::Any;
@@ -121,6 +124,52 @@ impl ScalarUDFImpl for ArrayHas {
121124
Ok(DataType::Boolean)
122125
}
123126

127+
fn simplify(
128+
&self,
129+
mut args: Vec<Expr>,
130+
_info: &dyn datafusion_expr::simplify::SimplifyInfo,
131+
) -> Result<ExprSimplifyResult> {
132+
let [haystack, needle] = take_function_args(self.name(), &mut args)?;
133+
134+
// if the haystack is a constant list, we can use an inlist expression which is more
135+
// efficient because the haystack is not varying per-row
136+
if let Expr::Literal(ScalarValue::List(array)) = haystack {
137+
// TODO: support LargeList
138+
// (not supported by `convert_array_to_scalar_vec`)
139+
// (FixedSizeList not supported either, but seems to have worked fine when attempting to
140+
// build a reproducer)
141+
142+
assert_eq!(array.len(), 1); // guarantee of ScalarValue
143+
if let Ok(scalar_values) =
144+
ScalarValue::convert_array_to_scalar_vec(array.as_ref())
145+
{
146+
assert_eq!(scalar_values.len(), 1);
147+
let list = scalar_values
148+
.into_iter()
149+
.flatten()
150+
.map(Expr::Literal)
151+
.collect();
152+
153+
return Ok(ExprSimplifyResult::Simplified(Expr::InList(InList {
154+
expr: Box::new(std::mem::take(needle)),
155+
list,
156+
negated: false,
157+
})));
158+
}
159+
} else if let Expr::ScalarFunction(ScalarFunction { func, args }) = haystack {
160+
// make_array has a static set of arguments, so we can pull the arguments out from it
161+
if func == &make_array_udf() {
162+
return Ok(ExprSimplifyResult::Simplified(Expr::InList(InList {
163+
expr: Box::new(std::mem::take(needle)),
164+
list: std::mem::take(args),
165+
negated: false,
166+
})));
167+
}
168+
}
169+
170+
Ok(ExprSimplifyResult::Original(args))
171+
}
172+
124173
fn invoke_with_args(
125174
&self,
126175
args: datafusion_expr::ScalarFunctionArgs,
@@ -542,3 +591,86 @@ fn general_array_has_all_and_any_kernel(
542591
}),
543592
}
544593
}
594+
595+
#[cfg(test)]
596+
mod tests {
597+
use arrow::array::create_array;
598+
use datafusion_common::utils::SingleRowListArrayBuilder;
599+
use datafusion_expr::{
600+
col, execution_props::ExecutionProps, lit, simplify::ExprSimplifyResult, Expr,
601+
ScalarUDFImpl,
602+
};
603+
604+
use crate::expr_fn::make_array;
605+
606+
use super::ArrayHas;
607+
608+
#[test]
609+
fn test_simplify_array_has_to_in_list() {
610+
let haystack = lit(SingleRowListArrayBuilder::new(create_array!(
611+
Int32,
612+
[1, 2, 3]
613+
))
614+
.build_list_scalar());
615+
let needle = col("c");
616+
617+
let props = ExecutionProps::new();
618+
let context = datafusion_expr::simplify::SimplifyContext::new(&props);
619+
620+
let Ok(ExprSimplifyResult::Simplified(Expr::InList(in_list))) =
621+
ArrayHas::new().simplify(vec![haystack, needle.clone()], &context)
622+
else {
623+
panic!("Expected simplified expression");
624+
};
625+
626+
assert_eq!(
627+
in_list,
628+
datafusion_expr::expr::InList {
629+
expr: Box::new(needle),
630+
list: vec![lit(1), lit(2), lit(3)],
631+
negated: false,
632+
}
633+
);
634+
}
635+
636+
#[test]
637+
fn test_simplify_array_has_with_make_array_to_in_list() {
638+
let haystack = make_array(vec![lit(1), lit(2), lit(3)]);
639+
let needle = col("c");
640+
641+
let props = ExecutionProps::new();
642+
let context = datafusion_expr::simplify::SimplifyContext::new(&props);
643+
644+
let Ok(ExprSimplifyResult::Simplified(Expr::InList(in_list))) =
645+
ArrayHas::new().simplify(vec![haystack, needle.clone()], &context)
646+
else {
647+
panic!("Expected simplified expression");
648+
};
649+
650+
assert_eq!(
651+
in_list,
652+
datafusion_expr::expr::InList {
653+
expr: Box::new(needle),
654+
list: vec![lit(1), lit(2), lit(3)],
655+
negated: false,
656+
}
657+
);
658+
}
659+
660+
#[test]
661+
fn test_array_has_complex_list_not_simplified() {
662+
let haystack = col("c1");
663+
let needle = col("c2");
664+
665+
let props = ExecutionProps::new();
666+
let context = datafusion_expr::simplify::SimplifyContext::new(&props);
667+
668+
let Ok(ExprSimplifyResult::Original(args)) =
669+
ArrayHas::new().simplify(vec![haystack, needle.clone()], &context)
670+
else {
671+
panic!("Expected simplified expression");
672+
};
673+
674+
assert_eq!(args, vec![col("c1"), col("c2")],);
675+
}
676+
}

datafusion/sqllogictest/test_files/array.slt

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5969,6 +5969,188 @@ true false true false false false true true false false true false true
59695969
#----
59705970
#true false true false false false true true false false true false true
59715971

5972+
# rewrite various array_has operations to InList where the haystack is a literal list
5973+
# NB that `col in (a, b, c)` is simplified to OR if there are <= 3 elements, so we make 4-element haystack lists
5974+
5975+
query I
5976+
with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i))
5977+
select count(*) from test WHERE needle IN ('7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c');
5978+
----
5979+
1
5980+
5981+
query TT
5982+
explain with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i))
5983+
select count(*) from test WHERE needle IN ('7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c');
5984+
----
5985+
logical_plan
5986+
01)Projection: count(Int64(1)) AS count(*)
5987+
02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]
5988+
03)----SubqueryAlias: test
5989+
04)------SubqueryAlias: t
5990+
05)--------Projection:
5991+
06)----------Filter: substr(CAST(md5(CAST(tmp_table.value AS Utf8)) AS Utf8), Int64(1), Int64(32)) IN ([Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), Utf8View("a"), Utf8View("b"), Utf8View("c")])
5992+
07)------------TableScan: tmp_table projection=[value]
5993+
physical_plan
5994+
01)ProjectionExec: expr=[count(Int64(1))@0 as count(*)]
5995+
02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))]
5996+
03)----CoalescePartitionsExec
5997+
04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))]
5998+
05)--------ProjectionExec: expr=[]
5999+
06)----------CoalesceBatchesExec: target_batch_size=8192
6000+
07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8)), 1, 32) IN ([Literal { value: Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278") }, Literal { value: Utf8View("a") }, Literal { value: Utf8View("b") }, Literal { value: Utf8View("c") }])
6001+
08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
6002+
09)----------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192]
6003+
6004+
query I
6005+
with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i))
6006+
select count(*) from test WHERE needle = ANY(['7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c']);
6007+
----
6008+
1
6009+
6010+
query TT
6011+
explain with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i))
6012+
select count(*) from test WHERE needle = ANY(['7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c']);
6013+
----
6014+
logical_plan
6015+
01)Projection: count(Int64(1)) AS count(*)
6016+
02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]
6017+
03)----SubqueryAlias: test
6018+
04)------SubqueryAlias: t
6019+
05)--------Projection:
6020+
06)----------Filter: substr(CAST(md5(CAST(tmp_table.value AS Utf8)) AS Utf8), Int64(1), Int64(32)) IN ([Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), Utf8View("a"), Utf8View("b"), Utf8View("c")])
6021+
07)------------TableScan: tmp_table projection=[value]
6022+
physical_plan
6023+
01)ProjectionExec: expr=[count(Int64(1))@0 as count(*)]
6024+
02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))]
6025+
03)----CoalescePartitionsExec
6026+
04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))]
6027+
05)--------ProjectionExec: expr=[]
6028+
06)----------CoalesceBatchesExec: target_batch_size=8192
6029+
07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8)), 1, 32) IN ([Literal { value: Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278") }, Literal { value: Utf8View("a") }, Literal { value: Utf8View("b") }, Literal { value: Utf8View("c") }])
6030+
08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
6031+
09)----------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192]
6032+
6033+
query I
6034+
with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i))
6035+
select count(*) from test WHERE array_has(['7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c'], needle);
6036+
----
6037+
1
6038+
6039+
query TT
6040+
explain with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i))
6041+
select count(*) from test WHERE array_has(['7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c'], needle);
6042+
----
6043+
logical_plan
6044+
01)Projection: count(Int64(1)) AS count(*)
6045+
02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]
6046+
03)----SubqueryAlias: test
6047+
04)------SubqueryAlias: t
6048+
05)--------Projection:
6049+
06)----------Filter: substr(CAST(md5(CAST(tmp_table.value AS Utf8)) AS Utf8), Int64(1), Int64(32)) IN ([Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), Utf8View("a"), Utf8View("b"), Utf8View("c")])
6050+
07)------------TableScan: tmp_table projection=[value]
6051+
physical_plan
6052+
01)ProjectionExec: expr=[count(Int64(1))@0 as count(*)]
6053+
02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))]
6054+
03)----CoalescePartitionsExec
6055+
04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))]
6056+
05)--------ProjectionExec: expr=[]
6057+
06)----------CoalesceBatchesExec: target_batch_size=8192
6058+
07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8)), 1, 32) IN ([Literal { value: Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278") }, Literal { value: Utf8View("a") }, Literal { value: Utf8View("b") }, Literal { value: Utf8View("c") }])
6059+
08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
6060+
09)----------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192]
6061+
6062+
# FIXME: due to rewrite below not working, this is _extremely_ slow to evaluate
6063+
# query I
6064+
# with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i))
6065+
# select count(*) from test WHERE array_has(arrow_cast(['7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c'], 'LargeList(Utf8View)'), needle);
6066+
# ----
6067+
# 1
6068+
6069+
# FIXME: array_has with large list haystack not currently rewritten to InList
6070+
query TT
6071+
explain with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i))
6072+
select count(*) from test WHERE array_has(arrow_cast(['7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c'], 'LargeList(Utf8View)'), needle);
6073+
----
6074+
logical_plan
6075+
01)Projection: count(Int64(1)) AS count(*)
6076+
02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]
6077+
03)----SubqueryAlias: test
6078+
04)------SubqueryAlias: t
6079+
05)--------Projection:
6080+
06)----------Filter: array_has(LargeList([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c]), substr(CAST(md5(CAST(tmp_table.value AS Utf8)) AS Utf8), Int64(1), Int64(32)))
6081+
07)------------TableScan: tmp_table projection=[value]
6082+
physical_plan
6083+
01)ProjectionExec: expr=[count(Int64(1))@0 as count(*)]
6084+
02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))]
6085+
03)----CoalescePartitionsExec
6086+
04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))]
6087+
05)--------ProjectionExec: expr=[]
6088+
06)----------CoalesceBatchesExec: target_batch_size=8192
6089+
07)------------FilterExec: array_has([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c], substr(md5(CAST(value@0 AS Utf8)), 1, 32))
6090+
08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
6091+
09)----------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192]
6092+
6093+
query I
6094+
with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i))
6095+
select count(*) from test WHERE array_has(arrow_cast(['7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c'], 'FixedSizeList(4, Utf8View)'), needle);
6096+
----
6097+
1
6098+
6099+
query TT
6100+
explain with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i))
6101+
select count(*) from test WHERE array_has(arrow_cast(['7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c'], 'FixedSizeList(4, Utf8View)'), needle);
6102+
----
6103+
logical_plan
6104+
01)Projection: count(Int64(1)) AS count(*)
6105+
02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]
6106+
03)----SubqueryAlias: test
6107+
04)------SubqueryAlias: t
6108+
05)--------Projection:
6109+
06)----------Filter: substr(CAST(md5(CAST(tmp_table.value AS Utf8)) AS Utf8), Int64(1), Int64(32)) IN ([Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), Utf8View("a"), Utf8View("b"), Utf8View("c")])
6110+
07)------------TableScan: tmp_table projection=[value]
6111+
physical_plan
6112+
01)ProjectionExec: expr=[count(Int64(1))@0 as count(*)]
6113+
02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))]
6114+
03)----CoalescePartitionsExec
6115+
04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))]
6116+
05)--------ProjectionExec: expr=[]
6117+
06)----------CoalesceBatchesExec: target_batch_size=8192
6118+
07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8)), 1, 32) IN ([Literal { value: Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278") }, Literal { value: Utf8View("a") }, Literal { value: Utf8View("b") }, Literal { value: Utf8View("c") }])
6119+
08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
6120+
09)----------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192]
6121+
6122+
query I
6123+
with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i))
6124+
select count(*) from test WHERE array_has([needle], needle);
6125+
----
6126+
100000
6127+
6128+
# TODO: this should probably be possible to completely remove the filter as always true?
6129+
query TT
6130+
explain with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i))
6131+
select count(*) from test WHERE array_has([needle], needle);
6132+
----
6133+
logical_plan
6134+
01)Projection: count(Int64(1)) AS count(*)
6135+
02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]
6136+
03)----SubqueryAlias: test
6137+
04)------SubqueryAlias: t
6138+
05)--------Projection:
6139+
06)----------Filter: __common_expr_3 = __common_expr_3
6140+
07)------------Projection: substr(CAST(md5(CAST(tmp_table.value AS Utf8)) AS Utf8), Int64(1), Int64(32)) AS __common_expr_3
6141+
08)--------------TableScan: tmp_table projection=[value]
6142+
physical_plan
6143+
01)ProjectionExec: expr=[count(Int64(1))@0 as count(*)]
6144+
02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))]
6145+
03)----CoalescePartitionsExec
6146+
04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))]
6147+
05)--------ProjectionExec: expr=[]
6148+
06)----------CoalesceBatchesExec: target_batch_size=8192
6149+
07)------------FilterExec: __common_expr_3@0 = __common_expr_3@0
6150+
08)--------------ProjectionExec: expr=[substr(md5(CAST(value@0 AS Utf8)), 1, 32) as __common_expr_3]
6151+
09)----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
6152+
10)------------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192]
6153+
59726154
# any operator
59736155
query ?
59746156
select column3 from arrays where 'L'=any(column3);

0 commit comments

Comments
 (0)