diff --git a/.github/workflows/python_build.yml b/.github/workflows/python_build.yml index ce2a7e0bfd..d4ed2d4f7d 100644 --- a/.github/workflows/python_build.yml +++ b/.github/workflows/python_build.yml @@ -31,7 +31,7 @@ jobs: run: make check-rust test-minimal: - name: Python Build (Python 3.8 PyArrow 8.0.0) + name: Python Build (Python 3.8 PyArrow 16.0.0) runs-on: ubuntu-latest env: RUSTFLAGS: "-C debuginfo=line-tables-only" @@ -51,7 +51,7 @@ jobs: source venv/bin/activate make setup # Install minimum PyArrow version - pip install -e .[pandas,devel] pyarrow==8.0.0 + pip install -e .[pandas,devel] pyarrow==16.0.0 env: RUSTFLAGS: "-C debuginfo=line-tables-only" @@ -60,10 +60,6 @@ jobs: source venv/bin/activate make unit-test - # - name: Run Integration tests - # run: | - # py.test --cov tests -m integration - test: name: Python Build (Python 3.10 PyArrow latest) runs-on: ubuntu-latest diff --git a/crates/core/src/delta_datafusion/find_files/mod.rs b/crates/core/src/delta_datafusion/find_files/mod.rs index d25d0765ee..956966d3e7 100644 --- a/crates/core/src/delta_datafusion/find_files/mod.rs +++ b/crates/core/src/delta_datafusion/find_files/mod.rs @@ -137,7 +137,7 @@ async fn scan_table_by_files( .with_file_column(true) .build(&snapshot)?; - let logical_schema = df_logical_schema(&snapshot, &scan_config)?; + let logical_schema = df_logical_schema(&snapshot, &scan_config.file_column_name, None)?; // Identify which columns we need to project let mut used_columns = expression diff --git a/crates/core/src/delta_datafusion/mod.rs b/crates/core/src/delta_datafusion/mod.rs index c2b410cb74..99749bf79a 100644 --- a/crates/core/src/delta_datafusion/mod.rs +++ b/crates/core/src/delta_datafusion/mod.rs @@ -305,9 +305,13 @@ pub(crate) fn register_store(store: LogStoreRef, env: Arc) { /// at the physical level pub(crate) fn df_logical_schema( snapshot: &DeltaTableState, - scan_config: &DeltaScanConfig, + file_column_name: &Option, + schema: Option, ) -> DeltaResult { - let input_schema = snapshot.arrow_schema()?; + let input_schema = match schema { + Some(schema) => schema, + None => snapshot.input_schema()?, + }; let table_partition_cols = &snapshot.metadata().partition_columns; let mut fields: Vec> = input_schema @@ -326,7 +330,7 @@ pub(crate) fn df_logical_schema( )); } - if let Some(file_column_name) = &scan_config.file_column_name { + if let Some(file_column_name) = file_column_name { fields.push(Arc::new(Field::new(file_column_name, DataType::Utf8, true))); } @@ -528,7 +532,11 @@ impl<'a> DeltaScanBuilder<'a> { None => self.snapshot.arrow_schema(), }?; - let logical_schema = df_logical_schema(self.snapshot, &config)?; + let logical_schema = df_logical_schema( + self.snapshot, + &config.file_column_name, + Some(schema.clone()), + )?; let logical_schema = if let Some(used_columns) = self.projection { let mut fields = vec![]; @@ -733,7 +741,7 @@ impl TableProvider for DeltaTable { filter: &[&Expr], ) -> DataFusionResult> { Ok(filter - .into_iter() + .iter() .map(|_| TableProviderFilterPushDown::Inexact) .collect()) } @@ -760,7 +768,7 @@ impl DeltaTableProvider { config: DeltaScanConfig, ) -> DeltaResult { Ok(DeltaTableProvider { - schema: df_logical_schema(&snapshot, &config)?, + schema: df_logical_schema(&snapshot, &config.file_column_name, config.schema.clone())?, snapshot, log_store, config, @@ -1524,7 +1532,7 @@ pub(crate) async fn find_files_scan<'a>( } .build(snapshot)?; - let logical_schema = df_logical_schema(snapshot, &scan_config)?; + let logical_schema = df_logical_schema(snapshot, 
&scan_config.file_column_name, None)?; // Identify which columns we need to project let mut used_columns = expression diff --git a/crates/core/src/operations/add_column.rs b/crates/core/src/operations/add_column.rs index 028a6e5b2e..8fff1677b8 100644 --- a/crates/core/src/operations/add_column.rs +++ b/crates/core/src/operations/add_column.rs @@ -4,11 +4,10 @@ use delta_kernel::schema::StructType; use futures::future::BoxFuture; use itertools::Itertools; -use super::cast::merge_struct; use super::transaction::{CommitBuilder, CommitProperties, PROTOCOL}; - use crate::kernel::StructField; use crate::logstore::LogStoreRef; +use crate::operations::cast::merge_schema::merge_delta_struct; use crate::protocol::DeltaOperation; use crate::table::state::DeltaTableState; use crate::{DeltaResult, DeltaTable, DeltaTableError}; @@ -67,7 +66,7 @@ impl std::future::IntoFuture for AddColumnBuilder { let fields_right = &StructType::new(fields.clone()); let table_schema = this.snapshot.schema(); - let new_table_schema = merge_struct(table_schema, fields_right)?; + let new_table_schema = merge_delta_struct(table_schema, fields_right)?; // TODO(ion): Think of a way how we can simply this checking through the API or centralize some checks. let contains_timestampntz = PROTOCOL.contains_timestampntz(fields.iter()); diff --git a/crates/core/src/operations/cast/merge_schema.rs b/crates/core/src/operations/cast/merge_schema.rs new file mode 100644 index 0000000000..597700ad3f --- /dev/null +++ b/crates/core/src/operations/cast/merge_schema.rs @@ -0,0 +1,350 @@ +//! Provide schema merging for delta schemas +//! +use crate::kernel::{ArrayType, DataType as DeltaDataType, MapType, StructField, StructType}; +use arrow::datatypes::DataType::Dictionary; +use arrow_schema::{ + ArrowError, DataType, Field as ArrowField, Fields, Schema as ArrowSchema, + SchemaRef as ArrowSchemaRef, +}; +use std::collections::HashMap; + +fn try_merge_metadata( + left: &mut HashMap, + right: &HashMap, +) -> Result<(), ArrowError> { + for (k, v) in right { + if let Some(vl) = left.get(k) { + if vl != v { + return Err(ArrowError::SchemaError(format!( + "Cannot merge metadata with different values for key {}", + k + ))); + } + } else { + left.insert(k.clone(), v.clone()); + } + } + Ok(()) +} + +pub(crate) fn merge_delta_type( + left: &DeltaDataType, + right: &DeltaDataType, +) -> Result { + if left == right { + return Ok(left.clone()); + } + match (left, right) { + (DeltaDataType::Array(a), DeltaDataType::Array(b)) => { + let merged = merge_delta_type(&a.element_type, &b.element_type)?; + Ok(DeltaDataType::Array(Box::new(ArrayType::new( + merged, + a.contains_null() || b.contains_null(), + )))) + } + (DeltaDataType::Map(a), DeltaDataType::Map(b)) => { + let merged_key = merge_delta_type(&a.key_type, &b.key_type)?; + let merged_value = merge_delta_type(&a.value_type, &b.value_type)?; + Ok(DeltaDataType::Map(Box::new(MapType::new( + merged_key, + merged_value, + a.value_contains_null() || b.value_contains_null(), + )))) + } + (DeltaDataType::Struct(a), DeltaDataType::Struct(b)) => { + let merged = merge_delta_struct(a, b)?; + Ok(DeltaDataType::Struct(Box::new(merged))) + } + (a, b) => Err(ArrowError::SchemaError(format!( + "Cannot merge types {} and {}", + a, b + ))), + } +} + +pub(crate) fn merge_delta_struct( + left: &StructType, + right: &StructType, +) -> Result { + let mut errors = Vec::new(); + let merged_fields: Result, ArrowError> = left + .fields() + .map(|field| { + let right_field = right.field(field.name()); + if let Some(right_field) = 
right_field { + let type_or_not = merge_delta_type(field.data_type(), right_field.data_type()); + match type_or_not { + Err(e) => { + errors.push(e.to_string()); + Err(e) + } + Ok(f) => { + let mut new_field = StructField::new( + field.name(), + f, + field.is_nullable() || right_field.is_nullable(), + ); + + new_field.metadata.clone_from(&field.metadata); + try_merge_metadata(&mut new_field.metadata, &right_field.metadata)?; + Ok(new_field) + } + } + } else { + Ok(field.clone()) + } + }) + .collect(); + match merged_fields { + Ok(mut fields) => { + for field in right.fields() { + if !left.field(field.name()).is_some() { + fields.push(field.clone()); + } + } + + Ok(StructType::new(fields)) + } + Err(e) => { + errors.push(e.to_string()); + Err(ArrowError::SchemaError(errors.join("\n"))) + } + } +} + +pub(crate) fn merge_arrow_field( + left: &ArrowField, + right: &ArrowField, + preserve_new_fields: bool, +) -> Result { + if left == right { + return Ok(left.clone()); + } + + let (table_type, batch_type) = (left.data_type(), right.data_type()); + + match (table_type, batch_type) { + (Dictionary(key_type, value_type), _) + if matches!( + value_type.as_ref(), + DataType::Utf8 | DataType::Utf8View | DataType::LargeUtf8 + ) && matches!( + batch_type, + DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View + ) => + { + Ok(ArrowField::new( + right.name(), + Dictionary(key_type.clone(), Box::new(batch_type.clone())), + left.is_nullable() || right.is_nullable(), + )) + } + (Dictionary(key_type, value_type), _) + if matches!( + value_type.as_ref(), + DataType::Binary | DataType::BinaryView | DataType::LargeBinary + ) && matches!( + batch_type, + DataType::Binary | DataType::LargeBinary | DataType::BinaryView + ) => + { + Ok(ArrowField::new( + right.name(), + Dictionary(key_type.clone(), Box::new(batch_type.clone())), + left.is_nullable() || right.is_nullable(), + )) + } + (Dictionary(_, value_type), _) if value_type.equals_datatype(batch_type) => Ok(left + .clone() + .with_nullable(left.is_nullable() || right.is_nullable())), + + (_, Dictionary(_, value_type)) + if matches!( + table_type, + DataType::Utf8 | DataType::Utf8View | DataType::LargeUtf8 + ) && matches!( + value_type.as_ref(), + DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View + ) => + { + Ok(right + .clone() + .with_nullable(left.is_nullable() || right.is_nullable())) + } + (_, Dictionary(_, value_type)) + if matches!( + table_type, + DataType::Binary | DataType::BinaryView | DataType::LargeBinary + ) && matches!( + value_type.as_ref(), + DataType::Binary | DataType::LargeBinary | DataType::BinaryView + ) => + { + Ok(right + .clone() + .with_nullable(left.is_nullable() || right.is_nullable())) + } + (_, Dictionary(_, value_type)) if value_type.equals_datatype(table_type) => Ok(right + .clone() + .with_nullable(left.is_nullable() || right.is_nullable())), + // With Utf8/binary we always take the right type since that is coming from the incoming data + // by doing that we allow passthrough of any string flavor + ( + DataType::Utf8 | DataType::Utf8View | DataType::LargeUtf8, + DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View, + ) + | ( + DataType::Binary | DataType::BinaryView | DataType::LargeBinary, + DataType::Binary | DataType::LargeBinary | DataType::BinaryView, + ) => Ok(ArrowField::new( + left.name(), + batch_type.clone(), + right.is_nullable() || left.is_nullable(), + )), + ( + DataType::List(left_child_fields) | DataType::LargeList(left_child_fields), + DataType::LargeList(right_child_fields), + ) => { + let merged 
= + merge_arrow_field(left_child_fields, right_child_fields, preserve_new_fields)?; + Ok(ArrowField::new( + left.name(), + DataType::LargeList(merged.into()), + right.is_nullable() || left.is_nullable(), + )) + } + ( + DataType::List(left_child_fields) | DataType::LargeList(left_child_fields), + DataType::List(right_child_fields), + ) => { + let merged = + merge_arrow_field(left_child_fields, right_child_fields, preserve_new_fields)?; + Ok(ArrowField::new( + left.name(), + DataType::List(merged.into()), + right.is_nullable() || left.is_nullable(), + )) + } + (DataType::Struct(left_child_fields), DataType::Struct(right_child_fields)) => { + let merged = + merge_arrow_vec_fields(left_child_fields, right_child_fields, preserve_new_fields)?; + Ok(ArrowField::new( + left.name(), + DataType::Struct(merged.into()), + right.is_nullable() || left.is_nullable(), + )) + } + (DataType::Map(left_field, left_sorted), DataType::Map(right_field, right_sorted)) + if left_sorted == right_sorted => + { + let merged = merge_arrow_field(left_field, right_field, preserve_new_fields)?; + Ok(ArrowField::new( + left.name(), + DataType::Map(merged.into(), *right_sorted), + right.is_nullable() || left.is_nullable(), + )) + } + _ => { + let mut new_field = left.clone(); + match new_field.try_merge(right) { + Ok(()) => (), + Err(err) => { + // We cannot keep the table field here, there is some weird behavior where + // Decimal(5,1) can be safely casted into Decimal(4,1) with out loss of data + // Then our stats parser fails to parse this decimal(1000.1) into Decimal(4,1) + // even though datafusion was able to write it into parquet + // We manually have to check if the decimal in the recordbatch is a subset of the table decimal + if let ( + DataType::Decimal128(left_precision, left_scale) + | DataType::Decimal256(left_precision, left_scale), + DataType::Decimal128(right_precision, right_scale), + ) = (right.data_type(), new_field.data_type()) + { + if !(left_precision <= right_precision && left_scale <= right_scale) { + return Err(ArrowError::SchemaError(format!( + "Cannot merge field {} from {} to {}", + right.name(), + right.data_type(), + new_field.data_type() + ))); + } + }; + // If it's not Decimal datatype, the new_field remains the left table field. + } + }; + Ok(new_field) + } + } +} + +/// Merges Arrow Table schema and Arrow Batch Schema, by allowing Large/View Types to passthrough. +// Sometimes fields can't be merged because they are not the same types. So table has int32, +// but batch int64. We want the preserve the table type. At later stage we will call cast_record_batch +// which will cast the batch int64->int32. This is desired behaviour so we can have flexibility +// in the batch data types. But preserve the correct table and parquet types. +// +// Preserve_new_fields can also be disabled if you just want to only use the passthrough functionality +pub(crate) fn merge_arrow_schema( + table_schema: ArrowSchemaRef, + batch_schema: ArrowSchemaRef, + preserve_new_fields: bool, +) -> Result { + let table_fields = table_schema.fields(); + let batch_fields = batch_schema.fields(); + + let merged_schema = ArrowSchema::new(merge_arrow_vec_fields( + table_fields, + batch_fields, + preserve_new_fields, + )?) 
+ .into(); + Ok(merged_schema) +} + +fn merge_arrow_vec_fields( + table_fields: &Fields, + batch_fields: &Fields, + preserve_new_fields: bool, +) -> Result, ArrowError> { + let mut errors = Vec::with_capacity(table_fields.len()); + let merged_fields: Result, ArrowError> = table_fields + .iter() + .map(|field| { + let right_field = batch_fields.find(field.name()); + if let Some((_, right_field)) = right_field { + let field_or_not = + merge_arrow_field(field.as_ref(), right_field, preserve_new_fields); + match field_or_not { + Err(e) => { + errors.push(e.to_string()); + Err(e) + } + Ok(mut f) => { + let mut field_matadata = f.metadata().clone(); + try_merge_metadata(&mut field_matadata, right_field.metadata())?; + f.set_metadata(field_matadata); + Ok(f) + } + } + } else { + Ok(field.as_ref().clone()) + } + }) + .collect(); + match merged_fields { + Ok(mut fields) => { + if preserve_new_fields { + for field in batch_fields.into_iter() { + if table_fields.find(field.name()).is_none() { + fields.push(field.as_ref().clone()); + } + } + } + Ok(fields) + } + Err(e) => { + errors.push(e.to_string()); + Err(ArrowError::SchemaError(errors.join("\n"))) + } + } +} diff --git a/crates/core/src/operations/cast.rs b/crates/core/src/operations/cast/mod.rs similarity index 76% rename from crates/core/src/operations/cast.rs rename to crates/core/src/operations/cast/mod.rs index 68f630239d..554373e623 100644 --- a/crates/core/src/operations/cast.rs +++ b/crates/core/src/operations/cast/mod.rs @@ -1,8 +1,5 @@ //! Provide common cast functionality for callers //! -use crate::kernel::{ - ArrayType, DataType as DeltaDataType, MapType, MetadataValue, StructField, StructType, -}; use arrow_array::cast::AsArray; use arrow_array::{ new_null_array, Array, ArrayRef, GenericListArray, MapArray, OffsetSizeTrait, RecordBatch, @@ -10,124 +7,11 @@ use arrow_array::{ }; use arrow_cast::{cast_with_options, CastOptions}; use arrow_schema::{ArrowError, DataType, FieldRef, Fields, SchemaRef as ArrowSchemaRef}; -use std::collections::HashMap; use std::sync::Arc; -use crate::DeltaResult; - -fn try_merge_metadata( - left: &mut HashMap, - right: &HashMap, -) -> Result<(), ArrowError> { - for (k, v) in right { - if let Some(vl) = left.get(k) { - if vl != v { - return Err(ArrowError::SchemaError(format!( - "Cannot merge metadata with different values for key {}", - k - ))); - } - } else { - left.insert(k.clone(), v.clone()); - } - } - Ok(()) -} - -pub(crate) fn merge_struct( - left: &StructType, - right: &StructType, -) -> Result { - let mut errors = Vec::new(); - let merged_fields: Result, ArrowError> = left - .fields() - .map(|field| { - let right_field = right.field(field.name()); - if let Some(right_field) = right_field { - let type_or_not = merge_type(field.data_type(), right_field.data_type()); - match type_or_not { - Err(e) => { - errors.push(e.to_string()); - Err(e) - } - Ok(f) => { - let mut new_field = StructField::new( - field.name(), - f, - field.is_nullable() || right_field.is_nullable(), - ); - - new_field.metadata.clone_from(&field.metadata); - try_merge_metadata(&mut new_field.metadata, &right_field.metadata)?; - Ok(new_field) - } - } - } else { - Ok(field.clone()) - } - }) - .collect(); - match merged_fields { - Ok(mut fields) => { - for field in right.fields() { - if !left.field(field.name()).is_some() { - fields.push(field.clone()); - } - } - - Ok(StructType::new(fields)) - } - Err(e) => { - errors.push(e.to_string()); - Err(ArrowError::SchemaError(errors.join("\n"))) - } - } -} - -pub(crate) fn merge_type( - left: 
&DeltaDataType, - right: &DeltaDataType, -) -> Result { - if left == right { - return Ok(left.clone()); - } - match (left, right) { - (DeltaDataType::Array(a), DeltaDataType::Array(b)) => { - let merged = merge_type(&a.element_type, &b.element_type)?; - Ok(DeltaDataType::Array(Box::new(ArrayType::new( - merged, - a.contains_null() || b.contains_null(), - )))) - } - (DeltaDataType::Map(a), DeltaDataType::Map(b)) => { - let merged_key = merge_type(&a.key_type, &b.key_type)?; - let merged_value = merge_type(&a.value_type, &b.value_type)?; - Ok(DeltaDataType::Map(Box::new(MapType::new( - merged_key, - merged_value, - a.value_contains_null() || b.value_contains_null(), - )))) - } - (DeltaDataType::Struct(a), DeltaDataType::Struct(b)) => { - let merged = merge_struct(a, b)?; - Ok(DeltaDataType::Struct(Box::new(merged))) - } - (a, b) => Err(ArrowError::SchemaError(format!( - "Cannot merge types {} and {}", - a, b - ))), - } -} +pub(crate) mod merge_schema; -pub(crate) fn merge_schema( - left: ArrowSchemaRef, - right: ArrowSchemaRef, -) -> Result { - let left_delta: StructType = left.try_into()?; - let right_delta: StructType = right.try_into()?; - let merged: StructType = merge_struct(&left_delta, &right_delta)?; - Ok(Arc::new((&merged).try_into()?)) -} +use crate::DeltaResult; fn cast_struct( struct_array: &StructArray, @@ -142,15 +26,16 @@ fn cast_struct( .map(|field| { let col_or_not = struct_array.column_by_name(field.name()); match col_or_not { - None => match add_missing { - true if field.is_nullable() => { + None => { + if add_missing && field.is_nullable() { Ok(new_null_array(field.data_type(), struct_array.len())) + } else { + Err(ArrowError::SchemaError(format!( + "Could not find column {}", + field.name() + ))) } - _ => Err(ArrowError::SchemaError(format!( - "Could not find column {0}", - field.name() - ))), - }, + } Some(col) => cast_field(col, field, cast_options, add_missing), } }) @@ -204,64 +89,64 @@ fn cast_field( cast_options: &CastOptions, add_missing: bool, ) -> Result { - if let (DataType::Struct(_), DataType::Struct(child_fields)) = - (col.data_type(), field.data_type()) - { - let child_struct = StructArray::from(col.into_data()); - Ok(Arc::new(cast_struct( - &child_struct, - child_fields, - cast_options, - add_missing, - )?) as ArrayRef) - } else if let (DataType::List(_), DataType::List(child_fields)) = - (col.data_type(), field.data_type()) - { - Ok(Arc::new(cast_list( + let (col_type, field_type) = (col.data_type(), field.data_type()); + + match (col_type, field_type) { + (DataType::Struct(_), DataType::Struct(child_fields)) => { + let child_struct = StructArray::from(col.into_data()); + Ok(Arc::new(cast_struct( + &child_struct, + child_fields, + cast_options, + add_missing, + )?) as ArrayRef) + } + (DataType::List(_), DataType::List(child_fields)) => Ok(Arc::new(cast_list( col.as_any() .downcast_ref::>() - .ok_or(ArrowError::CastError(format!( - "Expected a list for {} but got {}", - field.name(), - col.data_type() - )))?, + .ok_or_else(|| { + ArrowError::CastError(format!( + "Expected a list for {} but got {}", + field.name(), + col_type + )) + })?, child_fields, cast_options, add_missing, - )?) as ArrayRef) - } else if let (DataType::LargeList(_), DataType::LargeList(child_fields)) = - (col.data_type(), field.data_type()) - { - Ok(Arc::new(cast_list( + )?) 
as ArrayRef), + (DataType::LargeList(_), DataType::LargeList(child_fields)) => Ok(Arc::new(cast_list( col.as_any() .downcast_ref::>() - .ok_or(ArrowError::CastError(format!( - "Expected a list for {} but got {}", - field.name(), - col.data_type() - )))?, + .ok_or_else(|| { + ArrowError::CastError(format!( + "Expected a list for {} but got {}", + field.name(), + col_type + )) + })?, child_fields, cast_options, add_missing, - )?) as ArrayRef) - } else if let (DataType::Map(_, _), DataType::Map(child_fields, sorted)) = - (col.data_type(), field.data_type()) - { - Ok(Arc::new(cast_map( - col.as_map_opt().ok_or(ArrowError::CastError(format!( - "Expected a map for {} but got {}", - field.name(), - col.data_type() - )))?, + )?) as ArrayRef), + // TODO: add list view cast + (DataType::Map(_, _), DataType::Map(child_fields, sorted)) => Ok(Arc::new(cast_map( + col.as_map_opt().ok_or_else(|| { + ArrowError::CastError(format!( + "Expected a map for {} but got {}", + field.name(), + col_type + )) + })?, child_fields, *sorted, cast_options, add_missing, - )?) as ArrayRef) - } else if is_cast_required(col.data_type(), field.data_type()) { - cast_with_options(col, field.data_type(), cast_options) - } else { - Ok(col.clone()) + )?) as ArrayRef), + _ if is_cast_required(col_type, field_type) => { + cast_with_options(col, field_type, cast_options) + } + _ => Ok(col.clone()), } } @@ -293,6 +178,7 @@ pub fn cast_record_batch( None, ); let struct_array = cast_struct(&s, target_schema.fields(), &cast_options, add_missing)?; + Ok(RecordBatch::try_new_with_options( target_schema, struct_array.columns().to_vec(), @@ -306,6 +192,7 @@ mod tests { use std::ops::Deref; use std::sync::Arc; + use super::merge_schema::{merge_arrow_schema, merge_delta_struct}; use arrow::array::types::Int32Type; use arrow::array::{ new_empty_array, new_null_array, Array, ArrayData, ArrayRef, AsArray, Int32Array, @@ -313,17 +200,17 @@ mod tests { }; use arrow::buffer::{Buffer, NullBuffer}; use arrow_schema::{DataType, Field, FieldRef, Fields, Schema, SchemaRef}; + use delta_kernel::schema::MetadataValue; use itertools::Itertools; use crate::kernel::{ ArrayType as DeltaArrayType, DataType as DeltaDataType, StructField as DeltaStructField, StructType as DeltaStructType, }; - use crate::operations::cast::MetadataValue; use crate::operations::cast::{cast_record_batch, is_cast_required}; #[test] - fn test_merge_schema_with_dict() { + fn test_merge_arrow_schema_with_dict() { let left_schema = Arc::new(Schema::new(vec![Field::new( "f", DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), @@ -335,7 +222,7 @@ mod tests { true, )])); - let result = super::merge_schema(left_schema, right_schema).unwrap(); + let result = merge_arrow_schema(left_schema, right_schema, true).unwrap(); assert_eq!(result.fields().len(), 1); let delta_type: DeltaDataType = result.fields()[0].data_type().try_into().unwrap(); assert_eq!(delta_type, DeltaDataType::STRING); @@ -343,7 +230,7 @@ mod tests { } #[test] - fn test_merge_schema_with_meta() { + fn test_merge_delta_schema_with_meta() { let mut left_meta = HashMap::new(); left_meta.insert("a".to_string(), "a1".to_string()); let left_schema = DeltaStructType::new(vec![DeltaStructField::new( @@ -361,7 +248,7 @@ mod tests { ) .with_metadata(right_meta)]); - let result = super::merge_struct(&left_schema, &right_schema).unwrap(); + let result = merge_delta_struct(&left_schema, &right_schema).unwrap(); let fields = result.fields().collect_vec(); assert_eq!(fields.len(), 1); let delta_type = 
fields[0].data_type(); @@ -373,7 +260,7 @@ mod tests { } #[test] - fn test_merge_schema_with_nested() { + fn test_merge_arrow_schema_with_nested() { let left_schema = Arc::new(Schema::new(vec![Field::new( "f", DataType::LargeList(Arc::new(Field::new("item", DataType::Utf8, false))), @@ -385,7 +272,7 @@ mod tests { true, )])); - let result = super::merge_schema(left_schema, right_schema).unwrap(); + let result = merge_arrow_schema(left_schema, right_schema, true).unwrap(); assert_eq!(result.fields().len(), 1); let delta_type: DeltaDataType = result.fields()[0].data_type().try_into().unwrap(); assert_eq!( diff --git a/crates/core/src/operations/cdc.rs b/crates/core/src/operations/cdc.rs index 42a33cbcab..f95b32ea75 100644 --- a/crates/core/src/operations/cdc.rs +++ b/crates/core/src/operations/cdc.rs @@ -5,12 +5,8 @@ use crate::table::state::DeltaTableState; use crate::DeltaResult; -use arrow::datatypes::{DataType, Field, SchemaRef}; - use datafusion::prelude::*; use datafusion_common::ScalarValue; -use std::sync::Arc; -use tracing::log::*; /// The CDCTracker is useful for hooking reads/writes in a manner nececessary to create CDC files /// associated with commits @@ -84,6 +80,8 @@ pub(crate) fn should_write_cdc(snapshot: &DeltaTableState) -> DeltaResult #[cfg(test)] mod tests { + use std::sync::Arc; + use super::*; use crate::kernel::DataType as DeltaDataType; use crate::kernel::{Action, PrimitiveType, Protocol}; @@ -91,7 +89,7 @@ mod tests { use crate::{DeltaConfigKey, DeltaTable}; use arrow::array::{ArrayRef, Int32Array, StructArray}; use arrow_array::RecordBatch; - use arrow_schema::Schema; + use arrow_schema::{DataType, Field, Schema}; use datafusion::assert_batches_sorted_eq; use datafusion::datasource::{MemTable, TableProvider}; diff --git a/crates/core/src/operations/delete.rs b/crates/core/src/operations/delete.rs index 692c1b303b..8ddd948d35 100644 --- a/crates/core/src/operations/delete.rs +++ b/crates/core/src/operations/delete.rs @@ -43,6 +43,7 @@ use serde::Serialize; use super::cdc::should_write_cdc; use super::datafusion_utils::Expression; use super::transaction::{CommitBuilder, CommitProperties, PROTOCOL}; +use super::write::WriterStatsConfig; use crate::delta_datafusion::expr::fmt_expr_to_sql; use crate::delta_datafusion::{ @@ -51,7 +52,7 @@ use crate::delta_datafusion::{ }; use crate::errors::DeltaResult; use crate::kernel::{Action, Add, Remove}; -use crate::operations::write::{write_execution_plan, write_execution_plan_cdc, WriterStatsConfig}; +use crate::operations::write::{write_execution_plan, write_execution_plan_cdc}; use crate::protocol::DeltaOperation; use crate::table::state::DeltaTableState; use crate::{DeltaTable, DeltaTableError}; @@ -236,8 +237,6 @@ async fn excute_non_empty_expr( Some(snapshot.table_config().target_file_size() as usize), None, writer_properties.clone(), - false, - None, writer_stats_config.clone(), None, ) @@ -266,7 +265,6 @@ async fn excute_non_empty_expr( .create_physical_plan() .await?; - use crate::operations::write::write_execution_plan_cdc; let cdc_actions = write_execution_plan_cdc( Some(snapshot), state.clone(), @@ -276,7 +274,6 @@ async fn excute_non_empty_expr( Some(snapshot.table_config().target_file_size() as usize), None, writer_properties, - false, writer_stats_config, None, ) diff --git a/crates/core/src/operations/merge/mod.rs b/crates/core/src/operations/merge/mod.rs index ea54e4e211..1699011f51 100644 --- a/crates/core/src/operations/merge/mod.rs +++ b/crates/core/src/operations/merge/mod.rs @@ -69,8 +69,8 @@ use 
crate::delta_datafusion::logical::MetricObserver; use crate::delta_datafusion::physical::{find_metric_node, get_metric, MetricObserverExec}; use crate::delta_datafusion::planner::DeltaPlanner; use crate::delta_datafusion::{ - execute_plan_to_batch, register_store, DeltaColumn, DeltaScanConfigBuilder, DeltaSessionConfig, - DeltaTableProvider, + execute_plan_to_batch, register_store, DataFusionMixins, DeltaColumn, DeltaScanConfigBuilder, + DeltaSessionConfig, DeltaTableProvider, }; use crate::kernel::Action; use crate::logstore::LogStoreRef; @@ -1060,6 +1060,7 @@ async fn execute( let scan_config = DeltaScanConfigBuilder::default() .with_file_column(true) .with_parquet_pushdown(false) + .with_schema(snapshot.input_schema()?) .build(&snapshot)?; let target_provider = Arc::new(DeltaTableProvider::try_new( @@ -1459,8 +1460,6 @@ async fn execute( Some(snapshot.table_config().target_file_size() as usize), None, writer_properties, - safe_cast, - None, writer_stats_config, None, ) diff --git a/crates/core/src/operations/update.rs b/crates/core/src/operations/update.rs index 2a947f486f..37837ed60a 100644 --- a/crates/core/src/operations/update.rs +++ b/crates/core/src/operations/update.rs @@ -283,6 +283,7 @@ async fn execute( let scan_config = DeltaScanConfigBuilder::default() .with_file_column(false) + .with_schema(snapshot.input_schema()?) .build(&snapshot)?; // For each rewrite evaluate the predicate and then modify each expression @@ -345,10 +346,7 @@ async fn execute( .map(|v| v.iter().map(|v| v.to_string()).collect::>()), ); - let tracker = CDCTracker::new( - df, - updated_df.drop_columns(&vec![UPDATE_PREDICATE_COLNAME])?, - ); + let tracker = CDCTracker::new(df, updated_df.drop_columns(&[UPDATE_PREDICATE_COLNAME])?); let add_actions = write_execution_plan( Some(&snapshot), @@ -359,8 +357,6 @@ async fn execute( Some(snapshot.table_config().target_file_size() as usize), None, writer_properties.clone(), - safe_cast, - None, writer_stats_config.clone(), None, ) @@ -424,7 +420,6 @@ async fn execute( Some(snapshot.table_config().target_file_size() as usize), None, writer_properties, - safe_cast, writer_stats_config, None, ) diff --git a/crates/core/src/operations/write.rs b/crates/core/src/operations/write.rs index 923eadeeaf..95afed637b 100644 --- a/crates/core/src/operations/write.rs +++ b/crates/core/src/operations/write.rs @@ -61,7 +61,7 @@ use crate::delta_datafusion::{DataFusionMixins, DeltaDataChecker}; use crate::errors::{DeltaResult, DeltaTableError}; use crate::kernel::{Action, Add, AddCDCFile, Metadata, PartitionsExt, Remove, StructType}; use crate::logstore::LogStoreRef; -use crate::operations::cast::{cast_record_batch, merge_schema}; +use crate::operations::cast::{cast_record_batch, merge_schema::merge_arrow_schema}; use crate::protocol::{DeltaOperation, SaveMode}; use crate::storage::ObjectStoreRef; use crate::table::state::DeltaTableState; @@ -378,18 +378,12 @@ async fn write_execution_plan_with_predicate( target_file_size: Option, write_batch_size: Option, writer_properties: Option, - safe_cast: bool, - schema_mode: Option, writer_stats_config: WriterStatsConfig, sender: Option>, ) -> DeltaResult> { - let schema: ArrowSchemaRef = if schema_mode.is_some() { - plan.schema() - } else { - snapshot - .and_then(|s| s.input_schema().ok()) - .unwrap_or(plan.schema()) - }; + // We always take the plan Schema since the data may contain Large/View arrow types, + // the schema and batches were prior constructed with this in mind. 
+ let schema: ArrowSchemaRef = plan.schema(); let checker = if let Some(snapshot) = snapshot { DeltaDataChecker::new(snapshot) } else { @@ -431,21 +425,15 @@ async fn write_execution_plan_with_predicate( let batch = maybe_batch?; checker_stream.check_batch(&batch).await?; - let arr = super::cast::cast_record_batch( - &batch, - inner_schema.clone(), - safe_cast, - schema_mode == Some(SchemaMode::Merge), - )?; if let Some(s) = sendable.as_ref() { - if let Err(e) = s.send(arr.clone()).await { + if let Err(e) = s.send(batch.clone()).await { error!("Failed to send data to observer: {e:#?}"); } } else { debug!("write_execution_plan_with_predicate did not send any batches, no sender."); } - writer.write(&arr).await?; + writer.write(&batch).await?; } let add_actions = writer.close().await; match add_actions { @@ -481,15 +469,11 @@ pub(crate) async fn write_execution_plan_cdc( target_file_size: Option, write_batch_size: Option, writer_properties: Option, - safe_cast: bool, writer_stats_config: WriterStatsConfig, sender: Option>, ) -> DeltaResult> { let cdc_store = Arc::new(PrefixStore::new(object_store, "_change_data")); - // If not overwrite, the plan schema is not taken but table schema, - // however we need the plan schema since it has the _change_type_col - let schema_mode = Some(SchemaMode::Overwrite); Ok(write_execution_plan( snapshot, state, @@ -499,8 +483,6 @@ pub(crate) async fn write_execution_plan_cdc( target_file_size, write_batch_size, writer_properties, - safe_cast, - schema_mode, writer_stats_config, sender, ) @@ -536,8 +518,6 @@ pub(crate) async fn write_execution_plan( target_file_size: Option, write_batch_size: Option, writer_properties: Option, - safe_cast: bool, - schema_mode: Option, writer_stats_config: WriterStatsConfig, sender: Option>, ) -> DeltaResult> { @@ -551,8 +531,6 @@ pub(crate) async fn write_execution_plan( target_file_size, write_batch_size, writer_properties, - safe_cast, - schema_mode, writer_stats_config, sender, ) @@ -609,8 +587,6 @@ async fn execute_non_empty_expr( Some(snapshot.table_config().target_file_size() as usize), None, writer_properties.clone(), - false, - None, writer_stats_config.clone(), None, ) @@ -692,7 +668,6 @@ pub(crate) async fn execute_non_empty_expr_cdc( Some(snapshot.table_config().target_file_size() as usize), None, writer_properties, - false, writer_stats_config, None, ) @@ -851,41 +826,50 @@ impl std::future::IntoFuture for WriteBuilder { let mut new_schema = None; if let Some(snapshot) = &this.snapshot { - let table_schema = snapshot - .physical_arrow_schema(this.log_store.object_store().clone()) - .await - .or_else(|_| snapshot.arrow_schema()) - .unwrap_or(schema.clone()); - + let table_schema = snapshot.input_schema()?; if let Err(schema_err) = try_cast_batch(schema.fields(), table_schema.fields()) { schema_drift = true; if this.mode == SaveMode::Overwrite - && this.schema_mode == Some(SchemaMode::Merge) - { - new_schema = - Some(merge_schema(table_schema.clone(), schema.clone())?); - } else if this.mode == SaveMode::Overwrite && this.schema_mode.is_some() + && this.schema_mode == Some(SchemaMode::Overwrite) { new_schema = None // we overwrite anyway, so no need to cast } else if this.schema_mode == Some(SchemaMode::Merge) { - new_schema = - Some(merge_schema(table_schema.clone(), schema.clone())?); + new_schema = Some(merge_arrow_schema( + table_schema.clone(), + schema.clone(), + schema_drift, + )?); } else { return Err(schema_err.into()); } + } else if this.mode == SaveMode::Overwrite + && this.schema_mode == 
Some(SchemaMode::Overwrite) + { + new_schema = None // we overwrite anyway, so no need to cast + } else { + // Schema needs to be merged so that utf8/binary/list types are preserved from the batch side if both table + // and batch contains such type. Other types are preserved from the table side. + // At this stage it will never introduce more fields since try_cast_batch passed correctly. + new_schema = Some(merge_arrow_schema( + table_schema.clone(), + schema.clone(), + schema_drift, + )?); } } - let data = if !partition_columns.is_empty() { // TODO partitioning should probably happen in its own plan ... let mut partitions: HashMap> = HashMap::new(); for batch in batches { let real_batch = match new_schema.clone() { - Some(new_schema) => { - cast_record_batch(&batch, new_schema, false, true)? - } + Some(new_schema) => cast_record_batch( + &batch, + new_schema, + this.safe_cast, + schema_drift, // Schema drifted so we have to add the missing columns/structfields. + )?, None => batch, }; @@ -915,8 +899,8 @@ impl std::future::IntoFuture for WriteBuilder { new_batches.push(cast_record_batch( &batch, new_schema.clone(), - false, - true, + this.safe_cast, + schema_drift, // Schema drifted so we have to add the missing columns/structfields. )?); } vec![new_batches] @@ -1019,8 +1003,6 @@ impl std::future::IntoFuture for WriteBuilder { this.target_file_size, this.write_batch_size, this.writer_properties.clone(), - this.safe_cast, - this.schema_mode, writer_stats_config.clone(), None, ) @@ -1031,11 +1013,7 @@ impl std::future::IntoFuture for WriteBuilder { if let Some(snapshot) = &this.snapshot { if matches!(this.mode, SaveMode::Overwrite) { // Update metadata with new schema - let table_schema = snapshot - .physical_arrow_schema(this.log_store.object_store().clone()) - .await - .or_else(|_| snapshot.arrow_schema()) - .unwrap_or(schema.clone()); + let table_schema = snapshot.input_schema()?; let configuration = snapshot.metadata().configuration.clone(); let current_protocol = snapshot.protocol(); diff --git a/crates/core/src/writer/record_batch.rs b/crates/core/src/writer/record_batch.rs index d99673c8cb..493646d479 100644 --- a/crates/core/src/writer/record_batch.rs +++ b/crates/core/src/writer/record_batch.rs @@ -30,7 +30,7 @@ use super::utils::{ use super::{DeltaWriter, DeltaWriterError, WriteMode}; use crate::errors::DeltaTableError; use crate::kernel::{scalars::ScalarExt, Action, Add, PartitionsExt, StructType}; -use crate::operations::cast::merge_schema; +use crate::operations::cast::merge_schema::merge_arrow_schema; use crate::storage::ObjectStoreRetryExt; use crate::table::builder::DeltaTableBuilder; use crate::table::config::DEFAULT_NUM_INDEX_COLS; @@ -322,8 +322,11 @@ impl PartitionWriter { WriteMode::MergeSchema => { debug!("The writer and record batch schemas do not match, merging"); - let merged = - merge_schema(self.arrow_schema.clone(), record_batch.schema().clone())?; + let merged = merge_arrow_schema( + self.arrow_schema.clone(), + record_batch.schema().clone(), + true, + )?; self.arrow_schema = merged; let mut cols = vec![]; diff --git a/python/Cargo.toml b/python/Cargo.toml index b7feb2a36e..c6fd6e26a0 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "deltalake-python" -version = "0.18.3" +version = "0.19.0" authors = ["Qingping Hou ", "Will Jones "] homepage = "https://github.com/delta-io/delta-rs" license = "Apache-2.0" diff --git a/python/deltalake/schema.py b/python/deltalake/schema.py index cc7561cb39..8bc5c7e155 100644 --- 
a/python/deltalake/schema.py +++ b/python/deltalake/schema.py @@ -1,3 +1,4 @@ +from enum import Enum from typing import Generator, Union import pyarrow as pa @@ -15,9 +16,25 @@ DataType = Union["PrimitiveType", "MapType", "StructType", "ArrayType"] +class ArrowSchemaConversionMode(Enum): + NORMAL = "NORMAL" + LARGE = "LARGE" + PASSTHROUGH = "PASSTHROUGH" + + @classmethod + def from_str(cls, value: str) -> "ArrowSchemaConversionMode": + try: + return cls(value.upper()) + except ValueError: + raise ValueError( + f"{value} is not a valid ArrowSchemaConversionMode. Valid values are: {[item.value for item in ArrowSchemaConversionMode]}" + ) + + ### Inspired from Pola-rs repo - licensed with MIT License, see license in python/licenses/polars_license.txt.### def _convert_pa_schema_to_delta( - schema: pa.schema, large_dtypes: bool = False + schema: pa.schema, + schema_conversion_mode: ArrowSchemaConversionMode = ArrowSchemaConversionMode.NORMAL, ) -> pa.schema: """Convert a PyArrow schema to a schema compatible with Delta Lake. Converts unsigned to signed equivalent, and converts all timestamps to `us` timestamps. With the boolean flag large_dtypes you can control if the schema @@ -25,7 +42,9 @@ def _convert_pa_schema_to_delta( Args schema: Source schema - large_dtypes: If True, the pyarrow schema is casted to large_dtypes + schema_conversion_mode: large mode will cast all string/binary/list to the large version arrow types, normal mode + keeps the normal version of the types. Passthrough mode keeps string/binary/list flavored types in their original + version, whether that is view/large/normal. """ dtype_map = { pa.uint8(): pa.int8(), @@ -33,20 +52,39 @@ def _convert_pa_schema_to_delta( pa.uint32(): pa.int32(), pa.uint64(): pa.int64(), } - if large_dtypes: + if schema_conversion_mode == ArrowSchemaConversionMode.LARGE: dtype_map = { **dtype_map, - **{pa.string(): pa.large_string(), pa.binary(): pa.large_binary()}, + **{ + pa.string(): pa.large_string(), + pa.string_view(): pa.large_string(), + pa.binary(): pa.large_binary(), + pa.binary_view(): pa.large_binary(), + }, } - else: + elif schema_conversion_mode == ArrowSchemaConversionMode.NORMAL: dtype_map = { **dtype_map, - **{pa.large_string(): pa.string(), pa.large_binary(): pa.binary()}, + **{ + pa.large_string(): pa.string(), + pa.string_view(): pa.string(), + pa.large_binary(): pa.binary(), + pa.binary_view(): pa.binary(), + }, } def dtype_to_delta_dtype(dtype: pa.DataType) -> pa.DataType: # Handle nested types - if isinstance(dtype, (pa.LargeListType, pa.ListType, pa.FixedSizeListType)): + if isinstance( + dtype, + ( + pa.LargeListType, + pa.ListType, + pa.FixedSizeListType, + pa.ListViewType, + pa.LargeListViewType, + ), + ): return list_to_delta_dtype(dtype) elif isinstance(dtype, pa.StructType): return struct_to_delta_dtype(dtype) @@ -63,14 +101,35 @@ def dtype_to_delta_dtype(dtype: pa.DataType) -> pa.DataType: return dtype def list_to_delta_dtype( - dtype: Union[pa.LargeListType, pa.ListType], + dtype: Union[ + pa.LargeListType, + pa.ListType, + pa.ListViewType, + pa.LargeListViewType, + pa.FixedSizeListType, + ], ) -> Union[pa.LargeListType, pa.ListType]: nested_dtype = dtype.value_type nested_dtype_cast = dtype_to_delta_dtype(nested_dtype) - if large_dtypes: + if schema_conversion_mode == ArrowSchemaConversionMode.LARGE: return pa.large_list(nested_dtype_cast) - else: + elif schema_conversion_mode == ArrowSchemaConversionMode.NORMAL: return pa.list_(nested_dtype_cast) + elif schema_conversion_mode == 
ArrowSchemaConversionMode.PASSTHROUGH: + if isinstance(dtype, pa.LargeListType): + return pa.large_list(nested_dtype_cast) + elif isinstance(dtype, pa.ListType): + return pa.list_(nested_dtype_cast) + elif isinstance(dtype, pa.FixedSizeListType): + return pa.list_(nested_dtype_cast) + elif isinstance(dtype, pa.LargeListViewType): + return pa.large_list_view(nested_dtype_cast) + elif isinstance(dtype, pa.ListViewType): + return pa.list_view(nested_dtype_cast) + else: + raise NotImplementedError + else: + raise NotImplementedError def struct_to_delta_dtype(dtype: pa.StructType) -> pa.StructType: fields = [dtype[i] for i in range(dtype.num_fields)] @@ -91,10 +150,12 @@ def _cast_schema_to_recordbatchreader( def convert_pyarrow_recordbatchreader( - data: pa.RecordBatchReader, large_dtypes: bool + data: pa.RecordBatchReader, schema_conversion_mode: ArrowSchemaConversionMode ) -> pa.RecordBatchReader: """Converts a PyArrow RecordBatchReader to a PyArrow RecordBatchReader with a compatible delta schema""" - schema = _convert_pa_schema_to_delta(data.schema, large_dtypes=large_dtypes) + schema = _convert_pa_schema_to_delta( + data.schema, schema_conversion_mode=schema_conversion_mode + ) data = pa.RecordBatchReader.from_batches( schema, @@ -104,25 +165,33 @@ def convert_pyarrow_recordbatchreader( def convert_pyarrow_recordbatch( - data: pa.RecordBatch, large_dtypes: bool + data: pa.RecordBatch, schema_conversion_mode: ArrowSchemaConversionMode ) -> pa.RecordBatchReader: """Converts a PyArrow RecordBatch to a PyArrow RecordBatchReader with a compatible delta schema""" - schema = _convert_pa_schema_to_delta(data.schema, large_dtypes=large_dtypes) + schema = _convert_pa_schema_to_delta( + data.schema, schema_conversion_mode=schema_conversion_mode + ) data = pa.Table.from_batches([data]).cast(schema).to_reader() return data -def convert_pyarrow_table(data: pa.Table, large_dtypes: bool) -> pa.RecordBatchReader: +def convert_pyarrow_table( + data: pa.Table, schema_conversion_mode: ArrowSchemaConversionMode +) -> pa.RecordBatchReader: """Converts a PyArrow table to a PyArrow RecordBatchReader with a compatible delta schema""" - schema = _convert_pa_schema_to_delta(data.schema, large_dtypes=large_dtypes) + schema = _convert_pa_schema_to_delta( + data.schema, schema_conversion_mode=schema_conversion_mode + ) data = data.cast(schema).to_reader() return data def convert_pyarrow_dataset( - data: ds.Dataset, large_dtypes: bool + data: ds.Dataset, schema_conversion_mode: ArrowSchemaConversionMode ) -> pa.RecordBatchReader: """Converts a PyArrow dataset to a PyArrow RecordBatchReader with a compatible delta schema""" data = data.scanner().to_reader() - data = convert_pyarrow_recordbatchreader(data, large_dtypes) + data = convert_pyarrow_recordbatchreader( + data, schema_conversion_mode=schema_conversion_mode + ) return data diff --git a/python/deltalake/table.py b/python/deltalake/table.py index 4c82b40cd0..5cd0d252cf 100644 --- a/python/deltalake/table.py +++ b/python/deltalake/table.py @@ -22,7 +22,6 @@ import pyarrow import pyarrow.dataset as ds import pyarrow.fs as pa_fs -import pyarrow_hotfix # noqa: F401; addresses CVE-2023-47248; # type: ignore from pyarrow.dataset import ( Expression, FileSystemDataset, @@ -39,13 +38,13 @@ if TYPE_CHECKING: import os -from deltalake._internal import DeltaDataChecker as _DeltaDataChecker from deltalake._internal import RawDeltaTable from deltalake._internal import create_deltalake as _create_deltalake from deltalake._util import encode_partition_value from 
deltalake.data_catalog import DataCatalog from deltalake.exceptions import DeltaProtocolError from deltalake.fs import DeltaStorageHandler +from deltalake.schema import ArrowSchemaConversionMode from deltalake.schema import Field as DeltaField from deltalake.schema import Schema as DeltaSchema @@ -876,7 +875,7 @@ def merge( target_alias: Optional[str] = None, error_on_type_mismatch: bool = True, writer_properties: Optional[WriterProperties] = None, - large_dtypes: bool = False, + large_dtypes: Optional[bool] = None, custom_metadata: Optional[Dict[str, str]] = None, post_commithook_properties: Optional[PostCommitHookProperties] = None, ) -> "TableMerger": @@ -891,15 +890,21 @@ def merge( target_alias: Alias for the target table error_on_type_mismatch: specify if merge will return error if data types are mismatching :default = True writer_properties: Pass writer properties to the Rust parquet writer - large_dtypes: If True, the data schema is kept in large_dtypes. + large_dtypes: Deprecated, will be removed in 1.0 + arrow_schema_conversion_mode: Large converts all types of data schema into Large Arrow types, passthrough keeps string/binary/list types untouched custom_metadata: custom metadata that will be added to the transaction commit. post_commithook_properties: properties for the post commit hook. If None, default values are used. Returns: TableMerger: TableMerger Object """ - invariants = self.schema().invariants - checker = _DeltaDataChecker(invariants) + if large_dtypes: + warnings.warn( + "large_dtypes is deprecated", + category=DeprecationWarning, + stacklevel=2, + ) + conversion_mode = ArrowSchemaConversionMode.PASSTHROUGH from .schema import ( convert_pyarrow_dataset, @@ -909,28 +914,24 @@ def merge( ) if isinstance(source, pyarrow.RecordBatchReader): - source = convert_pyarrow_recordbatchreader(source, large_dtypes) + source = convert_pyarrow_recordbatchreader(source, conversion_mode) elif isinstance(source, pyarrow.RecordBatch): - source = convert_pyarrow_recordbatch(source, large_dtypes) + source = convert_pyarrow_recordbatch(source, conversion_mode) elif isinstance(source, pyarrow.Table): - source = convert_pyarrow_table(source, large_dtypes) + source = convert_pyarrow_table(source, conversion_mode) elif isinstance(source, ds.Dataset): - source = convert_pyarrow_dataset(source, large_dtypes) + source = convert_pyarrow_dataset(source, conversion_mode) elif _has_pandas and isinstance(source, pd.DataFrame): source = convert_pyarrow_table( - pyarrow.Table.from_pandas(source), large_dtypes + pyarrow.Table.from_pandas(source), conversion_mode ) else: raise TypeError( f"{type(source).__name__} is not a valid input. Only PyArrow RecordBatchReader, RecordBatch, Table or Pandas DataFrame are valid inputs for source." 
) - def validate_batch(batch: pyarrow.RecordBatch) -> pyarrow.RecordBatch: - checker.check_batch(batch) - return batch - source = pyarrow.RecordBatchReader.from_batches( - source.schema, (validate_batch(batch) for batch in source) + source.schema, (batch for batch in source) ) return TableMerger( diff --git a/python/deltalake/writer.py b/python/deltalake/writer.py index f323b90e35..50aa5841a0 100644 --- a/python/deltalake/writer.py +++ b/python/deltalake/writer.py @@ -1,5 +1,6 @@ import json import uuid +import warnings from dataclasses import dataclass from datetime import date, datetime from decimal import Decimal @@ -41,6 +42,7 @@ from ._util import encode_partition_value from .exceptions import DeltaProtocolError, TableNotFoundError from .schema import ( + ArrowSchemaConversionMode, convert_pyarrow_dataset, convert_pyarrow_recordbatch, convert_pyarrow_recordbatchreader, @@ -62,7 +64,6 @@ else: _has_pandas = True -PYARROW_MAJOR_VERSION = int(pa.__version__.split(".", maxsplit=1)[0]) DEFAULT_DATA_SKIPPING_NUM_INDEX_COLS = 32 DTYPE_MAP = { @@ -148,7 +149,7 @@ def write_deltalake( schema_mode: Optional[Literal["merge", "overwrite"]] = ..., storage_options: Optional[Dict[str, str]] = ..., large_dtypes: bool = ..., - engine: Literal["rust"], + engine: Literal["rust"] = ..., writer_properties: WriterProperties = ..., custom_metadata: Optional[Dict[str, str]] = ..., post_commithook_properties: Optional[PostCommitHookProperties] = ..., @@ -178,7 +179,7 @@ def write_deltalake( storage_options: Optional[Dict[str, str]] = ..., predicate: Optional[str] = ..., large_dtypes: bool = ..., - engine: Literal["rust"], + engine: Literal["rust"] = ..., writer_properties: WriterProperties = ..., custom_metadata: Optional[Dict[str, str]] = ..., post_commithook_properties: Optional[PostCommitHookProperties] = ..., @@ -214,7 +215,7 @@ def write_deltalake( partition_filters: Optional[List[Tuple[str, str, Any]]] = None, predicate: Optional[str] = None, large_dtypes: bool = False, - engine: Literal["pyarrow", "rust"] = "pyarrow", + engine: Literal["pyarrow", "rust"] = "rust", writer_properties: Optional[WriterProperties] = None, custom_metadata: Optional[Dict[str, str]] = None, post_commithook_properties: Optional[PostCommitHookProperties] = None, @@ -268,9 +269,8 @@ def write_deltalake( storage_options: options passed to the native delta filesystem. predicate: When using `Overwrite` mode, replace data that matches a predicate. Only used in rust engine. partition_filters: the partition filters that will be used for partition overwrite. Only used in pyarrow engine. - large_dtypes: If True, the data schema is kept in large_dtypes, has no effect on pandas dataframe input. - engine: writer engine to write the delta table. `Rust` engine is still experimental but you may - see up to 4x performance improvements over pyarrow. + large_dtypes: Only used for pyarrow engine + engine: writer engine to write the delta table. PyArrow engine is deprecated, and will be removed in v1.0. writer_properties: Pass writer properties to the Rust parquet writer. custom_metadata: Custom metadata to add to the commitInfo. post_commithook_properties: properties for the post commit hook. If None, default values are used. 
@@ -285,14 +285,20 @@ def write_deltalake( if isinstance(partition_by, str): partition_by = [partition_by] - data, schema = _convert_data_and_schema( - data=data, schema=schema, large_dtypes=large_dtypes - ) - if engine == "rust": + if partition_filters is not None: + raise ValueError( + "Partition filters can only be used with PyArrow engine, use predicate instead. PyArrow engine will be deprecated in 1.0" + ) + if table is not None and mode == "ignore": return + data, schema = _convert_data_and_schema( + data=data, + schema=schema, + conversion_mode=ArrowSchemaConversionMode.PASSTHROUGH, + ) data = RecordBatchReader.from_batches(schema, (batch for batch in data)) write_deltalake_rust( table_uri=table_uri, @@ -316,13 +322,34 @@ def write_deltalake( ) if table: table.update_incremental() - elif engine == "pyarrow": + warnings.warn( + "pyarrow engine is deprecated and will be removed in v1.0", + category=DeprecationWarning, + stacklevel=2, + ) + + if predicate is not None: + raise ValueError( + "Predicate can only be used with Rust engine, use partition_filters instead. PyArrow engine will be removed in 1.0" + ) + + if large_dtypes: + arrow_schema_conversion_mode = "large" + else: + arrow_schema_conversion_mode = "normal" + + conversion_mode = ArrowSchemaConversionMode.from_str( + arrow_schema_conversion_mode + ) + data, schema = _convert_data_and_schema( + data=data, schema=schema, conversion_mode=conversion_mode + ) + if schema_mode == "merge": raise ValueError( "schema_mode 'merge' is not supported in pyarrow engine. Use engine=rust" ) - # We need to write against the latest table version num_indexed_cols, stats_cols = get_num_idx_cols_and_stats_columns( table._table if table is not None else None, configuration @@ -366,19 +393,9 @@ def write_deltalake( if partition_by: table_schema: pa.Schema = schema - if PYARROW_MAJOR_VERSION < 12: - partition_schema = pa.schema( - [ - pa.field( - name, _large_to_normal_dtype(table_schema.field(name).type) - ) - for name in partition_by - ] - ) - else: - partition_schema = pa.schema( - [table_schema.field(name) for name in partition_by] - ) + partition_schema = pa.schema( + [table_schema.field(name) for name in partition_by] + ) partitioning = ds.partitioning(partition_schema, flavor="hive") else: partitioning = None @@ -393,18 +410,10 @@ def visitor(written_file: Any) -> None: columns_to_collect_stats=stats_cols, ) - # PyArrow added support for written_file.size in 9.0.0 - if PYARROW_MAJOR_VERSION >= 9: - size = written_file.size - elif filesystem is not None: - size = filesystem.get_file_info([path])[0].size - else: - size = 0 - add_actions.append( AddAction( path, - size, + written_file.size, partition_values, int(datetime.now().timestamp() * 1000), True, @@ -634,29 +643,27 @@ def _convert_data_and_schema( ArrowStreamExportable, ], schema: Optional[Union[pa.Schema, DeltaSchema]], - large_dtypes: bool, + conversion_mode: ArrowSchemaConversionMode, ) -> Tuple[pa.RecordBatchReader, pa.Schema]: if isinstance(data, RecordBatchReader): - data = convert_pyarrow_recordbatchreader(data, large_dtypes) + data = convert_pyarrow_recordbatchreader(data, conversion_mode) elif isinstance(data, pa.RecordBatch): - data = convert_pyarrow_recordbatch(data, large_dtypes) + data = convert_pyarrow_recordbatch(data, conversion_mode) elif isinstance(data, pa.Table): - data = convert_pyarrow_table(data, large_dtypes) + data = convert_pyarrow_table(data, conversion_mode) elif isinstance(data, ds.Dataset): - data = convert_pyarrow_dataset(data, large_dtypes) + data = 
convert_pyarrow_dataset(data, conversion_mode) elif _has_pandas and isinstance(data, pd.DataFrame): if schema is not None: data = convert_pyarrow_table( - pa.Table.from_pandas(data, schema=schema), large_dtypes=large_dtypes + pa.Table.from_pandas(data, schema=schema), conversion_mode ) else: - data = convert_pyarrow_table( - pa.Table.from_pandas(data), large_dtypes=large_dtypes - ) + data = convert_pyarrow_table(pa.Table.from_pandas(data), conversion_mode) elif hasattr(data, "__arrow_c_array__"): data = convert_pyarrow_recordbatch( pa.record_batch(data), # type:ignore[attr-defined] - large_dtypes, + conversion_mode, ) elif hasattr(data, "__arrow_c_stream__"): if not hasattr(RecordBatchReader, "from_stream"): @@ -665,7 +672,7 @@ def _convert_data_and_schema( ) data = convert_pyarrow_recordbatchreader( - RecordBatchReader.from_stream(data), large_dtypes + RecordBatchReader.from_stream(data), conversion_mode ) elif isinstance(data, Iterable): if schema is None: @@ -675,8 +682,18 @@ def _convert_data_and_schema( f"{type(data).__name__} is not a valid input. Only PyArrow RecordBatchReader, RecordBatch, Iterable[RecordBatch], Table, Dataset or Pandas DataFrame or objects implementing the Arrow PyCapsule Interface are valid inputs for source." ) - if isinstance(schema, DeltaSchema): - schema = schema.to_pyarrow(as_large_types=large_dtypes) + if ( + isinstance(schema, DeltaSchema) + and conversion_mode == ArrowSchemaConversionMode.PASSTHROUGH + ): + raise NotImplementedError( + "ArrowSchemaConversionMode.passthrough is not implemented to work with DeltaSchema, skip passing a schema or pass an arrow schema." + ) + elif isinstance(schema, DeltaSchema): + if conversion_mode == ArrowSchemaConversionMode.LARGE: + schema = schema.to_pyarrow(as_large_types=True) + else: + schema = schema.to_pyarrow(as_large_types=False) elif schema is None: schema = data.schema @@ -811,19 +828,6 @@ def iter_groups(metadata: Any) -> Iterator[Any]: # Min and Max are recorded in physical type, not logical type # https://stackoverflow.com/questions/66753485/decoding-parquet-min-max-statistics-for-decimal-type # TODO: Add logic to decode physical type for DATE, DECIMAL - logical_type = ( - metadata.row_group(0) - .column(column_idx) - .statistics.logical_type.type - ) - - if PYARROW_MAJOR_VERSION < 8 and logical_type not in [ - "STRING", - "INT", - "TIMESTAMP", - "NONE", - ]: - continue minimums = ( group.column(column_idx).statistics.min diff --git a/python/pyproject.toml b/python/pyproject.toml index 013ec09aca..a13886209b 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -18,8 +18,7 @@ classifiers = [ "Programming Language :: Python :: 3.12" ] dependencies = [ - "pyarrow>=8", - "pyarrow-hotfix", + "pyarrow>=16", ] [project.optional-dependencies] diff --git a/python/stubs/pyarrow/__init__.pyi b/python/stubs/pyarrow/__init__.pyi index aaf92ea962..e500d11191 100644 --- a/python/stubs/pyarrow/__init__.pyi +++ b/python/stubs/pyarrow/__init__.pyi @@ -11,6 +11,8 @@ ListType: Any StructType: Any MapType: Any FixedSizeListType: Any +LargeListViewType: Any +ListViewType: Any FixedSizeBinaryType: Any schema: Any map_: Any @@ -36,7 +38,11 @@ large_string: Any string: Any large_binary: Any binary: Any +binary_view: Any +string_view: Any +list_view: Any large_list: Any +large_list_view: Any LargeListType: Any dictionary: Any timestamp: Any diff --git a/python/tests/pyspark_integration/test_write_to_pyspark.py b/python/tests/pyspark_integration/test_write_to_pyspark.py index 81cda71883..d826140fbc 100644 --- 
a/python/tests/pyspark_integration/test_write_to_pyspark.py +++ b/python/tests/pyspark_integration/test_write_to_pyspark.py @@ -112,7 +112,7 @@ def test_checks_min_writer_version(tmp_path: pathlib.Path): DeltaProtocolError, match="This table's min_writer_version is 3, but" ): valid_data = pa.table({"c1": pa.array([5, 6])}) - write_deltalake(str(tmp_path), valid_data, mode="append") + write_deltalake(str(tmp_path), valid_data, mode="append", engine="pyarrow") @pytest.mark.pyspark diff --git a/python/tests/test_schema.py b/python/tests/test_schema.py index 23198d9ef3..a3ad6b62e1 100644 --- a/python/tests/test_schema.py +++ b/python/tests/test_schema.py @@ -6,6 +6,7 @@ from deltalake import DeltaTable, Field from deltalake.schema import ( ArrayType, + ArrowSchemaConversionMode, MapType, PrimitiveType, Schema, @@ -222,13 +223,74 @@ def test_delta_schema(): assert schema_without_metadata == Schema.from_pyarrow(pa_schema) -@pytest.mark.parametrize( - "schema,expected_schema,large_dtypes", - [ +def _generate_test_tuples(): + test_tuples = [ + ( + pa.schema([("some_int", pa.uint32()), ("some_string", pa.string_view())]), + pa.schema([("some_int", pa.int32()), ("some_string", pa.string_view())]), + ArrowSchemaConversionMode.PASSTHROUGH, + ), + ( + pa.schema( + [ + ("some_int", pa.uint32()), + ("some_string", pa.list_view(pa.large_string())), + ] + ), + pa.schema( + [ + ("some_int", pa.int32()), + ("some_string", pa.list_view(pa.large_string())), + ] + ), + ArrowSchemaConversionMode.PASSTHROUGH, + ), ( pa.schema([("some_int", pa.uint32()), ("some_string", pa.string())]), pa.schema([("some_int", pa.int32()), ("some_string", pa.string())]), - False, + ArrowSchemaConversionMode.NORMAL, + ), + ( + pa.schema([("some_int", pa.uint32()), ("some_string", pa.string())]), + pa.schema([("some_int", pa.int32()), ("some_string", pa.string())]), + ArrowSchemaConversionMode.PASSTHROUGH, + ), + ( + pa.schema([("some_int", pa.uint32()), ("some_string", pa.large_string())]), + pa.schema([("some_int", pa.int32()), ("some_string", pa.large_string())]), + ArrowSchemaConversionMode.PASSTHROUGH, + ), + ( + pa.schema([("some_int", pa.uint32()), ("some_binary", pa.large_binary())]), + pa.schema([("some_int", pa.int32()), ("some_binary", pa.large_binary())]), + ArrowSchemaConversionMode.PASSTHROUGH, + ), + ( + pa.schema( + [ + ("some_int", pa.uint32()), + ("some_string", pa.large_list(pa.large_string())), + ] + ), + pa.schema( + [ + ("some_int", pa.int32()), + ("some_string", pa.large_list(pa.large_string())), + ] + ), + ArrowSchemaConversionMode.PASSTHROUGH, + ), + ( + pa.schema( + [ + ("some_int", pa.uint32()), + ("some_string", pa.list_(pa.large_string())), + ] + ), + pa.schema( + [("some_int", pa.int32()), ("some_string", pa.list_(pa.large_string()))] + ), + ArrowSchemaConversionMode.PASSTHROUGH, ), ( pa.schema( @@ -247,7 +309,7 @@ def test_delta_schema(): pa.field("some_decimal", pa.decimal128(10, 2), nullable=False), ] ), - False, + ArrowSchemaConversionMode.NORMAL, ), ( pa.schema( @@ -262,17 +324,17 @@ def test_delta_schema(): pa.field("some_string", pa.large_string(), nullable=False), ] ), - True, + ArrowSchemaConversionMode.LARGE, ), ( pa.schema([("some_int", pa.uint32()), ("some_string", pa.string())]), pa.schema([("some_int", pa.int32()), ("some_string", pa.large_string())]), - True, + ArrowSchemaConversionMode.LARGE, ), ( pa.schema([("some_int", pa.uint32()), ("some_string", pa.large_string())]), pa.schema([("some_int", pa.int32()), ("some_string", pa.string())]), - False, + ArrowSchemaConversionMode.NORMAL, ), ( 
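The new test tuples exercise the three conversion modes; a sketch of what they check, calling the private `_convert_pa_schema_to_delta` helper the same way the tests do (the sample schema is illustrative).

```python
import pyarrow as pa
from deltalake.schema import ArrowSchemaConversionMode, _convert_pa_schema_to_delta

src = pa.schema([("some_int", pa.uint32()), ("some_string", pa.large_string())])

# PASSTHROUGH keeps the caller's string/list flavour (large_string stays
# large_string), NORMAL coerces to the small variants, LARGE to the large
# ones; unsigned integers are normalised to signed in every mode.
for mode in (
    ArrowSchemaConversionMode.PASSTHROUGH,
    ArrowSchemaConversionMode.NORMAL,
    ArrowSchemaConversionMode.LARGE,
):
    print(mode, _convert_pa_schema_to_delta(src, mode))
```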
pa.schema( @@ -291,7 +353,7 @@ def test_delta_schema(): ("some_int3", pa.int64()), ] ), - True, + ArrowSchemaConversionMode.LARGE, ), ( pa.schema( @@ -310,7 +372,7 @@ def test_delta_schema(): ("some_string", pa.large_string()), ] ), - True, + ArrowSchemaConversionMode.LARGE, ), ( pa.schema( @@ -327,7 +389,7 @@ def test_delta_schema(): ("some_binary", pa.binary()), ] ), - False, + ArrowSchemaConversionMode.NORMAL, ), ( pa.schema( @@ -355,7 +417,7 @@ def test_delta_schema(): ("some_binary", pa.large_binary()), ] ), - True, + ArrowSchemaConversionMode.LARGE, ), ( pa.schema( @@ -383,7 +445,7 @@ def test_delta_schema(): ("some_binary", pa.binary()), ] ), - False, + ArrowSchemaConversionMode.NORMAL, ), ( pa.schema( @@ -410,7 +472,7 @@ def test_delta_schema(): ("timestamp7", pa.timestamp("us", tz="UTC")), ] ), - False, + ArrowSchemaConversionMode.NORMAL, ), ( pa.schema( @@ -451,11 +513,18 @@ def test_delta_schema(): ) ] ), - False, + ArrowSchemaConversionMode.NORMAL, ), - ], + ] + + return test_tuples + + +@pytest.mark.parametrize( + "schema,expected_schema,conversion_mode", + _generate_test_tuples(), ) -def test_schema_conversions(schema, expected_schema, large_dtypes): - result_schema = _convert_pa_schema_to_delta(schema, large_dtypes) +def test_schema_conversions(schema, expected_schema, conversion_mode): + result_schema = _convert_pa_schema_to_delta(schema, conversion_mode) assert result_schema == expected_schema diff --git a/python/tests/test_table_read.py b/python/tests/test_table_read.py index 9fb644e285..8d03ff0863 100644 --- a/python/tests/test_table_read.py +++ b/python/tests/test_table_read.py @@ -6,8 +6,6 @@ from typing import Any, List, Tuple from unittest.mock import Mock -from packaging import version - from deltalake._util import encode_partition_value from deltalake.exceptions import DeltaProtocolError from deltalake.table import ProtocolVersions @@ -280,13 +278,11 @@ def test_read_table_with_stats(): data = dataset.to_table(filter=filter_expr) assert data.num_rows == 0 - # PyArrow added support for is_null and is_valid simplification in 8.0.0 - if version.parse(pa.__version__).major >= 8: - filter_expr = ds.field("cases").is_null() - assert len(list(dataset.get_fragments(filter=filter_expr))) == 0 + filter_expr = ds.field("cases").is_null() + assert len(list(dataset.get_fragments(filter=filter_expr))) == 0 - data = dataset.to_table(filter=filter_expr) - assert data.num_rows == 0 + data = dataset.to_table(filter=filter_expr) + assert data.num_rows == 0 def test_read_special_partition(): diff --git a/python/tests/test_writer.py b/python/tests/test_writer.py index 3c9a977b56..1534d42789 100644 --- a/python/tests/test_writer.py +++ b/python/tests/test_writer.py @@ -13,7 +13,6 @@ import pyarrow as pa import pyarrow.compute as pc import pytest -from packaging import version from pyarrow.dataset import ParquetFileFormat, ParquetReadOptions, dataset from pyarrow.lib import RecordBatchReader @@ -147,7 +146,7 @@ def test_enforce_schema_rust_writer(existing_table: DeltaTable, mode: str): def test_update_schema(existing_table: DeltaTable): new_data = pa.table({"x": pa.array([1, 2, 3])}) - with pytest.raises(ValueError): + with pytest.raises(DeltaError): write_deltalake( existing_table, new_data, mode="append", schema_mode="overwrite" ) @@ -239,9 +238,7 @@ def test_overwrite_schema(existing_table: DeltaTable): def test_update_schema_rust_writer_append(existing_table: DeltaTable): - with pytest.raises( - SchemaMismatchError, match="Cannot cast schema, number of fields does not match" - ): + 
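The test changes above move schema-drift failures from Python-side `ValueError` to errors raised by the Rust writer; a sketch under the assumption that a table with a different schema already exists at the illustrative path.

```python
import pyarrow as pa
from deltalake import write_deltalake
from deltalake.exceptions import DeltaError, SchemaMismatchError

table_uri = "/tmp/existing_table"  # assumed to hold a table without an "x" column
new_data = pa.table({"x": pa.array([1, 2, 3])})

# Appending drifting data without a schema_mode still fails fast.
try:
    write_deltalake(table_uri, new_data, mode="append", engine="rust")
except SchemaMismatchError as exc:
    print(exc)

# schema_mode="overwrite" is only valid with mode="overwrite"; the Rust
# writer now reports this as a DeltaError rather than a ValueError.
try:
    write_deltalake(table_uri, new_data, mode="append", schema_mode="overwrite")
except DeltaError as exc:
    print(exc)
```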
with pytest.raises(SchemaMismatchError): # It's illegal to do schema drift without correct schema_mode write_deltalake( existing_table, @@ -258,24 +255,44 @@ def test_update_schema_rust_writer_append(existing_table: DeltaTable): schema_mode="overwrite", engine="rust", ) + write_deltalake( + existing_table, + pa.table({"x2": pa.array([1, 2, 3])}), + mode="append", + schema_mode="merge", + engine="rust", + ) + + +def test_write_type_castable_types(existing_table: DeltaTable): + write_deltalake( + existing_table, + pa.table({"utf8": pa.array([1, 2, 3])}), + mode="append", + schema_mode="merge", + engine="rust", + ) + with pytest.raises( + Exception, match="Cast error: Cannot cast string 'hello' to value of Int8 type" + ): + write_deltalake( + existing_table, + pa.table({"int8": pa.array(["hello", "2", "3"])}), + mode="append", + schema_mode="merge", + engine="rust", + ) + with pytest.raises( - SchemaMismatchError, - match="Schema error: Cannot merge types string and long", + Exception, match="Cast error: Can't cast value 1000 to type Int8" ): write_deltalake( existing_table, - pa.table({"utf8": pa.array([1, 2, 3])}), + pa.table({"int8": pa.array([1000, 100, 10])}), mode="append", schema_mode="merge", engine="rust", ) - write_deltalake( - existing_table, - pa.table({"x2": pa.array([1, 2, 3])}), - mode="append", - schema_mode="merge", - engine="rust", - ) def test_update_schema_rust_writer_invalid(existing_table: DeltaTable): @@ -475,10 +492,10 @@ def test_write_modes(tmp_path: pathlib.Path, sample_data: pa.Table, engine): if engine == "pyarrow": with pytest.raises(FileExistsError): - write_deltalake(tmp_path, sample_data, mode="error") + write_deltalake(tmp_path, sample_data, mode="error", engine=engine) elif engine == "rust": with pytest.raises(DeltaError): - write_deltalake(tmp_path, sample_data, mode="error", engine="rust") + write_deltalake(tmp_path, sample_data, mode="error", engine=engine) write_deltalake(tmp_path, sample_data, mode="ignore", engine="rust") assert ("0" * 19 + "1.json") not in os.listdir(tmp_path / "_delta_log") @@ -629,6 +646,7 @@ def test_write_recordbatchreader( existing_table.to_pyarrow_dataset().schema, batches ) + print("writing second time") write_deltalake( tmp_path, reader, mode="overwrite", large_dtypes=large_dtypes, engine=engine ) @@ -686,7 +704,13 @@ def test_writer_stats(existing_table: DeltaTable, sample_data: pa.Table): assert stats["numRecords"] == sample_data.num_rows - assert all(null_count == 0 for null_count in stats["nullCount"].values()) + null_values = [] + for null_count in stats["nullCount"].values(): + if isinstance(null_count, dict): + null_values.extend(list(null_count.values())) + else: + null_values.append(null_count) + assert all(i == 0 for i in null_values) expected_mins = { "utf8": "0", @@ -694,19 +718,19 @@ def test_writer_stats(existing_table: DeltaTable, sample_data: pa.Table): "int32": 0, "int16": 0, "int8": 0, - "float32": 0.0, - "float64": 0.0, + "float32": -0.0, + "float64": -0.0, "bool": False, "binary": "0", - "timestamp": "2022-01-01T00:00:00", - "struct.x": 0, - "struct.y": "0", - "list.list.item": 0, + "timestamp": "2022-01-01T00:00:00Z", + "struct": { + "x": 0, + "y": "0", + }, } # PyArrow added support for decimal and date32 in 8.0.0 - if version.parse(pa.__version__).major >= 8: - expected_mins["decimal"] = "10.000" - expected_mins["date32"] = "2022-01-01" + expected_mins["decimal"] = 10.0 + expected_mins["date32"] = "2022-01-01" assert stats["minValues"] == expected_mins @@ -720,15 +744,12 @@ def 
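The new `test_write_type_castable_types` documents the merge-cast behaviour; a sketch of it, assuming an existing table (illustrative path) with a string column `utf8` and an `int8` column, like the test fixture.

```python
import pyarrow as pa
from deltalake import write_deltalake

table_uri = "/tmp/existing_table"  # assumed schema includes utf8: string, int8: int8

# With schema_mode="merge" the Rust writer casts compatible values, so the
# integers below are stored as their string representation in "utf8".
write_deltalake(
    table_uri,
    pa.table({"utf8": pa.array([1, 2, 3])}),
    mode="append",
    schema_mode="merge",
    engine="rust",
)

# Values that cannot be represented in the target type raise a cast error,
# e.g. 1000 does not fit in int8.
write_deltalake(
    table_uri,
    pa.table({"int8": pa.array([1000, 100, 10])}),
    mode="append",
    schema_mode="merge",
    engine="rust",
)  # raises: "Cast error: Can't cast value 1000 to type Int8"
```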
test_writer_stats(existing_table: DeltaTable, sample_data: pa.Table): "float64": 4.0, "bool": True, "binary": "4", - "timestamp": "2022-01-01T04:00:00", - "struct.x": 4, - "struct.y": "4", - "list.list.item": 4, + "timestamp": "2022-01-01T04:00:00Z", + "struct": {"x": 4, "y": "4"}, } # PyArrow added support for decimal and date32 in 8.0.0 - if version.parse(pa.__version__).major >= 8: - expected_maxs["decimal"] = "14.000" - expected_maxs["date32"] = "2022-01-05" + expected_maxs["decimal"] = 14.0 + expected_maxs["date32"] = "2022-01-05" assert stats["maxValues"] == expected_maxs @@ -794,6 +815,7 @@ def get_multifile_stats(table: DeltaTable) -> Iterable[Dict]: write_deltalake( tmp_path, data, + engine="pyarrow", file_options=ParquetFileFormat().make_write_options(), max_rows_per_file=rows_per_file, min_rows_per_group=rows_per_file, @@ -895,7 +917,7 @@ def _normalize_path(t): # who does not love Windows? ;) [ (1, 2, pa.int64(), "1"), (False, True, pa.bool_(), "false"), - (date(2022, 1, 1), date(2022, 1, 2), pa.date32(), "2022-01-01"), + (date(2022, 1, 1), date(2022, 1, 2), pa.date32(), "'2022-01-01'"), ], ) def test_partition_overwrite( @@ -940,7 +962,7 @@ def test_partition_overwrite( tmp_path, sample_data, mode="overwrite", - partition_filters=[("p1", "=", "1")], + predicate="p1 = 1", ) delta_table.update_incremental() @@ -970,7 +992,7 @@ def test_partition_overwrite( tmp_path, sample_data, mode="overwrite", - partition_filters=[("p2", ">", filter_string)], + predicate=f"p2 > {filter_string}", ) delta_table.update_incremental() assert ( @@ -999,7 +1021,7 @@ def test_partition_overwrite( tmp_path, sample_data, mode="overwrite", - partition_filters=[("p1", "=", "1"), ("p2", "=", filter_string)], + predicate=f"p1 = 1 AND p2 = {filter_string}", ) delta_table.update_incremental() assert ( @@ -1008,13 +1030,9 @@ def test_partition_overwrite( ) == expected_data ) - - with pytest.raises(ValueError, match="Data should be aligned with partitioning"): + with pytest.raises(DeltaProtocolError, match="Invariant violations"): write_deltalake( - tmp_path, - sample_data, - mode="overwrite", - partition_filters=[("p2", "<", filter_string)], + tmp_path, sample_data, mode="overwrite", predicate=f"p2 < {filter_string}" ) @@ -1199,24 +1217,19 @@ def test_partition_overwrite_with_new_partition( new_sample_data = pa.table( { - "p1": pa.array(["2", "1"], pa.string()), - "p2": pa.array([3, 2], pa.int64()), - "val": pa.array([2, 2], pa.int64()), + "p1": pa.array(["1", "2"], pa.string()), + "p2": pa.array([2, 2], pa.int64()), + "val": pa.array([2, 3], pa.int64()), } ) expected_data = pa.table( { "p1": pa.array(["1", "1", "2", "2"], pa.string()), - "p2": pa.array([1, 2, 1, 3], pa.int64()), - "val": pa.array([1, 2, 1, 2], pa.int64()), + "p2": pa.array([1, 2, 1, 2], pa.int64()), + "val": pa.array([1, 2, 1, 3], pa.int64()), } ) - write_deltalake( - tmp_path, - new_sample_data, - mode="overwrite", - partition_filters=[("p2", "=", "2")], - ) + write_deltalake(tmp_path, new_sample_data, mode="overwrite", predicate="p2 = 2") delta_table = DeltaTable(tmp_path) assert ( delta_table.to_pyarrow_table().sort_by( @@ -1230,14 +1243,12 @@ def test_partition_overwrite_with_non_partitioned_data( tmp_path: pathlib.Path, sample_data_for_partitioning: pa.Table ): write_deltalake(tmp_path, sample_data_for_partitioning, mode="overwrite") - - with pytest.raises(ValueError, match=r'not partition columns: \["p1"\]'): - write_deltalake( - tmp_path, - sample_data_for_partitioning, - mode="overwrite", - partition_filters=[("p1", "=", "1")], - ) + 
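The partition-overwrite tests above switch from `partition_filters` to `predicate`; a replace-where sketch mirroring them, assuming a table partitioned by `p1`/`p2` at the illustrative path.

```python
import pyarrow as pa
from deltalake import write_deltalake
from deltalake.exceptions import DeltaProtocolError

table_uri = "/tmp/partitioned_table"
rows = pa.table(
    {
        "p1": pa.array(["1", "1"], pa.string()),
        "p2": pa.array([1, 2], pa.int64()),
        "val": pa.array([10, 20], pa.int64()),
    }
)

# Only data matching the predicate is replaced.
write_deltalake(table_uri, rows, mode="overwrite", predicate="p1 = 1")

# Rows falling outside the predicate are rejected with an "Invariant
# violations" DeltaProtocolError instead of the old ValueError.
try:
    write_deltalake(table_uri, rows, mode="overwrite", predicate="p2 > 2")
except DeltaProtocolError as exc:
    print(exc)
```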
write_deltalake( + tmp_path, + sample_data_for_partitioning.filter(pc.field("p1") == "1"), + mode="overwrite", + predicate="p1 = 1", + ) def test_partition_overwrite_with_wrong_partition( @@ -1249,21 +1260,15 @@ def test_partition_overwrite_with_wrong_partition( mode="overwrite", partition_by=["p1", "p2"], ) + from deltalake.exceptions import DeltaError - with pytest.raises(ValueError, match=r'not in table schema: \["p999"\]'): + with pytest.raises(DeltaError, match="No field named p999."): write_deltalake( tmp_path, sample_data_for_partitioning, mode="overwrite", - partition_filters=[("p999", "=", "1")], - ) - - with pytest.raises(ValueError, match=r'not partition columns: \["val"\]'): - write_deltalake( - tmp_path, - sample_data_for_partitioning, - mode="overwrite", - partition_filters=[("val", "=", "1")], + predicate="p999 = 1", + # partition_filters=[("p999", "=", "1")], ) new_data = pa.table( @@ -1275,15 +1280,14 @@ def test_partition_overwrite_with_wrong_partition( ) with pytest.raises( - ValueError, - match="Data should be aligned with partitioning. " - "Data contained values for partition p1=1 p2=2", + DeltaProtocolError, + match="Invariant violations", ): write_deltalake( tmp_path, new_data, mode="overwrite", - partition_filters=[("p1", "=", "1"), ("p2", "=", "1")], + predicate="p1 = 1 AND p2 = 1", ) @@ -1308,12 +1312,14 @@ def test_max_partitions_exceeding_fragment_should_fail( tmp_path, sample_data_for_partitioning, mode="overwrite", + engine="pyarrow", max_partitions=1, partition_by=["p1", "p2"], ) -def test_large_arrow_types(tmp_path: pathlib.Path): +@pytest.mark.parametrize("engine", ["rust", "pyarrow"]) +def test_large_arrow_types(tmp_path: pathlib.Path, engine): pylist = [ {"name": "Joey", "gender": b"M", "arr_type": ["x", "y"], "dict": {"a": b"M"}}, {"name": "Ivan", "gender": b"F", "arr_type": ["x", "z"]}, @@ -1329,15 +1335,14 @@ def test_large_arrow_types(tmp_path: pathlib.Path): ) table = pa.Table.from_pylist(pylist, schema=schema) - write_deltalake(tmp_path, table) + write_deltalake(tmp_path, table, mode="append", engine=engine, large_dtypes=True) + write_deltalake(tmp_path, table, mode="append", engine=engine, large_dtypes=True) + write_deltalake(tmp_path, table, mode="append", engine=engine, large_dtypes=True) dt = DeltaTable(tmp_path) assert table.schema == dt.schema().to_pyarrow(as_large_types=True) -@pytest.mark.skipif( - int(pa.__version__.split(".")[0]) < 10, reason="map casts require pyarrow >= 10" -) def test_large_arrow_types_dataset_as_large_types(tmp_path: pathlib.Path): pylist = [ {"name": "Joey", "gender": b"M", "arr_type": ["x", "y"], "dict": {"a": b"M"}}, @@ -1363,9 +1368,6 @@ def test_large_arrow_types_dataset_as_large_types(tmp_path: pathlib.Path): assert union_ds.to_table().shape[0] == 4 -@pytest.mark.skipif( - int(pa.__version__.split(".")[0]) < 10, reason="map casts require pyarrow >= 10" -) def test_large_arrow_types_explicit_scan_schema(tmp_path: pathlib.Path): pylist = [ {"name": "Joey", "gender": b"M", "arr_type": ["x", "y"], "dict": {"a": b"M"}}, @@ -1529,7 +1531,10 @@ def test_float_values(tmp_path: pathlib.Path): def test_with_deltalake_schema(tmp_path: pathlib.Path, sample_data: pa.Table): write_deltalake( - tmp_path, sample_data, schema=Schema.from_pyarrow(sample_data.schema) + tmp_path, + sample_data, + engine="pyarrow", + schema=Schema.from_pyarrow(sample_data.schema), ) delta_table = DeltaTable(tmp_path) assert delta_table.schema().to_pyarrow() == sample_data.schema @@ -1544,7 +1549,7 @@ def test_with_deltalake_json_schema(tmp_path: 
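As the reworked `test_partition_overwrite_with_wrong_partition` shows, referencing a column that is not in the table schema now surfaces from the Rust planner; a sketch with an illustrative path and toy rows.

```python
import pyarrow as pa
from deltalake import write_deltalake
from deltalake.exceptions import DeltaError

rows = pa.table({"p1": pa.array(["1"]), "p2": pa.array([1]), "val": pa.array([1])})

# An unknown predicate column is reported as a DeltaError ("No field named
# p999.") rather than a Python-side ValueError about partition columns.
try:
    write_deltalake(
        "/tmp/partitioned_table", rows, mode="overwrite", predicate="p999 = 1"
    )
except DeltaError as exc:
    print(exc)
```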
pathlib.Path): "account": pa.array([]), } ) - write_deltalake(tmp_path, table, schema=table_schema) + write_deltalake(tmp_path, table, engine="pyarrow", schema=table_schema) table = pa.table( { "campaign": pa.array(["deltaLake"]), @@ -1552,7 +1557,9 @@ def test_with_deltalake_json_schema(tmp_path: pathlib.Path): } ) - write_deltalake(tmp_path, data=table, schema=table_schema, mode="append") + write_deltalake( + tmp_path, data=table, engine="pyarrow", schema=table_schema, mode="append" + ) delta_table = DeltaTable(tmp_path) assert delta_table.schema() == table_schema @@ -1644,7 +1651,7 @@ def test_rust_decimal_cast(tmp_path: pathlib.Path): ): write_deltalake(tmp_path, data, mode="append", engine="rust") - with pytest.raises(SchemaMismatchError, match="Cannot merge types decimal"): + with pytest.raises(SchemaMismatchError): write_deltalake( tmp_path, data, mode="append", schema_mode="merge", engine="rust" ) @@ -1843,4 +1850,5 @@ def test_roundtrip_cdc_evolution(tmp_path: pathlib.Path): def test_empty_dataset_write(tmp_path: pathlib.Path, sample_data: pa.Table): empty_arrow_table = sample_data.schema.empty_table() empty_dataset = dataset(empty_arrow_table) - write_deltalake(tmp_path, empty_dataset, mode="append") + with pytest.raises(DeltaError, match="No data source supplied to write command"): + write_deltalake(tmp_path, empty_dataset, mode="append")
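The last hunk turns an empty write into an error; a sketch mirroring `test_empty_dataset_write` (path is illustrative).

```python
import pyarrow as pa
from pyarrow.dataset import dataset
from deltalake import write_deltalake
from deltalake.exceptions import DeltaError

empty = dataset(pa.table({"x": pa.array([], pa.int64())}))

# A dataset that yields no record batches is now rejected by the Rust writer
# instead of succeeding silently as before.
try:
    write_deltalake("/tmp/empty_table", empty, mode="append")
except DeltaError as exc:
    print(exc)  # "No data source supplied to write command"
```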