Skip to content

Commit 3854419

Browse files
feat: add support for fixed list wildcard in type signature (#9312)
* feat: add support for fixed list wildcard in type signature * fmt * clippy * update tests * switch coercing * update logic * fix test * add tests to make sure it cant coerce diff sizes * add test for same type
1 parent 6041dea commit 3854419

File tree

3 files changed

+115
-6
lines changed

3 files changed

+115
-6
lines changed

datafusion/common/src/utils.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -518,9 +518,9 @@ pub fn coerced_fixed_size_list_to_list(data_type: &DataType) -> DataType {
518518
/// Compute the number of dimensions in a list data type.
519519
pub fn list_ndims(data_type: &DataType) -> u64 {
520520
match data_type {
521-
DataType::List(field) | DataType::LargeList(field) => {
522-
1 + list_ndims(field.data_type())
523-
}
521+
DataType::List(field)
522+
| DataType::LargeList(field)
523+
| DataType::FixedSizeList(field, _) => 1 + list_ndims(field.data_type()),
524524
_ => 0,
525525
}
526526
}

datafusion/expr/src/signature.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,11 @@ use arrow::datatypes::DataType;
3030
/// return results with this timezone.
3131
pub const TIMEZONE_WILDCARD: &str = "+TZ";
3232

33+
/// Constant that is used as a placeholder for any valid fixed size list.
34+
/// This is used where a function can accept a fixed size list type with any
35+
/// valid length. It exists to avoid the need to enumerate all possible fixed size list lengths.
36+
pub const FIXED_SIZE_LIST_WILDCARD: i32 = i32::MIN;
37+
3338
///A function's volatility, which defines the functions eligibility for certain optimizations
3439
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)]
3540
pub enum Volatility {

datafusion/expr/src/type_coercion/functions.rs

Lines changed: 107 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,11 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use crate::signature::{ArrayFunctionSignature, TIMEZONE_WILDCARD};
18+
use std::sync::Arc;
19+
20+
use crate::signature::{
21+
ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD,
22+
};
1923
use crate::{Signature, TypeSignature};
2024
use arrow::{
2125
compute::can_cast_types,
@@ -379,13 +383,28 @@ fn coerced_from<'a>(
379383
List(_) if matches!(type_from, FixedSizeList(_, _)) => Some(type_into.clone()),
380384

381385
// Only accept list and largelist with the same number of dimensions unless the type is Null.
382-
// List or LargeList with different dimensions should be handled in TypeSignature or other places before this.
386+
// List or LargeList with different dimensions should be handled in TypeSignature or other places before this
383387
List(_) | LargeList(_)
384388
if datafusion_common::utils::base_type(type_from).eq(&Null)
385389
|| list_ndims(type_from) == list_ndims(type_into) =>
386390
{
387391
Some(type_into.clone())
388392
}
393+
// should be able to coerce wildcard fixed size list to non wildcard fixed size list
394+
FixedSizeList(f_into, FIXED_SIZE_LIST_WILDCARD) => match type_from {
395+
FixedSizeList(f_from, size_from) => {
396+
match coerced_from(f_into.data_type(), f_from.data_type()) {
397+
Some(data_type) if &data_type != f_into.data_type() => {
398+
let new_field =
399+
Arc::new(f_into.as_ref().clone().with_data_type(data_type));
400+
Some(FixedSizeList(new_field, *size_from))
401+
}
402+
Some(_) => Some(FixedSizeList(f_into.clone(), *size_from)),
403+
_ => None,
404+
}
405+
}
406+
_ => None,
407+
},
389408

390409
Timestamp(unit, Some(tz)) if tz.as_ref() == TIMEZONE_WILDCARD => {
391410
match type_from {
@@ -415,8 +434,12 @@ fn coerced_from<'a>(
415434

416435
#[cfg(test)]
417436
mod tests {
437+
use std::sync::Arc;
438+
439+
use crate::Volatility;
440+
418441
use super::*;
419-
use arrow::datatypes::{DataType, TimeUnit};
442+
use arrow::datatypes::{DataType, Field, TimeUnit};
420443

421444
#[test]
422445
fn test_maybe_data_types() {
@@ -492,4 +515,85 @@ mod tests {
492515

493516
Ok(())
494517
}
518+
519+
#[test]
520+
fn test_fixed_list_wildcard_coerce() -> Result<()> {
521+
let inner = Arc::new(Field::new("item", DataType::Int32, false));
522+
let current_types = vec![
523+
DataType::FixedSizeList(inner.clone(), 2), // able to coerce for any size
524+
];
525+
526+
let signature = Signature::exact(
527+
vec![DataType::FixedSizeList(
528+
inner.clone(),
529+
FIXED_SIZE_LIST_WILDCARD,
530+
)],
531+
Volatility::Stable,
532+
);
533+
534+
let coerced_data_types = data_types(&current_types, &signature).unwrap();
535+
assert_eq!(coerced_data_types, current_types);
536+
537+
// make sure it can't coerce to a different size
538+
let signature = Signature::exact(
539+
vec![DataType::FixedSizeList(inner.clone(), 3)],
540+
Volatility::Stable,
541+
);
542+
let coerced_data_types = data_types(&current_types, &signature);
543+
assert!(coerced_data_types.is_err());
544+
545+
// make sure it works with the same type.
546+
let signature = Signature::exact(
547+
vec![DataType::FixedSizeList(inner.clone(), 2)],
548+
Volatility::Stable,
549+
);
550+
let coerced_data_types = data_types(&current_types, &signature).unwrap();
551+
assert_eq!(coerced_data_types, current_types);
552+
553+
Ok(())
554+
}
555+
556+
#[test]
557+
fn test_nested_wildcard_fixed_size_lists() -> Result<()> {
558+
let type_into = DataType::FixedSizeList(
559+
Arc::new(Field::new(
560+
"item",
561+
DataType::FixedSizeList(
562+
Arc::new(Field::new("item", DataType::Int32, false)),
563+
FIXED_SIZE_LIST_WILDCARD,
564+
),
565+
false,
566+
)),
567+
FIXED_SIZE_LIST_WILDCARD,
568+
);
569+
570+
let type_from = DataType::FixedSizeList(
571+
Arc::new(Field::new(
572+
"item",
573+
DataType::FixedSizeList(
574+
Arc::new(Field::new("item", DataType::Int8, false)),
575+
4,
576+
),
577+
false,
578+
)),
579+
3,
580+
);
581+
582+
assert_eq!(
583+
coerced_from(&type_into, &type_from),
584+
Some(DataType::FixedSizeList(
585+
Arc::new(Field::new(
586+
"item",
587+
DataType::FixedSizeList(
588+
Arc::new(Field::new("item", DataType::Int32, false)),
589+
4,
590+
),
591+
false,
592+
)),
593+
3,
594+
))
595+
);
596+
597+
Ok(())
598+
}
495599
}

0 commit comments

Comments
 (0)