|
15 | 15 | // specific language governing permissions and limitations
|
16 | 16 | // under the License.
|
17 | 17 |
|
18 |
| -use crate::{ArrowError, Field, FieldRef, SchemaBuilder}; |
19 | 18 | use std::ops::Deref;
|
20 | 19 | use std::sync::Arc;
|
21 | 20 |
|
| 21 | +use crate::{ArrowError, DataType, Field, FieldRef, SchemaBuilder}; |
| 22 | + |
22 | 23 | /// A cheaply cloneable, owned slice of [`FieldRef`]
|
23 | 24 | ///
|
24 | 25 | /// Similar to `Arc<Vec<FieldRef>>` or `Arc<[FieldRef]>`
|
@@ -99,6 +100,108 @@ impl Fields {
|
99 | 100 | .all(|(a, b)| Arc::ptr_eq(a, b) || a.contains(b))
|
100 | 101 | }
|
101 | 102 |
|
| 103 | + /// Returns a copy of this [`Fields`] containing only those [`FieldRef`] passing a predicate |
| 104 | + /// |
| 105 | + /// Performs a depth-first scan of [`Fields`] invoking `filter` for each [`FieldRef`] |
| 106 | + /// containing no child [`FieldRef`], a leaf field, along with a count of the number |
| 107 | + /// of such leaves encountered so far. Only [`FieldRef`] for which `filter` |
| 108 | + /// returned `true` will be included in the result. |
| 109 | + /// |
| 110 | + /// This can therefore be used to select a subset of fields from nested types |
| 111 | + /// such as [`DataType::Struct`] or [`DataType::List`]. |
| 112 | + /// |
| 113 | + /// ``` |
| 114 | + /// # use arrow_schema::{DataType, Field, Fields}; |
| 115 | + /// let fields = Fields::from(vec![ |
| 116 | + /// Field::new("a", DataType::Int32, true), // Leaf 0 |
| 117 | + /// Field::new("b", DataType::Struct(Fields::from(vec![ |
| 118 | + /// Field::new("c", DataType::Float32, false), // Leaf 1 |
| 119 | + /// Field::new("d", DataType::Float64, false), // Leaf 2 |
| 120 | + /// Field::new("e", DataType::Struct(Fields::from(vec![ |
| 121 | + /// Field::new("f", DataType::Int32, false), // Leaf 3 |
| 122 | + /// Field::new("g", DataType::Float16, false), // Leaf 4 |
| 123 | + /// ])), true), |
| 124 | + /// ])), false) |
| 125 | + /// ]); |
| 126 | + /// let filtered = fields.filter_leaves(|idx, _| [0, 2, 3, 4].contains(&idx)); |
| 127 | + /// let expected = Fields::from(vec![ |
| 128 | + /// Field::new("a", DataType::Int32, true), |
| 129 | + /// Field::new("b", DataType::Struct(Fields::from(vec![ |
| 130 | + /// Field::new("d", DataType::Float64, false), |
| 131 | + /// Field::new("e", DataType::Struct(Fields::from(vec![ |
| 132 | + /// Field::new("f", DataType::Int32, false), |
| 133 | + /// Field::new("g", DataType::Float16, false), |
| 134 | + /// ])), true), |
| 135 | + /// ])), false) |
| 136 | + /// ]); |
| 137 | + /// assert_eq!(filtered, expected); |
| 138 | + /// ``` |
| 139 | + pub fn filter_leaves<F: FnMut(usize, &FieldRef) -> bool>(&self, mut filter: F) -> Self { |
| 140 | + fn filter_field<F: FnMut(&FieldRef) -> bool>( |
| 141 | + f: &FieldRef, |
| 142 | + filter: &mut F, |
| 143 | + ) -> Option<FieldRef> { |
| 144 | + use DataType::*; |
| 145 | + |
| 146 | + let v = match f.data_type() { |
| 147 | + Dictionary(_, v) => v.as_ref(), // Key must be integer |
| 148 | + RunEndEncoded(_, v) => v.data_type(), // Run-ends must be integer |
| 149 | + d => d, |
| 150 | + }; |
| 151 | + let d = match v { |
| 152 | + List(child) => List(filter_field(child, filter)?), |
| 153 | + LargeList(child) => LargeList(filter_field(child, filter)?), |
| 154 | + Map(child, ordered) => Map(filter_field(child, filter)?, *ordered), |
| 155 | + FixedSizeList(child, size) => FixedSizeList(filter_field(child, filter)?, *size), |
| 156 | + Struct(fields) => { |
| 157 | + let filtered: Fields = fields |
| 158 | + .iter() |
| 159 | + .filter_map(|f| filter_field(f, filter)) |
| 160 | + .collect(); |
| 161 | + |
| 162 | + if filtered.is_empty() { |
| 163 | + return None; |
| 164 | + } |
| 165 | + |
| 166 | + Struct(filtered) |
| 167 | + } |
| 168 | + Union(fields, mode) => { |
| 169 | + let filtered: UnionFields = fields |
| 170 | + .iter() |
| 171 | + .filter_map(|(id, f)| Some((id, filter_field(f, filter)?))) |
| 172 | + .collect(); |
| 173 | + |
| 174 | + if filtered.is_empty() { |
| 175 | + return None; |
| 176 | + } |
| 177 | + |
| 178 | + Union(filtered, *mode) |
| 179 | + } |
| 180 | + _ => return filter(f).then(|| f.clone()), |
| 181 | + }; |
| 182 | + let d = match f.data_type() { |
| 183 | + Dictionary(k, _) => Dictionary(k.clone(), Box::new(d)), |
| 184 | + RunEndEncoded(v, f) => { |
| 185 | + RunEndEncoded(v.clone(), Arc::new(f.as_ref().clone().with_data_type(d))) |
| 186 | + } |
| 187 | + _ => d, |
| 188 | + }; |
| 189 | + Some(Arc::new(f.as_ref().clone().with_data_type(d))) |
| 190 | + } |
| 191 | + |
| 192 | + let mut leaf_idx = 0; |
| 193 | + let mut filter = |f: &FieldRef| { |
| 194 | + let t = filter(leaf_idx, f); |
| 195 | + leaf_idx += 1; |
| 196 | + t |
| 197 | + }; |
| 198 | + |
| 199 | + self.0 |
| 200 | + .iter() |
| 201 | + .filter_map(|f| filter_field(f, &mut filter)) |
| 202 | + .collect() |
| 203 | + } |
| 204 | + |
102 | 205 | /// Remove a field by index and return it.
|
103 | 206 | ///
|
104 | 207 | /// # Panic
|
@@ -307,3 +410,130 @@ impl FromIterator<(i8, FieldRef)> for UnionFields {
|
307 | 410 | Self(iter.into_iter().collect())
|
308 | 411 | }
|
309 | 412 | }
|
| 413 | + |
| 414 | +#[cfg(test)] |
| 415 | +mod tests { |
| 416 | + use super::*; |
| 417 | + use crate::UnionMode; |
| 418 | + |
| 419 | + #[test] |
| 420 | + fn test_filter() { |
| 421 | + let floats = Fields::from(vec![ |
| 422 | + Field::new("a", DataType::Float32, false), |
| 423 | + Field::new("b", DataType::Float32, false), |
| 424 | + ]); |
| 425 | + let fields = Fields::from(vec![ |
| 426 | + Field::new("a", DataType::Int32, true), |
| 427 | + Field::new("floats", DataType::Struct(floats.clone()), true), |
| 428 | + Field::new("b", DataType::Int16, true), |
| 429 | + Field::new( |
| 430 | + "c", |
| 431 | + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), |
| 432 | + false, |
| 433 | + ), |
| 434 | + Field::new( |
| 435 | + "d", |
| 436 | + DataType::Dictionary( |
| 437 | + Box::new(DataType::Int32), |
| 438 | + Box::new(DataType::Struct(floats.clone())), |
| 439 | + ), |
| 440 | + false, |
| 441 | + ), |
| 442 | + Field::new_list( |
| 443 | + "e", |
| 444 | + Field::new("floats", DataType::Struct(floats.clone()), true), |
| 445 | + true, |
| 446 | + ), |
| 447 | + Field::new( |
| 448 | + "f", |
| 449 | + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, false)), 3), |
| 450 | + false, |
| 451 | + ), |
| 452 | + Field::new_map( |
| 453 | + "g", |
| 454 | + "entries", |
| 455 | + Field::new("keys", DataType::LargeUtf8, false), |
| 456 | + Field::new("values", DataType::Int32, true), |
| 457 | + false, |
| 458 | + false, |
| 459 | + ), |
| 460 | + Field::new( |
| 461 | + "h", |
| 462 | + DataType::Union( |
| 463 | + UnionFields::new( |
| 464 | + vec![1, 3], |
| 465 | + vec![ |
| 466 | + Field::new("field1", DataType::UInt8, false), |
| 467 | + Field::new("field3", DataType::Utf8, false), |
| 468 | + ], |
| 469 | + ), |
| 470 | + UnionMode::Dense, |
| 471 | + ), |
| 472 | + true, |
| 473 | + ), |
| 474 | + Field::new( |
| 475 | + "i", |
| 476 | + DataType::RunEndEncoded( |
| 477 | + Arc::new(Field::new("run_ends", DataType::Int32, false)), |
| 478 | + Arc::new(Field::new("values", DataType::Struct(floats.clone()), true)), |
| 479 | + ), |
| 480 | + false, |
| 481 | + ), |
| 482 | + ]); |
| 483 | + |
| 484 | + let floats_a = DataType::Struct(vec![floats[0].clone()].into()); |
| 485 | + |
| 486 | + let r = fields.filter_leaves(|idx, _| idx == 0 || idx == 1); |
| 487 | + assert_eq!(r.len(), 2); |
| 488 | + assert_eq!(r[0], fields[0]); |
| 489 | + assert_eq!(r[1].data_type(), &floats_a); |
| 490 | + |
| 491 | + let r = fields.filter_leaves(|_, f| f.name() == "a"); |
| 492 | + assert_eq!(r.len(), 5); |
| 493 | + assert_eq!(r[0], fields[0]); |
| 494 | + assert_eq!(r[1].data_type(), &floats_a); |
| 495 | + assert_eq!( |
| 496 | + r[2].data_type(), |
| 497 | + &DataType::Dictionary(Box::new(DataType::Int32), Box::new(floats_a.clone())) |
| 498 | + ); |
| 499 | + assert_eq!( |
| 500 | + r[3].as_ref(), |
| 501 | + &Field::new_list("e", Field::new("floats", floats_a.clone(), true), true) |
| 502 | + ); |
| 503 | + assert_eq!( |
| 504 | + r[4].as_ref(), |
| 505 | + &Field::new( |
| 506 | + "i", |
| 507 | + DataType::RunEndEncoded( |
| 508 | + Arc::new(Field::new("run_ends", DataType::Int32, false)), |
| 509 | + Arc::new(Field::new("values", floats_a.clone(), true)), |
| 510 | + ), |
| 511 | + false, |
| 512 | + ) |
| 513 | + ); |
| 514 | + |
| 515 | + let r = fields.filter_leaves(|_, f| f.name() == "floats"); |
| 516 | + assert_eq!(r.len(), 0); |
| 517 | + |
| 518 | + let r = fields.filter_leaves(|idx, _| idx == 9); |
| 519 | + assert_eq!(r.len(), 1); |
| 520 | + assert_eq!(r[0], fields[6]); |
| 521 | + |
| 522 | + let r = fields.filter_leaves(|idx, _| idx == 10 || idx == 11); |
| 523 | + assert_eq!(r.len(), 1); |
| 524 | + assert_eq!(r[0], fields[7]); |
| 525 | + |
| 526 | + let union = DataType::Union( |
| 527 | + UnionFields::new(vec![1], vec![Field::new("field1", DataType::UInt8, false)]), |
| 528 | + UnionMode::Dense, |
| 529 | + ); |
| 530 | + |
| 531 | + let r = fields.filter_leaves(|idx, _| idx == 12); |
| 532 | + assert_eq!(r.len(), 1); |
| 533 | + assert_eq!(r[0].data_type(), &union); |
| 534 | + |
| 535 | + let r = fields.filter_leaves(|idx, _| idx == 14 || idx == 15); |
| 536 | + assert_eq!(r.len(), 1); |
| 537 | + assert_eq!(r[0], fields[9]); |
| 538 | + } |
| 539 | +} |
0 commit comments