Skip to content

Commit 3b6aac2

Browse files
authored
Support struct coercion in type_union_resolution (#12839)
* support struct Signed-off-by: jayzhan211 <[email protected]> * fix struct Signed-off-by: jayzhan211 <[email protected]> * rm todo Signed-off-by: jayzhan211 <[email protected]> * add more test Signed-off-by: jayzhan211 <[email protected]> * fix field order Signed-off-by: jayzhan211 <[email protected]> * add list of struct test Signed-off-by: jayzhan211 <[email protected]> * upd err msg Signed-off-by: jayzhan211 <[email protected]> * fmt Signed-off-by: jayzhan211 <[email protected]> --------- Signed-off-by: jayzhan211 <[email protected]>
1 parent 58c32cb commit 3b6aac2

File tree

4 files changed

+228
-14
lines changed

4 files changed

+228
-14
lines changed

datafusion/expr-common/src/type_coercion/binary.rs

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ use crate::operator::Operator;
2525
use arrow::array::{new_empty_array, Array};
2626
use arrow::compute::can_cast_types;
2727
use arrow::datatypes::{
28-
DataType, Field, FieldRef, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE,
29-
DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE,
28+
DataType, Field, FieldRef, Fields, TimeUnit, DECIMAL128_MAX_PRECISION,
29+
DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE,
3030
};
3131
use datafusion_common::{exec_datafusion_err, plan_datafusion_err, plan_err, Result};
3232

@@ -370,6 +370,8 @@ impl From<&DataType> for TypeCategory {
370370
/// align with the behavior of Postgres. Therefore, we've made slight adjustments to the rules
371371
/// to better match the behavior of both Postgres and DuckDB. For example, we expect adjusted
372372
/// decimal precision and scale when coercing decimal types.
373+
///
374+
/// Note: for struct types this function does not preserve the original field names or nullability; it only resolves the data types.
373375
pub fn type_union_resolution(data_types: &[DataType]) -> Option<DataType> {
374376
if data_types.is_empty() {
375377
return None;
@@ -476,6 +478,46 @@ fn type_union_resolution_coercion(
476478
type_union_resolution_coercion(lhs.data_type(), rhs.data_type());
477479
new_item_type.map(|t| DataType::List(Arc::new(Field::new("item", t, true))))
478480
}
481+
(DataType::Struct(lhs), DataType::Struct(rhs)) => {
482+
if lhs.len() != rhs.len() {
483+
return None;
484+
}
485+
486+
// Search the field in the right hand side with the SAME field name
487+
fn search_corresponding_coerced_type(
488+
lhs_field: &FieldRef,
489+
rhs: &Fields,
490+
) -> Option<DataType> {
491+
for rhs_field in rhs.iter() {
492+
if lhs_field.name() == rhs_field.name() {
493+
if let Some(t) = type_union_resolution_coercion(
494+
lhs_field.data_type(),
495+
rhs_field.data_type(),
496+
) {
497+
return Some(t);
498+
} else {
499+
return None;
500+
}
501+
}
502+
}
503+
504+
None
505+
}
506+
507+
let types = lhs
508+
.iter()
509+
.map(|lhs_field| search_corresponding_coerced_type(lhs_field, rhs))
510+
.collect::<Option<Vec<_>>>()?;
511+
512+
let fields = types
513+
.into_iter()
514+
.enumerate()
515+
.map(|(i, datatype)| {
516+
Arc::new(Field::new(format!("c{i}"), datatype, true))
517+
})
518+
.collect::<Vec<FieldRef>>();
519+
Some(DataType::Struct(fields.into()))
520+
}
479521
_ => {
480522
// numeric coercion is the same as comparison coercion, both find the narrowest type
481523
// that can accommodate both types

datafusion/expr/src/type_coercion/functions.rs

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -221,20 +221,37 @@ fn get_valid_types_with_scalar_udf(
221221
current_types: &[DataType],
222222
func: &ScalarUDF,
223223
) -> Result<Vec<Vec<DataType>>> {
224-
let valid_types = match signature {
224+
match signature {
225225
TypeSignature::UserDefined => match func.coerce_types(current_types) {
226-
Ok(coerced_types) => vec![coerced_types],
227-
Err(e) => return exec_err!("User-defined coercion failed with {:?}", e),
226+
Ok(coerced_types) => Ok(vec![coerced_types]),
227+
Err(e) => exec_err!("User-defined coercion failed with {:?}", e),
228228
},
229-
TypeSignature::OneOf(signatures) => signatures
230-
.iter()
231-
.filter_map(|t| get_valid_types_with_scalar_udf(t, current_types, func).ok())
232-
.flatten()
233-
.collect::<Vec<_>>(),
234-
_ => get_valid_types(signature, current_types)?,
235-
};
229+
TypeSignature::OneOf(signatures) => {
230+
let mut res = vec![];
231+
let mut errors = vec![];
232+
for sig in signatures {
233+
match get_valid_types_with_scalar_udf(sig, current_types, func) {
234+
Ok(valid_types) => {
235+
res.extend(valid_types);
236+
}
237+
Err(e) => {
238+
errors.push(e.to_string());
239+
}
240+
}
241+
}
236242

237-
Ok(valid_types)
243+
// Every signature failed, return the joined error
244+
if res.is_empty() {
245+
internal_err!(
246+
"Failed to match any signature, errors: {}",
247+
errors.join(",")
248+
)
249+
} else {
250+
Ok(res)
251+
}
252+
}
253+
_ => get_valid_types(signature, current_types),
254+
}
238255
}
239256

240257
fn get_valid_types_with_aggregate_udf(

datafusion/functions-nested/src/make_array.rs

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,12 @@ use arrow_array::{
2727
use arrow_buffer::OffsetBuffer;
2828
use arrow_schema::DataType::{LargeList, List, Null};
2929
use arrow_schema::{DataType, Field};
30+
use datafusion_common::{exec_err, internal_err};
3031
use datafusion_common::{plan_err, utils::array_into_list_array_nullable, Result};
3132
use datafusion_expr::binary::type_union_resolution;
3233
use datafusion_expr::TypeSignature;
3334
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
35+
use itertools::Itertools;
3436

3537
use crate::utils::make_scalar_function;
3638

@@ -106,6 +108,32 @@ impl ScalarUDFImpl for MakeArray {
106108

107109
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
108110
if let Some(new_type) = type_union_resolution(arg_types) {
111+
// TODO: Move the logic to type_union_resolution if this applies to other functions as well
112+
// Handle struct where we only change the data type but preserve the field name and nullability.
113+
// The field name is the key of the struct, so it must not be replaced by a generic column name like "c0" or "c1".
114+
let is_struct_and_has_same_key = are_all_struct_and_have_same_key(arg_types)?;
115+
if is_struct_and_has_same_key {
116+
let data_types: Vec<_> = if let DataType::Struct(fields) = &arg_types[0] {
117+
fields.iter().map(|f| f.data_type().to_owned()).collect()
118+
} else {
119+
return internal_err!("Struct type is checked is the previous function, so this should be unreachable");
120+
};
121+
122+
let mut final_struct_types = vec![];
123+
for s in arg_types {
124+
let mut new_fields = vec![];
125+
if let DataType::Struct(fields) = s {
126+
for (i, f) in fields.iter().enumerate() {
127+
let field = Arc::unwrap_or_clone(Arc::clone(f))
128+
.with_data_type(data_types[i].to_owned());
129+
new_fields.push(Arc::new(field));
130+
}
131+
}
132+
final_struct_types.push(DataType::Struct(new_fields.into()))
133+
}
134+
return Ok(final_struct_types);
135+
}
136+
109137
if let DataType::FixedSizeList(field, _) = new_type {
110138
Ok(vec![DataType::List(field); arg_types.len()])
111139
} else if new_type.is_null() {
@@ -123,6 +151,26 @@ impl ScalarUDFImpl for MakeArray {
123151
}
124152
}
125153

154+
fn are_all_struct_and_have_same_key(data_types: &[DataType]) -> Result<bool> {
155+
let mut keys_string: Option<String> = None;
156+
for data_type in data_types {
157+
if let DataType::Struct(fields) = data_type {
158+
let keys = fields.iter().map(|f| f.name().to_owned()).join(",");
159+
if let Some(ref k) = keys_string {
160+
if *k != keys {
161+
return exec_err!("Expect same keys for struct type but got mismatched pair {} and {}", *k, keys);
162+
}
163+
} else {
164+
keys_string = Some(keys);
165+
}
166+
} else {
167+
return Ok(false);
168+
}
169+
}
170+
171+
Ok(true)
172+
}
173+
126174
// Empty array is a special case that is useful for many other array functions
127175
pub(super) fn empty_array_type() -> DataType {
128176
DataType::List(Arc::new(Field::new("item", DataType::Int64, true)))

datafusion/sqllogictest/test_files/struct.slt

Lines changed: 108 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,34 @@ You reached the bottom!
374374
statement ok
375375
drop view complex_view;
376376

377+
# struct with different keys r1 and r2 is not valid
378+
statement ok
379+
create table t(a struct<r1 varchar, c int>, b struct<r2 varchar, c float>) as values (struct('red', 1), struct('blue', 2.3));
380+
381+
# Expect same keys for struct type but got mismatched pair r1,c and r2,c
382+
query error
383+
select [a, b] from t;
384+
385+
statement ok
386+
drop table t;
387+
388+
# struct with the same key
389+
statement ok
390+
create table t(a struct<r varchar, c int>, b struct<r varchar, c float>) as values (struct('red', 1), struct('blue', 2.3));
391+
392+
query T
393+
select arrow_typeof([a, b]) from t;
394+
----
395+
List(Field { name: "item", data_type: Struct([Field { name: "r", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })
396+
397+
query ?
398+
select [a, b] from t;
399+
----
400+
[{r: red, c: 1}, {r: blue, c: 2}]
401+
402+
statement ok
403+
drop table t;
404+
377405
# Test row alias
378406

379407
query ?
@@ -412,7 +440,6 @@ select * from t;
412440
----
413441
{r: red, b: 2} {r: blue, b: 2.3}
414442

415-
# TODO: Should be coerced to float
416443
query T
417444
select arrow_typeof(c1) from t;
418445
----
@@ -422,3 +449,83 @@ query T
422449
select arrow_typeof(c2) from t;
423450
----
424451
Struct([Field { name: "r", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }])
452+
453+
statement ok
454+
drop table t;
455+
456+
##################################
457+
## Test Coalesce with Struct
458+
##################################
459+
460+
statement ok
461+
CREATE TABLE t (
462+
s1 struct(a int, b varchar),
463+
s2 struct(a float, b varchar)
464+
) AS VALUES
465+
(row(1, 'red'), row(1.1, 'string1')),
466+
(row(2, 'blue'), row(2.2, 'string2')),
467+
(row(3, 'green'), row(33.2, 'string3'))
468+
;
469+
470+
query ?
471+
select coalesce(s1) from t;
472+
----
473+
{a: 1, b: red}
474+
{a: 2, b: blue}
475+
{a: 3, b: green}
476+
477+
# TODO: a's type should be float
478+
query T
479+
select arrow_typeof(coalesce(s1)) from t;
480+
----
481+
Struct([Field { name: "a", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }])
482+
Struct([Field { name: "a", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }])
483+
Struct([Field { name: "a", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }])
484+
485+
statement ok
486+
drop table t;
487+
488+
statement ok
489+
CREATE TABLE t (
490+
s1 struct(a int, b varchar),
491+
s2 struct(a float, b varchar)
492+
) AS VALUES
493+
(row(1, 'red'), row(1.1, 'string1')),
494+
(null, row(2.2, 'string2')),
495+
(row(3, 'green'), row(33.2, 'string3'))
496+
;
497+
498+
# TODO: second column should not be null
499+
query ?
500+
select coalesce(s1) from t;
501+
----
502+
{a: 1, b: red}
503+
NULL
504+
{a: 3, b: green}
505+
506+
statement ok
507+
drop table t;
508+
509+
# row() with incorrect order
510+
statement error DataFusion error: Arrow error: Cast error: Cannot cast string 'blue' to value of Float64 type
511+
create table t(a struct(r varchar, c int), b struct(r varchar, c float)) as values
512+
(row('red', 1), row(2.3, 'blue')),
513+
(row('purple', 1), row('green', 2.3));
514+
515+
# out of order struct literal
516+
# TODO: This query should not fail
517+
statement error DataFusion error: Arrow error: Cast error: Cannot cast string 'a' to value of Int64 type
518+
create table t(a struct(r varchar, c int)) as values ({r: 'a', c: 1}), ({c: 2, r: 'b'});
519+
520+
##################################
521+
## Test Array of Struct
522+
##################################
523+
524+
query ?
525+
select [{r: 'a', c: 1}, {r: 'b', c: 2}];
526+
----
527+
[{r: a, c: 1}, {r: b, c: 2}]
528+
529+
# Can't create a list of structs whose fields are in a different order
530+
query error
531+
select [{r: 'a', c: 1}, {c: 2, r: 'b'}];

0 commit comments

Comments
 (0)