Skip to content

Commit 6d4b8bb

Browse files
authored
Support nested schema projection (#5148) (#5149)
* Support nested schema projection * Tweak doc * Review feedback
1 parent 8867a1f commit 6d4b8bb

File tree

1 file changed

+231
-1
lines changed

1 file changed

+231
-1
lines changed

arrow-schema/src/fields.rs

Lines changed: 231 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,11 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use crate::{ArrowError, Field, FieldRef, SchemaBuilder};
1918
use std::ops::Deref;
2019
use std::sync::Arc;
2120

21+
use crate::{ArrowError, DataType, Field, FieldRef, SchemaBuilder};
22+
2223
/// A cheaply cloneable, owned slice of [`FieldRef`]
2324
///
2425
/// Similar to `Arc<Vec<FieldRef>>` or `Arc<[FieldRef]>`
@@ -99,6 +100,108 @@ impl Fields {
99100
.all(|(a, b)| Arc::ptr_eq(a, b) || a.contains(b))
100101
}
101102

103+
/// Returns a copy of this [`Fields`] containing only those [`FieldRef`] passing a predicate
104+
///
105+
/// Performs a depth-first scan of [`Fields`] invoking `filter` for each [`FieldRef`]
106+
/// containing no child [`FieldRef`], a leaf field, along with a count of the number
107+
/// of such leaves encountered so far. Only [`FieldRef`] for which `filter`
108+
/// returned `true` will be included in the result.
109+
///
110+
/// This can therefore be used to select a subset of fields from nested types
111+
/// such as [`DataType::Struct`] or [`DataType::List`].
112+
///
113+
/// ```
114+
/// # use arrow_schema::{DataType, Field, Fields};
115+
/// let fields = Fields::from(vec![
116+
/// Field::new("a", DataType::Int32, true), // Leaf 0
117+
/// Field::new("b", DataType::Struct(Fields::from(vec![
118+
/// Field::new("c", DataType::Float32, false), // Leaf 1
119+
/// Field::new("d", DataType::Float64, false), // Leaf 2
120+
/// Field::new("e", DataType::Struct(Fields::from(vec![
121+
/// Field::new("f", DataType::Int32, false), // Leaf 3
122+
/// Field::new("g", DataType::Float16, false), // Leaf 4
123+
/// ])), true),
124+
/// ])), false)
125+
/// ]);
126+
/// let filtered = fields.filter_leaves(|idx, _| [0, 2, 3, 4].contains(&idx));
127+
/// let expected = Fields::from(vec![
128+
/// Field::new("a", DataType::Int32, true),
129+
/// Field::new("b", DataType::Struct(Fields::from(vec![
130+
/// Field::new("d", DataType::Float64, false),
131+
/// Field::new("e", DataType::Struct(Fields::from(vec![
132+
/// Field::new("f", DataType::Int32, false),
133+
/// Field::new("g", DataType::Float16, false),
134+
/// ])), true),
135+
/// ])), false)
136+
/// ]);
137+
/// assert_eq!(filtered, expected);
138+
/// ```
139+
pub fn filter_leaves<F: FnMut(usize, &FieldRef) -> bool>(&self, mut filter: F) -> Self {
140+
fn filter_field<F: FnMut(&FieldRef) -> bool>(
141+
f: &FieldRef,
142+
filter: &mut F,
143+
) -> Option<FieldRef> {
144+
use DataType::*;
145+
146+
let v = match f.data_type() {
147+
Dictionary(_, v) => v.as_ref(), // Key must be integer
148+
RunEndEncoded(_, v) => v.data_type(), // Run-ends must be integer
149+
d => d,
150+
};
151+
let d = match v {
152+
List(child) => List(filter_field(child, filter)?),
153+
LargeList(child) => LargeList(filter_field(child, filter)?),
154+
Map(child, ordered) => Map(filter_field(child, filter)?, *ordered),
155+
FixedSizeList(child, size) => FixedSizeList(filter_field(child, filter)?, *size),
156+
Struct(fields) => {
157+
let filtered: Fields = fields
158+
.iter()
159+
.filter_map(|f| filter_field(f, filter))
160+
.collect();
161+
162+
if filtered.is_empty() {
163+
return None;
164+
}
165+
166+
Struct(filtered)
167+
}
168+
Union(fields, mode) => {
169+
let filtered: UnionFields = fields
170+
.iter()
171+
.filter_map(|(id, f)| Some((id, filter_field(f, filter)?)))
172+
.collect();
173+
174+
if filtered.is_empty() {
175+
return None;
176+
}
177+
178+
Union(filtered, *mode)
179+
}
180+
_ => return filter(f).then(|| f.clone()),
181+
};
182+
let d = match f.data_type() {
183+
Dictionary(k, _) => Dictionary(k.clone(), Box::new(d)),
184+
RunEndEncoded(v, f) => {
185+
RunEndEncoded(v.clone(), Arc::new(f.as_ref().clone().with_data_type(d)))
186+
}
187+
_ => d,
188+
};
189+
Some(Arc::new(f.as_ref().clone().with_data_type(d)))
190+
}
191+
192+
let mut leaf_idx = 0;
193+
let mut filter = |f: &FieldRef| {
194+
let t = filter(leaf_idx, f);
195+
leaf_idx += 1;
196+
t
197+
};
198+
199+
self.0
200+
.iter()
201+
.filter_map(|f| filter_field(f, &mut filter))
202+
.collect()
203+
}
204+
102205
/// Remove a field by index and return it.
103206
///
104207
/// # Panic
@@ -307,3 +410,130 @@ impl FromIterator<(i8, FieldRef)> for UnionFields {
307410
Self(iter.into_iter().collect())
308411
}
309412
}
413+
414+
#[cfg(test)]
415+
mod tests {
416+
use super::*;
417+
use crate::UnionMode;
418+
419+
#[test]
420+
fn test_filter() {
421+
let floats = Fields::from(vec![
422+
Field::new("a", DataType::Float32, false),
423+
Field::new("b", DataType::Float32, false),
424+
]);
425+
let fields = Fields::from(vec![
426+
Field::new("a", DataType::Int32, true),
427+
Field::new("floats", DataType::Struct(floats.clone()), true),
428+
Field::new("b", DataType::Int16, true),
429+
Field::new(
430+
"c",
431+
DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
432+
false,
433+
),
434+
Field::new(
435+
"d",
436+
DataType::Dictionary(
437+
Box::new(DataType::Int32),
438+
Box::new(DataType::Struct(floats.clone())),
439+
),
440+
false,
441+
),
442+
Field::new_list(
443+
"e",
444+
Field::new("floats", DataType::Struct(floats.clone()), true),
445+
true,
446+
),
447+
Field::new(
448+
"f",
449+
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, false)), 3),
450+
false,
451+
),
452+
Field::new_map(
453+
"g",
454+
"entries",
455+
Field::new("keys", DataType::LargeUtf8, false),
456+
Field::new("values", DataType::Int32, true),
457+
false,
458+
false,
459+
),
460+
Field::new(
461+
"h",
462+
DataType::Union(
463+
UnionFields::new(
464+
vec![1, 3],
465+
vec![
466+
Field::new("field1", DataType::UInt8, false),
467+
Field::new("field3", DataType::Utf8, false),
468+
],
469+
),
470+
UnionMode::Dense,
471+
),
472+
true,
473+
),
474+
Field::new(
475+
"i",
476+
DataType::RunEndEncoded(
477+
Arc::new(Field::new("run_ends", DataType::Int32, false)),
478+
Arc::new(Field::new("values", DataType::Struct(floats.clone()), true)),
479+
),
480+
false,
481+
),
482+
]);
483+
484+
let floats_a = DataType::Struct(vec![floats[0].clone()].into());
485+
486+
let r = fields.filter_leaves(|idx, _| idx == 0 || idx == 1);
487+
assert_eq!(r.len(), 2);
488+
assert_eq!(r[0], fields[0]);
489+
assert_eq!(r[1].data_type(), &floats_a);
490+
491+
let r = fields.filter_leaves(|_, f| f.name() == "a");
492+
assert_eq!(r.len(), 5);
493+
assert_eq!(r[0], fields[0]);
494+
assert_eq!(r[1].data_type(), &floats_a);
495+
assert_eq!(
496+
r[2].data_type(),
497+
&DataType::Dictionary(Box::new(DataType::Int32), Box::new(floats_a.clone()))
498+
);
499+
assert_eq!(
500+
r[3].as_ref(),
501+
&Field::new_list("e", Field::new("floats", floats_a.clone(), true), true)
502+
);
503+
assert_eq!(
504+
r[4].as_ref(),
505+
&Field::new(
506+
"i",
507+
DataType::RunEndEncoded(
508+
Arc::new(Field::new("run_ends", DataType::Int32, false)),
509+
Arc::new(Field::new("values", floats_a.clone(), true)),
510+
),
511+
false,
512+
)
513+
);
514+
515+
let r = fields.filter_leaves(|_, f| f.name() == "floats");
516+
assert_eq!(r.len(), 0);
517+
518+
let r = fields.filter_leaves(|idx, _| idx == 9);
519+
assert_eq!(r.len(), 1);
520+
assert_eq!(r[0], fields[6]);
521+
522+
let r = fields.filter_leaves(|idx, _| idx == 10 || idx == 11);
523+
assert_eq!(r.len(), 1);
524+
assert_eq!(r[0], fields[7]);
525+
526+
let union = DataType::Union(
527+
UnionFields::new(vec![1], vec![Field::new("field1", DataType::UInt8, false)]),
528+
UnionMode::Dense,
529+
);
530+
531+
let r = fields.filter_leaves(|idx, _| idx == 12);
532+
assert_eq!(r.len(), 1);
533+
assert_eq!(r[0].data_type(), &union);
534+
535+
let r = fields.filter_leaves(|idx, _| idx == 14 || idx == 15);
536+
assert_eq!(r.len(), 1);
537+
assert_eq!(r[0], fields[9]);
538+
}
539+
}

0 commit comments

Comments
 (0)