Skip to content

Commit ac640fb

Browse files
committed
distinct
Signed-off-by: jayzhan211 <[email protected]>
1 parent 9d17c1c commit ac640fb

File tree

8 files changed

+87
-497
lines changed

8 files changed

+87
-497
lines changed

datafusion/core/src/physical_planner.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -1842,8 +1842,9 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
18421842
// TODO: Remove this once every array_agg variant is implemented as a UDAF
18431843
let (agg_expr, filter, order_by) = match func_def {
18441844
AggregateFunctionDefinition::UDF(udf)
1845-
if udf.name() == "ARRAY_AGG" && (*distinct || order_by.is_some()) =>
1845+
if udf.name() == "ARRAY_AGG" && order_by.is_some() =>
18461846
{
1847+
// ARRAY_AGG with ORDER BY is not yet supported by the UDAF; fall back to the builtin implementation
18471848
let physical_sort_exprs = match order_by {
18481849
Some(exprs) => Some(create_physical_sort_exprs(
18491850
exprs,

datafusion/functions-aggregate/src/array_agg.rs

+78-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
//! Defines physical expressions that can be evaluated at runtime during query execution
1919
20-
use arrow::array::{Array, ArrayRef};
20+
use arrow::array::{Array, ArrayRef, AsArray};
2121
use arrow::datatypes::DataType;
2222
use arrow_schema::Field;
2323

@@ -29,6 +29,7 @@ use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
2929
use datafusion_expr::utils::format_state_name;
3030
use datafusion_expr::AggregateUDFImpl;
3131
use datafusion_expr::{Accumulator, Signature, Volatility};
32+
use std::collections::HashSet;
3233
use std::sync::Arc;
3334

3435
make_udaf_expr_and_func!(
@@ -82,6 +83,14 @@ impl AggregateUDFImpl for ArrayAgg {
8283
}
8384

8485
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
86+
if args.is_distinct {
87+
return Ok(vec![Field::new_list(
88+
format_state_name(args.name, "distinct_array_agg"),
89+
Field::new("item", args.input_type.clone(), true),
90+
true,
91+
)]);
92+
}
93+
8594
Ok(vec![Field::new_list(
8695
format_state_name(args.name, "array_agg"),
8796
Field::new("item", args.input_type.clone(), true),
@@ -90,6 +99,12 @@ impl AggregateUDFImpl for ArrayAgg {
9099
}
91100

92101
fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
102+
if acc_args.is_distinct {
103+
return Ok(Box::new(DistinctArrayAggAccumulator::try_new(
104+
acc_args.input_type,
105+
)?));
106+
}
107+
93108
Ok(Box::new(ArrayAggAccumulator::try_new(acc_args.input_type)?))
94109
}
95110
}
@@ -170,3 +185,65 @@ impl Accumulator for ArrayAggAccumulator {
170185
- std::mem::size_of_val(&self.datatype)
171186
}
172187
}
188+
189+
#[derive(Debug)]
190+
struct DistinctArrayAggAccumulator {
191+
values: HashSet<ScalarValue>,
192+
datatype: DataType,
193+
}
194+
195+
impl DistinctArrayAggAccumulator {
196+
pub fn try_new(datatype: &DataType) -> Result<Self> {
197+
Ok(Self {
198+
values: HashSet::new(),
199+
datatype: datatype.clone(),
200+
})
201+
}
202+
}
203+
204+
impl Accumulator for DistinctArrayAggAccumulator {
205+
fn state(&mut self) -> Result<Vec<ScalarValue>> {
206+
Ok(vec![self.evaluate()?])
207+
}
208+
209+
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
210+
assert_eq!(values.len(), 1, "batch input should only include 1 column!");
211+
212+
let array = &values[0];
213+
214+
for i in 0..array.len() {
215+
let scalar = ScalarValue::try_from_array(&array, i)?;
216+
self.values.insert(scalar);
217+
}
218+
219+
Ok(())
220+
}
221+
222+
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
223+
if states.is_empty() {
224+
return Ok(());
225+
}
226+
227+
states[0]
228+
.as_list::<i32>()
229+
.iter()
230+
.flatten()
231+
.try_for_each(|val| self.update_batch(&[val]))
232+
}
233+
234+
fn evaluate(&mut self) -> Result<ScalarValue> {
235+
let values: Vec<ScalarValue> = self.values.iter().cloned().collect();
236+
if values.is_empty() {
237+
return Ok(ScalarValue::new_null_list(self.datatype.clone(), true, 1));
238+
}
239+
let arr = ScalarValue::new_list(&values, &self.datatype, true);
240+
Ok(ScalarValue::List(arr))
241+
}
242+
243+
fn size(&self) -> usize {
244+
std::mem::size_of_val(self) + ScalarValue::size_of_hashset(&self.values)
245+
- std::mem::size_of_val(&self.values)
246+
+ self.datatype.size()
247+
- std::mem::size_of_val(&self.datatype)
248+
}
249+
}

0 commit comments

Comments
 (0)