Skip to content

Remove element's nullability of array_agg function #11447

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion datafusion/core/tests/sql/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ async fn csv_query_array_agg_distinct() -> Result<()> {
*actual[0].schema(),
Schema::new(vec![Field::new_list(
"ARRAY_AGG(DISTINCT aggregate_test_100.c2)",
Field::new("item", DataType::UInt32, false),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`array_agg` returns NULL when there are no inputs (I believe after #11299), so this change makes sense to me.

Field::new("item", DataType::UInt32, true),
true
),])
);
Expand Down
23 changes: 6 additions & 17 deletions datafusion/physical-expr/src/aggregate/array_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use arrow::array::ArrayRef;
use arrow::datatypes::{DataType, Field};
use arrow_array::Array;
use datafusion_common::cast::as_list_array;
use datafusion_common::utils::array_into_list_array;
use datafusion_common::utils::array_into_list_array_nullable;
use datafusion_common::Result;
use datafusion_common::ScalarValue;
use datafusion_expr::Accumulator;
Expand All @@ -40,8 +40,6 @@ pub struct ArrayAgg {
input_data_type: DataType,
/// The input expression
expr: Arc<dyn PhysicalExpr>,
/// If the input expression can have NULLs
nullable: bool,
}

impl ArrayAgg {
Expand All @@ -50,13 +48,11 @@ impl ArrayAgg {
expr: Arc<dyn PhysicalExpr>,
name: impl Into<String>,
data_type: DataType,
nullable: bool,
) -> Self {
Self {
name: name.into(),
input_data_type: data_type,
expr,
nullable,
}
}
}
Expand All @@ -70,22 +66,21 @@ impl AggregateExpr for ArrayAgg {
Ok(Field::new_list(
&self.name,
// This should be the same as return type of AggregateFunction::ArrayAgg
Field::new("item", self.input_data_type.clone(), self.nullable),
Field::new("item", self.input_data_type.clone(), true),
true,
))
}

fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(ArrayAggAccumulator::try_new(
&self.input_data_type,
self.nullable,
)?))
}

fn state_fields(&self) -> Result<Vec<Field>> {
Ok(vec![Field::new_list(
format_state_name(&self.name, "array_agg"),
Field::new("item", self.input_data_type.clone(), self.nullable),
Field::new("item", self.input_data_type.clone(), true),
true,
)])
}
Expand Down Expand Up @@ -116,16 +111,14 @@ impl PartialEq<dyn Any> for ArrayAgg {
pub(crate) struct ArrayAggAccumulator {
values: Vec<ArrayRef>,
datatype: DataType,
nullable: bool,
}

impl ArrayAggAccumulator {
/// new array_agg accumulator based on given item data type
pub fn try_new(datatype: &DataType, nullable: bool) -> Result<Self> {
pub fn try_new(datatype: &DataType) -> Result<Self> {
Ok(Self {
values: vec![],
datatype: datatype.clone(),
nullable,
})
}
}
Expand Down Expand Up @@ -169,15 +162,11 @@ impl Accumulator for ArrayAggAccumulator {
self.values.iter().map(|a| a.as_ref()).collect();

if element_arrays.is_empty() {
return Ok(ScalarValue::new_null_list(
self.datatype.clone(),
self.nullable,
1,
));
return Ok(ScalarValue::new_null_list(self.datatype.clone(), true, 1));
}

let concated_array = arrow::compute::concat(&element_arrays)?;
let list_array = array_into_list_array(concated_array, self.nullable);
let list_array = array_into_list_array_nullable(concated_array);

Ok(ScalarValue::List(Arc::new(list_array)))
}
Expand Down
23 changes: 5 additions & 18 deletions datafusion/physical-expr/src/aggregate/array_agg_distinct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,6 @@ pub struct DistinctArrayAgg {
input_data_type: DataType,
/// The input expression
expr: Arc<dyn PhysicalExpr>,
/// If the input expression can have NULLs
nullable: bool,
}

impl DistinctArrayAgg {
Expand All @@ -52,14 +50,12 @@ impl DistinctArrayAgg {
expr: Arc<dyn PhysicalExpr>,
name: impl Into<String>,
input_data_type: DataType,
nullable: bool,
) -> Self {
let name = name.into();
Self {
name,
input_data_type,
expr,
nullable,
}
}
}
Expand All @@ -74,22 +70,21 @@ impl AggregateExpr for DistinctArrayAgg {
Ok(Field::new_list(
&self.name,
// This should be the same as return type of AggregateFunction::ArrayAgg
Field::new("item", self.input_data_type.clone(), self.nullable),
Field::new("item", self.input_data_type.clone(), true),
true,
))
}

fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(DistinctArrayAggAccumulator::try_new(
&self.input_data_type,
self.nullable,
)?))
}

fn state_fields(&self) -> Result<Vec<Field>> {
Ok(vec![Field::new_list(
format_state_name(&self.name, "distinct_array_agg"),
Field::new("item", self.input_data_type.clone(), self.nullable),
Field::new("item", self.input_data_type.clone(), true),
true,
)])
}
Expand Down Expand Up @@ -120,15 +115,13 @@ impl PartialEq<dyn Any> for DistinctArrayAgg {
struct DistinctArrayAggAccumulator {
values: HashSet<ScalarValue>,
datatype: DataType,
nullable: bool,
}

impl DistinctArrayAggAccumulator {
pub fn try_new(datatype: &DataType, nullable: bool) -> Result<Self> {
pub fn try_new(datatype: &DataType) -> Result<Self> {
Ok(Self {
values: HashSet::new(),
datatype: datatype.clone(),
nullable,
})
}
}
Expand Down Expand Up @@ -166,13 +159,9 @@ impl Accumulator for DistinctArrayAggAccumulator {
fn evaluate(&mut self) -> Result<ScalarValue> {
let values: Vec<ScalarValue> = self.values.iter().cloned().collect();
if values.is_empty() {
return Ok(ScalarValue::new_null_list(
self.datatype.clone(),
self.nullable,
1,
));
return Ok(ScalarValue::new_null_list(self.datatype.clone(), true, 1));
}
let arr = ScalarValue::new_list(&values, &self.datatype, self.nullable);
let arr = ScalarValue::new_list(&values, &self.datatype, true);
Ok(ScalarValue::List(arr))
}

Expand Down Expand Up @@ -255,7 +244,6 @@ mod tests {
col("a", &schema)?,
"bla".to_string(),
datatype,
true,
));
let actual = aggregate(&batch, agg)?;
compare_list_contents(expected, actual)
Expand All @@ -272,7 +260,6 @@ mod tests {
col("a", &schema)?,
"bla".to_string(),
datatype,
true,
));

let mut accum1 = agg.create_accumulator()?;
Expand Down
37 changes: 9 additions & 28 deletions datafusion/physical-expr/src/aggregate/array_agg_ordered.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use arrow::datatypes::{DataType, Field};
use arrow_array::cast::AsArray;
use arrow_array::{new_empty_array, Array, ArrayRef, StructArray};
use arrow_schema::Fields;
use datafusion_common::utils::{array_into_list_array, get_row_at_idx};
use datafusion_common::utils::{array_into_list_array_nullable, get_row_at_idx};
use datafusion_common::{exec_err, Result, ScalarValue};
use datafusion_expr::utils::AggregateOrderSensitivity;
use datafusion_expr::Accumulator;
Expand All @@ -50,8 +50,6 @@ pub struct OrderSensitiveArrayAgg {
input_data_type: DataType,
/// The input expression
expr: Arc<dyn PhysicalExpr>,
/// If the input expression can have `NULL`s
nullable: bool,
/// Ordering data types
order_by_data_types: Vec<DataType>,
/// Ordering requirement
Expand All @@ -66,15 +64,13 @@ impl OrderSensitiveArrayAgg {
expr: Arc<dyn PhysicalExpr>,
name: impl Into<String>,
input_data_type: DataType,
nullable: bool,
order_by_data_types: Vec<DataType>,
ordering_req: LexOrdering,
) -> Self {
Self {
name: name.into(),
input_data_type,
expr,
nullable,
order_by_data_types,
ordering_req,
reverse: false,
Expand All @@ -90,8 +86,8 @@ impl AggregateExpr for OrderSensitiveArrayAgg {
fn field(&self) -> Result<Field> {
Ok(Field::new_list(
&self.name,
// This should be the same as return type of AggregateFunction::ArrayAgg
Field::new("item", self.input_data_type.clone(), self.nullable),
// This should be the same as return type of AggregateFunction::OrderSensitiveArrayAgg
Field::new("item", self.input_data_type.clone(), true),
true,
))
}
Expand All @@ -102,25 +98,20 @@ impl AggregateExpr for OrderSensitiveArrayAgg {
&self.order_by_data_types,
self.ordering_req.clone(),
self.reverse,
self.nullable,
)
.map(|acc| Box::new(acc) as _)
}

fn state_fields(&self) -> Result<Vec<Field>> {
let mut fields = vec![Field::new_list(
format_state_name(&self.name, "array_agg"),
Field::new("item", self.input_data_type.clone(), self.nullable),
Field::new("item", self.input_data_type.clone(), true),
true, // This should be the same as field()
)];
let orderings = ordering_fields(&self.ordering_req, &self.order_by_data_types);
fields.push(Field::new_list(
format_state_name(&self.name, "array_agg_orderings"),
Field::new(
"item",
DataType::Struct(Fields::from(orderings)),
self.nullable,
),
Field::new("item", DataType::Struct(Fields::from(orderings)), true),
false,
));
Ok(fields)
Expand All @@ -147,7 +138,6 @@ impl AggregateExpr for OrderSensitiveArrayAgg {
name: self.name.to_string(),
input_data_type: self.input_data_type.clone(),
expr: Arc::clone(&self.expr),
nullable: self.nullable,
order_by_data_types: self.order_by_data_types.clone(),
// Reverse requirement:
ordering_req: reverse_order_bys(&self.ordering_req),
Expand Down Expand Up @@ -186,8 +176,6 @@ pub(crate) struct OrderSensitiveArrayAggAccumulator {
ordering_req: LexOrdering,
/// Whether the aggregation is running in reverse.
reverse: bool,
/// Whether the input expr is nullable
nullable: bool,
}

impl OrderSensitiveArrayAggAccumulator {
Expand All @@ -198,7 +186,6 @@ impl OrderSensitiveArrayAggAccumulator {
ordering_dtypes: &[DataType],
ordering_req: LexOrdering,
reverse: bool,
nullable: bool,
) -> Result<Self> {
let mut datatypes = vec![datatype.clone()];
datatypes.extend(ordering_dtypes.iter().cloned());
Expand All @@ -208,7 +195,6 @@ impl OrderSensitiveArrayAggAccumulator {
datatypes,
ordering_req,
reverse,
nullable,
})
}
}
Expand Down Expand Up @@ -312,7 +298,7 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator {
if self.values.is_empty() {
return Ok(ScalarValue::new_null_list(
self.datatypes[0].clone(),
self.nullable,
true,
1,
));
}
Expand All @@ -322,14 +308,10 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator {
ScalarValue::new_list_from_iter(
values.into_iter().rev(),
&self.datatypes[0],
self.nullable,
true,
)
} else {
ScalarValue::new_list_from_iter(
values.into_iter(),
&self.datatypes[0],
self.nullable,
)
ScalarValue::new_list_from_iter(values.into_iter(), &self.datatypes[0], true)
};
Ok(ScalarValue::List(array))
}
Expand Down Expand Up @@ -385,9 +367,8 @@ impl OrderSensitiveArrayAggAccumulator {
column_wise_ordering_values,
None,
)?;
Ok(ScalarValue::List(Arc::new(array_into_list_array(
Ok(ScalarValue::List(Arc::new(array_into_list_array_nullable(
Arc::new(ordering_array),
self.nullable,
))))
}
}
Expand Down
12 changes: 2 additions & 10 deletions datafusion/physical-expr/src/aggregate/build_in.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,16 +62,14 @@ pub fn create_aggregate_expr(
Ok(match (fun, distinct) {
(AggregateFunction::ArrayAgg, false) => {
let expr = Arc::clone(&input_phy_exprs[0]);
let nullable = expr.nullable(input_schema)?;

if ordering_req.is_empty() {
Arc::new(expressions::ArrayAgg::new(expr, name, data_type, nullable))
Arc::new(expressions::ArrayAgg::new(expr, name, data_type))
} else {
Arc::new(expressions::OrderSensitiveArrayAgg::new(
expr,
name,
data_type,
nullable,
ordering_types,
ordering_req.to_vec(),
))
Expand All @@ -84,13 +82,7 @@ pub fn create_aggregate_expr(
);
}
let expr = Arc::clone(&input_phy_exprs[0]);
let is_expr_nullable = expr.nullable(input_schema)?;
Arc::new(expressions::DistinctArrayAgg::new(
expr,
name,
data_type,
is_expr_nullable,
))
Arc::new(expressions::DistinctArrayAgg::new(expr, name, data_type))
}
(AggregateFunction::Min, _) => Arc::new(expressions::Min::new(
Arc::clone(&input_phy_exprs[0]),
Expand Down
1 change: 0 additions & 1 deletion datafusion/physical-plan/src/aggregates/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2231,7 +2231,6 @@ mod tests {
Arc::clone(col_a),
"array_agg",
DataType::Int32,
false,
vec![],
order_by_expr.unwrap_or_default(),
)) as _
Expand Down