@@ -39,7 +39,9 @@ use arrow::array::*;
 use arrow::datatypes::SchemaRef;
 use arrow_schema::SortOptions;
 use datafusion_common::utils::get_arrayref_at_indices;
-use datafusion_common::{internal_datafusion_err, DataFusionError, Result};
+use datafusion_common::{
+    arrow_datafusion_err, internal_datafusion_err, DataFusionError, Result,
+};
 use datafusion_execution::disk_manager::RefCountedTempFile;
 use datafusion_execution::memory_pool::proxy::VecAllocExt;
 use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
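Note on the import change: the newly added `arrow_datafusion_err` macro is used further down in `emit_partitioned` to convert the `ArrowError` returned by `RecordBatch::try_new` into a `DataFusionError`. A minimal standalone sketch of that conversion (the `to_df_err` helper is illustrative, not part of this change):

```rust
use arrow::error::ArrowError;
use datafusion_common::{arrow_datafusion_err, DataFusionError, Result};

// Illustrative helper: wrap an ArrowError into a DataFusionError via the
// arrow_datafusion_err! macro, mirroring the map_err call in emit_partitioned.
fn to_df_err(e: ArrowError) -> DataFusionError {
    arrow_datafusion_err!(e)
}

fn main() {
    let res: Result<()> = Err(to_df_err(ArrowError::ComputeError("example".into())));
    println!("{res:?}");
}
```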
@@ -868,7 +870,7 @@ impl GroupedHashAggregateStream {
                 get_filter_at_indices(opt_filter, &batch_indices)
             })
             .collect::<Result<Vec<_>>>()?;
-
+
         // Update the accumulators of each partition
         for (part_idx, part_start_end) in offsets.windows(2).enumerate() {
             let (offset, length) =
@@ -886,7 +888,8 @@ impl GroupedHashAggregateStream {
                 .map(|array| array.slice(offset, length))
                 .collect::<Vec<_>>();
 
-            let part_opt_filter = opt_filter.as_ref().map(|f| f.slice(offset, length));
+            let part_opt_filter =
+                opt_filter.as_ref().map(|f| f.slice(offset, length));
             let part_opt_filter =
                 part_opt_filter.as_ref().map(|filter| filter.as_boolean());
 
@@ -980,21 +983,29 @@ impl GroupedHashAggregateStream {
 
     /// Create an output RecordBatch with the group keys and
     /// accumulator states/values specified in emit_to
-    fn emit(&mut self, emit_to: EmitTo, spilling: bool) -> Result<RecordBatch> {
-        self.emit_single(emit_to, spilling)
+    fn emit(&mut self, emit_to: EmitTo, spilling: bool) -> Result<Vec<RecordBatch>> {
+        if !self.group_values.is_partitioned() {
+            self.emit_single(emit_to, spilling)
+        } else {
+            self.emit_partitioned(emit_to)
+        }
     }
 
     /// Create an output RecordBatch with the group keys and
     /// accumulator states/values specified in emit_to
-    fn emit_single(&mut self, emit_to: EmitTo, spilling: bool) -> Result<RecordBatch> {
+    fn emit_single(
+        &mut self,
+        emit_to: EmitTo,
+        spilling: bool,
+    ) -> Result<Vec<RecordBatch>> {
         let schema = if spilling {
             Arc::clone(&self.spill_state.spill_schema)
         } else {
             self.schema()
         };
 
         if self.group_values.is_empty() {
-            return Ok(RecordBatch::new_empty(schema));
+            return Ok(Vec::new());
         }
 
         let group_values = self.group_values.as_single_mut();
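Context for the signature changes: `emit` and `emit_single` now return `Result<Vec<RecordBatch>>`, and the empty-group case returns `Ok(Vec::new())` rather than an empty `RecordBatch`, so callers can loop over whatever comes back without a special case. A self-contained sketch of that calling pattern, with illustrative names (`emit_groups` is not a DataFusion API):

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, Int64Array};
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use arrow::error::ArrowError;
use arrow::record_batch::RecordBatch;

// Illustrative stand-in for an emit that may produce zero, one, or many batches.
fn emit_groups(
    schema: SchemaRef,
    groups: Vec<Vec<ArrayRef>>,
) -> Result<Vec<RecordBatch>, ArrowError> {
    if groups.is_empty() {
        // Mirrors `return Ok(Vec::new())` above: no special empty batch needed.
        return Ok(Vec::new());
    }
    groups
        .into_iter()
        .map(|cols| RecordBatch::try_new(schema.clone(), cols))
        .collect()
}

fn main() -> Result<(), ArrowError> {
    let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int64, false)]));

    // Empty input: the caller's loop simply does nothing.
    for batch in emit_groups(schema.clone(), Vec::new())? {
        println!("{} rows", batch.num_rows());
    }

    // One emitted group set becomes one batch.
    let cols = vec![Arc::new(Int64Array::from(vec![10_i64, 20])) as ArrayRef];
    for batch in emit_groups(schema, vec![cols])? {
        println!("{} rows", batch.num_rows());
    }
    Ok(())
}
```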
@@ -1023,10 +1034,11 @@ impl GroupedHashAggregateStream {
         // over the target memory size after emission, we can emit again rather than returning Err.
         let _ = self.update_memory_reservation();
         let batch = RecordBatch::try_new(schema, output)?;
-        Ok(batch)
+
+        Ok(vec![batch])
     }
 
-    fn emit_partitioned(&mut self, emit_to: EmitTo, spilling: bool) -> Result<Vec<RecordBatch>> {
+    fn emit_partitioned(&mut self, emit_to: EmitTo) -> Result<Vec<RecordBatch>> {
         assert!(
             self.mode == AggregateMode::Partial
                 && matches!(self.group_ordering, GroupOrdering::None)
@@ -1035,22 +1047,35 @@ impl GroupedHashAggregateStream {
         let schema = self.schema();
 
         if self.group_values.is_empty() {
-            return Ok(RecordBatch::new_empty(schema));
+            return Ok(Vec::new());
         }
 
         let group_values = self.group_values.as_partitioned_mut();
-        let mut output = group_values.emit(emit_to)?;
+        let mut partitioned_outputs = group_values.emit(emit_to)?;
 
         // Next output each aggregate value
-        for acc in self.accumulators[0].iter_mut() {
-            output.extend(acc.state(emit_to)?)
+        for (output, accs) in partitioned_outputs
+            .iter_mut()
+            .zip(self.accumulators.iter_mut())
+        {
+            for acc in accs.iter_mut() {
+                output.extend(acc.state(emit_to)?);
+            }
         }
 
         // emit reduces the memory usage. Ignore Err from update_memory_reservation. Even if it is
         // over the target memory size after emission, we can emit again rather than returning Err.
         let _ = self.update_memory_reservation();
-        let batch = RecordBatch::try_new(schema, output)?;
-        Ok(batch)
+
+        let batch_parts = partitioned_outputs
+            .into_iter()
+            .map(|part| {
+                RecordBatch::try_new(schema.clone(), part)
+                    .map_err(|e| arrow_datafusion_err!(e))
+            })
+            .collect::<Result<Vec<_>>>()?;
+
+        Ok(batch_parts)
     }
 
     /// Optimistically, [`Self::group_aggregate_batch`] allows to exceed the memory target slightly
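To summarize the per-partition flow in `emit_partitioned` above: each partition's group columns are extended with that partition's accumulator state, and one `RecordBatch` is built per partition. A self-contained sketch of the same shape using only `arrow` types; all names here (`build_partitioned_batches`, the example schema) are illustrative rather than DataFusion internals:

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, Int64Array, StringArray};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::error::ArrowError;
use arrow::record_batch::RecordBatch;

// Illustrative mirror of the emit_partitioned shape: zip per-partition group
// columns with per-partition accumulator columns, then build one batch each.
fn build_partitioned_batches(
    mut group_columns: Vec<Vec<ArrayRef>>,
    accumulator_columns: Vec<Vec<ArrayRef>>,
) -> Result<Vec<RecordBatch>, ArrowError> {
    let schema = Arc::new(Schema::new(vec![
        Field::new("group", DataType::Utf8, false),
        Field::new("count", DataType::Int64, false),
    ]));

    // Append each partition's accumulator state columns to its group columns.
    for (output, accs) in group_columns.iter_mut().zip(accumulator_columns) {
        output.extend(accs);
    }

    // One RecordBatch per partition, collected into a Vec like batch_parts above.
    group_columns
        .into_iter()
        .map(|cols| RecordBatch::try_new(schema.clone(), cols))
        .collect()
}

fn main() -> Result<(), ArrowError> {
    let groups = vec![
        vec![Arc::new(StringArray::from(vec!["a", "b"])) as ArrayRef],
        vec![Arc::new(StringArray::from(vec!["c"])) as ArrayRef],
    ];
    let counts = vec![
        vec![Arc::new(Int64Array::from(vec![2_i64, 1])) as ArrayRef],
        vec![Arc::new(Int64Array::from(vec![5_i64])) as ArrayRef],
    ];
    for batch in build_partitioned_batches(groups, counts)? {
        println!("partition batch with {} rows", batch.num_rows());
    }
    Ok(())
}
```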