
Commit b0cfec3

added an example

1 parent 2930223 commit b0cfec3

File tree

datafusion/functions-nested/src/lib.rs
datafusion/functions-nested/src/pack.rs

2 files changed: +260 -0


datafusion/functions-nested/src/lib.rs

+1
@@ -46,6 +46,7 @@ pub mod map;
 pub mod map_extract;
 pub mod map_keys;
 pub mod map_values;
+pub mod pack;
 pub mod planner;
 pub mod position;
 pub mod range;
datafusion/functions-nested/src/pack.rs

+259
@@ -0,0 +1,259 @@
use std::sync::Arc;

use arrow::array::{Array, ArrayRef, StructArray};
use arrow::datatypes::{DataType, Field, Fields};
use itertools::Itertools;
use datafusion_expr::{ColumnarValue, Expr, ScalarUDF, ScalarUDFImpl, Signature, TypeSignature, Volatility};
use datafusion_common::error::Result as DFResult;
use datafusion_common::{exec_err, plan_err, ExprSchema};
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::ExprSchemable;

#[derive(Debug)]
pub struct Pack {
    signature: Signature,
    names: Vec<String>,
}

impl Pack {
    pub(crate) const NAME: &'static str = "struct.pack";

    pub fn new<I>(names: I) -> Self
    where
        I: IntoIterator,
        I::Item: AsRef<str>,
    {
        Self {
            signature: Signature::one_of(
                vec![TypeSignature::Any(0), TypeSignature::VariadicAny],
                Volatility::Immutable,
            ),
            names: names
                .into_iter()
                .map(|n| n.as_ref().to_string())
                .collect_vec(),
        }
    }

    pub fn names(&self) -> &[String] {
        self.names.as_slice()
    }

    pub fn new_instance<T>(
        names: impl IntoIterator<Item = T>,
        args: impl IntoIterator<Item = Expr>,
    ) -> Expr
    where
        T: AsRef<str>,
    {
        Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(ScalarUDF::new_from_impl(Pack::new(
                names
                    .into_iter()
                    .map(|n| n.as_ref().to_string())
                    .collect_vec(),
            ))),
            args: args.into_iter().collect_vec(),
        })
    }

    pub fn new_instance_from_pair(
        pairs: impl IntoIterator<Item = (impl AsRef<str>, Expr)>,
    ) -> Expr {
        let (names, args): (Vec<String>, Vec<Expr>) = pairs
            .into_iter()
            .map(|(k, v)| (k.as_ref().to_string(), v))
            .unzip();
        Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(ScalarUDF::new_from_impl(Pack::new(names))),
            args,
        })
    }
}
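
// Usage sketch: the constructors above build a `struct.pack` expression
// directly, with no separate UDF registration step. Assuming `col` and `lit`
// from `datafusion_expr` are in scope (field names "x"/"y" and column "a" are
// illustrative):
//
//     let expr = Pack::new_instance_from_pair([("x", col("a")), ("y", lit(1))]);
//
// is equivalent to
//
//     let expr = Pack::new_instance(["x", "y"], [col("a"), lit(1)]);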

impl ScalarUDFImpl for Pack {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn name(&self) -> &str {
        Self::NAME
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
        todo!()
    }

    // fn return_type(&self, arg_types: &[DataType]) -> DFResult<DataType> {
    //     if self.names.len() != arg_types.len() {
    //         return plan_err!("The number of arguments provided must equal the number of expected field names");
    //     }
    //
    //     let fields = self
    //         .names
    //         .iter()
    //         .zip(arg_types.iter())
    //         .map(|(name, dt)| Field::new(name, dt.clone(), true))
    //         .collect::<Fields>();
    //
    //     Ok(DataType::Struct(fields))
    // }

    fn invoke(&self, args: &[ColumnarValue]) -> DFResult<ColumnarValue> {
        if self.names.len() != args.len() {
            return exec_err!("The number of arguments provided must equal the number of expected field names");
        }

        let children = self
            .names
            .iter()
            .zip(args.iter())
            .map(|(name, arg)| {
                let arr = match arg {
                    ColumnarValue::Array(array_value) => array_value.clone(),
                    ColumnarValue::Scalar(scalar_value) => scalar_value.to_array()?,
                };

                Ok((name.as_str(), arr))
            })
            .collect::<DFResult<Vec<_>>>()?;

        let (fields, arrays): (Vec<_>, _) = children
            .into_iter()
            // Here we can either hard-code nullability as true or derive it from the presence
            // of nulls in the array; neither is correct, because nullability depends on the
            // schema and not on a single chunk of the data.
            .map(|(name, array)| (Field::new(name, array.data_type().clone(), true), array))
            .unzip();

        let struct_array = StructArray::try_new(fields.into(), arrays, None)?;

        Ok(ColumnarValue::from(Arc::new(struct_array) as ArrayRef))
    }
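
    // For illustration: with `Pack::new(["x", "y"])` and two Int32 input arrays
    // `[1, 2]` and `[3, 4]`, `invoke` produces a StructArray whose rows are
    // `{x: 1, y: 3}` and `{x: 2, y: 4}`, with both fields marked nullable
    // regardless of the input schema (see the nullability note above).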

    // TODO(joe): support propagating nullability into invoke and therefore use the below method
    // see https://github.com/apache/datafusion/issues/12819
    fn return_type_from_exprs(
        &self,
        args: &[Expr],
        schema: &dyn ExprSchema,
        _arg_types: &[DataType],
    ) -> DFResult<DataType> {
        if self.names.len() != args.len() {
            return plan_err!("The number of arguments provided must equal the number of expected field names");
        }

        let fields = self
            .names
            .iter()
            .zip(args.iter())
            .map(|(name, expr)| {
                let (dt, null) = expr.data_type_and_nullable(schema)?;
                Ok(Field::new(name, dt, null))
            })
            .collect::<DFResult<Vec<Field>>>()?;

        Ok(DataType::Struct(Fields::from(fields)))
    }
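
    // For illustration: with a schema field ("a", Int32, nullable = false) and
    // args `[col("a")]`, this returns Struct([("a", Int32, false)]); the field
    // nullability comes from the input schema rather than being hard-coded to
    // true as in `invoke` above.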

    fn invoke_with_return_type(&self, args: &[ColumnarValue], return_type: &DataType) -> DFResult<ColumnarValue> {
        if self.names.len() != args.len() {
            return exec_err!("The number of arguments provided must equal the number of expected field names");
        }

        let fields = match return_type {
            DataType::Struct(fields) => fields.clone(),
            _ => return exec_err!("Return type must be a struct, however it was {:?}", return_type),
        };

        let children = fields
            .into_iter()
            .zip(args.iter())
            .map(|(name, arg)| {
                let arr = match arg {
                    ColumnarValue::Array(array_value) => array_value.clone(),
                    ColumnarValue::Scalar(scalar_value) => scalar_value.to_array()?,
                };

                Ok((name.clone(), arr))
            })
            .collect::<DFResult<Vec<_>>>()?;

        let struct_array = StructArray::from(children);

        Ok(ColumnarValue::from(Arc::new(struct_array) as ArrayRef))
    }

    fn invoke_no_args(&self, number_rows: usize) -> DFResult<ColumnarValue> {
        Ok(ColumnarValue::Array(Arc::new(
            StructArray::new_empty_fields(number_rows, None),
        )))
    }
}

#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::Arc;

    use arrow::array::{ArrayRef, Int32Array};
    use arrow_array::Array;
    use arrow_buffer::NullBuffer;
    use arrow_schema::{DataType, Field, Fields};
    use datafusion_common::DFSchema;
    use datafusion_expr::{col, ColumnarValue, ScalarUDFImpl};
    use crate::pack::Pack;

    #[test]
    fn test_pack_not_null() {
        let a1 = Arc::new(Int32Array::from_iter_values_with_nulls(
            vec![1, 2],
            Some(NullBuffer::from([true, false].as_slice())),
        )) as ArrayRef;
        let schema = DFSchema::from_unqualified_fields(
            Fields::from([Arc::new(Field::new("a", DataType::Int32, true))].as_slice()),
            HashMap::new(),
        );
        let pack = Pack::new(vec!["a"]);

        assert_eq!(
            DataType::Struct(Fields::from([Arc::new(Field::new("a", DataType::Int32, true))])),
            pack.invoke(&[ColumnarValue::Array(a1)]).unwrap().data_type()
        );
    }

    // Cannot have a return value of struct[("a", int32, null)], since the nullability is static
    #[test]
    // fails
    fn test_pack_null() {
        let a1 = Arc::new(Int32Array::from_iter_values(vec![1, 2]));
        let schema = DFSchema::from_unqualified_fields(
            Fields::from([Arc::new(Field::new("a", DataType::Int32, false))].as_slice()),
            HashMap::new(),
        );
        let pack = Pack::new(vec!["a"]);

        assert_eq!(
            DataType::Struct(Fields::from([Arc::new(Field::new("a", DataType::Int32, false))])),
            pack.invoke(&[ColumnarValue::Array(a1)]).unwrap().data_type()
        );
    }

    #[test]
    fn test_pack_rt_null() {
        let a1 = Arc::new(Int32Array::from_iter_values(vec![1, 2])) as ArrayRef;
        let schema = DFSchema::from_unqualified_fields(
            Fields::from([Arc::new(Field::new("a", DataType::Int32, true))]),
            HashMap::new(),
        )
        .unwrap();
        let pack = Pack::new(vec!["a"]);

        let rt = pack
            .return_type_from_exprs(&[col("a")], &schema, &[DataType::Int32])
            .unwrap();

        let ret = pack
            .invoke_with_return_type(&[ColumnarValue::Array(a1)], &rt)
            .unwrap();

        println!("{:?}", ret.into_array(1).unwrap().data_type());
    }

    #[test]
    fn test_pack_rt_not_null() {
        let a1 = Arc::new(Int32Array::from_iter_values(vec![1, 2])) as ArrayRef;
        let schema = DFSchema::from_unqualified_fields(
            Fields::from([Arc::new(Field::new("a", DataType::Int32, false))]),
            HashMap::new(),
        )
        .unwrap();
        let pack = Pack::new(vec!["a"]);

        let rt = pack
            .return_type_from_exprs(&[col("a")], &schema, &[DataType::Int32])
            .unwrap();

        let ret = pack
            .invoke_with_return_type(&[ColumnarValue::Array(a1)], &rt)
            .unwrap();

        println!("{:?}", ret.into_array(1).unwrap().data_type());
    }
}
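
For context, a rough end-to-end sketch of how the packed expression could be used through the DataFrame API (assuming an async context, a `SessionContext` named `ctx`, and a registered table `t` with an Int32 column `a`; the table, column, and field names are illustrative):

    use datafusion_expr::{col, lit};
    use datafusion_functions_nested::pack::Pack;

    // Build `struct.pack` over column `a` and the literal 1, producing a
    // struct column with fields "x" and "y"; the expression embeds the
    // ScalarUDF, so no separate registration step is required.
    let packed = Pack::new_instance_from_pair([("x", col("a")), ("y", lit(1))]);
    let df = ctx.table("t").await?.select(vec![packed.alias("s")])?;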
