Skip to content

Commit 315b8e9

Browse files
committed
replace parts of test
Signed-off-by: jayzhan211 <[email protected]>
1 parent 5d98c32 commit 315b8e9

File tree

5 files changed

+48
-89
lines changed

5 files changed

+48
-89
lines changed

datafusion/core/src/physical_optimizer/aggregate_statistics.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,7 @@ pub(crate) mod tests {
420420
// Return appropriate expr depending if COUNT is for col or table (*)
421421
pub(crate) fn count_expr(&self, schema: &Schema) -> Arc<dyn AggregateExpr> {
422422
AggregateExprBuilder::new(count_udaf(), vec![self.column()])
423-
.schema(schema.clone())
423+
.schema(Arc::new(schema.clone()))
424424
.name(self.column_name())
425425
.build()
426426
.unwrap()

datafusion/physical-expr-common/src/aggregate/mod.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ pub mod stats;
2222
pub mod tdigest;
2323
pub mod utils;
2424

25-
use arrow::datatypes::{DataType, Field, Schema};
25+
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
2626
use datafusion_common::{internal_err, not_impl_err, DFSchema, Result};
2727
use datafusion_expr::function::StateFieldsArgs;
2828
use datafusion_expr::type_coercion::aggregates::check_arg_count;
@@ -75,7 +75,7 @@ pub fn create_aggregate_expr(
7575
builder = builder.sort_exprs(sort_exprs.to_vec());
7676
builder = builder.order_by(ordering_req.to_vec());
7777
builder = builder.logical_exprs(input_exprs.to_vec());
78-
builder = builder.schema(schema.clone());
78+
builder = builder.schema(Arc::new(schema.clone()));
7979
builder = builder.name(name);
8080

8181
if ignore_nulls {
@@ -109,7 +109,7 @@ pub fn create_aggregate_expr_with_dfschema(
109109
builder = builder.logical_exprs(input_exprs.to_vec());
110110
builder = builder.dfschema(dfschema.clone());
111111
let schema: Schema = dfschema.into();
112-
builder = builder.schema(schema);
112+
builder = builder.schema(Arc::new(schema));
113113
builder = builder.name(name);
114114

115115
if ignore_nulls {
@@ -134,7 +134,7 @@ pub struct AggregateExprBuilder {
134134
logical_args: Vec<Expr>,
135135
name: String,
136136
/// Arrow Schema for the aggregate function
137-
schema: Schema,
137+
schema: SchemaRef,
138138
/// Datafusion Schema for the aggregate function
139139
dfschema: DFSchema,
140140
/// The logical order by expressions, it will be deprecated in <https://github.com/apache/datafusion/issues/11359>
@@ -156,7 +156,7 @@ impl AggregateExprBuilder {
156156
args,
157157
logical_args: vec![],
158158
name: String::new(),
159-
schema: Schema::empty(),
159+
schema: Arc::new(Schema::empty()),
160160
dfschema: DFSchema::empty(),
161161
sort_exprs: vec![],
162162
ordering_req: vec![],
@@ -215,7 +215,7 @@ impl AggregateExprBuilder {
215215
logical_args,
216216
data_type,
217217
name,
218-
schema,
218+
schema: Arc::unwrap_or_clone(schema),
219219
dfschema,
220220
sort_exprs,
221221
ordering_req,
@@ -232,7 +232,7 @@ impl AggregateExprBuilder {
232232
self
233233
}
234234

235-
pub fn schema(mut self, schema: Schema) -> Self {
235+
pub fn schema(mut self, schema: SchemaRef) -> Self {
236236
self.schema = schema;
237237
self
238238
}

datafusion/physical-plan/src/aggregates/mod.rs

Lines changed: 19 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1211,7 +1211,7 @@ mod tests {
12111211

12121212
use crate::common::collect;
12131213
use datafusion_physical_expr_common::aggregate::{
1214-
create_aggregate_expr, create_aggregate_expr_with_dfschema,
1214+
create_aggregate_expr, create_aggregate_expr_with_dfschema, AggregateExprBuilder,
12151215
};
12161216
use datafusion_physical_expr_common::expressions::Literal;
12171217
use futures::{FutureExt, Stream};
@@ -1351,18 +1351,11 @@ mod tests {
13511351
],
13521352
};
13531353

1354-
let aggregates = vec![create_aggregate_expr(
1355-
&count_udaf(),
1356-
&[lit(1i8)],
1357-
&[datafusion_expr::lit(1i8)],
1358-
&[],
1359-
&[],
1360-
&input_schema,
1361-
"COUNT(1)",
1362-
false,
1363-
false,
1364-
false,
1365-
)?];
1354+
let aggregates = vec![AggregateExprBuilder::new(count_udaf(), vec![lit(1i8)])
1355+
.schema(Arc::clone(&input_schema))
1356+
.name("COUNT(1)")
1357+
.logical_exprs(vec![datafusion_expr::lit(1i8)])
1358+
.build()?];
13661359

13671360
let task_ctx = if spill {
13681361
new_spill_ctx(4, 1000)
@@ -1501,18 +1494,13 @@ mod tests {
15011494
groups: vec![vec![false]],
15021495
};
15031496

1504-
let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![create_aggregate_expr(
1505-
&avg_udaf(),
1506-
&[col("b", &input_schema)?],
1507-
&[datafusion_expr::col("b")],
1508-
&[],
1509-
&[],
1510-
&input_schema,
1511-
"AVG(b)",
1512-
false,
1513-
false,
1514-
false,
1515-
)?];
1497+
let aggregates: Vec<Arc<dyn AggregateExpr>> =
1498+
vec![
1499+
AggregateExprBuilder::new(avg_udaf(), vec![col("b", &input_schema)?])
1500+
.schema(Arc::clone(&input_schema))
1501+
.name("AVG(b)")
1502+
.build()?,
1503+
];
15161504

15171505
let task_ctx = if spill {
15181506
// set to an appropriate value to trigger spill
@@ -1803,21 +1791,11 @@ mod tests {
18031791
}
18041792

18051793
// Median(a)
1806-
fn test_median_agg_expr(schema: &Schema) -> Result<Arc<dyn AggregateExpr>> {
1807-
let args = vec![col("a", schema)?];
1808-
let fun = median_udaf();
1809-
datafusion_physical_expr_common::aggregate::create_aggregate_expr(
1810-
&fun,
1811-
&args,
1812-
&[],
1813-
&[],
1814-
&[],
1815-
schema,
1816-
"MEDIAN(a)",
1817-
false,
1818-
false,
1819-
false,
1820-
)
1794+
fn test_median_agg_expr(schema: SchemaRef) -> Result<Arc<dyn AggregateExpr>> {
1795+
AggregateExprBuilder::new(median_udaf(), vec![col("a", &schema)?])
1796+
.schema(schema)
1797+
.name("MEDIAN(a)")
1798+
.build()
18211799
}
18221800

18231801
#[tokio::test]
@@ -1840,7 +1818,7 @@ mod tests {
18401818

18411819
// something that allocates within the aggregator
18421820
let aggregates_v0: Vec<Arc<dyn AggregateExpr>> =
1843-
vec![test_median_agg_expr(&input_schema)?];
1821+
vec![test_median_agg_expr(Arc::clone(&input_schema))?];
18441822

18451823
// use fast-path in `row_hash.rs`.
18461824
let aggregates_v2: Vec<Arc<dyn AggregateExpr>> = vec![create_aggregate_expr(

datafusion/proto/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ chrono = { workspace = true }
5050
datafusion = { workspace = true, default-features = true }
5151
datafusion-common = { workspace = true, default-features = true }
5252
datafusion-expr = { workspace = true }
53+
datafusion-physical-expr-common = { workspace = true }
5354
datafusion-proto-common = { workspace = true }
5455
object_store = { workspace = true }
5556
pbjson = { version = "0.6.0", optional = true }

datafusion/proto/tests/cases/roundtrip_physical_plan.rs

Lines changed: 20 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ use std::vec;
2424

2525
use arrow::array::RecordBatch;
2626
use arrow::csv::WriterBuilder;
27+
use datafusion_physical_expr_common::aggregate::AggregateExprBuilder;
2728
use prost::Message;
2829

2930
use datafusion::arrow::array::ArrayRef;
@@ -86,7 +87,7 @@ use datafusion_expr::{
8687
};
8788
use datafusion_functions_aggregate::average::avg_udaf;
8889
use datafusion_functions_aggregate::nth_value::nth_value_udaf;
89-
use datafusion_functions_aggregate::string_agg::StringAgg;
90+
use datafusion_functions_aggregate::string_agg::string_agg_udaf;
9091
use datafusion_proto::physical_plan::{
9192
AsExecutionPlan, DefaultPhysicalExtensionCodec, PhysicalExtensionCodec,
9293
};
@@ -357,49 +358,28 @@ fn rountrip_aggregate() -> Result<()> {
357358
let groups: Vec<(Arc<dyn PhysicalExpr>, String)> =
358359
vec![(col("a", &schema)?, "unused".to_string())];
359360

361+
let avg_expr = AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?])
362+
.schema(Arc::clone(&schema))
363+
.name("AVG(b)")
364+
.build()?;
365+
let nth_expr =
366+
AggregateExprBuilder::new(nth_value_udaf(), vec![col("b", &schema)?, lit(1u64)])
367+
.schema(Arc::clone(&schema))
368+
.name("NTH_VALUE(b, 1)")
369+
.build()?;
370+
let str_agg_expr =
371+
AggregateExprBuilder::new(string_agg_udaf(), vec![col("b", &schema)?, lit(1u64)])
372+
.schema(Arc::clone(&schema))
373+
.name("NTH_VALUE(b, 1)")
374+
.build()?;
375+
360376
let test_cases: Vec<Vec<Arc<dyn AggregateExpr>>> = vec![
361377
// AVG
362-
vec![create_aggregate_expr(
363-
&avg_udaf(),
364-
&[col("b", &schema)?],
365-
&[],
366-
&[],
367-
&[],
368-
&schema,
369-
"AVG(b)",
370-
false,
371-
false,
372-
false,
373-
)?],
378+
vec![avg_expr],
374379
// NTH_VALUE
375-
vec![create_aggregate_expr(
376-
&nth_value_udaf(),
377-
&[col("b", &schema)?, lit(1u64)],
378-
&[],
379-
&[],
380-
&[],
381-
&schema,
382-
"NTH_VALUE(b, 1)",
383-
false,
384-
false,
385-
false,
386-
)?],
380+
vec![nth_expr],
387381
// STRING_AGG
388-
vec![create_aggregate_expr(
389-
&AggregateUDF::new_from_impl(StringAgg::new()),
390-
&[
391-
cast(col("b", &schema)?, &schema, DataType::Utf8)?,
392-
lit(ScalarValue::Utf8(Some(",".to_string()))),
393-
],
394-
&[],
395-
&[],
396-
&[],
397-
&schema,
398-
"STRING_AGG(name, ',')",
399-
false,
400-
false,
401-
false,
402-
)?],
382+
vec![str_agg_expr],
403383
];
404384

405385
for aggregates in test_cases {

0 commit comments

Comments
 (0)