Skip to content

Commit 314af4c

Browse files
feat: nan_value_counts support (#907)
## Issue Fixes #417 ## Description - We compute upper and lower bounds by relying on parquet statistics, but those statistics don't provide `nan_value_count`, so we have to implement it in the library itself when arrow record batches are received. - We keep track of it at the `ParquetWriter` level because `write` can be called multiple times. - Added a couple of new tests covering `nan_val_count` for different types. --------- Co-authored-by: Renjie Liu <[email protected]>
1 parent b84e0d2 commit 314af4c

File tree

8 files changed

+823
-13
lines changed

8 files changed

+823
-13
lines changed

Cargo.lock

+1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/iceberg/src/arrow/mod.rs

+5
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,12 @@
1919
2020
mod schema;
2121
pub use schema::*;
22+
23+
mod nan_val_cnt_visitor;
24+
pub(crate) use nan_val_cnt_visitor::*;
25+
2226
pub(crate) mod delete_file_manager;
27+
2328
mod reader;
2429
pub(crate) mod record_batch_projector;
2530
pub(crate) mod record_batch_transformer;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! The module contains the visitor for calculating NaN value counts in a given arrow record batch.
19+
20+
use std::collections::hash_map::Entry;
21+
use std::collections::HashMap;
22+
use std::sync::Arc;
23+
24+
use arrow_array::{ArrayRef, Float32Array, Float64Array, RecordBatch, StructArray};
25+
use arrow_schema::DataType;
26+
27+
use crate::arrow::ArrowArrayAccessor;
28+
use crate::spec::{
29+
visit_struct_with_partner, ListType, MapType, NestedFieldRef, PrimitiveType, Schema, SchemaRef,
30+
SchemaWithPartnerVisitor, StructType,
31+
};
32+
use crate::Result;
33+
34+
/// Downcasts `$col` to the concrete float array type `$t`, counts the NaN
/// values it contains, and accumulates that count into
/// `$self.nan_value_counts` under `$field_id`.
///
/// `$self` is expected to expose a `nan_value_counts: HashMap<i32, u64>`
/// field (see `NanValueCountVisitor`).
macro_rules! cast_and_update_cnt_map {
    ($t:ty, $col:ident, $self:ident, $field_id:ident) => {
        // Nulls are not NaNs: only present values are inspected.
        let nan_val_cnt = $col
            .as_any()
            .downcast_ref::<$t>()
            // Callers dispatch on `DataType` before invoking this macro,
            // so the downcast cannot fail here.
            .unwrap()
            .iter()
            .filter(|value| value.is_some_and(|v| v.is_nan()))
            .count() as u64;

        // `write` can be called multiple times per file, so counts are
        // accumulated across record batches rather than overwritten.
        match $self.nan_value_counts.entry($field_id) {
            Entry::Occupied(mut ele) => {
                *ele.get_mut() += nan_val_cnt;
            }
            Entry::Vacant(v) => {
                v.insert(nan_val_cnt);
            }
        };
    };
}
55+
56+
/// Counts the NaNs in `$col` and records them for `$field_id` on `$self`
/// when the column is a floating point column; any other column type is
/// left untouched (only floats can hold NaN).
macro_rules! count_float_nans {
    ($col:ident, $self:ident, $field_id:ident) => {
        let data_type = $col.data_type();
        if let DataType::Float32 = data_type {
            cast_and_update_cnt_map!(Float32Array, $col, $self, $field_id);
        } else if let DataType::Float64 = data_type {
            cast_and_update_cnt_map!(Float64Array, $col, $self, $field_id);
        }
    };
}
69+
70+
/// Visitor which counts and keeps track of NaN value counts in the given
/// record batch(es).
///
/// Counts accumulate across calls to `compute`, since a writer may receive
/// several record batches for the same file.
#[derive(Debug)]
pub struct NanValueCountVisitor {
    /// Maps a field ID to the number of NaN values seen for that field.
    pub nan_value_counts: HashMap<i32, u64>,
}
75+
76+
// NaN counting happens entirely in the `after_*` hooks below; the plain
// visit callbacks are no-ops because nothing needs to be produced on the
// way down the schema tree. The walk is driven by `visit_struct_with_partner`
// (see `NanValueCountVisitor::compute`), with the arrow array as partner.
impl SchemaWithPartnerVisitor<ArrayRef> for NanValueCountVisitor {
    // No per-node result: state accumulates in `self.nan_value_counts`.
    type T = ();

    // No-op: nothing to compute for the schema root itself.
    fn schema(
        &mut self,
        _schema: &Schema,
        _partner: &ArrayRef,
        _value: Self::T,
    ) -> Result<Self::T> {
        Ok(())
    }

    // No-op: per-field work is done in `after_struct_field` instead.
    fn field(
        &mut self,
        _field: &NestedFieldRef,
        _partner: &ArrayRef,
        _value: Self::T,
    ) -> Result<Self::T> {
        Ok(())
    }

    // No-op: struct children are handled individually via the hooks.
    fn r#struct(
        &mut self,
        _struct: &StructType,
        _partner: &ArrayRef,
        _results: Vec<Self::T>,
    ) -> Result<Self::T> {
        Ok(())
    }

    // No-op: list elements are handled in `after_list_element`.
    fn list(&mut self, _list: &ListType, _list_arr: &ArrayRef, _value: Self::T) -> Result<Self::T> {
        Ok(())
    }

    // No-op: map keys/values are handled in `after_map_key`/`after_map_value`.
    fn map(
        &mut self,
        _map: &MapType,
        _partner: &ArrayRef,
        _key_value: Self::T,
        _value: Self::T,
    ) -> Result<Self::T> {
        Ok(())
    }

    // No-op: primitives are counted by the enclosing `after_*` hook, which
    // has access to the field ID the count must be recorded under.
    fn primitive(&mut self, _p: &PrimitiveType, _col: &ArrayRef) -> Result<Self::T> {
        Ok(())
    }

    // `partner` is the arrow array paired with this struct field (paired up
    // by the `ArrowArrayAccessor` — presumably by field ID; see that type).
    // `field_id` must be bound to a local because the macro takes an ident.
    fn after_struct_field(&mut self, field: &NestedFieldRef, partner: &ArrayRef) -> Result<()> {
        let field_id = field.id;
        count_float_nans!(partner, self, field_id);
        Ok(())
    }

    // Count NaNs in a list's element array under the element field's ID.
    fn after_list_element(&mut self, field: &NestedFieldRef, partner: &ArrayRef) -> Result<()> {
        let field_id = field.id;
        count_float_nans!(partner, self, field_id);
        Ok(())
    }

    // Count NaNs in a map's key array under the key field's ID.
    fn after_map_key(&mut self, field: &NestedFieldRef, partner: &ArrayRef) -> Result<()> {
        let field_id = field.id;
        count_float_nans!(partner, self, field_id);
        Ok(())
    }

    // Count NaNs in a map's value array under the value field's ID.
    fn after_map_value(&mut self, field: &NestedFieldRef, partner: &ArrayRef) -> Result<()> {
        let field_id = field.id;
        count_float_nans!(partner, self, field_id);
        Ok(())
    }
}
148+
149+
impl NanValueCountVisitor {
150+
/// Creates new instance of NanValueCountVisitor
151+
pub fn new() -> Self {
152+
Self {
153+
nan_value_counts: HashMap::new(),
154+
}
155+
}
156+
157+
/// Compute nan value counts in given schema and record batch
158+
pub fn compute(&mut self, schema: SchemaRef, batch: RecordBatch) -> Result<()> {
159+
let arrow_arr_partner_accessor = ArrowArrayAccessor {};
160+
161+
let struct_arr = Arc::new(StructArray::from(batch)) as ArrayRef;
162+
visit_struct_with_partner(
163+
schema.as_struct(),
164+
&struct_arr,
165+
self,
166+
&arrow_arr_partner_accessor,
167+
)?;
168+
169+
Ok(())
170+
}
171+
}
172+
173+
impl Default for NanValueCountVisitor {
174+
fn default() -> Self {
175+
Self::new()
176+
}
177+
}

crates/iceberg/src/arrow/value.rs

+5-1
Original file line numberDiff line numberDiff line change
@@ -425,7 +425,8 @@ impl SchemaWithPartnerVisitor<ArrayRef> for ArrowArrayToIcebergStructConverter {
425425
}
426426
}
427427

428-
struct ArrowArrayAccessor;
428+
/// Partner type used to access and walk arrow arrays alongside an iceberg schema
429+
pub struct ArrowArrayAccessor;
429430

430431
impl PartnerAccessor<ArrayRef> for ArrowArrayAccessor {
431432
fn struct_parner<'a>(&self, schema_partner: &'a ArrayRef) -> Result<&'a ArrayRef> {
@@ -435,6 +436,7 @@ impl PartnerAccessor<ArrayRef> for ArrowArrayAccessor {
435436
"The schema partner is not a struct type",
436437
));
437438
}
439+
438440
Ok(schema_partner)
439441
}
440442

@@ -452,6 +454,7 @@ impl PartnerAccessor<ArrayRef> for ArrowArrayAccessor {
452454
"The struct partner is not a struct array",
453455
)
454456
})?;
457+
455458
let field_pos = struct_array
456459
.fields()
457460
.iter()
@@ -466,6 +469,7 @@ impl PartnerAccessor<ArrayRef> for ArrowArrayAccessor {
466469
format!("Field id {} not found in struct array", field.id),
467470
)
468471
})?;
472+
469473
Ok(struct_array.column(field_pos))
470474
}
471475

crates/iceberg/src/writer/base_writer/data_file_writer.rs

+18-4
Original file line numberDiff line numberDiff line change
@@ -103,11 +103,13 @@ impl<B: FileWriterBuilder> CurrentFileStatus for DataFileWriter<B> {
103103

104104
#[cfg(test)]
105105
mod test {
106+
use std::collections::HashMap;
106107
use std::sync::Arc;
107108

108109
use arrow_array::{Int32Array, StringArray};
109110
use arrow_schema::{DataType, Field};
110111
use parquet::arrow::arrow_reader::{ArrowReaderMetadata, ArrowReaderOptions};
112+
use parquet::arrow::PARQUET_FIELD_ID_META_KEY;
111113
use parquet::file::properties::WriterProperties;
112114
use tempfile::TempDir;
113115

@@ -153,8 +155,14 @@ mod test {
153155
.unwrap();
154156

155157
let arrow_schema = arrow_schema::Schema::new(vec![
156-
Field::new("foo", DataType::Int32, false),
157-
Field::new("bar", DataType::Utf8, false),
158+
Field::new("foo", DataType::Int32, false).with_metadata(HashMap::from([(
159+
PARQUET_FIELD_ID_META_KEY.to_string(),
160+
3.to_string(),
161+
)])),
162+
Field::new("bar", DataType::Utf8, false).with_metadata(HashMap::from([(
163+
PARQUET_FIELD_ID_META_KEY.to_string(),
164+
4.to_string(),
165+
)])),
158166
]);
159167
let batch = RecordBatch::try_new(Arc::new(arrow_schema.clone()), vec![
160168
Arc::new(Int32Array::from(vec![1, 2, 3])),
@@ -224,8 +232,14 @@ mod test {
224232
.await?;
225233

226234
let arrow_schema = arrow_schema::Schema::new(vec![
227-
Field::new("id", DataType::Int32, false),
228-
Field::new("name", DataType::Utf8, false),
235+
Field::new("id", DataType::Int32, false).with_metadata(HashMap::from([(
236+
PARQUET_FIELD_ID_META_KEY.to_string(),
237+
5.to_string(),
238+
)])),
239+
Field::new("name", DataType::Utf8, false).with_metadata(HashMap::from([(
240+
PARQUET_FIELD_ID_META_KEY.to_string(),
241+
6.to_string(),
242+
)])),
229243
]);
230244
let batch = RecordBatch::try_new(Arc::new(arrow_schema.clone()), vec![
231245
Arc::new(Int32Array::from(vec![1, 2, 3])),

0 commit comments

Comments
 (0)