Skip to content

Commit 5955860

Browse files
Unparsing optimized (> 2 inputs) unions (#14031)
* tests and optimizer in testing queries * unparse optimized unions * format Cargo.toml * format Cargo.toml * revert test * rewrite test to avoid cyclic dep * remove old test * cleanup * comments and error handling * handle union with lt 2 inputs
1 parent ad5a04f commit 5955860

File tree

2 files changed

+68
-17
lines changed

2 files changed

+68
-17
lines changed

datafusion/sql/src/unparser/plan.rs

+16-13
Original file line numberDiff line numberDiff line change
@@ -706,13 +706,6 @@ impl Unparser<'_> {
706706
Ok(())
707707
}
708708
LogicalPlan::Union(union) => {
709-
if union.inputs.len() != 2 {
710-
return not_impl_err!(
711-
"UNION ALL expected 2 inputs, but found {}",
712-
union.inputs.len()
713-
);
714-
}
715-
716709
// Covers cases where the UNION is a subquery and the projection is at the top level
717710
if select.already_projected() {
718711
return self.derive_with_dialect_alias(
@@ -729,12 +722,22 @@ impl Unparser<'_> {
729722
.map(|input| self.select_to_sql_expr(input, query))
730723
.collect::<Result<Vec<_>>>()?;
731724

732-
let union_expr = SetExpr::SetOperation {
733-
op: ast::SetOperator::Union,
734-
set_quantifier: ast::SetQuantifier::All,
735-
left: Box::new(input_exprs[0].clone()),
736-
right: Box::new(input_exprs[1].clone()),
737-
};
725+
if input_exprs.len() < 2 {
726+
return internal_err!("UNION operator requires at least 2 inputs");
727+
}
728+
729+
// Build the union expression tree bottom-up by reversing the order
730+
// note that we are also swapping left and right inputs because of the rev
731+
let union_expr = input_exprs
732+
.into_iter()
733+
.rev()
734+
.reduce(|a, b| SetExpr::SetOperation {
735+
op: ast::SetOperator::Union,
736+
set_quantifier: ast::SetQuantifier::All,
737+
left: Box::new(b),
738+
right: Box::new(a),
739+
})
740+
.unwrap();
738741

739742
let Some(query) = query.as_mut() else {
740743
return internal_err!(

datafusion/sql/tests/cases/plan_to_sql.rs

+52-4
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,14 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use arrow_schema::*;
18+
use arrow_schema::{DataType, Field, Schema};
1919
use datafusion_common::{assert_contains, DFSchema, DFSchemaRef, Result, TableReference};
2020
use datafusion_expr::test::function_stub::{
2121
count_udaf, max_udaf, min_udaf, sum, sum_udaf,
2222
};
2323
use datafusion_expr::{
24-
col, lit, table_scan, wildcard, Expr, Extension, LogicalPlan, LogicalPlanBuilder,
25-
UserDefinedLogicalNode, UserDefinedLogicalNodeCore,
24+
col, lit, table_scan, wildcard, EmptyRelation, Expr, Extension, LogicalPlan,
25+
LogicalPlanBuilder, Union, UserDefinedLogicalNode, UserDefinedLogicalNodeCore,
2626
};
2727
use datafusion_functions::unicode;
2828
use datafusion_functions_aggregate::grouping::grouping_udaf;
@@ -42,7 +42,7 @@ use std::{fmt, vec};
4242

4343
use crate::common::{MockContextProvider, MockSessionState};
4444
use datafusion_expr::builder::{
45-
table_scan_with_filter_and_fetch, table_scan_with_filters,
45+
project, table_scan_with_filter_and_fetch, table_scan_with_filters,
4646
};
4747
use datafusion_functions::core::planner::CoreFunctionPlanner;
4848
use datafusion_functions_nested::extract::array_element_udf;
@@ -1615,3 +1615,51 @@ fn test_unparse_extension_to_sql() -> Result<()> {
16151615
}
16161616
Ok(())
16171617
}
1618+
1619+
#[test]
1620+
fn test_unparse_optimized_multi_union() -> Result<()> {
1621+
let unparser = Unparser::default();
1622+
1623+
let schema = Schema::new(vec![
1624+
Field::new("x", DataType::Int32, false),
1625+
Field::new("y", DataType::Utf8, false),
1626+
]);
1627+
1628+
let dfschema = Arc::new(DFSchema::try_from(schema)?);
1629+
1630+
let empty = LogicalPlan::EmptyRelation(EmptyRelation {
1631+
produce_one_row: true,
1632+
schema: dfschema.clone(),
1633+
});
1634+
1635+
let plan = LogicalPlan::Union(Union {
1636+
inputs: vec![
1637+
project(empty.clone(), vec![lit(1).alias("x"), lit("a").alias("y")])?.into(),
1638+
project(empty.clone(), vec![lit(1).alias("x"), lit("b").alias("y")])?.into(),
1639+
project(empty.clone(), vec![lit(2).alias("x"), lit("a").alias("y")])?.into(),
1640+
project(empty.clone(), vec![lit(2).alias("x"), lit("c").alias("y")])?.into(),
1641+
],
1642+
schema: dfschema.clone(),
1643+
});
1644+
1645+
let sql = "SELECT 1 AS x, 'a' AS y UNION ALL SELECT 1 AS x, 'b' AS y UNION ALL SELECT 2 AS x, 'a' AS y UNION ALL SELECT 2 AS x, 'c' AS y";
1646+
1647+
assert_eq!(unparser.plan_to_sql(&plan)?.to_string(), sql);
1648+
1649+
let plan = LogicalPlan::Union(Union {
1650+
inputs: vec![project(
1651+
empty.clone(),
1652+
vec![lit(1).alias("x"), lit("a").alias("y")],
1653+
)?
1654+
.into()],
1655+
schema: dfschema.clone(),
1656+
});
1657+
1658+
if let Some(err) = plan_to_sql(&plan).err() {
1659+
assert_contains!(err.to_string(), "UNION operator requires at least 2 inputs");
1660+
} else {
1661+
panic!("Expected error")
1662+
}
1663+
1664+
Ok(())
1665+
}

0 commit comments

Comments
 (0)