Skip to content

Commit 61ef64c

Browse files
Address review comments III
1 parent 3a5b50a commit 61ef64c

File tree

3 files changed

+29
-15
lines changed

3 files changed

+29
-15
lines changed

datafusion/functions-nested/src/max.rs

+14-14
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,16 @@
1818
//! [`ScalarUDFImpl`] definitions for array_max function.
1919
use crate::utils::make_scalar_function;
2020
use arrow::array::ArrayRef;
21-
use arrow_schema::DataType;
22-
use arrow_schema::DataType::{FixedSizeList, LargeList, List};
21+
use arrow::datatypes::DataType;
22+
use arrow::datatypes::DataType::List;
2323
use datafusion_common::cast::as_list_array;
24-
use datafusion_common::exec_err;
24+
use datafusion_common::utils::take_function_args;
25+
use datafusion_common::{exec_err, ScalarValue};
2526
use datafusion_doc::Documentation;
2627
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
27-
use datafusion_functions::utils::take_function_args;
2828
use datafusion_functions_aggregate::min_max;
2929
use datafusion_macros::user_doc;
30+
use itertools::Itertools;
3031
use std::any::Any;
3132

3233
make_udf_expr_and_func!(
@@ -90,12 +91,8 @@ impl ScalarUDFImpl for ArrayMax {
9091

9192
fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
9293
match &arg_types[0] {
93-
List(field) | LargeList(field) | FixedSizeList(field, _) => {
94-
Ok(field.data_type().clone())
95-
}
96-
_ => exec_err!(
97-
"Not reachable, data_type should be List, LargeList or FixedSizeList"
98-
),
94+
List(field) => Ok(field.data_type().clone()),
95+
_ => exec_err!("Not reachable, data_type should be List"),
9996
}
10097
}
10198

@@ -127,10 +124,13 @@ pub fn array_max_inner(args: &[ArrayRef]) -> datafusion_common::Result<ArrayRef>
127124
let [arg1] = take_function_args("array_max", args)?;
128125

129126
match arg1.data_type() {
130-
List(_) | LargeList(_) | FixedSizeList(_, _) => {
131-
let input_array = as_list_array(&arg1)?.value(0);
132-
let max_result = min_max::max_batch(&input_array);
133-
max_result?.to_array()
127+
List(_) => {
128+
let input_list_array = as_list_array(&arg1)?;
129+
let result_vec = input_list_array
130+
.iter()
131+
.flat_map(|arr| min_max::max_batch(&arr.unwrap()))
132+
.collect_vec();
133+
ScalarValue::iter_to_array(result_vec)
134134
}
135135
_ => exec_err!("array_max does not support type: {:?}", arg1.data_type()),
136136
}

datafusion/sqllogictest/test_files/array.slt

+14
Original file line numberDiff line numberDiff line change
@@ -1496,6 +1496,20 @@ select array_max(make_array(5.1, -3.2, 6.3, 4.9));
14961496
----
14971497
6.3
14981498

1499+
query ?I
1500+
select input, array_max(input) from (select make_array(d - 1, d, d + 1) input from (values (0), (10), (20), (30), (NULL)) t(d))
1501+
----
1502+
[-1, 0, 1] 1
1503+
[9, 10, 11] 11
1504+
[19, 20, 21] 21
1505+
[29, 30, 31] 31
1506+
[NULL, NULL, NULL] NULL
1507+
1508+
query II
1509+
select array_max(arrow_cast(make_array(1, 2, 3), 'FixedSizeList(3, Int64)')), array_max(arrow_cast(make_array(1), 'FixedSizeList(1, Int64)'));
1510+
----
1511+
3 1
1512+
14991513
query I
15001514
select array_max(make_array());
15011515
----

docs/source/user-guide/sql/scalar_functions.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -3008,7 +3008,7 @@ array_length(array, dimension)
30083008

30093009
Returns the maximum value in the array.
30103010

3011-
```
3011+
```sql
30123012
array_max(array)
30133013
```
30143014

0 commit comments

Comments
 (0)