From 76d7fcffdd9d8664a003b76754033ebad4a15847 Mon Sep 17 00:00:00 2001 From: Dan Lovell Date: Thu, 28 Dec 2023 15:10:46 -0500 Subject: [PATCH] feat: udaf: enable multiple column input (#546) --- datafusion/__init__.py | 2 ++ src/udaf.rs | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/datafusion/__init__.py b/datafusion/__init__.py index c854f3f9..df53b396 100644 --- a/datafusion/__init__.py +++ b/datafusion/__init__.py @@ -213,6 +213,8 @@ def udaf(accum, input_type, return_type, state_type, volatility, name=None): ) if name is None: name = accum.__qualname__.lower() + if isinstance(input_type, pa.lib.DataType): + input_type = [input_type] return AggregateUDF( name=name, accumulator=accum, diff --git a/src/udaf.rs b/src/udaf.rs index 5c43b671..0e7a8dea 100644 --- a/src/udaf.rs +++ b/src/udaf.rs @@ -148,14 +148,14 @@ impl PyAggregateUDF { fn new( name: &str, accumulator: PyObject, - input_type: PyArrowType, + input_type: PyArrowType>, return_type: PyArrowType, state_type: PyArrowType>, volatility: &str, ) -> PyResult { let function = create_udaf( name, - vec![input_type.0], + input_type.0, Arc::new(return_type.0), parse_volatility(volatility)?, to_rust_accumulator(accumulator),