@@ -24,6 +24,7 @@ use std::sync::Arc;
24
24
25
25
use arrow:: array:: ArrayRef ;
26
26
use arrow:: datatypes:: { DataType , Field } ;
27
+ use arrow_array:: cast:: AsArray ;
27
28
28
29
use crate :: aggregate:: utils:: down_cast_any_ref;
29
30
use crate :: expressions:: format_state_name;
@@ -138,9 +139,10 @@ impl Accumulator for DistinctArrayAggAccumulator {
138
139
assert_eq ! ( values. len( ) , 1 , "batch input should only include 1 column!" ) ;
139
140
140
141
let array = & values[ 0 ] ;
141
- let scalar_vec = ScalarValue :: convert_array_to_scalar_vec ( array) ?;
142
- for scalars in scalar_vec {
143
- self . values . extend ( scalars) ;
142
+
143
+ for i in 0 ..array. len ( ) {
144
+ let scalar = ScalarValue :: try_from_array ( & array, i) ?;
145
+ self . values . insert ( scalar) ;
144
146
}
145
147
146
148
Ok ( ( ) )
@@ -151,7 +153,12 @@ impl Accumulator for DistinctArrayAggAccumulator {
151
153
return Ok ( ( ) ) ;
152
154
}
153
155
154
- self . update_batch ( states)
156
+ let array = & states[ 0 ] ;
157
+
158
+ assert_eq ! ( array. len( ) , 1 , "state array should only include 1 row!" ) ;
159
+ // Unwrap outer ListArray then do update batch
160
+ let inner_array = array. as_list :: < i32 > ( ) . value ( 0 ) ;
161
+ self . update_batch ( & [ inner_array] )
155
162
}
156
163
157
164
fn evaluate ( & mut self ) -> Result < ScalarValue > {
@@ -181,47 +188,55 @@ mod tests {
181
188
use arrow_array:: Array ;
182
189
use arrow_array:: ListArray ;
183
190
use arrow_buffer:: OffsetBuffer ;
184
- use datafusion_common:: utils:: array_into_list_array;
185
191
use datafusion_common:: { internal_err, DataFusionError } ;
186
192
187
- // arrow::compute::sort cann't sort ListArray directly, so we need to sort the inner primitive array and wrap it back into ListArray.
188
- fn sort_list_inner ( arr : ScalarValue ) -> ScalarValue {
189
- let arr = match arr {
190
- ScalarValue :: List ( arr) => arr. value ( 0 ) ,
191
- _ => {
192
- panic ! ( "Expected ScalarValue::List, got {:?}" , arr)
193
- }
194
- } ;
193
+ // arrow::compute::sort can't sort nested ListArray directly, so we compare the scalar values pair-wise.
194
+ fn compare_list_contents (
195
+ expected : Vec < ScalarValue > ,
196
+ actual : ScalarValue ,
197
+ ) -> Result < ( ) > {
198
+ let array = actual. to_array ( ) ?;
199
+ let list_array = array. as_list :: < i32 > ( ) ;
200
+ let inner_array = list_array. value ( 0 ) ;
201
+ let mut actual_scalars = vec ! [ ] ;
202
+ for index in 0 ..inner_array. len ( ) {
203
+ let sv = ScalarValue :: try_from_array ( & inner_array, index) ?;
204
+ actual_scalars. push ( sv) ;
205
+ }
195
206
196
- let arr = arrow:: compute:: sort ( & arr, None ) . unwrap ( ) ;
197
- let list_arr = array_into_list_array ( arr) ;
198
- ScalarValue :: List ( Arc :: new ( list_arr) )
199
- }
207
+ if actual_scalars. len ( ) != expected. len ( ) {
208
+ return internal_err ! (
209
+ "Expected and actual list lengths differ: expected={}, actual={}" ,
210
+ expected. len( ) ,
211
+ actual_scalars. len( )
212
+ ) ;
213
+ }
200
214
201
- fn compare_list_contents ( expected : ScalarValue , actual : ScalarValue ) -> Result < ( ) > {
202
- let actual = sort_list_inner ( actual) ;
203
-
204
- match ( & expected, & actual) {
205
- ( ScalarValue :: List ( arr1) , ScalarValue :: List ( arr2) ) => {
206
- if arr1. eq ( arr2) {
207
- Ok ( ( ) )
208
- } else {
209
- internal_err ! (
210
- "Actual value {:?} not found in expected values {:?}" ,
211
- actual,
212
- expected
213
- )
215
+ let mut seen = vec ! [ false ; expected. len( ) ] ;
216
+ for v in expected {
217
+ let mut found = false ;
218
+ for ( i, sv) in actual_scalars. iter ( ) . enumerate ( ) {
219
+ if sv == & v {
220
+ seen[ i] = true ;
221
+ found = true ;
222
+ break ;
214
223
}
215
224
}
216
- _ => {
217
- internal_err ! ( "Expected scalar lists as inputs" )
225
+ if !found {
226
+ return internal_err ! (
227
+ "Expected value {:?} not found in actual values {:?}" ,
228
+ v,
229
+ actual_scalars
230
+ ) ;
218
231
}
219
232
}
233
+
234
+ Ok ( ( ) )
220
235
}
221
236
222
237
fn check_distinct_array_agg (
223
238
input : ArrayRef ,
224
- expected : ScalarValue ,
239
+ expected : Vec < ScalarValue > ,
225
240
datatype : DataType ,
226
241
) -> Result < ( ) > {
227
242
let schema = Schema :: new ( vec ! [ Field :: new( "a" , datatype. clone( ) , false ) ] ) ;
@@ -234,14 +249,13 @@ mod tests {
234
249
true ,
235
250
) ) ;
236
251
let actual = aggregate ( & batch, agg) ?;
237
-
238
252
compare_list_contents ( expected, actual)
239
253
}
240
254
241
255
fn check_merge_distinct_array_agg (
242
256
input1 : ArrayRef ,
243
257
input2 : ArrayRef ,
244
- expected : ScalarValue ,
258
+ expected : Vec < ScalarValue > ,
245
259
datatype : DataType ,
246
260
) -> Result < ( ) > {
247
261
let schema = Schema :: new ( vec ! [ Field :: new( "a" , datatype. clone( ) , false ) ] ) ;
@@ -262,23 +276,20 @@ mod tests {
262
276
accum1. merge_batch ( & [ array] ) ?;
263
277
264
278
let actual = accum1. evaluate ( ) ?;
265
-
266
279
compare_list_contents ( expected, actual)
267
280
}
268
281
269
282
// The input column repeats the value 2; the distinct aggregate is expected to
// hold each value exactly once (element order is not significant to the check).
#[test]
fn distinct_array_agg_i32() -> Result<()> {
    let col: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 4, 5, 2]));

    let expected: Vec<ScalarValue> = vec![1, 2, 4, 5, 7]
        .into_iter()
        .map(|value| ScalarValue::Int32(Some(value)))
        .collect();

    check_distinct_array_agg(col, expected, DataType::Int32)
}
@@ -288,18 +299,15 @@ mod tests {
288
299
let col1: ArrayRef = Arc :: new ( Int32Array :: from ( vec ! [ 1 , 2 , 7 , 4 , 5 , 2 ] ) ) ;
289
300
let col2: ArrayRef = Arc :: new ( Int32Array :: from ( vec ! [ 1 , 3 , 7 , 8 , 4 ] ) ) ;
290
301
291
- let expected =
292
- ScalarValue :: List ( Arc :: new (
293
- ListArray :: from_iter_primitive :: < Int32Type , _ , _ > ( vec ! [ Some ( vec![
294
- Some ( 1 ) ,
295
- Some ( 2 ) ,
296
- Some ( 3 ) ,
297
- Some ( 4 ) ,
298
- Some ( 5 ) ,
299
- Some ( 7 ) ,
300
- Some ( 8 ) ,
301
- ] ) ] ) ,
302
- ) ) ;
302
+ let expected = vec ! [
303
+ ScalarValue :: Int32 ( Some ( 1 ) ) ,
304
+ ScalarValue :: Int32 ( Some ( 2 ) ) ,
305
+ ScalarValue :: Int32 ( Some ( 3 ) ) ,
306
+ ScalarValue :: Int32 ( Some ( 4 ) ) ,
307
+ ScalarValue :: Int32 ( Some ( 5 ) ) ,
308
+ ScalarValue :: Int32 ( Some ( 7 ) ) ,
309
+ ScalarValue :: Int32 ( Some ( 8 ) ) ,
310
+ ] ;
303
311
304
312
check_merge_distinct_array_agg ( col1, col2, expected, DataType :: Int32 )
305
313
}
@@ -351,23 +359,16 @@ mod tests {
351
359
let l2 = ScalarValue :: List ( Arc :: new ( l2) ) ;
352
360
let l3 = ScalarValue :: List ( Arc :: new ( l3) ) ;
353
361
354
- // Duplicate l1 in the input array and check that it is deduped in the output.
355
- let array = ScalarValue :: iter_to_array ( vec ! [ l1. clone( ) , l2, l3, l1] ) . unwrap ( ) ;
356
-
357
- let expected =
358
- ScalarValue :: List ( Arc :: new (
359
- ListArray :: from_iter_primitive :: < Int32Type , _ , _ > ( vec ! [ Some ( vec![
360
- Some ( 1 ) ,
361
- Some ( 2 ) ,
362
- Some ( 3 ) ,
363
- Some ( 4 ) ,
364
- Some ( 5 ) ,
365
- Some ( 6 ) ,
366
- Some ( 7 ) ,
367
- Some ( 8 ) ,
368
- Some ( 9 ) ,
369
- ] ) ] ) ,
370
- ) ) ;
362
+ // Duplicate l1 and l3 in the input array and check that they are deduped in the output.
363
+ let array = ScalarValue :: iter_to_array ( vec ! [
364
+ l1. clone( ) ,
365
+ l2. clone( ) ,
366
+ l3. clone( ) ,
367
+ l3. clone( ) ,
368
+ l1. clone( ) ,
369
+ ] )
370
+ . unwrap ( ) ;
371
+ let expected = vec ! [ l1, l2, l3] ;
371
372
372
373
check_distinct_array_agg (
373
374
array,
@@ -426,22 +427,10 @@ mod tests {
426
427
let l3 = ScalarValue :: List ( Arc :: new ( l3) ) ;
427
428
428
429
// Duplicate l1 across the two input arrays and check that it is deduped in the output.
429
- let input1 = ScalarValue :: iter_to_array ( vec ! [ l1. clone( ) , l2] ) . unwrap ( ) ;
430
- let input2 = ScalarValue :: iter_to_array ( vec ! [ l1, l3] ) . unwrap ( ) ;
431
-
432
- let expected =
433
- ScalarValue :: List ( Arc :: new (
434
- ListArray :: from_iter_primitive :: < Int32Type , _ , _ > ( vec ! [ Some ( vec![
435
- Some ( 1 ) ,
436
- Some ( 2 ) ,
437
- Some ( 3 ) ,
438
- Some ( 4 ) ,
439
- Some ( 5 ) ,
440
- Some ( 6 ) ,
441
- Some ( 7 ) ,
442
- Some ( 8 ) ,
443
- ] ) ] ) ,
444
- ) ) ;
430
+ let input1 = ScalarValue :: iter_to_array ( vec ! [ l1. clone( ) , l2. clone( ) ] ) . unwrap ( ) ;
431
+ let input2 = ScalarValue :: iter_to_array ( vec ! [ l1. clone( ) , l3. clone( ) ] ) . unwrap ( ) ;
432
+
433
+ let expected = vec ! [ l1, l2, l3] ;
445
434
446
435
check_merge_distinct_array_agg ( input1, input2, expected, DataType :: Int32 )
447
436
}
0 commit comments