Skip to content

Commit 667c77a

Browse files
rluvatonOmega359alamb
authored
feat(function): add least function (#13786)
* start adding least fn * feat(function): add least function * update function name * fix scalar smaller function * add tests * run Clippy and Fmt * Generated docs using `./dev/update_function_docs.sh` * add comment why `descending: false` * update comment * Update least.rs Co-authored-by: Bruce Ritchie <[email protected]> * Update scalar_functions.md * run ./dev/update_function_docs.sh to update docs * merge greatest and least implementation to one * add header --------- Co-authored-by: Bruce Ritchie <[email protected]> Co-authored-by: Andrew Lamb <[email protected]>
1 parent 9b19d36 commit 667c77a

File tree

6 files changed

+612
-134
lines changed

6 files changed

+612
-134
lines changed

datafusion/functions/src/core/greatest.rs

+49-134
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,19 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use arrow::array::{make_comparator, Array, ArrayRef, BooleanArray};
18+
use crate::core::greatest_least_utils::GreatestLeastOperator;
19+
use arrow::array::{make_comparator, Array, BooleanArray};
1920
use arrow::compute::kernels::cmp;
20-
use arrow::compute::kernels::zip::zip;
2121
use arrow::compute::SortOptions;
2222
use arrow::datatypes::DataType;
2323
use arrow_buffer::BooleanBuffer;
24-
use datafusion_common::{exec_err, plan_err, Result, ScalarValue};
24+
use datafusion_common::{internal_err, Result, ScalarValue};
2525
use datafusion_doc::Documentation;
26-
use datafusion_expr::binary::type_union_resolution;
2726
use datafusion_expr::scalar_doc_sections::DOC_SECTION_CONDITIONAL;
2827
use datafusion_expr::ColumnarValue;
2928
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
3029
use std::any::Any;
31-
use std::sync::{Arc, OnceLock};
30+
use std::sync::OnceLock;
3231

3332
const SORT_OPTIONS: SortOptions = SortOptions {
3433
// We want greatest first
@@ -57,79 +56,57 @@ impl GreatestFunc {
5756
}
5857
}
5958

60-
fn get_logical_null_count(arr: &dyn Array) -> usize {
61-
arr.logical_nulls()
62-
.map(|n| n.null_count())
63-
.unwrap_or_default()
64-
}
59+
impl GreatestLeastOperator for GreatestFunc {
60+
const NAME: &'static str = "greatest";
6561

66-
/// Return boolean array where `arr[i] = lhs[i] >= rhs[i]` for all i, where `arr` is the result array
67-
/// Nulls are always considered smaller than any other value
68-
fn get_larger(lhs: &dyn Array, rhs: &dyn Array) -> Result<BooleanArray> {
69-
// Fast path:
70-
// If both arrays are not nested, have the same length and no nulls, we can use the faster vectorised kernel
71-
// - If both arrays are not nested: Nested types, such as lists, are not supported as the null semantics are not well-defined.
72-
// - both array does not have any nulls: cmp::gt_eq will return null if any of the input is null while we want to return false in that case
73-
if !lhs.data_type().is_nested()
74-
&& get_logical_null_count(lhs) == 0
75-
&& get_logical_null_count(rhs) == 0
76-
{
77-
return cmp::gt_eq(&lhs, &rhs).map_err(|e| e.into());
78-
}
62+
fn keep_scalar<'a>(
63+
lhs: &'a ScalarValue,
64+
rhs: &'a ScalarValue,
65+
) -> Result<&'a ScalarValue> {
66+
if !lhs.data_type().is_nested() {
67+
return if lhs >= rhs { Ok(lhs) } else { Ok(rhs) };
68+
}
7969

80-
let cmp = make_comparator(lhs, rhs, SORT_OPTIONS)?;
70+
// If complex type we can't compare directly as we want null values to be smaller
71+
let cmp = make_comparator(
72+
lhs.to_array()?.as_ref(),
73+
rhs.to_array()?.as_ref(),
74+
SORT_OPTIONS,
75+
)?;
8176

82-
if lhs.len() != rhs.len() {
83-
return exec_err!(
84-
"All arrays should have the same length for greatest comparison"
85-
);
77+
if cmp(0, 0).is_ge() {
78+
Ok(lhs)
79+
} else {
80+
Ok(rhs)
81+
}
8682
}
8783

88-
let values = BooleanBuffer::collect_bool(lhs.len(), |i| cmp(i, i).is_ge());
89-
90-
// No nulls as we only want to keep the values that are larger, its either true or false
91-
Ok(BooleanArray::new(values, None))
92-
}
93-
94-
/// Return array where the largest value at each index is kept
95-
fn keep_larger(lhs: ArrayRef, rhs: ArrayRef) -> Result<ArrayRef> {
96-
// True for values that we should keep from the left array
97-
let keep_lhs = get_larger(lhs.as_ref(), rhs.as_ref())?;
98-
99-
let larger = zip(&keep_lhs, &lhs, &rhs)?;
84+
/// Return boolean array where `arr[i] = lhs[i] >= rhs[i]` for all i, where `arr` is the result array
85+
/// Nulls are always considered smaller than any other value
86+
fn get_indexes_to_keep(lhs: &dyn Array, rhs: &dyn Array) -> Result<BooleanArray> {
87+
// Fast path:
88+
// If both arrays are not nested, have the same length and no nulls, we can use the faster vectorised kernel
89+
// - If both arrays are not nested: Nested types, such as lists, are not supported as the null semantics are not well-defined.
90+
// - both array does not have any nulls: cmp::gt_eq will return null if any of the input is null while we want to return false in that case
91+
if !lhs.data_type().is_nested()
92+
&& lhs.logical_null_count() == 0
93+
&& rhs.logical_null_count() == 0
94+
{
95+
return cmp::gt_eq(&lhs, &rhs).map_err(|e| e.into());
96+
}
10097

101-
Ok(larger)
102-
}
98+
let cmp = make_comparator(lhs, rhs, SORT_OPTIONS)?;
10399

104-
fn keep_larger_scalar<'a>(
105-
lhs: &'a ScalarValue,
106-
rhs: &'a ScalarValue,
107-
) -> Result<&'a ScalarValue> {
108-
if !lhs.data_type().is_nested() {
109-
return if lhs >= rhs { Ok(lhs) } else { Ok(rhs) };
110-
}
111-
112-
// If complex type we can't compare directly as we want null values to be smaller
113-
let cmp = make_comparator(
114-
lhs.to_array()?.as_ref(),
115-
rhs.to_array()?.as_ref(),
116-
SORT_OPTIONS,
117-
)?;
100+
if lhs.len() != rhs.len() {
101+
return internal_err!(
102+
"All arrays should have the same length for greatest comparison"
103+
);
104+
}
118105

119-
if cmp(0, 0).is_ge() {
120-
Ok(lhs)
121-
} else {
122-
Ok(rhs)
123-
}
124-
}
106+
let values = BooleanBuffer::collect_bool(lhs.len(), |i| cmp(i, i).is_ge());
125107

126-
fn find_coerced_type(data_types: &[DataType]) -> Result<DataType> {
127-
if data_types.is_empty() {
128-
plan_err!("greatest was called without any arguments. It requires at least 1.")
129-
} else if let Some(coerced_type) = type_union_resolution(data_types) {
130-
Ok(coerced_type)
131-
} else {
132-
plan_err!("Cannot find a common type for arguments")
108+
// No nulls as we only want to keep the values that are larger, its either true or false
109+
Ok(BooleanArray::new(values, None))
133110
}
134111
}
135112

@@ -151,74 +128,12 @@ impl ScalarUDFImpl for GreatestFunc {
151128
}
152129

153130
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
154-
if args.is_empty() {
155-
return exec_err!(
156-
"greatest was called with no arguments. It requires at least 1."
157-
);
158-
}
159-
160-
// Some engines (e.g. SQL Server) allow greatest with single arg, it's a noop
161-
if args.len() == 1 {
162-
return Ok(args[0].clone());
163-
}
164-
165-
// Split to scalars and arrays for later optimization
166-
let (scalars, arrays): (Vec<_>, Vec<_>) = args.iter().partition(|x| match x {
167-
ColumnarValue::Scalar(_) => true,
168-
ColumnarValue::Array(_) => false,
169-
});
170-
171-
let mut arrays_iter = arrays.iter().map(|x| match x {
172-
ColumnarValue::Array(a) => a,
173-
_ => unreachable!(),
174-
});
175-
176-
let first_array = arrays_iter.next();
177-
178-
let mut largest: ArrayRef;
179-
180-
// Optimization: merge all scalars into one to avoid recomputing
181-
if !scalars.is_empty() {
182-
let mut scalars_iter = scalars.iter().map(|x| match x {
183-
ColumnarValue::Scalar(s) => s,
184-
_ => unreachable!(),
185-
});
186-
187-
// We have at least one scalar
188-
let mut largest_scalar = scalars_iter.next().unwrap();
189-
190-
for scalar in scalars_iter {
191-
largest_scalar = keep_larger_scalar(largest_scalar, scalar)?;
192-
}
193-
194-
// If we only have scalars, return the largest one
195-
if arrays.is_empty() {
196-
return Ok(ColumnarValue::Scalar(largest_scalar.clone()));
197-
}
198-
199-
// We have at least one array
200-
let first_array = first_array.unwrap();
201-
202-
// Start with the largest value
203-
largest = keep_larger(
204-
Arc::clone(first_array),
205-
largest_scalar.to_array_of_size(first_array.len())?,
206-
)?;
207-
} else {
208-
// If we only have arrays, start with the first array
209-
// (We must have at least one array)
210-
largest = Arc::clone(first_array.unwrap());
211-
}
212-
213-
for array in arrays_iter {
214-
largest = keep_larger(Arc::clone(array), largest)?;
215-
}
216-
217-
Ok(ColumnarValue::Array(largest))
131+
super::greatest_least_utils::execute_conditional::<Self>(args)
218132
}
219133

220134
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
221-
let coerced_type = find_coerced_type(arg_types)?;
135+
let coerced_type =
136+
super::greatest_least_utils::find_coerced_type::<Self>(arg_types)?;
222137

223138
Ok(vec![coerced_type; arg_types.len()])
224139
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::array::{Array, ArrayRef, BooleanArray};
19+
use arrow::compute::kernels::zip::zip;
20+
use arrow::datatypes::DataType;
21+
use datafusion_common::{internal_err, plan_err, Result, ScalarValue};
22+
use datafusion_expr_common::columnar_value::ColumnarValue;
23+
use datafusion_expr_common::type_coercion::binary::type_union_resolution;
24+
use std::sync::Arc;
25+
26+
pub(super) trait GreatestLeastOperator {
27+
const NAME: &'static str;
28+
29+
fn keep_scalar<'a>(
30+
lhs: &'a ScalarValue,
31+
rhs: &'a ScalarValue,
32+
) -> Result<&'a ScalarValue>;
33+
34+
/// Return array with true for values that we should keep from the lhs array
35+
fn get_indexes_to_keep(lhs: &dyn Array, rhs: &dyn Array) -> Result<BooleanArray>;
36+
}
37+
38+
fn keep_array<Op: GreatestLeastOperator>(
39+
lhs: ArrayRef,
40+
rhs: ArrayRef,
41+
) -> Result<ArrayRef> {
42+
// True for values that we should keep from the left array
43+
let keep_lhs = Op::get_indexes_to_keep(lhs.as_ref(), rhs.as_ref())?;
44+
45+
let result = zip(&keep_lhs, &lhs, &rhs)?;
46+
47+
Ok(result)
48+
}
49+
50+
pub(super) fn execute_conditional<Op: GreatestLeastOperator>(
51+
args: &[ColumnarValue],
52+
) -> Result<ColumnarValue> {
53+
if args.is_empty() {
54+
return internal_err!(
55+
"{} was called with no arguments. It requires at least 1.",
56+
Op::NAME
57+
);
58+
}
59+
60+
// Some engines (e.g. SQL Server) allow greatest/least with single arg, it's a noop
61+
if args.len() == 1 {
62+
return Ok(args[0].clone());
63+
}
64+
65+
// Split to scalars and arrays for later optimization
66+
let (scalars, arrays): (Vec<_>, Vec<_>) = args.iter().partition(|x| match x {
67+
ColumnarValue::Scalar(_) => true,
68+
ColumnarValue::Array(_) => false,
69+
});
70+
71+
let mut arrays_iter = arrays.iter().map(|x| match x {
72+
ColumnarValue::Array(a) => a,
73+
_ => unreachable!(),
74+
});
75+
76+
let first_array = arrays_iter.next();
77+
78+
let mut result: ArrayRef;
79+
80+
// Optimization: merge all scalars into one to avoid recomputing (constant folding)
81+
if !scalars.is_empty() {
82+
let mut scalars_iter = scalars.iter().map(|x| match x {
83+
ColumnarValue::Scalar(s) => s,
84+
_ => unreachable!(),
85+
});
86+
87+
// We have at least one scalar
88+
let mut result_scalar = scalars_iter.next().unwrap();
89+
90+
for scalar in scalars_iter {
91+
result_scalar = Op::keep_scalar(result_scalar, scalar)?;
92+
}
93+
94+
// If we only have scalars, return the one that we should keep (largest/least)
95+
if arrays.is_empty() {
96+
return Ok(ColumnarValue::Scalar(result_scalar.clone()));
97+
}
98+
99+
// We have at least one array
100+
let first_array = first_array.unwrap();
101+
102+
// Start with the result value
103+
result = keep_array::<Op>(
104+
Arc::clone(first_array),
105+
result_scalar.to_array_of_size(first_array.len())?,
106+
)?;
107+
} else {
108+
// If we only have arrays, start with the first array
109+
// (We must have at least one array)
110+
result = Arc::clone(first_array.unwrap());
111+
}
112+
113+
for array in arrays_iter {
114+
result = keep_array::<Op>(Arc::clone(array), result)?;
115+
}
116+
117+
Ok(ColumnarValue::Array(result))
118+
}
119+
120+
pub(super) fn find_coerced_type<Op: GreatestLeastOperator>(
121+
data_types: &[DataType],
122+
) -> Result<DataType> {
123+
if data_types.is_empty() {
124+
plan_err!(
125+
"{} was called without any arguments. It requires at least 1.",
126+
Op::NAME
127+
)
128+
} else if let Some(coerced_type) = type_union_resolution(data_types) {
129+
Ok(coerced_type)
130+
} else {
131+
plan_err!("Cannot find a common type for arguments")
132+
}
133+
}

0 commit comments

Comments
 (0)