Skip to content

Commit d840e98

Browse files
authored
fix: serialize user-defined window functions to proto (#13421)
* Adds roundtrip physical plan test
* Adds enum for udwf to `WindowFunction`
* initial fix for serializing udwf
* Revives deleted test
* Adds codec methods for physical plan
* Rewrite error message
* Minor: rename binding + formatting fixes
* Extends `PhysicalExtensionCodec` for udwf
* Minor: formatting
* Restricts visibility to tests
1 parent c51b432 commit d840e98

File tree

9 files changed: +272 −15 lines changed

datafusion/physical-plan/src/windows/mod.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ pub fn create_udwf_window_expr(
194194

195195
/// Implements [`BuiltInWindowFunctionExpr`] for [`WindowUDF`]
196196
#[derive(Clone, Debug)]
197-
struct WindowUDFExpr {
197+
pub struct WindowUDFExpr {
198198
fun: Arc<WindowUDF>,
199199
args: Vec<Arc<dyn PhysicalExpr>>,
200200
/// Display name
@@ -209,6 +209,12 @@ struct WindowUDFExpr {
209209
ignore_nulls: bool,
210210
}
211211

212+
impl WindowUDFExpr {
213+
pub fn fun(&self) -> &Arc<WindowUDF> {
214+
&self.fun
215+
}
216+
}
217+
212218
impl BuiltInWindowFunctionExpr for WindowUDFExpr {
213219
fn as_any(&self) -> &dyn std::any::Any {
214220
self

datafusion/proto/proto/datafusion.proto

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -853,6 +853,7 @@ message PhysicalWindowExprNode {
853853
oneof window_function {
854854
// BuiltInWindowFunction built_in_function = 2;
855855
string user_defined_aggr_function = 3;
856+
string user_defined_window_function = 10;
856857
}
857858
repeated PhysicalExprNode args = 4;
858859
repeated PhysicalExprNode partition_by = 5;

datafusion/proto/src/generated/pbjson.rs

Lines changed: 13 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/proto/src/generated/prost.rs

Lines changed: 3 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/proto/src/physical_plan/from_proto.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,12 @@ pub fn parse_physical_window_expr(
152152
None => registry.udaf(udaf_name)?
153153
})
154154
}
155+
protobuf::physical_window_expr_node::WindowFunction::UserDefinedWindowFunction(udwf_name) => {
156+
WindowFunctionDefinition::WindowUDF(match &proto.fun_definition {
157+
Some(buf) => codec.try_decode_udwf(udwf_name, buf)?,
158+
None => registry.udwf(udwf_name)?
159+
})
160+
}
155161
}
156162
} else {
157163
return Err(proto_error("Missing required field in protobuf"));

datafusion/proto/src/physical_plan/mod.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ use datafusion::physical_plan::{
6464
ExecutionPlan, InputOrderMode, PhysicalExpr, WindowExpr,
6565
};
6666
use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result};
67-
use datafusion_expr::{AggregateUDF, ScalarUDF};
67+
use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF};
6868

6969
use crate::common::{byte_to_string, str_to_byte};
7070
use crate::physical_plan::from_proto::{
@@ -2119,6 +2119,14 @@ pub trait PhysicalExtensionCodec: Debug + Send + Sync {
21192119
fn try_encode_udaf(&self, _node: &AggregateUDF, _buf: &mut Vec<u8>) -> Result<()> {
21202120
Ok(())
21212121
}
2122+
2123+
fn try_decode_udwf(&self, name: &str, _buf: &[u8]) -> Result<Arc<WindowUDF>> {
2124+
not_impl_err!("PhysicalExtensionCodec is not provided for window function {name}")
2125+
}
2126+
2127+
fn try_encode_udwf(&self, _node: &WindowUDF, _buf: &mut Vec<u8>) -> Result<()> {
2128+
Ok(())
2129+
}
21222130
}
21232131

21242132
#[derive(Debug)]

datafusion/proto/src/physical_plan/to_proto.rs

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@ use std::sync::Arc;
1919

2020
#[cfg(feature = "parquet")]
2121
use datafusion::datasource::file_format::parquet::ParquetSink;
22-
use datafusion::physical_expr::window::SlidingAggregateWindowExpr;
22+
use datafusion::physical_expr::window::{BuiltInWindowExpr, SlidingAggregateWindowExpr};
2323
use datafusion::physical_expr::{LexOrdering, PhysicalSortExpr, ScalarFunctionExpr};
2424
use datafusion::physical_plan::expressions::{
2525
BinaryExpr, CaseExpr, CastExpr, Column, InListExpr, IsNotNullExpr, IsNullExpr,
2626
Literal, NegativeExpr, NotExpr, TryCastExpr,
2727
};
2828
use datafusion::physical_plan::udaf::AggregateFunctionExpr;
29-
use datafusion::physical_plan::windows::PlainAggregateWindowExpr;
29+
use datafusion::physical_plan::windows::{PlainAggregateWindowExpr, WindowUDFExpr};
3030
use datafusion::physical_plan::{Partitioning, PhysicalExpr, WindowExpr};
3131
use datafusion::{
3232
datasource::{
@@ -68,7 +68,7 @@ pub fn serialize_physical_aggr_expr(
6868
ordering_req,
6969
distinct: aggr_expr.is_distinct(),
7070
ignore_nulls: aggr_expr.ignore_nulls(),
71-
fun_definition: (!buf.is_empty()).then_some(buf)
71+
fun_definition: (!buf.is_empty()).then_some(buf),
7272
},
7373
)),
7474
})
@@ -120,6 +120,25 @@ pub fn serialize_physical_window_expr(
120120
window_frame,
121121
codec,
122122
)?
123+
} else if let Some(built_in_window_expr) = expr.downcast_ref::<BuiltInWindowExpr>() {
124+
if let Some(expr) = built_in_window_expr
125+
.get_built_in_func_expr()
126+
.as_any()
127+
.downcast_ref::<WindowUDFExpr>()
128+
{
129+
let mut buf = Vec::new();
130+
codec.try_encode_udwf(expr.fun(), &mut buf)?;
131+
(
132+
physical_window_expr_node::WindowFunction::UserDefinedWindowFunction(
133+
expr.fun().name().to_string(),
134+
),
135+
(!buf.is_empty()).then_some(buf),
136+
)
137+
} else {
138+
return not_impl_err!(
139+
"User-defined window function not supported: {window_expr:?}"
140+
);
141+
}
123142
} else {
124143
return not_impl_err!("WindowExpr not supported: {window_expr:?}");
125144
};

datafusion/proto/tests/cases/mod.rs

Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,18 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use arrow::datatypes::{DataType, Field};
1819
use std::any::Any;
19-
20-
use arrow::datatypes::DataType;
20+
use std::fmt::Debug;
2121

2222
use datafusion_common::plan_err;
2323
use datafusion_expr::function::AccumulatorArgs;
2424
use datafusion_expr::{
25-
Accumulator, AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, Signature, Volatility,
25+
Accumulator, AggregateUDFImpl, ColumnarValue, PartitionEvaluator, ScalarUDFImpl,
26+
Signature, Volatility, WindowUDFImpl,
2627
};
28+
use datafusion_functions_window_common::field::WindowUDFFieldArgs;
29+
use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
2730

2831
mod roundtrip_logical_plan;
2932
mod roundtrip_physical_plan;
@@ -125,3 +128,54 @@ pub struct MyAggregateUdfNode {
125128
#[prost(string, tag = "1")]
126129
pub result: String,
127130
}
131+
132+
#[derive(Debug)]
133+
pub(in crate::cases) struct CustomUDWF {
134+
signature: Signature,
135+
payload: String,
136+
}
137+
138+
impl CustomUDWF {
139+
pub fn new(payload: String) -> Self {
140+
Self {
141+
signature: Signature::exact(vec![DataType::Int64], Volatility::Immutable),
142+
payload,
143+
}
144+
}
145+
}
146+
147+
impl WindowUDFImpl for CustomUDWF {
148+
fn as_any(&self) -> &dyn Any {
149+
self
150+
}
151+
152+
fn name(&self) -> &str {
153+
"custom_udwf"
154+
}
155+
156+
fn signature(&self) -> &Signature {
157+
&self.signature
158+
}
159+
160+
fn partition_evaluator(
161+
&self,
162+
_partition_evaluator_args: PartitionEvaluatorArgs,
163+
) -> datafusion_common::Result<Box<dyn PartitionEvaluator>> {
164+
Ok(Box::new(CustomUDWFEvaluator {}))
165+
}
166+
167+
fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result<Field> {
168+
Ok(Field::new(field_args.name(), DataType::UInt64, false))
169+
}
170+
}
171+
172+
#[derive(Debug)]
173+
struct CustomUDWFEvaluator;
174+
175+
impl PartitionEvaluator for CustomUDWFEvaluator {}
176+
177+
#[derive(Clone, PartialEq, ::prost::Message)]
178+
pub(in crate::cases) struct CustomUDWFNode {
179+
#[prost(string, tag = "1")]
180+
pub payload: String,
181+
}

0 commit comments

Comments
 (0)