Extract parquet statistics from f16 columns, add `ScalarValue::Float16` (#10763)

* Extract parquet statistics from f16 columns

* Update datafusion/common/src/scalar/mod.rs

Co-authored-by: Andrew Lamb <[email protected]>

---------

Co-authored-by: Andrew Lamb <[email protected]>
Lordworms and alamb authored Jun 3, 2024
1 parent 826331e commit fe53649
Showing 6 changed files with 136 additions and 24 deletions.
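
For orientation, a minimal sketch of what the new variant enables, assuming the `datafusion-common`, `arrow`, and `half` crates at this commit (the snippet is illustrative, not part of the diff):

use arrow::datatypes::DataType;
use datafusion_common::ScalarValue;
use half::f16;

fn main() {
    // construct the new variant added by this commit
    let v = ScalarValue::Float16(Some(f16::from_f32(1.5)));
    assert_eq!(v.data_type(), DataType::Float16);

    // a typed NULL, as produced by the new `TryFrom<&DataType>` arm
    let null = ScalarValue::Float16(None);
    assert!(null.is_null());
}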
39 changes: 35 additions & 4 deletions datafusion/common/src/scalar/mod.rs
@@ -18,13 +18,13 @@
//! [`ScalarValue`]: stores single values

mod struct_builder;

use std::borrow::Borrow;
use std::cmp::Ordering;
use std::collections::{HashSet, VecDeque};
use std::convert::Infallible;
use std::fmt;
use std::hash::Hash;
use std::hash::Hasher;
use std::iter::repeat;
use std::str::FromStr;
use std::sync::Arc;
@@ -55,6 +55,7 @@ use arrow::{
use arrow_buffer::Buffer;
use arrow_schema::{UnionFields, UnionMode};

use half::f16;
pub use struct_builder::ScalarStructBuilder;

/// A dynamically typed, nullable single value.
@@ -192,6 +193,8 @@ pub enum ScalarValue {
    Null,
    /// true or false value
    Boolean(Option<bool>),
    /// 16bit float
    Float16(Option<f16>),
    /// 32bit float
    Float32(Option<f32>),
    /// 64bit float
@@ -285,6 +288,12 @@ pub enum ScalarValue {
    Dictionary(Box<DataType>, Box<ScalarValue>),
}

// hash f16 by bit pattern, mirroring the `Fl` wrappers for f32/f64, so that
// `Hash` stays consistent with the bit-based float equality below
impl Hash for Fl<f16> {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.0.to_bits().hash(state);
    }
}
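
For intuition (illustrative values, not part of the commit): hashing the raw bit pattern distinguishes `0.0` from `-0.0` and gives any given NaN a stable hash, matching the `to_bits` comparisons in the manual `PartialEq` below.

// illustrative f16 bit patterns, assuming the `half` crate
assert_ne!(f16::from_f32(0.0).to_bits(), f16::from_f32(-0.0).to_bits()); // 0x0000 vs 0x8000
assert_eq!(f16::NAN.to_bits(), f16::NAN.to_bits()); // same NaN bit pattern hashes consistently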

// manual implementation of `PartialEq`
impl PartialEq for ScalarValue {
    fn eq(&self, other: &Self) -> bool {
@@ -307,7 +316,12 @@ impl PartialEq for ScalarValue {
                (Some(f1), Some(f2)) => f1.to_bits() == f2.to_bits(),
                _ => v1.eq(v2),
            },
            (Float16(v1), Float16(v2)) => match (v1, v2) {
                (Some(f1), Some(f2)) => f1.to_bits() == f2.to_bits(),
                _ => v1.eq(v2),
            },
            (Float32(_), _) => false,
            (Float16(_), _) => false,
            (Float64(v1), Float64(v2)) => match (v1, v2) {
                (Some(f1), Some(f2)) => f1.to_bits() == f2.to_bits(),
                _ => v1.eq(v2),
@@ -425,7 +439,12 @@ impl PartialOrd for ScalarValue {
                (Some(f1), Some(f2)) => Some(f1.total_cmp(f2)),
                _ => v1.partial_cmp(v2),
            },
            (Float16(v1), Float16(v2)) => match (v1, v2) {
                (Some(f1), Some(f2)) => Some(f1.total_cmp(f2)),
                _ => v1.partial_cmp(v2),
            },
            (Float32(_), _) => None,
            (Float16(_), _) => None,
            (Float64(v1), Float64(v2)) => match (v1, v2) {
                (Some(f1), Some(f2)) => Some(f1.total_cmp(f2)),
                _ => v1.partial_cmp(v2),
@@ -637,6 +656,7 @@ impl std::hash::Hash for ScalarValue {
                s.hash(state)
            }
            Boolean(v) => v.hash(state),
            Float16(v) => v.map(Fl).hash(state),
            Float32(v) => v.map(Fl).hash(state),
            Float64(v) => v.map(Fl).hash(state),
            Int8(v) => v.hash(state),
@@ -1082,6 +1102,7 @@ impl ScalarValue {
            ScalarValue::TimestampNanosecond(_, tz_opt) => {
                DataType::Timestamp(TimeUnit::Nanosecond, tz_opt.clone())
            }
            ScalarValue::Float16(_) => DataType::Float16,
            ScalarValue::Float32(_) => DataType::Float32,
            ScalarValue::Float64(_) => DataType::Float64,
            ScalarValue::Utf8(_) => DataType::Utf8,
@@ -1276,6 +1297,7 @@ impl ScalarValue {
        match self {
            ScalarValue::Boolean(v) => v.is_none(),
            ScalarValue::Null => true,
            ScalarValue::Float16(v) => v.is_none(),
            ScalarValue::Float32(v) => v.is_none(),
            ScalarValue::Float64(v) => v.is_none(),
            ScalarValue::Decimal128(v, _, _) => v.is_none(),
@@ -1522,6 +1544,7 @@ impl ScalarValue {
            }
            DataType::Null => ScalarValue::iter_to_null_array(scalars)?,
            DataType::Boolean => build_array_primitive!(BooleanArray, Boolean),
            DataType::Float16 => build_array_primitive!(Float16Array, Float16),
            DataType::Float32 => build_array_primitive!(Float32Array, Float32),
            DataType::Float64 => build_array_primitive!(Float64Array, Float64),
            DataType::Int8 => build_array_primitive!(Int8Array, Int8),
@@ -1682,8 +1705,7 @@ impl ScalarValue {
// not supported if the TimeUnit is not valid (Time32 can
// only be used with Second and Millisecond, Time64 only
// with Microsecond and Nanosecond)
DataType::Float16
| DataType::Time32(TimeUnit::Microsecond)
DataType::Time32(TimeUnit::Microsecond)
| DataType::Time32(TimeUnit::Nanosecond)
| DataType::Time64(TimeUnit::Second)
| DataType::Time64(TimeUnit::Millisecond)
@@ -1700,7 +1722,6 @@
);
}
};

Ok(array)
}

@@ -1921,6 +1942,9 @@ impl ScalarValue {
            ScalarValue::Float32(e) => {
                build_array_from_option!(Float32, Float32Array, e, size)
            }
            ScalarValue::Float16(e) => {
                build_array_from_option!(Float16, Float16Array, e, size)
            }
            ScalarValue::Int8(e) => build_array_from_option!(Int8, Int8Array, e, size),
            ScalarValue::Int16(e) => build_array_from_option!(Int16, Int16Array, e, size),
            ScalarValue::Int32(e) => build_array_from_option!(Int32, Int32Array, e, size),
@@ -2595,6 +2619,9 @@ impl ScalarValue {
            ScalarValue::Boolean(val) => {
                eq_array_primitive!(array, index, BooleanArray, val)?
            }
            ScalarValue::Float16(val) => {
                eq_array_primitive!(array, index, Float16Array, val)?
            }
            ScalarValue::Float32(val) => {
                eq_array_primitive!(array, index, Float32Array, val)?
            }
@@ -2738,6 +2765,7 @@ impl ScalarValue {
            + match self {
                ScalarValue::Null
                | ScalarValue::Boolean(_)
                | ScalarValue::Float16(_)
                | ScalarValue::Float32(_)
                | ScalarValue::Float64(_)
                | ScalarValue::Decimal128(_, _, _)
@@ -3022,6 +3050,7 @@ impl TryFrom<&DataType> for ScalarValue {
    fn try_from(data_type: &DataType) -> Result<Self> {
        Ok(match data_type {
            DataType::Boolean => ScalarValue::Boolean(None),
            DataType::Float16 => ScalarValue::Float16(None),
            DataType::Float64 => ScalarValue::Float64(None),
            DataType::Float32 => ScalarValue::Float32(None),
            DataType::Int8 => ScalarValue::Int8(None),
@@ -3147,6 +3176,7 @@ impl fmt::Display for ScalarValue {
                write!(f, "{v:?},{p:?},{s:?}")?;
            }
            ScalarValue::Boolean(e) => format_option!(f, e)?,
            ScalarValue::Float16(e) => format_option!(f, e)?,
            ScalarValue::Float32(e) => format_option!(f, e)?,
            ScalarValue::Float64(e) => format_option!(f, e)?,
            ScalarValue::Int8(e) => format_option!(f, e)?,
@@ -3260,6 +3290,7 @@ impl fmt::Debug for ScalarValue {
            ScalarValue::Decimal128(_, _, _) => write!(f, "Decimal128({self})"),
            ScalarValue::Decimal256(_, _, _) => write!(f, "Decimal256({self})"),
            ScalarValue::Boolean(_) => write!(f, "Boolean({self})"),
            ScalarValue::Float16(_) => write!(f, "Float16({self})"),
            ScalarValue::Float32(_) => write!(f, "Float32({self})"),
            ScalarValue::Float64(_) => write!(f, "Float64({self})"),
            ScalarValue::Int8(_) => write!(f, "Int8({self})"),
14 changes: 12 additions & 2 deletions datafusion/core/src/datasource/physical_plan/parquet/statistics.rs
@@ -25,11 +25,11 @@ use arrow_schema::{Field, FieldRef, Schema};
use datafusion_common::{
    internal_datafusion_err, internal_err, plan_err, Result, ScalarValue,
};
use half::f16;
use parquet::file::metadata::ParquetMetaData;
use parquet::file::statistics::Statistics as ParquetStatistics;
use parquet::schema::types::SchemaDescriptor;
use std::sync::Arc;

// Convert the bytes array to i128.
// The endian of the input bytes array must be big-endian.
pub(crate) fn from_bytes_to_i128(b: &[u8]) -> i128 {
@@ -39,6 +39,14 @@ pub(crate) fn from_bytes_to_i128(b: &[u8]) -> i128 {
    i128::from_be_bytes(sign_extend_be(b))
}

// Convert the bytes array to f16.
// Parquet stores `Float16` as a two-byte fixed-length byte array in
// little-endian order, so the two bytes are reversed before decoding as
// big-endian; inputs of any other length yield `None`.
pub(crate) fn from_bytes_to_f16(b: &[u8]) -> Option<f16> {
    match b {
        [low, high] => Some(f16::from_be_bytes([*high, *low])),
        _ => None,
    }
}
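
A quick sanity check on the byte order (hypothetical values, not part of the commit): f16 `1.0` has bit pattern `0x3C00`, which the statistics bytes store little-endian as `[0x00, 0x3C]`.

// hypothetical: decode 1.0_f16 from its little-endian statistics bytes
assert_eq!(from_bytes_to_f16(&[0x00, 0x3C]), Some(f16::from_f32(1.0)));
// any length other than two bytes yields None
assert_eq!(from_bytes_to_f16(&[0x00]), None);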

// Copy from arrow-rs
// https://github.com/apache/arrow-rs/blob/733b7e7fd1e8c43a404c3ce40ecf741d493c21b4/parquet/src/arrow/buffer/bit_util.rs#L55
// Convert the byte slice to fixed length byte array with the length of 16
@@ -196,6 +204,9 @@ macro_rules! get_statistic {
                        value,
                    ))
                }
                Some(DataType::Float16) => {
                    Some(ScalarValue::Float16(from_bytes_to_f16(s.$bytes_func())))
                }
                _ => None,
            }
        }
@@ -344,7 +355,6 @@ impl<'a> StatisticsConverter<'a> {
column_name
);
};

Ok(Self {
column_name,
statistics_type,
43 changes: 37 additions & 6 deletions datafusion/core/tests/parquet/arrow_statistics.rs
@@ -21,28 +21,29 @@
use std::fs::File;
use std::sync::Arc;

use crate::parquet::{struct_array, Scenario};
use arrow::compute::kernels::cast_utils::Parser;
use arrow::datatypes::{
Date32Type, Date64Type, TimestampMicrosecondType, TimestampMillisecondType,
TimestampNanosecondType, TimestampSecondType,
};
use arrow_array::{
make_array, Array, ArrayRef, BinaryArray, BooleanArray, Date32Array, Date64Array,
Decimal128Array, FixedSizeBinaryArray, Float32Array, Float64Array, Int16Array,
Int32Array, Int64Array, Int8Array, LargeStringArray, RecordBatch, StringArray,
TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
Decimal128Array, FixedSizeBinaryArray, Float16Array, Float32Array, Float64Array,
Int16Array, Int32Array, Int64Array, Int8Array, LargeStringArray, RecordBatch,
StringArray, TimestampMicrosecondArray, TimestampMillisecondArray,
TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array,
UInt64Array, UInt8Array,
};
use arrow_schema::{DataType, Field, Schema};
use datafusion::datasource::physical_plan::parquet::{
RequestedStatistics, StatisticsConverter,
};
use half::f16;
use parquet::arrow::arrow_reader::{ArrowReaderBuilder, ParquetRecordBatchReaderBuilder};
use parquet::arrow::ArrowWriter;
use parquet::file::properties::{EnabledStatistics, WriterProperties};

use crate::parquet::{struct_array, Scenario};

use super::make_test_file_rg;

// TEST HELPERS
@@ -1203,6 +1204,36 @@ async fn test_float64() {
    .run();
}

#[tokio::test]
async fn test_float16() {
    // This creates a parquet file of 1 column "f".
    // The file has 4 record batches of 5 rows each, written as 4 row groups.
    let reader = TestReader {
        scenario: Scenario::Float16,
        row_per_group: 5,
    };

    Test {
        reader: reader.build().await,
        expected_min: Arc::new(Float16Array::from(
            vec![-5.0, -4.0, -0.0, 5.0]
                .into_iter()
                .map(f16::from_f32)
                .collect::<Vec<_>>(),
        )),
        expected_max: Arc::new(Float16Array::from(
            vec![-1.0, 0.0, 4.0, 9.0]
                .into_iter()
                .map(f16::from_f32)
                .collect::<Vec<_>>(),
        )),
        expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0]),
        expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]),
        column_name: "f",
    }
    .run();
}

#[tokio::test]
async fn test_decimal() {
    // This creates a parquet file of 1 column "decimal_col" with decimal data type and precision 9, scale 2
55 changes: 43 additions & 12 deletions datafusion/core/tests/parquet/mod.rs
@@ -19,32 +19,29 @@
use arrow::array::Decimal128Array;
use arrow::{
array::{
Array, ArrayRef, BinaryArray, Date32Array, Date64Array, FixedSizeBinaryArray,
Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, StringArray,
TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
make_array, Array, ArrayRef, BinaryArray, BooleanArray, Date32Array, Date64Array,
DictionaryArray, FixedSizeBinaryArray, Float16Array, Float32Array, Float64Array,
Int16Array, Int32Array, Int64Array, Int8Array, LargeStringArray, StringArray,
StructArray, TimestampMicrosecondArray, TimestampMillisecondArray,
TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array,
UInt64Array, UInt8Array,
},
datatypes::{DataType, Field, Schema},
datatypes::{DataType, Field, Int32Type, Int8Type, Schema},
record_batch::RecordBatch,
util::pretty::pretty_format_batches,
};
use arrow_array::types::{Int32Type, Int8Type};
use arrow_array::{
make_array, BooleanArray, DictionaryArray, Float32Array, LargeStringArray,
StructArray,
};
use chrono::{Datelike, Duration, TimeDelta};
use datafusion::{
datasource::{physical_plan::ParquetExec, provider_as_source, TableProvider},
physical_plan::{accept, metrics::MetricsSet, ExecutionPlan, ExecutionPlanVisitor},
prelude::{ParquetReadOptions, SessionConfig, SessionContext},
};
use datafusion_expr::{Expr, LogicalPlan, LogicalPlanBuilder};
use half::f16;
use parquet::arrow::ArrowWriter;
use parquet::file::properties::WriterProperties;
use std::sync::Arc;
use tempfile::NamedTempFile;

mod arrow_statistics;
mod custom_reader;
mod file_statistics;
@@ -79,6 +76,7 @@ enum Scenario {
    /// 7 Rows, for each i8, i16, i32, i64, u8, u16, u32, u64, f32, f64
    /// -MIN, -100, -1, 0, 1, 100, MAX
    NumericLimits,
    Float16,
    Float64,
    Decimal,
    DecimalBloomFilterInt32,
@@ -542,6 +540,12 @@ fn make_f64_batch(v: Vec<f64>) -> RecordBatch {
    RecordBatch::try_new(schema, vec![array.clone()]).unwrap()
}

fn make_f16_batch(v: Vec<f16>) -> RecordBatch {
    let schema = Arc::new(Schema::new(vec![Field::new("f", DataType::Float16, true)]));
    let array = Arc::new(Float16Array::from(v)) as ArrayRef;
    RecordBatch::try_new(schema, vec![array.clone()]).unwrap()
}
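
For illustration (not part of the diff), a batch built with this helper carries a single nullable Float16 column named "f":

// hypothetical usage of make_f16_batch
let batch = make_f16_batch(vec![f16::from_f32(-5.0), f16::from_f32(9.0)]);
assert_eq!(batch.num_rows(), 2);
assert_eq!(batch.schema().field(0).name(), "f");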

/// Return record batch with decimal vector
///
/// Columns are named
@@ -897,6 +901,34 @@ fn create_data_batch(scenario: Scenario) -> Vec<RecordBatch> {
        Scenario::NumericLimits => {
            vec![make_numeric_limit_batch()]
        }
        Scenario::Float16 => {
            vec![
                make_f16_batch(
                    vec![-5.0, -4.0, -3.0, -2.0, -1.0]
                        .into_iter()
                        .map(f16::from_f32)
                        .collect(),
                ),
                make_f16_batch(
                    vec![-4.0, -3.0, -2.0, -1.0, 0.0]
                        .into_iter()
                        .map(f16::from_f32)
                        .collect(),
                ),
                make_f16_batch(
                    vec![0.0, 1.0, 2.0, 3.0, 4.0]
                        .into_iter()
                        .map(f16::from_f32)
                        .collect(),
                ),
                make_f16_batch(
                    vec![5.0, 6.0, 7.0, 8.0, 9.0]
                        .into_iter()
                        .map(f16::from_f32)
                        .collect(),
                ),
            ]
        }
        Scenario::Float64 => {
            vec![
                make_f64_batch(vec![-5.0, -4.0, -3.0, -2.0, -1.0]),
@@ -1087,7 +1119,6 @@ async fn make_test_file_rg(scenario: Scenario, row_per_group: usize) -> NamedTempFile
.build();

let batches = create_data_batch(scenario);

let schema = batches[0].schema();

let mut writer = ArrowWriter::try_new(&mut output_file, schema, Some(props)).unwrap();