Skip to content

Commit efd2fd2

Browse files
mustafasrepoalamb
andauthored
Return proper number of expressions for nth_value_agg (#9044)
* Return proper number of expressions * Fix bug * Update datafusion/proto/tests/cases/roundtrip_physical_plan.rs Co-authored-by: Andrew Lamb <[email protected]> * Fix argument type --------- Co-authored-by: Andrew Lamb <[email protected]>
1 parent 7fb83cc commit efd2fd2

File tree

2 files changed

+17
-14
lines changed

2 files changed

+17
-14
lines changed

datafusion/physical-expr/src/aggregate/nth_value.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ use std::sync::Arc;
2424

2525
use crate::aggregate::array_agg_ordered::merge_ordered_arrays;
2626
use crate::aggregate::utils::{down_cast_any_ref, ordering_fields};
27-
use crate::expressions::format_state_name;
27+
use crate::expressions::{format_state_name, Literal};
2828
use crate::{
2929
reverse_order_bys, AggregateExpr, LexOrdering, PhysicalExpr, PhysicalSortExpr,
3030
};
@@ -117,7 +117,8 @@ impl AggregateExpr for NthValueAgg {
117117
}
118118

119119
fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
120-
vec![self.expr.clone()]
120+
let n = Arc::new(Literal::new(ScalarValue::Int64(Some(self.n)))) as _;
121+
vec![self.expr.clone(), n]
121122
}
122123

123124
fn order_bys(&self) -> Option<&[PhysicalSortExpr]> {
@@ -393,7 +394,9 @@ impl NthValueAccumulator {
393394
for index in 0..n_to_add {
394395
let row = get_row_at_idx(values, index)?;
395396
self.values.push_back(row[0].clone());
396-
self.ordering_values.push_back(row[1..].to_vec());
397+
// At index 1, we have n index argument.
398+
// Ordering values cover starting from 2nd index to end
399+
self.ordering_values.push_back(row[2..].to_vec());
397400
}
398401
Ok(())
399402
}

datafusion/proto/tests/cases/roundtrip_physical_plan.rs

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ use datafusion::logical_expr::{
3636
};
3737
use datafusion::parquet::file::properties::WriterProperties;
3838
use datafusion::physical_expr::expressions::Literal;
39+
use datafusion::physical_expr::expressions::NthValueAgg;
3940
use datafusion::physical_expr::window::SlidingAggregateWindowExpr;
4041
use datafusion::physical_expr::{PhysicalSortRequirement, ScalarFunctionExpr};
4142
use datafusion::physical_plan::aggregates::{
@@ -337,17 +338,16 @@ fn rountrip_aggregate() -> Result<()> {
337338
"AVG(b)".to_string(),
338339
DataType::Float64,
339340
))],
340-
// TODO: See <https://github.com/apache/arrow-datafusion/issues/9028>
341-
// // NTH_VALUE
342-
// vec![Arc::new(NthValueAgg::new(
343-
// col("b", &schema)?,
344-
// 1,
345-
// "NTH_VALUE(b, 1)".to_string(),
346-
// DataType::Int64,
347-
// false,
348-
// Vec::new(),
349-
// Vec::new(),
350-
// ))],
341+
// NTH_VALUE
342+
vec![Arc::new(NthValueAgg::new(
343+
col("b", &schema)?,
344+
1,
345+
"NTH_VALUE(b, 1)".to_string(),
346+
DataType::Int64,
347+
false,
348+
Vec::new(),
349+
Vec::new(),
350+
))],
351351
// STRING_AGG
352352
vec![Arc::new(StringAgg::new(
353353
cast(col("b", &schema)?, &schema, DataType::Utf8)?,

0 commit comments

Comments
 (0)