|
18 | 18 | //! [`ScalarUDFImpl`] definitions for array_max function.
|
19 | 19 | use crate::utils::make_scalar_function;
|
20 | 20 | use arrow::array::ArrayRef;
|
21 |
| -use arrow_schema::DataType; |
22 |
| -use arrow_schema::DataType::{FixedSizeList, LargeList, List}; |
| 21 | +use arrow::datatypes::DataType; |
| 22 | +use arrow::datatypes::DataType::List; |
23 | 23 | use datafusion_common::cast::as_list_array;
|
24 |
| -use datafusion_common::exec_err; |
| 24 | +use datafusion_common::utils::take_function_args; |
| 25 | +use datafusion_common::{exec_err, ScalarValue}; |
25 | 26 | use datafusion_doc::Documentation;
|
26 | 27 | use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
|
27 |
| -use datafusion_functions::utils::take_function_args; |
28 | 28 | use datafusion_functions_aggregate::min_max;
|
29 | 29 | use datafusion_macros::user_doc;
|
| 30 | +use itertools::Itertools; |
30 | 31 | use std::any::Any;
|
31 | 32 |
|
32 | 33 | make_udf_expr_and_func!(
|
@@ -90,12 +91,8 @@ impl ScalarUDFImpl for ArrayMax {
|
90 | 91 |
|
91 | 92 | fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
|
92 | 93 | match &arg_types[0] {
|
93 |
| - List(field) | LargeList(field) | FixedSizeList(field, _) => { |
94 |
| - Ok(field.data_type().clone()) |
95 |
| - } |
96 |
| - _ => exec_err!( |
97 |
| - "Not reachable, data_type should be List, LargeList or FixedSizeList" |
98 |
| - ), |
| 94 | + List(field) => Ok(field.data_type().clone()), |
| 95 | + _ => exec_err!("Not reachable, data_type should be List"), |
99 | 96 | }
|
100 | 97 | }
|
101 | 98 |
|
@@ -127,10 +124,13 @@ pub fn array_max_inner(args: &[ArrayRef]) -> datafusion_common::Result<ArrayRef>
|
127 | 124 | let [arg1] = take_function_args("array_max", args)?;
|
128 | 125 |
|
129 | 126 | match arg1.data_type() {
|
130 |
| - List(_) | LargeList(_) | FixedSizeList(_, _) => { |
131 |
| - let input_array = as_list_array(&arg1)?.value(0); |
132 |
| - let max_result = min_max::max_batch(&input_array); |
133 |
| - max_result?.to_array() |
| 127 | + List(_) => { |
| 128 | + let input_list_array = as_list_array(&arg1)?; |
| 129 | + let result_vec = input_list_array |
| 130 | + .iter() |
| 131 | + .flat_map(|arr| min_max::max_batch(&arr.unwrap())) |
| 132 | + .collect_vec(); |
| 133 | + ScalarValue::iter_to_array(result_vec) |
134 | 134 | }
|
135 | 135 | _ => exec_err!("array_max does not support type: {:?}", arg1.data_type()),
|
136 | 136 | }
|
|
0 commit comments