Skip to content

Revert unnecessary changes in first/last aggregator #7559

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 19 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use arrow::record_batch::RecordBatch;
use arrow_array::{downcast_primitive, ArrayRef};
use arrow_schema::SchemaRef;
use datafusion_common::Result;
Expand Down Expand Up @@ -42,6 +43,9 @@ pub trait GroupValues: Send {

/// Emits the group values
fn emit(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>>;

/// clear the contents and shrink the capacity to free up memory usage
fn clear_shrink(&mut self, batch: &RecordBatch);
}

pub fn new_group_values(schema: SchemaRef) -> Result<Box<dyn GroupValues>> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use ahash::RandomState;
use arrow::array::BooleanBufferBuilder;
use arrow::buffer::NullBuffer;
use arrow::datatypes::i256;
use arrow::record_batch::RecordBatch;
use arrow_array::cast::AsArray;
use arrow_array::{ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, PrimitiveArray};
use arrow_schema::DataType;
Expand Down Expand Up @@ -206,4 +207,12 @@ where
};
Ok(vec![Arc::new(array.with_data_type(self.data_type.clone()))])
}

/// Reset this group-values store and release surplus capacity.
///
/// `batch` is the most recently processed input batch; its row count is
/// used as the capacity hint to shrink toward.
fn clear_shrink(&mut self, batch: &RecordBatch) {
    let target = batch.num_rows();
    // Drop all stored group values, then trim spare allocation.
    self.values.clear();
    self.values.shrink_to(target);
    // After `clear` the map holds no entries, so the rehash closure passed
    // to `shrink_to` is never invoked; a dummy hasher suffices.
    self.map.clear();
    self.map.shrink_to(target, |_| 0);
}
}
12 changes: 12 additions & 0 deletions datafusion/core/src/physical_plan/aggregates/group_values/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

use crate::physical_plan::aggregates::group_values::GroupValues;
use ahash::RandomState;
use arrow::record_batch::RecordBatch;
use arrow::row::{RowConverter, Rows, SortField};
use arrow_array::ArrayRef;
use arrow_schema::SchemaRef;
Expand Down Expand Up @@ -181,4 +182,15 @@ impl GroupValues for GroupValuesRows {
}
})
}

/// Reset this group-values store and release surplus capacity, using the
/// row count of `batch` as the capacity hint for future batches.
fn clear_shrink(&mut self, batch: &RecordBatch) {
    let target = batch.num_rows();
    // FIXME: there is no good way to clear_shrink for self.group_values
    self.group_values = self.row_converter.empty_rows(target, 0);
    // The map is empty after `clear`, so the rehash closure handed to
    // `shrink_to` is never called; the hasher does not matter here.
    self.map.clear();
    self.map.shrink_to(target, |_| 0);
    // Keep the tracked allocation size in sync with the shrunken map.
    self.map_size = self.map.capacity() * std::mem::size_of::<(u64, usize)>();
    self.hashes_buffer.clear();
    self.hashes_buffer.shrink_to(target);
}
}
195 changes: 157 additions & 38 deletions datafusion/core/src/physical_plan/aggregates/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1291,6 +1291,7 @@ mod tests {
use std::sync::Arc;
use std::task::{Context, Poll};

use datafusion_execution::config::SessionConfig;
use futures::{FutureExt, Stream};

// Generate a schema which consists of 5 columns (a, b, c, d, e)
Expand Down Expand Up @@ -1461,7 +1462,22 @@ mod tests {
)
}

async fn check_grouping_sets(input: Arc<dyn ExecutionPlan>) -> Result<()> {
/// Build a [`TaskContext`] whose memory pool is capped at `max_memory`
/// bytes and whose batch size is `batch_size` rows, so that aggregations
/// running under it are forced to spill to disk.
fn new_spill_ctx(batch_size: usize, max_memory: usize) -> Arc<TaskContext> {
    let runtime_config = RuntimeConfig::default().with_memory_limit(max_memory, 1.0);
    let runtime = Arc::new(RuntimeEnv::new(runtime_config).unwrap());
    Arc::new(
        TaskContext::default()
            .with_session_config(SessionConfig::new().with_batch_size(batch_size))
            .with_runtime(runtime),
    )
}

async fn check_grouping_sets(
input: Arc<dyn ExecutionPlan>,
spill: bool,
) -> Result<()> {
let input_schema = input.schema();

let grouping_set = PhysicalGroupBy {
Expand All @@ -1486,7 +1502,11 @@ mod tests {
DataType::Int64,
))];

let task_ctx = Arc::new(TaskContext::default());
let task_ctx = if spill {
new_spill_ctx(4, 1000)
} else {
Arc::new(TaskContext::default())
};

let partial_aggregate = Arc::new(AggregateExec::try_new(
AggregateMode::Partial,
Expand All @@ -1501,24 +1521,53 @@ mod tests {
let result =
common::collect(partial_aggregate.execute(0, task_ctx.clone())?).await?;

let expected = vec![
"+---+-----+-----------------+",
"| a | b | COUNT(1)[count] |",
"+---+-----+-----------------+",
"| | 1.0 | 2 |",
"| | 2.0 | 2 |",
"| | 3.0 | 2 |",
"| | 4.0 | 2 |",
"| 2 | | 2 |",
"| 2 | 1.0 | 2 |",
"| 3 | | 3 |",
"| 3 | 2.0 | 2 |",
"| 3 | 3.0 | 1 |",
"| 4 | | 3 |",
"| 4 | 3.0 | 1 |",
"| 4 | 4.0 | 2 |",
"+---+-----+-----------------+",
];
let expected = if spill {
vec![
"+---+-----+-----------------+",
"| a | b | COUNT(1)[count] |",
"+---+-----+-----------------+",
"| | 1.0 | 1 |",
"| | 1.0 | 1 |",
"| | 2.0 | 1 |",
"| | 2.0 | 1 |",
"| | 3.0 | 1 |",
"| | 3.0 | 1 |",
"| | 4.0 | 1 |",
"| | 4.0 | 1 |",
"| 2 | | 1 |",
"| 2 | | 1 |",
"| 2 | 1.0 | 1 |",
"| 2 | 1.0 | 1 |",
"| 3 | | 1 |",
"| 3 | | 2 |",
"| 3 | 2.0 | 2 |",
"| 3 | 3.0 | 1 |",
"| 4 | | 1 |",
"| 4 | | 2 |",
"| 4 | 3.0 | 1 |",
"| 4 | 4.0 | 2 |",
"+---+-----+-----------------+",
]
} else {
vec![
"+---+-----+-----------------+",
"| a | b | COUNT(1)[count] |",
"+---+-----+-----------------+",
"| | 1.0 | 2 |",
"| | 2.0 | 2 |",
"| | 3.0 | 2 |",
"| | 4.0 | 2 |",
"| 2 | | 2 |",
"| 2 | 1.0 | 2 |",
"| 3 | | 3 |",
"| 3 | 2.0 | 2 |",
"| 3 | 3.0 | 1 |",
"| 4 | | 3 |",
"| 4 | 3.0 | 1 |",
"| 4 | 4.0 | 2 |",
"+---+-----+-----------------+",
]
};
assert_batches_sorted_eq!(expected, &result);

let groups = partial_aggregate.group_expr().expr().to_vec();
Expand All @@ -1532,6 +1581,12 @@ mod tests {

let final_grouping_set = PhysicalGroupBy::new_single(final_group);

let task_ctx = if spill {
new_spill_ctx(4, 3160)
} else {
task_ctx
};

let merged_aggregate = Arc::new(AggregateExec::try_new(
AggregateMode::Final,
final_grouping_set,
Expand Down Expand Up @@ -1577,7 +1632,7 @@ mod tests {
}

/// build the aggregates on the data from some_data() and check the results
async fn check_aggregates(input: Arc<dyn ExecutionPlan>) -> Result<()> {
async fn check_aggregates(input: Arc<dyn ExecutionPlan>, spill: bool) -> Result<()> {
let input_schema = input.schema();

let grouping_set = PhysicalGroupBy {
Expand All @@ -1592,7 +1647,11 @@ mod tests {
DataType::Float64,
))];

let task_ctx = Arc::new(TaskContext::default());
let task_ctx = if spill {
new_spill_ctx(2, 2144)
} else {
Arc::new(TaskContext::default())
};

let partial_aggregate = Arc::new(AggregateExec::try_new(
AggregateMode::Partial,
Expand All @@ -1607,15 +1666,29 @@ mod tests {
let result =
common::collect(partial_aggregate.execute(0, task_ctx.clone())?).await?;

let expected = [
"+---+---------------+-------------+",
"| a | AVG(b)[count] | AVG(b)[sum] |",
"+---+---------------+-------------+",
"| 2 | 2 | 2.0 |",
"| 3 | 3 | 7.0 |",
"| 4 | 3 | 11.0 |",
"+---+---------------+-------------+",
];
let expected = if spill {
vec![
"+---+---------------+-------------+",
"| a | AVG(b)[count] | AVG(b)[sum] |",
"+---+---------------+-------------+",
"| 2 | 1 | 1.0 |",
"| 2 | 1 | 1.0 |",
"| 3 | 1 | 2.0 |",
"| 3 | 2 | 5.0 |",
"| 4 | 3 | 11.0 |",
"+---+---------------+-------------+",
]
} else {
vec![
"+---+---------------+-------------+",
"| a | AVG(b)[count] | AVG(b)[sum] |",
"+---+---------------+-------------+",
"| 2 | 2 | 2.0 |",
"| 3 | 3 | 7.0 |",
"| 4 | 3 | 11.0 |",
"+---+---------------+-------------+",
]
};
assert_batches_sorted_eq!(expected, &result);

let merge = Arc::new(CoalescePartitionsExec::new(partial_aggregate));
Expand Down Expand Up @@ -1658,7 +1731,13 @@ mod tests {

let metrics = merged_aggregate.metrics().unwrap();
let output_rows = metrics.output_rows().unwrap();
assert_eq!(3, output_rows);
if spill {
// When spilling, the output rows metrics become partial output size + final output size
// This is because final aggregation starts while partial aggregation is still emitting
assert_eq!(8, output_rows);
} else {
assert_eq!(3, output_rows);
}

Ok(())
}
Expand Down Expand Up @@ -1779,31 +1858,63 @@ mod tests {
let input: Arc<dyn ExecutionPlan> =
Arc::new(TestYieldingExec { yield_first: false });

check_aggregates(input).await
check_aggregates(input, false).await
}

#[tokio::test]
async fn aggregate_grouping_sets_source_not_yielding() -> Result<()> {
    // Grouping-sets aggregation over a source that never yields first,
    // without any spilling.
    check_grouping_sets(Arc::new(TestYieldingExec { yield_first: false }), false).await
}

#[tokio::test]
async fn aggregate_source_with_yielding() -> Result<()> {
    // Aggregation over a source that yields first, without any spilling.
    check_aggregates(Arc::new(TestYieldingExec { yield_first: true }), false).await
}

#[tokio::test]
async fn aggregate_grouping_sets_with_yielding() -> Result<()> {
    // Grouping-sets aggregation over a source that yields first, without
    // any spilling.
    check_grouping_sets(Arc::new(TestYieldingExec { yield_first: true }), false).await
}

#[tokio::test]
async fn aggregate_source_not_yielding_with_spill() -> Result<()> {
    // Same as `aggregate_source_not_yielding`, but under a memory-limited
    // context that forces the aggregation to spill.
    check_aggregates(Arc::new(TestYieldingExec { yield_first: false }), true).await
}

#[tokio::test]
async fn aggregate_grouping_sets_source_not_yielding_with_spill() -> Result<()> {
    // Same as `aggregate_grouping_sets_source_not_yielding`, but under a
    // memory-limited context that forces the aggregation to spill.
    check_grouping_sets(Arc::new(TestYieldingExec { yield_first: false }), true).await
}

#[tokio::test]
async fn aggregate_source_with_yielding_with_spill() -> Result<()> {
    // Same as `aggregate_source_with_yielding`, but under a memory-limited
    // context that forces the aggregation to spill.
    check_aggregates(Arc::new(TestYieldingExec { yield_first: true }), true).await
}

#[tokio::test]
async fn aggregate_grouping_sets_with_yielding_with_spill() -> Result<()> {
    // Same as `aggregate_grouping_sets_with_yielding`, but under a
    // memory-limited context that forces the aggregation to spill.
    check_grouping_sets(Arc::new(TestYieldingExec { yield_first: true }), true).await
}

#[tokio::test]
Expand Down Expand Up @@ -1971,7 +2082,10 @@ mod tests {
async fn run_first_last_multi_partitions() -> Result<()> {
for use_coalesce_batches in [false, true] {
for is_first_acc in [false, true] {
first_last_multi_partitions(use_coalesce_batches, is_first_acc).await?
for spill in [false, true] {
first_last_multi_partitions(use_coalesce_batches, is_first_acc, spill)
.await?
}
}
}
Ok(())
Expand All @@ -1997,8 +2111,13 @@ mod tests {
async fn first_last_multi_partitions(
use_coalesce_batches: bool,
is_first_acc: bool,
spill: bool,
) -> Result<()> {
let task_ctx = Arc::new(TaskContext::default());
let task_ctx = if spill {
new_spill_ctx(2, 2812)
} else {
Arc::new(TaskContext::default())
};

let (schema, data) = some_data_v2();
let partition1 = data[0].clone();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ impl GroupOrderingPartial {
Ok(())
}

/// Return the size of memor allocated by this structure
/// Return the size of memory allocated by this structure
pub(crate) fn size(&self) -> usize {
std::mem::size_of::<Self>()
+ self.order_indices.allocated_size()
Expand Down
Loading