Skip to content

Commit dbc7890

Browse files
authored
feat: Support FixedSizedList in array_distance function (#12381)
1 parent 4b51bbe commit dbc7890

File tree

2 files changed

+74
-1
lines changed

2 files changed

+74
-1
lines changed

datafusion/functions-nested/src/distance.rs

+17-1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ use datafusion_common::cast::{
2828
as_float32_array, as_float64_array, as_generic_list_array, as_int32_array,
2929
as_int64_array,
3030
};
31+
use datafusion_common::utils::coerced_fixed_size_list_to_list;
3132
use datafusion_common::DataFusionError;
3233
use datafusion_common::{exec_err, Result};
3334
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
@@ -51,7 +52,7 @@ pub(super) struct ArrayDistance {
5152
impl ArrayDistance {
5253
pub fn new() -> Self {
5354
Self {
54-
signature: Signature::variadic_any(Volatility::Immutable),
55+
signature: Signature::user_defined(Volatility::Immutable),
5556
aliases: vec!["list_distance".to_string()],
5657
}
5758
}
@@ -77,6 +78,21 @@ impl ScalarUDFImpl for ArrayDistance {
7778
}
7879
}
7980

81+
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
82+
if arg_types.len() != 2 {
83+
return exec_err!("array_distance expects exactly two arguments");
84+
}
85+
let mut result = Vec::new();
86+
for arg_type in arg_types {
87+
match arg_type {
88+
List(_) | LargeList(_) | FixedSizeList(_, _) => result.push(coerced_fixed_size_list_to_list(arg_type)),
89+
_ => return exec_err!("The array_distance function can only accept List/LargeList/FixedSizeList."),
90+
}
91+
}
92+
93+
Ok(result)
94+
}
95+
8096
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
8197
make_scalar_function(array_distance_inner)(args)
8298
}

datafusion/sqllogictest/test_files/array.slt

+57
Original file line numberDiff line numberDiff line change
@@ -629,6 +629,38 @@ AS VALUES
629629
(arrow_cast(make_array([28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]), 'FixedSizeList(10, List(Int64))'), [28, 29, 30], [28, 29, 30], 10)
630630
;
631631

632+
statement ok
633+
CREATE TABLE arrays_distance_table
634+
AS VALUES
635+
(make_array(1, 2, 3), make_array(1, 2, 3), make_array(1.1, 2.2, 3.3) , make_array(1.1, NULL, 3.3)),
636+
(make_array(1, 2, 3), make_array(4, 5, 6), make_array(4.4, 5.5, 6.6), make_array(4.4, NULL, 6.6)),
637+
(make_array(1, 2, 3), make_array(7, 8, 9), make_array(7.7, 8.8, 9.9), make_array(7.7, NULL, 9.9)),
638+
(make_array(1, 2, 3), make_array(10, 11, 12), make_array(10.1, 11.2, 12.3), make_array(10.1, NULL, 12.3))
639+
;
640+
641+
statement ok
642+
CREATE TABLE large_arrays_distance_table
643+
AS
644+
SELECT
645+
arrow_cast(column1, 'LargeList(Int64)') AS column1,
646+
arrow_cast(column2, 'LargeList(Int64)') AS column2,
647+
arrow_cast(column3, 'LargeList(Float64)') AS column3,
648+
arrow_cast(column4, 'LargeList(Float64)') AS column4
649+
FROM arrays_distance_table
650+
;
651+
652+
statement ok
653+
CREATE TABLE fixed_size_arrays_distance_table
654+
AS
655+
SELECT
656+
arrow_cast(column1, 'FixedSizeList(3, Int64)') AS column1,
657+
arrow_cast(column2, 'FixedSizeList(3, Int64)') AS column2,
658+
arrow_cast(column3, 'FixedSizeList(3, Float64)') AS column3,
659+
arrow_cast(column4, 'FixedSizeList(3, Float64)') AS column4
660+
FROM arrays_distance_table
661+
;
662+
663+
632664
# Array literal
633665

634666
## boolean coercion is not supported
@@ -4768,6 +4800,31 @@ select list_distance([1, 2, 3], [1, 2, 3]) AS distance;
47684800
----
47694801
0
47704802

4803+
# array_distance with columns
4804+
query RRR
4805+
select array_distance(column1, column2), array_distance(column1, column3), array_distance(column1, column4) from arrays_distance_table;
4806+
----
4807+
0 0.374165738677 NULL
4808+
5.196152422707 6.063827174318 NULL
4809+
10.392304845413 11.778794505381 NULL
4810+
15.58845726812 15.935494971917 NULL
4811+
4812+
query RRR
4813+
select array_distance(column1, column2), array_distance(column1, column3), array_distance(column1, column4) from large_arrays_distance_table;
4814+
----
4815+
0 0.374165738677 NULL
4816+
5.196152422707 6.063827174318 NULL
4817+
10.392304845413 11.778794505381 NULL
4818+
15.58845726812 15.935494971917 NULL
4819+
4820+
query RRR
4821+
select array_distance(column1, column2), array_distance(column1, column3), array_distance(column1, column4) from fixed_size_arrays_distance_table;
4822+
----
4823+
0 0.374165738677 NULL
4824+
5.196152422707 6.063827174318 NULL
4825+
10.392304845413 11.778794505381 NULL
4826+
15.58845726812 15.935494971917 NULL
4827+
47714828

47724829
## array_dims (aliases: `list_dims`)
47734830

0 commit comments

Comments
 (0)