From 6e56d24b72ace8c981f8f443c1b7412ec330c180 Mon Sep 17 00:00:00 2001
From: Tim Saucer
Date: Mon, 23 Dec 2024 13:50:23 -0500
Subject: [PATCH 1/2] Merge partition columns into scan statistics for data skipping

---
 kernel/src/scan/data_skipping.rs | 171 ++++++++++++++++++++++++++++++-
 1 file changed, 169 insertions(+), 2 deletions(-)

diff --git a/kernel/src/scan/data_skipping.rs b/kernel/src/scan/data_skipping.rs
index 54eb5344c..a2e677b54 100644
--- a/kernel/src/scan/data_skipping.rs
+++ b/kernel/src/scan/data_skipping.rs
@@ -1,11 +1,14 @@
 use std::borrow::Cow;
 use std::cmp::Ordering;
+use std::collections::{HashMap, HashSet};
 use std::sync::{Arc, LazyLock};
 
+use arrow_array::{Array, ArrayRef, MapArray, RecordBatch, StringArray, StructArray};
 use tracing::debug;
 
 use crate::actions::get_log_add_schema;
 use crate::actions::visitors::SelectionVectorVisitor;
+use crate::engine::arrow_data::ArrowEngineData;
 use crate::error::DeltaResult;
 use crate::expressions::{
     column_expr, joined_column_expr, BinaryOperator, ColumnName, Expression as Expr, ExpressionRef,
@@ -14,8 +17,10 @@ use crate::expressions::{
 use crate::predicates::{
     DataSkippingPredicateEvaluator, PredicateEvaluator, PredicateEvaluatorDefaults,
 };
-use crate::schema::{DataType, PrimitiveType, SchemaRef, SchemaTransform, StructField, StructType};
-use crate::{Engine, EngineData, ExpressionEvaluator, JsonHandler, RowVisitor as _};
+use crate::schema::{
+    DataType, MapType, PrimitiveType, SchemaRef, SchemaTransform, StructField, StructType,
+};
+use crate::{Engine, EngineData, Error, ExpressionEvaluator, JsonHandler, RowVisitor as _};
 
 #[cfg(test)]
 mod tests;
@@ -43,6 +48,7 @@ fn as_data_skipping_predicate(expr: &Expr, inverted: bool) -> Option<Expr> {
 pub(crate) struct DataSkippingFilter {
     stats_schema: SchemaRef,
     select_stats_evaluator: Arc<dyn ExpressionEvaluator>,
+    partitions_evaluator: Arc<dyn ExpressionEvaluator>,
     skipping_evaluator: Arc<dyn ExpressionEvaluator>,
     filter_evaluator: Arc<dyn ExpressionEvaluator>,
     json_handler: Arc<dyn JsonHandler>,
@@ -61,6 +67,8 @@ impl DataSkippingFilter {
         static PREDICATE_SCHEMA: LazyLock<DataType> = LazyLock::new(|| {
             DataType::struct_type([StructField::new("predicate", DataType::BOOLEAN, true)])
         });
+        static PARTITIONS_EXPR: LazyLock<Expr> =
+            LazyLock::new(|| column_expr!("add.partitionValues"));
         static STATS_EXPR: LazyLock<Expr> = LazyLock::new(|| column_expr!("add.stats"));
         static FILTER_EXPR: LazyLock<Expr> =
             LazyLock::new(|| column_expr!("predicate").distinct(false));
@@ -88,6 +96,8 @@ impl DataSkippingFilter {
             StructField::new("maxValues", referenced_schema, true),
         ]));
 
+        let partitions_map_type = MapType::new(DataType::STRING, DataType::STRING, true);
+
         // Skipping happens in several steps:
         //
         // 1. The stats selector fetches add.stats from the metadata
@@ -106,6 +116,12 @@ impl DataSkippingFilter {
             DataType::STRING,
         );
 
+        let partitions_evaluator = engine.get_expression_handler().get_evaluator(
+            get_log_add_schema().clone(),
+            PARTITIONS_EXPR.clone(),
+            partitions_map_type.into(),
+        );
+
         let skipping_evaluator = engine.get_expression_handler().get_evaluator(
             stats_schema.clone(),
             Expr::struct_from([as_data_skipping_predicate(&predicate, false)?]),
@@ -121,6 +137,7 @@ impl DataSkippingFilter {
         Some(Self {
             stats_schema,
             select_stats_evaluator,
+            partitions_evaluator,
             skipping_evaluator,
             filter_evaluator,
             json_handler: engine.get_json_handler(),
@@ -138,6 +155,11 @@ impl DataSkippingFilter {
             .parse_json(stats, self.stats_schema.clone())?;
         assert_eq!(parsed_stats.len(), actions.len());
 
+        let parsed_partitions = self.partitions_evaluator.evaluate(actions)?;
+        assert_eq!(parsed_partitions.len(), actions.len());
+
+        let parsed_stats = merge_partitions_into_stats(parsed_partitions, parsed_stats)?;
+
         // evaluate the predicate on the parsed stats, then convert to selection vector
         let skipping_predicate = self.skipping_evaluator.evaluate(&*parsed_stats)?;
         assert_eq!(skipping_predicate.len(), actions.len());
@@ -257,3 +279,148 @@ impl DataSkippingPredicateEvaluator for DataSkippingPredicateCreator {
         Some(Expr::variadic(op, exprs))
     }
 }
+
+/// This function computes the arrays of partition values that are merged into the
+/// stats fields below. Since the partition columns arrive as a `MapArray`, we need
+/// to find, for each key, the value assigned in each log entry.
+fn compute_partition_arrays(
+    partitions_column: &ArrayRef,
+    output_schema: &Arc<arrow_schema::Schema>,
+) -> DeltaResult<HashMap<String, ArrayRef>> {
+    let output_types: HashMap<String, arrow_schema::DataType> =
+        match output_schema.field_with_name("minValues")?.data_type() {
+            arrow_schema::DataType::Struct(fields) => fields
+                .iter()
+                .map(|field| (field.name().to_owned(), field.data_type().to_owned())),
+            _ => return Err(Error::engine_data_type("minValues")),
+        }
+        .collect();
+
+    let partitions_array = partitions_column
+        .as_any()
+        .downcast_ref::<MapArray>()
+        .ok_or_else(|| Error::engine_data_type("Partitions"))?;
+
+    let keys: HashSet<String> = partitions_array
+        .keys()
+        .as_any()
+        .downcast_ref::<StringArray>()
+        .ok_or_else(|| Error::engine_data_type("Partition keys"))?
+        .iter()
+        .filter_map(|s| s.map(|t| t.to_string()))
+        .collect();
+
+    let partition_values: HashMap<String, ArrayRef> = keys
+        .iter()
+        .filter_map(|key| {
+            let cast_type = output_types.get(key)?;
+
+            let values = partitions_array
+                .iter()
+                .map(|maybe_partition| {
+                    maybe_partition.and_then(|partition_data| {
+                        let keys = partition_data
+                            .column_by_name("key")?
+                            .as_any()
+                            .downcast_ref::<StringArray>()?;
+                        let values = partition_data
+                            .column_by_name("value")?
+                            .as_any()
+                            .downcast_ref::<StringArray>()?;
+
+                        let mut kv =
+                            keys.iter()
+                                .zip(values.iter())
+                                .filter_map(|(k, v)| match (k, v) {
+                                    (Some(k), Some(v)) => Some((k, v)),
+                                    _ => None,
+                                });
+
+                        kv.find(|(k, _)| *k == key.as_str())
+                            .map(|(_, v)| v.to_string())
+                    })
+                })
+                .collect::<Vec<Option<String>>>();
+
+            let string_array = StringArray::from(values);
+            let value_array = arrow_cast::cast(&string_array, &cast_type).ok()?;
+            Some((key.to_owned(), value_array))
+        })
+        .collect();
+
+    Ok(partition_values)
+}
+
+/// This function builds up the stats fields for the min and max values. It assumes
+/// that the arrays for the partition fields have already been computed, and it only
+/// replaces the fields that match the predicate filters.
+fn merge_partition_fields_into_stats(
+    stats_batch: &RecordBatch,
+    idx: usize,
+    partition_values: &HashMap<String, Arc<dyn Array>>,
+) -> DeltaResult<Arc<dyn Array>> {
+    let (fields, mut arrays, nulls) = stats_batch
+        .column(idx)
+        .as_any()
+        .downcast_ref::<StructArray>()
+        .ok_or_else(|| Error::engine_data_type("minValues"))?
+        .to_owned()
+        .into_parts();
+    for (idx, field) in itertools::enumerate(fields.iter()) {
+        if let Some(arr) = partition_values.get(field.name()) {
+            arrays[idx] = Arc::clone(arr);
+        }
+    }
+    Ok(Arc::new(StructArray::new(fields, arrays, nulls))
+        as Arc<(dyn arrow_array::Array + 'static)>)
+}
+
+/// This function adds partition data to the stats fields. For each partition field in
+/// a log entry it adds the partition value to both the `minValues` and `maxValues`
+/// fields, so that when we match against them with the data skipping filters we can
+/// effectively skip files.
+fn merge_partitions_into_stats(
+    partitions: Box<dyn EngineData>,
+    stats: Box<dyn EngineData>,
+) -> DeltaResult<Box<dyn EngineData>> {
+    let partitions = ArrowEngineData::try_from_engine_data(partitions)?;
+    let partitions_batch = partitions.record_batch();
+
+    // If the partitions data is empty, return the original stats
+    let partitions_column = match partitions_batch.column_by_name("output") {
+        Some(c) => c,
+        None => return Ok(stats),
+    };
+
+    let stats = ArrowEngineData::try_from_engine_data(stats)?;
+    let stats_batch = stats.record_batch();
+    let output_schema = stats_batch.schema();
+
+    // For each unique partition key, generate the associated array
+    // to add to the stats fields
+    let partition_values = compute_partition_arrays(partitions_column, &output_schema)?;
+    if partition_values.is_empty() {
+        return Ok(stats);
+    }
+
+    let mut columns = Vec::default();
+    for (idx, field) in itertools::enumerate(output_schema.fields()) {
+        match field.name().as_str() {
+            // The partition values are merged into both minValues and maxValues
+            "minValues" => columns.push(merge_partition_fields_into_stats(
+                stats_batch,
+                idx,
+                &partition_values,
+            )?),
+            "maxValues" => columns.push(merge_partition_fields_into_stats(
+                stats_batch,
+                idx,
+                &partition_values,
+            )?),
+            _ => {
+                columns.push(Arc::clone(stats_batch.column(idx)));
+            }
+        }
+    }
+
+    let record_batch = RecordBatch::try_new(output_schema, columns)?;
+    Ok(Box::new(ArrowEngineData::new(record_batch)))
+}

From bd1b81423310890de3da3061d22f00e8dd17ac11 Mon Sep 17 00:00:00 2001
From: Tim Saucer
Date: Mon, 23 Dec 2024 14:25:21 -0500
Subject: [PATCH 2/2] Expand data skipping unit test to include partitions

---
 kernel/src/table_changes/log_replay/tests.rs | 33 +++++++++++++++++++-
 1 file changed, 32 insertions(+), 1 deletion(-)

diff --git a/kernel/src/table_changes/log_replay/tests.rs b/kernel/src/table_changes/log_replay/tests.rs
index f2dbdd956..81d452cff 100644
--- a/kernel/src/table_changes/log_replay/tests.rs
+++ b/kernel/src/table_changes/log_replay/tests.rs
@@ -25,6 +25,7 @@ fn get_schema() -> StructType {
     StructType::new([
         StructField::new("id", DataType::INTEGER, true),
         StructField::new("value", DataType::STRING, true),
+        StructField::new("partition_col", DataType::INTEGER, true),
     ])
 }
@@ -490,6 +491,9 @@
         }),
         Action::Add(Add {
             path: "fake_path_1".into(),
+            partition_values: HashMap::from([
+                ("partition_col".to_string(), "1".to_string()),
+            ]),
             stats: Some("{\"numRecords\":4,\"minValues\":{\"id\":4},\"maxValues\":{\"id\":6},\"nullCount\":{\"id\":3}}".into()),
             data_change: true,
             deletion_vector: deletion_vector.clone(),
@@ -503,6 +507,9 @@
         }),
         Action::Add(Add {
             path: "fake_path_2".into(),
+            partition_values: HashMap::from([
("partition_col".to_string(), "2".to_string()), + ]), stats: Some("{\"numRecords\":4,\"minValues\":{\"id\":4},\"maxValues\":{\"id\":4},\"nullCount\":{\"id\":3}}".into()), data_change: true, deletion_vector, @@ -511,6 +518,9 @@ async fn data_skipping_filter() { // Add action with max value id = 5 Action::Add(Add { path: "fake_path_3".into(), + partition_values: HashMap::from([ + ("partition_col".to_string(), "3".to_string()), + ]), stats: Some("{\"numRecords\":4,\"minValues\":{\"id\":4},\"maxValues\":{\"id\":5},\"nullCount\":{\"id\":3}}".into()), data_change: true, ..Default::default() @@ -533,7 +543,7 @@ async fn data_skipping_filter() { .unwrap() .into_iter(); - let sv = table_changes_action_iter(engine, commits, logical_schema.into(), predicate) + let sv = table_changes_action_iter(engine.clone(), commits.clone(), logical_schema.clone().into(), predicate) .unwrap() .flat_map(|scan_data| { let scan_data = scan_data.unwrap(); @@ -543,6 +553,27 @@ async fn data_skipping_filter() { // Note: since the first pair is a dv operation, remove action will always be filtered assert_eq!(sv, &[false, true, false, false, true]); + + let predicate = Expression::binary( + BinaryOperator::LessThanOrEqual, + column_expr!("partition_col"), + Scalar::from(2), + ); + let predicate = match PhysicalPredicate::try_new(&predicate, &logical_schema) { + Ok(PhysicalPredicate::Some(p, s)) => Some((p, s)), + other => panic!("Unexpected result: {:?}", other), + }; + + let sv = table_changes_action_iter(engine, commits, logical_schema.into(), predicate) + .unwrap() + .flat_map(|scan_data| { + let scan_data = scan_data.unwrap(); + scan_data.selection_vector + }) + .collect_vec(); + + // Note: since the first pair is a dv operation, remove action will always be filtered + assert_eq!(sv, &[false, true, false, true, false]); } #[tokio::test]