Skip to content

feat: metadata handling for aggregates and window functions #15911

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion datafusion-examples/examples/advanced_udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ impl AggregateUDFImpl for GeoMeanUdaf {
/// This is the description of the intermediate state. The accumulator's `state()` must match the types declared here.
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
Ok(vec![
Field::new("prod", args.return_type.clone(), true),
Field::new("prod", args.return_field.data_type().clone(), true),
Field::new("n", DataType::UInt32, true),
])
}
Expand Down
8 changes: 4 additions & 4 deletions datafusion/core/tests/fuzz_cases/window_fuzz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ use datafusion::prelude::{SessionConfig, SessionContext};
use datafusion_common::HashMap;
use datafusion_common::{Result, ScalarValue};
use datafusion_common_runtime::SpawnedTask;
use datafusion_expr::type_coercion::functions::data_types_with_aggregate_udf;
use datafusion_expr::type_coercion::functions::fields_with_aggregate_udf;
use datafusion_expr::{
WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
};
Expand Down Expand Up @@ -448,9 +448,9 @@ fn get_random_function(
if !args.is_empty() {
// Do type coercion first argument
let a = args[0].clone();
let dt = a.data_type(schema.as_ref()).unwrap();
let coerced = data_types_with_aggregate_udf(&[dt], udf).unwrap();
args[0] = cast(a, schema, coerced[0].clone()).unwrap();
let dt = a.return_field(schema.as_ref()).unwrap();
let coerced = fields_with_aggregate_udf(&[dt], udf).unwrap();
args[0] = cast(a, schema, coerced[0].data_type().clone()).unwrap();
}
}

Expand Down
275 changes: 269 additions & 6 deletions datafusion/core/tests/user_defined/user_defined_aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
//! This module contains end to end demonstrations of creating
//! user defined aggregate functions

use std::any::Any;
use std::collections::HashMap;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::mem::{size_of, size_of_val};
use std::sync::{
Expand All @@ -26,10 +28,10 @@ use std::sync::{
};

use arrow::array::{
types::UInt64Type, AsArray, Int32Array, PrimitiveArray, StringArray, StructArray,
record_batch, types::UInt64Type, Array, AsArray, Int32Array, PrimitiveArray,
StringArray, StructArray, UInt64Array,
};
use arrow::datatypes::{Fields, Schema};

use datafusion::common::test_util::batches_to_string;
use datafusion::dataframe::DataFrame;
use datafusion::datasource::MemTable;
Expand All @@ -48,11 +50,12 @@ use datafusion::{
prelude::SessionContext,
scalar::ScalarValue,
};
use datafusion_common::assert_contains;
use datafusion_common::{assert_contains, exec_datafusion_err};
use datafusion_common::{cast::as_primitive_array, exec_err};
use datafusion_expr::expr::WindowFunction;
use datafusion_expr::{
col, create_udaf, function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator,
LogicalPlanBuilder, SimpleAggregateUDF,
col, create_udaf, function::AccumulatorArgs, AggregateUDFImpl, Expr,
GroupsAccumulator, LogicalPlanBuilder, SimpleAggregateUDF, WindowFunctionDefinition,
};
use datafusion_functions_aggregate::average::AvgAccumulator;

Expand Down Expand Up @@ -781,7 +784,7 @@ struct TestGroupsAccumulator {
}

impl AggregateUDFImpl for TestGroupsAccumulator {
fn as_any(&self) -> &dyn std::any::Any {
fn as_any(&self) -> &dyn Any {
self
}

Expand Down Expand Up @@ -890,3 +893,263 @@ impl GroupsAccumulator for TestGroupsAccumulator {
size_of::<u64>()
}
}

/// Test UDAF whose output field carries user-supplied metadata and whose
/// accumulator behavior depends on the metadata of its input field.
#[derive(Debug)]
struct MetadataBasedAggregateUdf {
    // Unique per metadata map (see `new`) so distinct instances register as
    // distinct functions rather than being deduplicated by name.
    name: String,
    signature: Signature,
    // Attached verbatim to the output field produced by `return_field`.
    metadata: HashMap<String, String>,
}

impl MetadataBasedAggregateUdf {
    /// Builds a UDAF instance that will attach `metadata` to its output field.
    ///
    /// Registered UDAF names must be unique, otherwise DataFusion will not
    /// treat distinct instances of this UDF as distinct functions. For these
    /// unit tests we derive uniqueness from the size of the metadata map — a
    /// small hack; real code would want something more elegant.
    fn new(metadata: HashMap<String, String>) -> Self {
        let signature =
            Signature::exact(vec![DataType::UInt64], Volatility::Immutable);
        Self {
            name: format!("metadata_based_udf_{}", metadata.len()),
            signature,
            metadata,
        }
    }
}

impl AggregateUDFImpl for MetadataBasedAggregateUdf {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        &self.name
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        // The metadata-aware `return_field` below is the entry point the
        // planner should use; falling back here would indicate a bug.
        unimplemented!("this should never be called since return_field is implemented");
    }

    fn return_field(&self, _arg_fields: &[Field]) -> Result<Field> {
        // Attach this instance's metadata to the output field so it shows up
        // on the result schema (asserted by the tests below).
        Ok(Field::new(self.name(), DataType::UInt64, true)
            .with_metadata(self.metadata.clone()))
    }

    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        let input_expr = acc_args
            .exprs
            .first()
            // `ok_or_else` so the error is only constructed on the failure path.
            .ok_or_else(|| exec_datafusion_err!("Expected one argument"))?;
        let input_field = input_expr.return_field(acc_args.schema)?;

        // Input metadata drives behavior: a field tagged with
        // `modify_values = double_output` makes the accumulator double its sum.
        let double_output = input_field
            .metadata()
            .get("modify_values")
            .map(|v| v == "double_output")
            .unwrap_or(false);

        Ok(Box::new(MetadataBasedAccumulator {
            double_output,
            curr_sum: 0,
        }))
    }
}

/// Minimal summing accumulator whose final value is optionally doubled,
/// based on input-field metadata observed at accumulator-creation time.
#[derive(Debug)]
struct MetadataBasedAccumulator {
    // True when the input field carried `modify_values = double_output`.
    double_output: bool,
    // Running sum of all non-null input values.
    curr_sum: u64,
}

impl Accumulator for MetadataBasedAccumulator {
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
let arr = values[0]
.as_any()
.downcast_ref::<UInt64Array>()
.ok_or(exec_datafusion_err!("Expected UInt64Array"))?;

self.curr_sum = arr.iter().fold(self.curr_sum, |a, b| a + b.unwrap_or(0));

Ok(())
}

fn evaluate(&mut self) -> Result<ScalarValue> {
let v = match self.double_output {
true => self.curr_sum * 2,
false => self.curr_sum,
};

Ok(ScalarValue::from(v))
}

fn size(&self) -> usize {
9
}

fn state(&mut self) -> Result<Vec<ScalarValue>> {
Ok(vec![ScalarValue::from(self.curr_sum)])
}

fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
self.update_batch(states)
}
}

#[tokio::test]
async fn test_metadata_based_aggregate() -> Result<()> {
    // Both columns share the same values; only the field metadata differs, so
    // any difference in results must come from input-metadata handling.
    let data_array = Arc::new(UInt64Array::from(vec![0, 5, 10, 15, 20])) as ArrayRef;
    let schema = Arc::new(Schema::new(vec![
        Field::new("no_metadata", DataType::UInt64, true),
        Field::new("with_metadata", DataType::UInt64, true).with_metadata(
            [("modify_values".to_string(), "double_output".to_string())]
                .into_iter()
                .collect(),
        ),
    ]));

    let batch = RecordBatch::try_new(
        schema,
        vec![Arc::clone(&data_array), Arc::clone(&data_array)],
    )?;

    let ctx = SessionContext::new();
    ctx.register_batch("t", batch)?;
    let df = ctx.table("t").await?;

    // One UDAF without output metadata and one with it, so the four aggregate
    // expressions below cover every input/output metadata combination.
    let no_output_meta_udf =
        AggregateUDF::from(MetadataBasedAggregateUdf::new(HashMap::new()));
    let with_output_meta_udf = AggregateUDF::from(MetadataBasedAggregateUdf::new(
        [("output_metatype".to_string(), "custom_value".to_string())]
            .into_iter()
            .collect(),
    ));

    let df = df.aggregate(
        vec![],
        vec![
            no_output_meta_udf
                .call(vec![col("no_metadata")])
                .alias("meta_no_in_no_out"),
            no_output_meta_udf
                .call(vec![col("with_metadata")])
                .alias("meta_with_in_no_out"),
            with_output_meta_udf
                .call(vec![col("no_metadata")])
                .alias("meta_no_in_with_out"),
            with_output_meta_udf
                .call(vec![col("with_metadata")])
                .alias("meta_with_in_with_out"),
        ],
    )?;

    let actual = df.collect().await?;

    // To test for output metadata handling, we set the expected values on the result
    // To test for input metadata handling, we check the numbers returned
    let mut output_meta = HashMap::new();
    let _ = output_meta.insert("output_metatype".to_string(), "custom_value".to_string());
    let expected_schema = Schema::new(vec![
        Field::new("meta_no_in_no_out", DataType::UInt64, true),
        Field::new("meta_with_in_no_out", DataType::UInt64, true),
        Field::new("meta_no_in_with_out", DataType::UInt64, true)
            .with_metadata(output_meta.clone()),
        Field::new("meta_with_in_with_out", DataType::UInt64, true)
            .with_metadata(output_meta.clone()),
    ]);

    // 50 is the plain sum of the column; 100 is the doubled sum produced when
    // the input field was tagged `modify_values = double_output`.
    let expected = record_batch!(
        ("meta_no_in_no_out", UInt64, [50]),
        ("meta_with_in_no_out", UInt64, [100]),
        ("meta_no_in_with_out", UInt64, [50]),
        ("meta_with_in_with_out", UInt64, [100])
    )?
    .with_schema(Arc::new(expected_schema))?;

    assert_eq!(expected, actual[0]);

    Ok(())
}

#[tokio::test]
async fn test_metadata_based_aggregate_as_window() -> Result<()> {
    // Same setup as `test_metadata_based_aggregate`, but the UDAF is invoked
    // through the window-function path to exercise metadata handling there.
    let data_array = Arc::new(UInt64Array::from(vec![0, 5, 10, 15, 20])) as ArrayRef;
    let schema = Arc::new(Schema::new(vec![
        Field::new("no_metadata", DataType::UInt64, true),
        Field::new("with_metadata", DataType::UInt64, true).with_metadata(
            [("modify_values".to_string(), "double_output".to_string())]
                .into_iter()
                .collect(),
        ),
    ]));

    let batch = RecordBatch::try_new(
        schema,
        vec![Arc::clone(&data_array), Arc::clone(&data_array)],
    )?;

    let ctx = SessionContext::new();
    ctx.register_batch("t", batch)?;
    let df = ctx.table("t").await?;

    // Window definitions take the UDAF behind an Arc, unlike the aggregate test.
    let no_output_meta_udf = Arc::new(AggregateUDF::from(
        MetadataBasedAggregateUdf::new(HashMap::new()),
    ));
    let with_output_meta_udf =
        Arc::new(AggregateUDF::from(MetadataBasedAggregateUdf::new(
            [("output_metatype".to_string(), "custom_value".to_string())]
                .into_iter()
                .collect(),
        )));

    // Four window expressions covering every input/output metadata combination.
    let df = df.select(vec![
        Expr::WindowFunction(WindowFunction::new(
            WindowFunctionDefinition::AggregateUDF(Arc::clone(&no_output_meta_udf)),
            vec![col("no_metadata")],
        ))
        .alias("meta_no_in_no_out"),
        Expr::WindowFunction(WindowFunction::new(
            WindowFunctionDefinition::AggregateUDF(no_output_meta_udf),
            vec![col("with_metadata")],
        ))
        .alias("meta_with_in_no_out"),
        Expr::WindowFunction(WindowFunction::new(
            WindowFunctionDefinition::AggregateUDF(Arc::clone(&with_output_meta_udf)),
            vec![col("no_metadata")],
        ))
        .alias("meta_no_in_with_out"),
        Expr::WindowFunction(WindowFunction::new(
            WindowFunctionDefinition::AggregateUDF(with_output_meta_udf),
            vec![col("with_metadata")],
        ))
        .alias("meta_with_in_with_out"),
    ])?;

    let actual = df.collect().await?;

    // To test for output metadata handling, we set the expected values on the result
    // To test for input metadata handling, we check the numbers returned
    let mut output_meta = HashMap::new();
    let _ = output_meta.insert("output_metatype".to_string(), "custom_value".to_string());
    let expected_schema = Schema::new(vec![
        Field::new("meta_no_in_no_out", DataType::UInt64, true),
        Field::new("meta_with_in_no_out", DataType::UInt64, true),
        Field::new("meta_no_in_with_out", DataType::UInt64, true)
            .with_metadata(output_meta.clone()),
        Field::new("meta_with_in_with_out", DataType::UInt64, true)
            .with_metadata(output_meta.clone()),
    ]);

    // With no frame/partition specified, each window aggregates the whole
    // column, so every row repeats the total: 50 plain, 100 when the input
    // field was tagged `modify_values = double_output`.
    let expected = record_batch!(
        ("meta_no_in_no_out", UInt64, [50, 50, 50, 50, 50]),
        ("meta_with_in_no_out", UInt64, [100, 100, 100, 100, 100]),
        ("meta_no_in_with_out", UInt64, [50, 50, 50, 50, 50]),
        ("meta_with_in_with_out", UInt64, [100, 100, 100, 100, 100])
    )?
    .with_schema(Arc::new(expected_schema))?;

    assert_eq!(expected, actual[0]);

    Ok(())
}
Loading
Loading