Skip to content

Commit 04c9269

Browse files
committed
Added a pack example for the invoke_batch_with_return_type proposed change
1 parent eb15075 commit 04c9269

File tree

2 files changed

+314
-0
lines changed

2 files changed

+314
-0
lines changed

datafusion/functions-nested/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ pub mod map;
4747
pub mod map_extract;
4848
pub mod map_keys;
4949
pub mod map_values;
50+
pub mod pack;
5051
pub mod planner;
5152
pub mod position;
5253
pub mod range;
Lines changed: 313 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,313 @@
1+
use std::sync::Arc;
2+
3+
use arrow::array::{Array, ArrayRef, StructArray};
4+
use arrow::datatypes::{DataType, Field, Fields};
5+
use datafusion_common::error::Result as DFResult;
6+
use datafusion_common::{exec_err, plan_err, ExprSchema};
7+
use datafusion_expr::expr::ScalarFunction;
8+
use datafusion_expr::ExprSchemable;
9+
use datafusion_expr::{
10+
ColumnarValue, Expr, ScalarUDF, ScalarUDFImpl, Signature, TypeSignature, Volatility,
11+
};
12+
use itertools::Itertools;
13+
14+
#[derive(Debug)]
15+
pub struct Pack {
16+
signature: Signature,
17+
names: Vec<String>,
18+
}
19+
20+
impl Pack {
21+
pub(crate) const NAME: &'static str = "struct.pack";
22+
23+
pub fn new<I>(names: I) -> Self
24+
where
25+
I: IntoIterator,
26+
I::Item: AsRef<str>,
27+
{
28+
Self {
29+
signature: Signature::one_of(
30+
vec![TypeSignature::Any(0), TypeSignature::VariadicAny],
31+
Volatility::Immutable,
32+
),
33+
names: names
34+
.into_iter()
35+
.map(|n| n.as_ref().to_string())
36+
.collect_vec(),
37+
}
38+
}
39+
40+
pub fn names(&self) -> &[String] {
41+
self.names.as_slice()
42+
}
43+
44+
pub fn new_instance<T>(
45+
names: impl IntoIterator<Item = T>,
46+
args: impl IntoIterator<Item = Expr>,
47+
) -> Expr
48+
where
49+
T: AsRef<str>,
50+
{
51+
Expr::ScalarFunction(ScalarFunction {
52+
func: Arc::new(ScalarUDF::new_from_impl(Pack::new(
53+
names
54+
.into_iter()
55+
.map(|n| n.as_ref().to_string())
56+
.collect_vec(),
57+
))),
58+
args: args.into_iter().collect_vec(),
59+
})
60+
}
61+
62+
pub fn new_instance_from_pair(
63+
pairs: impl IntoIterator<Item = (impl AsRef<str>, Expr)>,
64+
) -> Expr {
65+
let (names, args): (Vec<String>, Vec<Expr>) = pairs
66+
.into_iter()
67+
.map(|(k, v)| (k.as_ref().to_string(), v))
68+
.unzip();
69+
Expr::ScalarFunction(ScalarFunction {
70+
func: Arc::new(ScalarUDF::new_from_impl(Pack::new(names))),
71+
args,
72+
})
73+
}
74+
}
75+
76+
impl ScalarUDFImpl for Pack {
77+
fn as_any(&self) -> &dyn std::any::Any {
78+
self
79+
}
80+
81+
fn name(&self) -> &str {
82+
Self::NAME
83+
}
84+
85+
fn signature(&self) -> &Signature {
86+
&self.signature
87+
}
88+
89+
fn return_type(&self, arg_types: &[DataType]) -> DFResult<DataType> {
90+
todo!()
91+
}
92+
93+
// fn return_type(&self, arg_types: &[DataType]) -> DFResult<DataType> {
94+
// if self.names.len() != arg_types.len() {
95+
// return plan_err!("The number of arguments provided argument must equal the number of expected field names");
96+
// }
97+
//
98+
// let fields = self
99+
// .names
100+
// .iter()
101+
// .zip(arg_types.iter())
102+
// This is how ee currently set nullability
103+
// .map(|(name, dt)| Field::new(name, dt.clone(), true))
104+
// .collect::<Fields>();
105+
//
106+
// Ok(DataType::Struct(fields))
107+
// }
108+
109+
fn invoke_batch(&self, args: &[ColumnarValue], number_rows: usize) -> DFResult<ColumnarValue> {
110+
if number_rows == 0 {
111+
return Ok(ColumnarValue::Array(Arc::new(
112+
StructArray::new_empty_fields(number_rows, None),
113+
)))
114+
}
115+
116+
if self.names.len() != args.len() {
117+
return exec_err!("The number of arguments provided argument must equal the number of expected field names");
118+
}
119+
120+
let children = self
121+
.names
122+
.iter()
123+
.zip(args.iter())
124+
.map(|(name, arg)| {
125+
let arr = match arg {
126+
ColumnarValue::Array(array_value) => array_value.clone(),
127+
ColumnarValue::Scalar(scalar_value) => scalar_value.to_array()?,
128+
};
129+
130+
Ok((name.as_str(), arr))
131+
})
132+
.collect::<DFResult<Vec<_>>>()?;
133+
134+
let (fields, arrays): (Vec<_>, _) = children
135+
.into_iter()
136+
// Here I can either set nullability as true or dependent on the presence of nulls in the array,
137+
// both are not correct nullability is dependent on the schema and not a chunk of the data
138+
.map(|(name, array)| {
139+
(Field::new(name, array.data_type().clone(), true), array)
140+
})
141+
.unzip();
142+
143+
let struct_array = StructArray::try_new(fields.into(), arrays, None)?;
144+
145+
Ok(ColumnarValue::from(Arc::new(struct_array) as ArrayRef))
146+
}
147+
148+
// TODO(joe): support propagating nullability into invoke and therefore use the below method
149+
// see https://github.com/apache/datafusion/issues/12819
150+
fn return_type_from_exprs(
151+
&self,
152+
args: &[Expr],
153+
schema: &dyn ExprSchema,
154+
_arg_types: &[DataType],
155+
) -> DFResult<DataType> {
156+
if self.names.len() != args.len() {
157+
return plan_err!("The number of arguments provided argument must equal the number of expected field names");
158+
}
159+
160+
let fields = self
161+
.names
162+
.iter()
163+
.zip(args.iter())
164+
.map(|(name, expr)| {
165+
let (dt, null) = expr.data_type_and_nullable(schema)?;
166+
Ok(Field::new(name, dt, null))
167+
})
168+
.collect::<DFResult<Vec<Field>>>()?;
169+
170+
Ok(DataType::Struct(Fields::from(fields)))
171+
}
172+
173+
fn invoke_batch_with_return_type(
174+
&self,
175+
args: &[ColumnarValue],
176+
_number_rows: usize,
177+
return_type: &DataType,
178+
) -> DFResult<ColumnarValue> {
179+
if self.names.len() != args.len() {
180+
return exec_err!("The number of arguments provided argument must equal the number of expected field names");
181+
}
182+
183+
let fields = match return_type {
184+
DataType::Struct(fields) => fields.clone(),
185+
_ => {
186+
return exec_err!(
187+
"Return type must be a struct, however it was {:?}",
188+
return_type
189+
)
190+
}
191+
};
192+
193+
let children = fields
194+
.into_iter()
195+
.zip(args.iter())
196+
.map(|(name, arg)| {
197+
let arr = match arg {
198+
ColumnarValue::Array(array_value) => array_value.clone(),
199+
ColumnarValue::Scalar(scalar_value) => scalar_value.to_array()?,
200+
};
201+
202+
Ok((name.clone(), arr))
203+
})
204+
.collect::<DFResult<Vec<_>>>()?;
205+
206+
let struct_array = StructArray::from(children);
207+
208+
Ok(ColumnarValue::from(Arc::new(struct_array) as ArrayRef))
209+
}
210+
}
211+
212+
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::Arc;

    use crate::pack::Pack;
    use arrow::array::{ArrayRef, Int32Array};
    use arrow_array::Array;
    use arrow_buffer::NullBuffer;
    use arrow_schema::{DataType, Field, Fields};
    use datafusion_common::DFSchema;
    use datafusion_expr::{col, ColumnarValue, ScalarUDFImpl};

    /// `invoke_batch` always marks output fields nullable, so an input with
    /// nulls round-trips with a nullable field.
    #[test]
    fn test_pack_not_null() {
        let a1 = Arc::new(Int32Array::from_iter_values_with_nulls(
            vec![1, 2],
            Some(NullBuffer::from([true, false].as_slice())),
        )) as ArrayRef;
        let pack = Pack::new(vec!["a"]);

        assert_eq!(
            DataType::Struct(Fields::from([Arc::new(Field::new(
                "a",
                DataType::Int32,
                true
            ))])),
            pack.invoke_batch(&[ColumnarValue::Array(a1.clone())], a1.len())
                .unwrap()
                .data_type()
        );
    }

    // Cannot have a return value of struct[("a", int32, null)], since the nullability is static
    #[test]
    // Known failure: invoke_batch hard-codes nullability to true, so the
    // expected non-nullable field can never be produced. Kept as a
    // demonstration of https://github.com/apache/datafusion/issues/12819.
    #[ignore = "invoke_batch cannot produce non-nullable fields; see issue #12819"]
    fn test_pack_null() {
        let a1 = Arc::new(Int32Array::from_iter_values(vec![1, 2]));
        let pack = Pack::new(vec!["a"]);

        assert_eq!(
            DataType::Struct(Fields::from([Arc::new(Field::new(
                "a",
                DataType::Int32,
                false
            ))])),
            pack.invoke_batch(&[ColumnarValue::Array(a1.clone())], a1.len())
                .unwrap()
                .data_type()
        );
    }

    /// With a nullable input column, the planned return type and the packed
    /// result both carry a nullable field.
    #[test]
    fn test_pack_rt_null() {
        let a1 = Arc::new(Int32Array::from_iter_values(vec![1, 2])) as ArrayRef;
        let schema = DFSchema::from_unqualified_fields(
            Fields::from([Arc::new(Field::new("a", DataType::Int32, true))]),
            HashMap::new(),
        )
        .unwrap();
        let pack = Pack::new(vec!["a"]);

        let rt = pack
            .return_type_from_exprs(&[col("a")], &schema, &[DataType::Int32])
            .unwrap();

        let ret = pack
            .invoke_batch_with_return_type(&[ColumnarValue::Array(a1.clone())], a1.len(), &rt)
            .unwrap();

        println!("{:?}", ret.into_array(1).unwrap().data_type());
    }

    /// With a non-nullable input column, the schema-aware path propagates
    /// the non-nullable field into the packed result.
    #[test]
    fn test_pack_rt_not_null() {
        let a1 = Arc::new(Int32Array::from_iter_values(vec![1, 2])) as ArrayRef;
        let schema = DFSchema::from_unqualified_fields(
            Fields::from([Arc::new(Field::new("a", DataType::Int32, false))]),
            HashMap::new(),
        )
        .unwrap();
        let pack = Pack::new(vec!["a"]);

        let rt = pack
            .return_type_from_exprs(&[col("a")], &schema, &[DataType::Int32])
            .unwrap();

        let ret = pack
            .invoke_batch_with_return_type(&[ColumnarValue::Array(a1.clone())], a1.len(), &rt)
            .unwrap();

        println!("{:?}", ret.into_array(1).unwrap().data_type());
    }
}

0 commit comments

Comments
 (0)