Skip to content

Commit fc84a63

Browse files
authored
Support List for Array aggregate order and distinct (#9234)
* first draft Signed-off-by: jayzhan211 <[email protected]>
* fix convert_first_level_array_to_scalar_vec Signed-off-by: jayzhan211 <[email protected]>
* add doc Signed-off-by: jayzhan211 <[email protected]>
* fix nth Signed-off-by: jayzhan211 <[email protected]>
* support distinct Signed-off-by: jayzhan211 <[email protected]>
* cleanup Signed-off-by: jayzhan211 <[email protected]>
* rm convert_first_level_array_to_scalar_vec Signed-off-by: jayzhan211 <[email protected]>
* add doc and assertion Signed-off-by: jayzhan211 <[email protected]>
* fix doc Signed-off-by: jayzhan211 <[email protected]>
* fix doc Signed-off-by: jayzhan211 <[email protected]>
---------
Signed-off-by: jayzhan211 <[email protected]>
1 parent 0e728fc commit fc84a63

File tree

6 files changed

+194
-123
lines changed

6 files changed

+194
-123
lines changed

datafusion/common/src/scalar/mod.rs

Lines changed: 49 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@ use arrow::{
5252
UInt16Type, UInt32Type, UInt64Type, UInt8Type, DECIMAL128_MAX_PRECISION,
5353
},
5454
};
55-
use arrow_array::cast::as_list_array;
5655
use arrow_array::{ArrowNativeTypeOp, Scalar};
5756

5857
pub use struct_builder::ScalarStructBuilder;
@@ -2138,28 +2137,67 @@ impl ScalarValue {
21382137

21392138
/// Retrieve ScalarValue for each row in `array`
21402139
///
2141-
/// Example
2140+
/// Example 1: Array (ScalarValue::Int32)
21422141
/// ```
21432142
/// use datafusion_common::ScalarValue;
21442143
/// use arrow::array::ListArray;
21452144
/// use arrow::datatypes::{DataType, Int32Type};
21462145
///
2146+
/// // Equivalent to [[1,2,3], [4,5]]
21472147
/// let list_arr = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
21482148
/// Some(vec![Some(1), Some(2), Some(3)]),
2149-
/// None,
21502149
/// Some(vec![Some(4), Some(5)])
21512150
/// ]);
21522151
///
2152+
/// // Convert the array into Scalar Values for each row
21532153
/// let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&list_arr).unwrap();
21542154
///
21552155
/// let expected = vec![
2156-
/// vec![
2156+
/// vec![
21572157
/// ScalarValue::Int32(Some(1)),
21582158
/// ScalarValue::Int32(Some(2)),
21592159
/// ScalarValue::Int32(Some(3)),
2160+
/// ],
2161+
/// vec![
2162+
/// ScalarValue::Int32(Some(4)),
2163+
/// ScalarValue::Int32(Some(5)),
2164+
/// ],
2165+
/// ];
2166+
///
2167+
/// assert_eq!(scalar_vec, expected);
2168+
/// ```
2169+
///
2170+
/// Example 2: Nested array (ScalarValue::List)
2171+
/// ```
2172+
/// use datafusion_common::ScalarValue;
2173+
/// use arrow::array::ListArray;
2174+
/// use arrow::datatypes::{DataType, Int32Type};
2175+
/// use datafusion_common::utils::array_into_list_array;
2176+
/// use std::sync::Arc;
2177+
///
2178+
/// let list_arr = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
2179+
/// Some(vec![Some(1), Some(2), Some(3)]),
2180+
/// Some(vec![Some(4), Some(5)])
2181+
/// ]);
2182+
///
2183+
/// // Wrap it in another list layer, giving the nested array [ [[1,2,3], [4,5]] ]
2184+
/// let list_arr = array_into_list_array(Arc::new(list_arr));
2185+
///
2186+
/// // Convert the array into ScalarValues for each row; in this example each scalar is a 1D list
2187+
/// let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&list_arr).unwrap();
2188+
///
2189+
/// let l1 = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
2190+
/// Some(vec![Some(1), Some(2), Some(3)]),
2191+
/// ]);
2192+
/// let l2 = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
2193+
/// Some(vec![Some(4), Some(5)]),
2194+
/// ]);
2195+
///
2196+
/// let expected = vec![
2197+
/// vec![
2198+
/// ScalarValue::List(Arc::new(l1)),
2199+
/// ScalarValue::List(Arc::new(l2)),
21602200
/// ],
2161-
/// vec![],
2162-
/// vec![ScalarValue::Int32(Some(4)), ScalarValue::Int32(Some(5))]
21632201
/// ];
21642202
///
21652203
/// assert_eq!(scalar_vec, expected);
@@ -2168,27 +2206,13 @@ impl ScalarValue {
21682206
let mut scalars = Vec::with_capacity(array.len());
21692207

21702208
for index in 0..array.len() {
2171-
let scalar_values = match array.data_type() {
2172-
DataType::List(_) => {
2173-
let list_array = as_list_array(array);
2174-
match list_array.is_null(index) {
2175-
true => Vec::new(),
2176-
false => {
2177-
let nested_array = list_array.value(index);
2178-
ScalarValue::convert_array_to_scalar_vec(&nested_array)?
2179-
.into_iter()
2180-
.flatten()
2181-
.collect()
2182-
}
2183-
}
2184-
}
2185-
_ => {
2186-
let scalar = ScalarValue::try_from_array(array, index)?;
2187-
vec![scalar]
2188-
}
2189-
};
2209+
let nested_array = array.as_list::<i32>().value(index);
2210+
let scalar_values = (0..nested_array.len())
2211+
.map(|i| ScalarValue::try_from_array(&nested_array, i))
2212+
.collect::<Result<Vec<_>>>()?;
21902213
scalars.push(scalar_values);
21912214
}
2215+
21922216
Ok(scalars)
21932217
}
21942218

datafusion/core/tests/sql/aggregates.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,9 @@ async fn csv_query_array_agg_distinct() -> Result<()> {
4444
// We should have 1 row containing a list
4545
let column = actual[0].column(0);
4646
assert_eq!(column.len(), 1);
47-
4847
let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&column)?;
4948
let mut scalars = scalar_vec[0].clone();
49+
5050
// workaround lack of Ord of ScalarValue
5151
let cmp = |a: &ScalarValue, b: &ScalarValue| {
5252
a.partial_cmp(b).expect("Can compare ScalarValues")

datafusion/physical-expr/src/aggregate/array_agg_distinct.rs

Lines changed: 80 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ use std::sync::Arc;
2424

2525
use arrow::array::ArrayRef;
2626
use arrow::datatypes::{DataType, Field};
27+
use arrow_array::cast::AsArray;
2728

2829
use crate::aggregate::utils::down_cast_any_ref;
2930
use crate::expressions::format_state_name;
@@ -138,9 +139,10 @@ impl Accumulator for DistinctArrayAggAccumulator {
138139
assert_eq!(values.len(), 1, "batch input should only include 1 column!");
139140

140141
let array = &values[0];
141-
let scalar_vec = ScalarValue::convert_array_to_scalar_vec(array)?;
142-
for scalars in scalar_vec {
143-
self.values.extend(scalars);
142+
143+
for i in 0..array.len() {
144+
let scalar = ScalarValue::try_from_array(&array, i)?;
145+
self.values.insert(scalar);
144146
}
145147

146148
Ok(())
@@ -151,7 +153,12 @@ impl Accumulator for DistinctArrayAggAccumulator {
151153
return Ok(());
152154
}
153155

154-
self.update_batch(states)
156+
let array = &states[0];
157+
158+
assert_eq!(array.len(), 1, "state array should only include 1 row!");
159+
// Unwrap outer ListArray then do update batch
160+
let inner_array = array.as_list::<i32>().value(0);
161+
self.update_batch(&[inner_array])
155162
}
156163

157164
fn evaluate(&mut self) -> Result<ScalarValue> {
@@ -181,47 +188,55 @@ mod tests {
181188
use arrow_array::Array;
182189
use arrow_array::ListArray;
183190
use arrow_buffer::OffsetBuffer;
184-
use datafusion_common::utils::array_into_list_array;
185191
use datafusion_common::{internal_err, DataFusionError};
186192

187-
// arrow::compute::sort cann't sort ListArray directly, so we need to sort the inner primitive array and wrap it back into ListArray.
188-
fn sort_list_inner(arr: ScalarValue) -> ScalarValue {
189-
let arr = match arr {
190-
ScalarValue::List(arr) => arr.value(0),
191-
_ => {
192-
panic!("Expected ScalarValue::List, got {:?}", arr)
193-
}
194-
};
193+
// arrow::compute::sort can't sort nested ListArray directly, so we compare the scalar values pair-wise.
194+
fn compare_list_contents(
195+
expected: Vec<ScalarValue>,
196+
actual: ScalarValue,
197+
) -> Result<()> {
198+
let array = actual.to_array()?;
199+
let list_array = array.as_list::<i32>();
200+
let inner_array = list_array.value(0);
201+
let mut actual_scalars = vec![];
202+
for index in 0..inner_array.len() {
203+
let sv = ScalarValue::try_from_array(&inner_array, index)?;
204+
actual_scalars.push(sv);
205+
}
195206

196-
let arr = arrow::compute::sort(&arr, None).unwrap();
197-
let list_arr = array_into_list_array(arr);
198-
ScalarValue::List(Arc::new(list_arr))
199-
}
207+
if actual_scalars.len() != expected.len() {
208+
return internal_err!(
209+
"Expected and actual list lengths differ: expected={}, actual={}",
210+
expected.len(),
211+
actual_scalars.len()
212+
);
213+
}
200214

201-
fn compare_list_contents(expected: ScalarValue, actual: ScalarValue) -> Result<()> {
202-
let actual = sort_list_inner(actual);
203-
204-
match (&expected, &actual) {
205-
(ScalarValue::List(arr1), ScalarValue::List(arr2)) => {
206-
if arr1.eq(arr2) {
207-
Ok(())
208-
} else {
209-
internal_err!(
210-
"Actual value {:?} not found in expected values {:?}",
211-
actual,
212-
expected
213-
)
215+
let mut seen = vec![false; expected.len()];
216+
for v in expected {
217+
let mut found = false;
218+
for (i, sv) in actual_scalars.iter().enumerate() {
219+
if sv == &v {
220+
seen[i] = true;
221+
found = true;
222+
break;
214223
}
215224
}
216-
_ => {
217-
internal_err!("Expected scalar lists as inputs")
225+
if !found {
226+
return internal_err!(
227+
"Expected value {:?} not found in actual values {:?}",
228+
v,
229+
actual_scalars
230+
);
218231
}
219232
}
233+
234+
Ok(())
220235
}
221236

222237
fn check_distinct_array_agg(
223238
input: ArrayRef,
224-
expected: ScalarValue,
239+
expected: Vec<ScalarValue>,
225240
datatype: DataType,
226241
) -> Result<()> {
227242
let schema = Schema::new(vec![Field::new("a", datatype.clone(), false)]);
@@ -234,14 +249,13 @@ mod tests {
234249
true,
235250
));
236251
let actual = aggregate(&batch, agg)?;
237-
238252
compare_list_contents(expected, actual)
239253
}
240254

241255
fn check_merge_distinct_array_agg(
242256
input1: ArrayRef,
243257
input2: ArrayRef,
244-
expected: ScalarValue,
258+
expected: Vec<ScalarValue>,
245259
datatype: DataType,
246260
) -> Result<()> {
247261
let schema = Schema::new(vec![Field::new("a", datatype.clone(), false)]);
@@ -262,23 +276,20 @@ mod tests {
262276
accum1.merge_batch(&[array])?;
263277

264278
let actual = accum1.evaluate()?;
265-
266279
compare_list_contents(expected, actual)
267280
}
268281

269282
#[test]
270283
fn distinct_array_agg_i32() -> Result<()> {
271284
let col: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 4, 5, 2]));
272-
let expected =
273-
ScalarValue::List(Arc::new(
274-
ListArray::from_iter_primitive::<Int32Type, _, _>(vec![Some(vec![
275-
Some(1),
276-
Some(2),
277-
Some(4),
278-
Some(5),
279-
Some(7),
280-
])]),
281-
));
285+
286+
let expected = vec![
287+
ScalarValue::Int32(Some(1)),
288+
ScalarValue::Int32(Some(2)),
289+
ScalarValue::Int32(Some(4)),
290+
ScalarValue::Int32(Some(5)),
291+
ScalarValue::Int32(Some(7)),
292+
];
282293

283294
check_distinct_array_agg(col, expected, DataType::Int32)
284295
}
@@ -288,18 +299,15 @@ mod tests {
288299
let col1: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 4, 5, 2]));
289300
let col2: ArrayRef = Arc::new(Int32Array::from(vec![1, 3, 7, 8, 4]));
290301

291-
let expected =
292-
ScalarValue::List(Arc::new(
293-
ListArray::from_iter_primitive::<Int32Type, _, _>(vec![Some(vec![
294-
Some(1),
295-
Some(2),
296-
Some(3),
297-
Some(4),
298-
Some(5),
299-
Some(7),
300-
Some(8),
301-
])]),
302-
));
302+
let expected = vec![
303+
ScalarValue::Int32(Some(1)),
304+
ScalarValue::Int32(Some(2)),
305+
ScalarValue::Int32(Some(3)),
306+
ScalarValue::Int32(Some(4)),
307+
ScalarValue::Int32(Some(5)),
308+
ScalarValue::Int32(Some(7)),
309+
ScalarValue::Int32(Some(8)),
310+
];
303311

304312
check_merge_distinct_array_agg(col1, col2, expected, DataType::Int32)
305313
}
@@ -351,23 +359,16 @@ mod tests {
351359
let l2 = ScalarValue::List(Arc::new(l2));
352360
let l3 = ScalarValue::List(Arc::new(l3));
353361

354-
// Duplicate l1 in the input array and check that it is deduped in the output.
355-
let array = ScalarValue::iter_to_array(vec![l1.clone(), l2, l3, l1]).unwrap();
356-
357-
let expected =
358-
ScalarValue::List(Arc::new(
359-
ListArray::from_iter_primitive::<Int32Type, _, _>(vec![Some(vec![
360-
Some(1),
361-
Some(2),
362-
Some(3),
363-
Some(4),
364-
Some(5),
365-
Some(6),
366-
Some(7),
367-
Some(8),
368-
Some(9),
369-
])]),
370-
));
362+
// Duplicate l1 and l3 in the input array and check that they are deduped in the output.
363+
let array = ScalarValue::iter_to_array(vec![
364+
l1.clone(),
365+
l2.clone(),
366+
l3.clone(),
367+
l3.clone(),
368+
l1.clone(),
369+
])
370+
.unwrap();
371+
let expected = vec![l1, l2, l3];
371372

372373
check_distinct_array_agg(
373374
array,
@@ -426,22 +427,10 @@ mod tests {
426427
let l3 = ScalarValue::List(Arc::new(l3));
427428

428429
// Duplicate l1 in the input array and check that it is deduped in the output.
429-
let input1 = ScalarValue::iter_to_array(vec![l1.clone(), l2]).unwrap();
430-
let input2 = ScalarValue::iter_to_array(vec![l1, l3]).unwrap();
431-
432-
let expected =
433-
ScalarValue::List(Arc::new(
434-
ListArray::from_iter_primitive::<Int32Type, _, _>(vec![Some(vec![
435-
Some(1),
436-
Some(2),
437-
Some(3),
438-
Some(4),
439-
Some(5),
440-
Some(6),
441-
Some(7),
442-
Some(8),
443-
])]),
444-
));
430+
let input1 = ScalarValue::iter_to_array(vec![l1.clone(), l2.clone()]).unwrap();
431+
let input2 = ScalarValue::iter_to_array(vec![l1.clone(), l3.clone()]).unwrap();
432+
433+
let expected = vec![l1, l2, l3];
445434

446435
check_merge_distinct_array_agg(input1, input2, expected, DataType::Int32)
447436
}

0 commit comments

Comments
 (0)