Skip to content

Commit ce835da

Browse files
authored
Add StateFieldsArgs::return_field (#16112)
1 parent dc8161e commit ce835da

File tree

5 files changed

+12
-5
lines changed

5 files changed

+12
-5
lines changed

datafusion-examples/examples/advanced_udaf.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ impl AggregateUDFImpl for GeoMeanUdaf {
9494
/// This is the description of the state. accumulator's state() must match the types here.
9595
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
9696
Ok(vec![
97-
Field::new("prod", args.return_field.data_type().clone(), true),
97+
Field::new("prod", args.return_type().clone(), true),
9898
Field::new("n", DataType::UInt32, true),
9999
])
100100
}

datafusion/functions-aggregate-common/src/accumulator.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,3 +100,10 @@ pub struct StateFieldsArgs<'a> {
100100
/// Whether the aggregate function is distinct.
101101
pub is_distinct: bool,
102102
}
103+
104+
impl StateFieldsArgs<'_> {
105+
/// The return type of the aggregate function.
106+
pub fn return_type(&self) -> &DataType {
107+
self.return_field.data_type()
108+
}
109+
}

datafusion/functions-aggregate/src/bit_and_or_xor.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ impl AggregateUDFImpl for BitwiseOperation {
271271
format!("{} distinct", self.name()).as_str(),
272272
),
273273
// See COMMENTS.md to understand why nullable is set to true
274-
Field::new_list_field(args.return_field.data_type().clone(), true),
274+
Field::new_list_field(args.return_type().clone(), true),
275275
false,
276276
)])
277277
} else {

datafusion/functions-aggregate/src/first_last.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ impl AggregateUDFImpl for FirstValue {
172172
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
173173
let mut fields = vec![Field::new(
174174
format_state_name(args.name, "first_value"),
175-
args.return_field.data_type().clone(),
175+
args.return_type().clone(),
176176
true,
177177
)];
178178
fields.extend(args.ordering_fields.to_vec());

datafusion/functions-aggregate/src/sum.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,13 +206,13 @@ impl AggregateUDFImpl for Sum {
206206
Ok(vec![Field::new_list(
207207
format_state_name(args.name, "sum distinct"),
208208
// See COMMENTS.md to understand why nullable is set to true
209-
Field::new_list_field(args.return_field.data_type().clone(), true),
209+
Field::new_list_field(args.return_type().clone(), true),
210210
false,
211211
)])
212212
} else {
213213
Ok(vec![Field::new(
214214
format_state_name(args.name, "sum"),
215-
args.return_field.data_type().clone(),
215+
args.return_type().clone(),
216216
true,
217217
)])
218218
}

0 commit comments

Comments
 (0)