diff --git a/datafusion-examples/examples/advanced_udaf.rs b/datafusion-examples/examples/advanced_udaf.rs index 819fcb7dd225..c1ee946df77e 100644 --- a/datafusion-examples/examples/advanced_udaf.rs +++ b/datafusion-examples/examples/advanced_udaf.rs @@ -94,7 +94,7 @@ impl AggregateUDFImpl for GeoMeanUdaf { /// This is the description of the state. accumulator's state() must match the types here. fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![ - Field::new("prod", args.return_field.data_type().clone(), true), + Field::new("prod", args.return_type().clone(), true), Field::new("n", DataType::UInt32, true), ]) } diff --git a/datafusion/functions-aggregate-common/src/accumulator.rs b/datafusion/functions-aggregate-common/src/accumulator.rs index 37bbd1508c91..eba4f6b70d2b 100644 --- a/datafusion/functions-aggregate-common/src/accumulator.rs +++ b/datafusion/functions-aggregate-common/src/accumulator.rs @@ -100,3 +100,10 @@ pub struct StateFieldsArgs<'a> { /// Whether the aggregate function is distinct. pub is_distinct: bool, } + +impl StateFieldsArgs<'_> { + /// The return type of the aggregate function. + pub fn return_type(&self) -> &DataType { + self.return_field.data_type() + } +} diff --git a/datafusion/functions-aggregate/src/bit_and_or_xor.rs b/datafusion/functions-aggregate/src/bit_and_or_xor.rs index 5f51377484a0..aeaeefcd7a72 100644 --- a/datafusion/functions-aggregate/src/bit_and_or_xor.rs +++ b/datafusion/functions-aggregate/src/bit_and_or_xor.rs @@ -271,7 +271,7 @@ impl AggregateUDFImpl for BitwiseOperation { format!("{} distinct", self.name()).as_str(), ), // See COMMENTS.md to understand why nullable is set to true - Field::new_list_field(args.return_field.data_type().clone(), true), + Field::new_list_field(args.return_type().clone(), true), false, )]) } else { diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index ea2ec63711b4..8264d5fa74cb 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -172,7 +172,7 @@ impl AggregateUDFImpl for FirstValue { fn state_fields(&self, args: StateFieldsArgs) -> Result> { let mut fields = vec![Field::new( format_state_name(args.name, "first_value"), - args.return_field.data_type().clone(), + args.return_type().clone(), true, )]; fields.extend(args.ordering_fields.to_vec()); diff --git a/datafusion/functions-aggregate/src/sum.rs b/datafusion/functions-aggregate/src/sum.rs index a54d0af34693..aaa0c4b94a7f 100644 --- a/datafusion/functions-aggregate/src/sum.rs +++ b/datafusion/functions-aggregate/src/sum.rs @@ -206,13 +206,13 @@ impl AggregateUDFImpl for Sum { Ok(vec![Field::new_list( format_state_name(args.name, "sum distinct"), // See COMMENTS.md to understand why nullable is set to true - Field::new_list_field(args.return_field.data_type().clone(), true), + Field::new_list_field(args.return_type().clone(), true), false, )]) } else { Ok(vec![Field::new( format_state_name(args.name, "sum"), - args.return_field.data_type().clone(), + args.return_type().clone(), true, )]) }