|
15 | 15 | // specific language governing permissions and limitations
|
16 | 16 | // under the License.
|
17 | 17 |
|
18 |
| -use crate::signature::{ArrayFunctionSignature, TIMEZONE_WILDCARD}; |
| 18 | +use std::sync::Arc; |
| 19 | + |
| 20 | +use crate::signature::{ |
| 21 | + ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD, |
| 22 | +}; |
19 | 23 | use crate::{Signature, TypeSignature};
|
20 | 24 | use arrow::{
|
21 | 25 | compute::can_cast_types,
|
@@ -379,13 +383,28 @@ fn coerced_from<'a>(
|
379 | 383 | List(_) if matches!(type_from, FixedSizeList(_, _)) => Some(type_into.clone()),
|
380 | 384 |
|
381 | 385 | // Only accept list and largelist with the same number of dimensions unless the type is Null.
|
382 |
| - // List or LargeList with different dimensions should be handled in TypeSignature or other places before this. |
| 386 | + // List or LargeList with different dimensions should be handled in TypeSignature or other places before this |
383 | 387 | List(_) | LargeList(_)
|
384 | 388 | if datafusion_common::utils::base_type(type_from).eq(&Null)
|
385 | 389 | || list_ndims(type_from) == list_ndims(type_into) =>
|
386 | 390 | {
|
387 | 391 | Some(type_into.clone())
|
388 | 392 | }
|
| 393 | + // should be able to coerce wildcard fixed size list to non wildcard fixed size list |
| 394 | + FixedSizeList(f_into, FIXED_SIZE_LIST_WILDCARD) => match type_from { |
| 395 | + FixedSizeList(f_from, size_from) => { |
| 396 | + match coerced_from(f_into.data_type(), f_from.data_type()) { |
| 397 | + Some(data_type) if &data_type != f_into.data_type() => { |
| 398 | + let new_field = |
| 399 | + Arc::new(f_into.as_ref().clone().with_data_type(data_type)); |
| 400 | + Some(FixedSizeList(new_field, *size_from)) |
| 401 | + } |
| 402 | + Some(_) => Some(FixedSizeList(f_into.clone(), *size_from)), |
| 403 | + _ => None, |
| 404 | + } |
| 405 | + } |
| 406 | + _ => None, |
| 407 | + }, |
389 | 408 |
|
390 | 409 | Timestamp(unit, Some(tz)) if tz.as_ref() == TIMEZONE_WILDCARD => {
|
391 | 410 | match type_from {
|
@@ -415,8 +434,12 @@ fn coerced_from<'a>(
|
415 | 434 |
|
416 | 435 | #[cfg(test)]
|
417 | 436 | mod tests {
|
| 437 | + use std::sync::Arc; |
| 438 | + |
| 439 | + use crate::Volatility; |
| 440 | + |
418 | 441 | use super::*;
|
419 |
| - use arrow::datatypes::{DataType, TimeUnit}; |
| 442 | + use arrow::datatypes::{DataType, Field, TimeUnit}; |
420 | 443 |
|
421 | 444 | #[test]
|
422 | 445 | fn test_maybe_data_types() {
|
@@ -492,4 +515,85 @@ mod tests {
|
492 | 515 |
|
493 | 516 | Ok(())
|
494 | 517 | }
|
| 518 | + |
| 519 | + #[test] |
| 520 | + fn test_fixed_list_wildcard_coerce() -> Result<()> { |
| 521 | + let inner = Arc::new(Field::new("item", DataType::Int32, false)); |
| 522 | + let current_types = vec![ |
| 523 | + DataType::FixedSizeList(inner.clone(), 2), // able to coerce for any size |
| 524 | + ]; |
| 525 | + |
| 526 | + let signature = Signature::exact( |
| 527 | + vec![DataType::FixedSizeList( |
| 528 | + inner.clone(), |
| 529 | + FIXED_SIZE_LIST_WILDCARD, |
| 530 | + )], |
| 531 | + Volatility::Stable, |
| 532 | + ); |
| 533 | + |
| 534 | + let coerced_data_types = data_types(¤t_types, &signature).unwrap(); |
| 535 | + assert_eq!(coerced_data_types, current_types); |
| 536 | + |
| 537 | + // make sure it can't coerce to a different size |
| 538 | + let signature = Signature::exact( |
| 539 | + vec![DataType::FixedSizeList(inner.clone(), 3)], |
| 540 | + Volatility::Stable, |
| 541 | + ); |
| 542 | + let coerced_data_types = data_types(¤t_types, &signature); |
| 543 | + assert!(coerced_data_types.is_err()); |
| 544 | + |
| 545 | + // make sure it works with the same type. |
| 546 | + let signature = Signature::exact( |
| 547 | + vec![DataType::FixedSizeList(inner.clone(), 2)], |
| 548 | + Volatility::Stable, |
| 549 | + ); |
| 550 | + let coerced_data_types = data_types(¤t_types, &signature).unwrap(); |
| 551 | + assert_eq!(coerced_data_types, current_types); |
| 552 | + |
| 553 | + Ok(()) |
| 554 | + } |
| 555 | + |
| 556 | + #[test] |
| 557 | + fn test_nested_wildcard_fixed_size_lists() -> Result<()> { |
| 558 | + let type_into = DataType::FixedSizeList( |
| 559 | + Arc::new(Field::new( |
| 560 | + "item", |
| 561 | + DataType::FixedSizeList( |
| 562 | + Arc::new(Field::new("item", DataType::Int32, false)), |
| 563 | + FIXED_SIZE_LIST_WILDCARD, |
| 564 | + ), |
| 565 | + false, |
| 566 | + )), |
| 567 | + FIXED_SIZE_LIST_WILDCARD, |
| 568 | + ); |
| 569 | + |
| 570 | + let type_from = DataType::FixedSizeList( |
| 571 | + Arc::new(Field::new( |
| 572 | + "item", |
| 573 | + DataType::FixedSizeList( |
| 574 | + Arc::new(Field::new("item", DataType::Int8, false)), |
| 575 | + 4, |
| 576 | + ), |
| 577 | + false, |
| 578 | + )), |
| 579 | + 3, |
| 580 | + ); |
| 581 | + |
| 582 | + assert_eq!( |
| 583 | + coerced_from(&type_into, &type_from), |
| 584 | + Some(DataType::FixedSizeList( |
| 585 | + Arc::new(Field::new( |
| 586 | + "item", |
| 587 | + DataType::FixedSizeList( |
| 588 | + Arc::new(Field::new("item", DataType::Int32, false)), |
| 589 | + 4, |
| 590 | + ), |
| 591 | + false, |
| 592 | + )), |
| 593 | + 3, |
| 594 | + )) |
| 595 | + ); |
| 596 | + |
| 597 | + Ok(()) |
| 598 | + } |
495 | 599 | }
|
0 commit comments