diff --git a/Cargo.toml b/Cargo.toml index b8e3a80e2..a2c26f047 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,7 @@ [workspace] members = [ "acceptance", + "derive-macros", "ffi", "kernel", "kernel/examples/dump-table", # todo: put back to `examples/*` when inspect-table is fixed diff --git a/acceptance/src/meta.rs b/acceptance/src/meta.rs index 43fec7fa1..b8d2404e9 100644 --- a/acceptance/src/meta.rs +++ b/acceptance/src/meta.rs @@ -87,7 +87,7 @@ impl TestCaseInfo { properties: metadata .configuration .iter() - .map(|(k, v)| (k.clone(), v.clone().unwrap())) + .map(|(k, v)| (k.clone(), v.clone())) .collect(), min_reader_version: protocol.min_reader_version as u32, min_writer_version: protocol.min_writer_version as u32, diff --git a/derive-macros/Cargo.toml b/derive-macros/Cargo.toml new file mode 100644 index 000000000..e103333b6 --- /dev/null +++ b/derive-macros/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "derive-macros" +authors.workspace = true +edition.workspace = true +homepage.workspace = true +keywords.workspace = true +license.workspace = true +repository.workspace = true +readme.workspace = true +version.workspace = true + +[lib] +proc-macro = true + +[dependencies] +proc-macro2 = "1" +syn = { version = "2.0", features = ["extra-traits"] } +quote = "1.0" + + diff --git a/derive-macros/src/lib.rs b/derive-macros/src/lib.rs new file mode 100644 index 000000000..7c0fdebe7 --- /dev/null +++ b/derive-macros/src/lib.rs @@ -0,0 +1,85 @@ +use proc_macro2::{Ident, TokenStream}; +use quote::{quote, quote_spanned}; +use syn::spanned::Spanned; +use syn::{parse_macro_input, Data, DataStruct, DeriveInput, Fields, PathArguments, Type}; + +/// Derive a `deltakernel::schemas::ToDataType` implementation for the annotated struct. The actual +/// field names in the schema (and therefore of the struct members) are all mandated by the Delta +/// spec, and so the user of this macro is responsible for ensuring that +/// e.g. `Metadata::schema_string` is the snake_case-ified version of `schemaString` from [Delta's +/// Change Metadata](https://github.com/delta-io/delta/blob/master/PROTOCOL.md#change-metadata) +/// action (this macro allows the use of standard rust snake_case, and will convert to the correct +/// delta schema camelCase version). +#[proc_macro_derive(Schema)] +pub fn derive_schema(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let input = parse_macro_input!(input as DeriveInput); + let struct_ident = input.ident; + + let schema_fields = gen_schema_fields(&input.data); + let output = quote! { + #[automatically_derived] + impl crate::actions::schemas::ToDataType for #struct_ident { + fn to_data_type() -> crate::schema::DataType { + use crate::actions::schemas::{ToDataType, GetStructField}; + crate::schema::StructType::new(vec![ + #schema_fields + ]).into() + } + } + }; + proc_macro::TokenStream::from(output) +} + +// turn our struct name into the schema name, goes from snake_case to camelCase +fn get_schema_name(name: &Ident) -> Ident { + let snake_name = name.to_string(); + let mut next_caps = false; + let ret: String = snake_name + .chars() + .filter_map(|c| { + if c == '_' { + next_caps = true; + None + } else if next_caps { + next_caps = false; + // This assumes we're using ascii, should be okay + Some(c.to_ascii_uppercase()) + } else { + Some(c) + } + }) + .collect(); + Ident::new(&ret, name.span()) +} + +fn gen_schema_fields(data: &Data) -> TokenStream { + let fields = match data { + Data::Struct(DataStruct { + fields: Fields::Named(fields), + .. 
+ }) => &fields.named, + _ => panic!("this derive macro only works on structs with named fields"), + }; + + let schema_fields = fields.iter().map(|field| { + let name = field.ident.as_ref().unwrap(); // we know these are named fields + let name = get_schema_name(name); + match field.ty { + Type::Path(ref type_path) => { + let type_path_quoted = type_path.path.segments.iter().map(|segment| { + let segment_ident = &segment.ident; + match &segment.arguments { + PathArguments::None => quote! { #segment_ident :: }, + PathArguments::AngleBracketed(angle_args) => quote! { #segment_ident::#angle_args :: }, + _ => panic!("Can only handle <> type path args"), + } + }); + quote_spanned! { field.span() => #(#type_path_quoted),* get_struct_field(stringify!(#name))} + } + _ => { + panic!("Can't handle type: {:?}", field.ty); + } + } + }); + quote! { #(#schema_fields),* } +} diff --git a/kernel/Cargo.toml b/kernel/Cargo.toml index 6a2bdb000..60ec550eb 100644 --- a/kernel/Cargo.toml +++ b/kernel/Cargo.toml @@ -27,6 +27,9 @@ url = "2" uuid = "1.3.0" z85 = "3.0.5" +# bring in our derive macros +derive-macros = { path = "../derive-macros" } + # used for developer-visibility visibility = "0.1.0" diff --git a/kernel/src/actions/deletion_vector.rs b/kernel/src/actions/deletion_vector.rs index 44450a12b..c7c79361d 100644 --- a/kernel/src/actions/deletion_vector.rs +++ b/kernel/src/actions/deletion_vector.rs @@ -4,12 +4,13 @@ use std::io::{Cursor, Read}; use std::sync::Arc; use bytes::Bytes; +use derive_macros::Schema; use roaring::RoaringTreemap; use url::Url; use crate::{DeltaResult, Error, FileSystemClient}; -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Schema)] pub struct DeletionVectorDescriptor { /// A single character to indicate how to access the DV. Legal options are: ['u', 'i', 'p']. pub storage_type: String, diff --git a/kernel/src/actions/mod.rs b/kernel/src/actions/mod.rs index a2c73d772..6d9a8349c 100644 --- a/kernel/src/actions/mod.rs +++ b/kernel/src/actions/mod.rs @@ -3,14 +3,42 @@ pub(crate) mod deletion_vector; pub(crate) mod schemas; pub(crate) mod visitors; -use std::{collections::HashMap, sync::Arc}; +use derive_macros::Schema; +use lazy_static::lazy_static; use visitors::{AddVisitor, MetadataVisitor, ProtocolVisitor}; +use self::deletion_vector::DeletionVectorDescriptor; +use crate::actions::schemas::GetStructField; use crate::{schema::StructType, DeltaResult, EngineData}; -use self::deletion_vector::DeletionVectorDescriptor; +use std::collections::HashMap; + +pub(crate) const ADD_NAME: &str = "add"; +pub(crate) const REMOVE_NAME: &str = "remove"; +pub(crate) const METADATA_NAME: &str = "metaData"; +pub(crate) const PROTOCOL_NAME: &str = "protocol"; + +lazy_static! 
{ + static ref LOG_SCHEMA: StructType = StructType::new( + vec![ + Option::<Add>::get_struct_field(ADD_NAME), + Option::<Remove>::get_struct_field(REMOVE_NAME), + Option::<Metadata>::get_struct_field(METADATA_NAME), + Option::<Protocol>::get_struct_field(PROTOCOL_NAME), + // We don't support the following actions yet + //Option::get_field(CDC_NAME), + //Option::get_field(COMMIT_INFO_NAME), + //Option::get_field(DOMAIN_METADATA_NAME), + //Option::get_field(TRANSACTION_NAME), + ] + ); +} -#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) fn get_log_schema() -> &'static StructType { + &LOG_SCHEMA +} + +#[derive(Debug, Clone, PartialEq, Eq, Schema)] pub struct Format { /// Name of the encoding for files in this table pub provider: String, @@ -27,7 +55,7 @@ impl Default for Format { } } -#[derive(Debug, Default, Clone, PartialEq, Eq)] +#[derive(Debug, Default, Clone, PartialEq, Eq, Schema)] pub struct Metadata { /// Unique identifier for this table pub id: String, @@ -44,14 +72,13 @@ pub struct Metadata { /// The time when this metadata action is created, in milliseconds since the Unix epoch pub created_time: Option<i64>, /// Configuration options for the metadata action - pub configuration: HashMap<String, Option<String>>, + pub configuration: HashMap<String, String>, } impl Metadata { pub fn try_new_from_data(data: &dyn EngineData) -> DeltaResult<Option<Metadata>> { - let schema = StructType::new(vec![crate::actions::schemas::METADATA_FIELD.clone()]); let mut visitor = MetadataVisitor::default(); - data.extract(Arc::new(schema), &mut visitor)?; + data.extract(get_log_schema().project(&[METADATA_NAME])?, &mut visitor)?; Ok(visitor.metadata) } @@ -60,7 +87,7 @@ impl Metadata { } } -#[derive(Default, Debug, Clone, PartialEq, Eq)] +#[derive(Default, Debug, Clone, PartialEq, Eq, Schema)] pub struct Protocol { /// The minimum version of the Delta read protocol that a client must implement /// in order to correctly read this table @@ -79,13 +106,12 @@ pub struct Protocol { impl Protocol { pub fn try_new_from_data(data: &dyn EngineData) -> DeltaResult<Option<Protocol>> { let mut visitor = ProtocolVisitor::default(); - let schema = StructType::new(vec![crate::actions::schemas::PROTOCOL_FIELD.clone()]); - data.extract(Arc::new(schema), &mut visitor)?; + data.extract(get_log_schema().project(&[PROTOCOL_NAME])?, &mut visitor)?; Ok(visitor.protocol) } } -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Schema)] pub struct Add { /// A relative path to a data file from the root of the table or an absolute path to a file /// that should be added to the table. The path is a URI as specified by @@ -95,7 +121,7 @@ pub struct Add { pub path: String, /// A map from partition column to value for this logical file. - pub partition_values: HashMap<String, Option<String>>, + pub partition_values: HashMap<String, String>, /// The size of this data file in bytes pub size: i64, @@ -113,7 +139,7 @@ pub struct Add { pub stats: Option<String>, /// Map containing metadata about this logical file. 
- pub tags: HashMap<String, Option<String>>, + pub tags: Option<HashMap<String, String>>, /// Information about deletion vector (DV) associated with this add action pub deletion_vector: Option<DeletionVectorDescriptor>, @@ -134,8 +160,7 @@ impl Add { /// Since we always want to parse multiple adds from data, we return a `Vec` pub fn parse_from_data(data: &dyn EngineData) -> DeltaResult<Vec<Add>> { let mut visitor = AddVisitor::default(); - let schema = StructType::new(vec![crate::actions::schemas::ADD_FIELD.clone()]); - data.extract(Arc::new(schema), &mut visitor)?; + data.extract(get_log_schema().project(&[ADD_NAME])?, &mut visitor)?; Ok(visitor.adds) } @@ -144,7 +169,7 @@ impl Add { } } -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Schema)] pub(crate) struct Remove { /// A relative path to a data file from the root of the table or an absolute path to a file /// that should be added to the table. The path is a URI as specified by @@ -153,24 +178,24 @@ pub(crate) struct Remove { /// [RFC 2396 URI Generic Syntax]: https://www.ietf.org/rfc/rfc2396.txt pub(crate) path: String, + /// The time this logical file was created, as milliseconds since the epoch. + pub(crate) deletion_timestamp: Option<i64>, + /// When `false` the logical file must already be present in the table or the records /// in the added file must be contained in one or more remove actions in the same version. pub(crate) data_change: bool, - /// The time this logical file was created, as milliseconds since the epoch. - pub(crate) deletion_timestamp: Option<i64>, - /// When true the fields `partition_values`, `size`, and `tags` are present pub(crate) extended_file_metadata: Option<bool>, /// A map from partition column to value for this logical file. - pub(crate) partition_values: Option<HashMap<String, Option<String>>>, + pub(crate) partition_values: Option<HashMap<String, String>>, /// The size of this data file in bytes pub(crate) size: Option<i64>, /// Map containing metadata about this logical file. 
- pub(crate) tags: Option<HashMap<String, Option<String>>>, + pub(crate) tags: Option<HashMap<String, String>>, /// Information about deletion vector (DV) associated with this add action pub(crate) deletion_vector: Option<DeletionVectorDescriptor>, @@ -185,19 +210,111 @@ } impl Remove { - // _try_new_from_data for now, to avoid warning, probably will need at some point - // pub(crate) fn _try_new_from_data( - // data: &dyn EngineData, - // ) -> DeltaResult<Remove> { - // let mut visitor = Visitor::new(visit_remove); - // let schema = StructType::new(vec![crate::actions::schemas::REMOVE_FIELD.clone()]); - // data.extract(Arc::new(schema), &mut visitor)?; - // visitor - // .extracted - // .unwrap_or_else(|| Err(Error::generic("Didn't get expected remove"))) - // } - pub(crate) fn dv_unique_id(&self) -> Option<String> { self.deletion_vector.as_ref().map(|dv| dv.unique_id()) } } + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use super::*; + use crate::schema::{ArrayType, DataType, MapType, StructField}; + + #[test] + fn test_metadata_schema() { + let schema = get_log_schema() + .project(&["metaData"]) + .expect("Couldn't get metaData field"); + + let expected = Arc::new(StructType::new(vec![StructField::new( + "metaData", + StructType::new(vec![ + StructField::new("id", DataType::STRING, false), + StructField::new("name", DataType::STRING, true), + StructField::new("description", DataType::STRING, true), + StructField::new( + "format", + StructType::new(vec![ + StructField::new("provider", DataType::STRING, false), + StructField::new( + "options", + MapType::new(DataType::STRING, DataType::STRING, false), + false, + ), + ]), + false, + ), + StructField::new("schemaString", DataType::STRING, false), + StructField::new( + "partitionColumns", + ArrayType::new(DataType::STRING, false), + false, + ), + StructField::new("createdTime", DataType::LONG, true), + StructField::new( + "configuration", + MapType::new(DataType::STRING, DataType::STRING, false), + false, + ), + ]), + true, + )])); + assert_eq!(schema, expected); + } + + fn tags_field() -> StructField { + StructField::new( + "tags", + MapType::new(DataType::STRING, DataType::STRING, false), + true, + ) + } + + fn partition_values_field() -> StructField { + StructField::new( + "partitionValues", + MapType::new(DataType::STRING, DataType::STRING, false), + true, + ) + } + + fn deletion_vector_field() -> StructField { + StructField::new( + "deletionVector", + DataType::Struct(Box::new(StructType::new(vec![ + StructField::new("storageType", DataType::STRING, false), + StructField::new("pathOrInlineDv", DataType::STRING, false), + StructField::new("offset", DataType::INTEGER, true), + StructField::new("sizeInBytes", DataType::INTEGER, false), + StructField::new("cardinality", DataType::LONG, false), + ]))), + true, + ) + } + + #[test] + fn test_remove_schema() { + let schema = get_log_schema() + .project(&["remove"]) + .expect("Couldn't get remove field"); + let expected = Arc::new(StructType::new(vec![StructField::new( + "remove", + StructType::new(vec![ + StructField::new("path", DataType::STRING, false), + StructField::new("deletionTimestamp", DataType::LONG, true), + StructField::new("dataChange", DataType::BOOLEAN, false), + StructField::new("extendedFileMetadata", DataType::BOOLEAN, true), + partition_values_field(), + StructField::new("size", DataType::LONG, true), + tags_field(), + deletion_vector_field(), + StructField::new("baseRowId", DataType::LONG, true), + StructField::new("defaultRowCommitVersion", DataType::LONG, true), + ]), + true, + )])); + assert_eq!(schema, expected); + } +} diff --git 
a/kernel/src/actions/schemas.rs b/kernel/src/actions/schemas.rs index a1cbf890c..18dd45caa 100644 --- a/kernel/src/actions/schemas.rs +++ b/kernel/src/actions/schemas.rs @@ -1,254 +1,64 @@ //! Schema definitions for action types -use lazy_static::lazy_static; +use std::collections::HashMap; -use crate::schema::{ArrayType, DataType, MapType, StructField, StructType}; +use crate::schema::{ArrayType, DataType, MapType, StructField}; -lazy_static! { - // https://github.com/delta-io/delta/blob/master/PROTOCOL.md#change-metadata - pub(crate) static ref METADATA_FIELD: StructField = StructField::new( - "metaData", - StructType::new(vec![ - StructField::new("id", DataType::STRING, false), - StructField::new("name", DataType::STRING, true), - StructField::new("description", DataType::STRING, true), - StructField::new( - "format", - StructType::new(vec![ - StructField::new("provider", DataType::STRING, false), - StructField::new( - "options", - MapType::new( - DataType::STRING, - DataType::STRING, - true, - ), - true, - ), - ]), - false, - ), - StructField::new("schemaString", DataType::STRING, false), - StructField::new( - "partitionColumns", - ArrayType::new(DataType::STRING, false), - false, - ), - StructField::new("createdTime", DataType::LONG, true), - StructField::new( - "configuration", - MapType::new( - DataType::STRING, - DataType::STRING, - true, - ), - false, - ), - ]), - true, - ); - // https://github.com/delta-io/delta/blob/master/PROTOCOL.md#protocol-evolution - pub(crate) static ref PROTOCOL_FIELD: StructField = StructField::new( - "protocol", - StructType::new(vec![ - StructField::new("minReaderVersion", DataType::INTEGER, false), - StructField::new("minWriterVersion", DataType::INTEGER, false), - StructField::new( - "readerFeatures", - ArrayType::new(DataType::STRING, false), - true, - ), - StructField::new( - "writerFeatures", - ArrayType::new(DataType::STRING, false), - true, - ), - ]), - true, - ); - // https://github.com/delta-io/delta/blob/master/PROTOCOL.md#commit-provenance-information - static ref COMMIT_INFO_FIELD: StructField = StructField::new( - "commitInfo", - StructType::new(vec![ - StructField::new("timestamp", DataType::LONG, false), - StructField::new("operation", DataType::STRING, false), - StructField::new("isolationLevel", DataType::STRING, true), - StructField::new("isBlindAppend", DataType::BOOLEAN, true), - StructField::new("txnId", DataType::STRING, true), - StructField::new("readVersion", DataType::LONG, true), - StructField::new( - "operationParameters", - MapType::new( - DataType::STRING, - DataType::STRING, - true, - ), - true, - ), - StructField::new( - "operationMetrics", - MapType::new( - DataType::STRING, - DataType::STRING, - true, - ), - true, - ), - ]), - true, - ); - // https://github.com/delta-io/delta/blob/master/PROTOCOL.md#add-file-and-remove-file - pub(crate) static ref ADD_FIELD: StructField = StructField::new( - "add", - StructType::new(vec![ - StructField::new("path", DataType::STRING, false), - partition_values_field(), - StructField::new("size", DataType::LONG, false), - StructField::new("modificationTime", DataType::LONG, false), - StructField::new("dataChange", DataType::BOOLEAN, false), - StructField::new("stats", DataType::STRING, true), - tags_field(), - deletion_vector_field(), - StructField::new("baseRowId", DataType::LONG, true), - StructField::new("defaultRowCommitVersion", DataType::LONG, true), - StructField::new("clusteringProvider", DataType::STRING, true), - ]), - true, - ); - // 
https://github.com/delta-io/delta/blob/master/PROTOCOL.md#add-file-and-remove-file - pub(crate) static ref REMOVE_FIELD: StructField = StructField::new( - "remove", - StructType::new(vec![ - StructField::new("path", DataType::STRING, false), - StructField::new("deletionTimestamp", DataType::LONG, true), - StructField::new("dataChange", DataType::BOOLEAN, false), - StructField::new("extendedFileMetadata", DataType::BOOLEAN, true), - partition_values_field(), - StructField::new("size", DataType::LONG, true), - StructField::new("stats", DataType::STRING, true), - tags_field(), - deletion_vector_field(), - StructField::new("baseRowId", DataType::LONG, true), - StructField::new("defaultRowCommitVersion", DataType::LONG, true), - ]), - true, - ); - static ref REMOVE_FIELD_CHECKPOINT: StructField = StructField::new( - "remove", - StructType::new(vec![ - StructField::new("path", DataType::STRING, false), - StructField::new("deletionTimestamp", DataType::LONG, true), - StructField::new("dataChange", DataType::BOOLEAN, false), - ]), - true, - ); - // https://github.com/delta-io/delta/blob/master/PROTOCOL.md#add-cdc-file - static ref CDC_FIELD: StructField = StructField::new( - "cdc", - StructType::new(vec![ - StructField::new("path", DataType::STRING, false), - partition_values_field(), - StructField::new("size", DataType::LONG, false), - StructField::new("dataChange", DataType::BOOLEAN, false), - tags_field(), - ]), - true, - ); - // https://github.com/delta-io/delta/blob/master/PROTOCOL.md#transaction-identifiers - static ref TXN_FIELD: StructField = StructField::new( - "txn", - StructType::new(vec![ - StructField::new("appId", DataType::STRING, false), - StructField::new("version", DataType::LONG, false), - StructField::new("lastUpdated", DataType::LONG, true), - ]), - true, - ); - // https://github.com/delta-io/delta/blob/master/PROTOCOL.md#domain-metadata - static ref DOMAIN_METADATA_FIELD: StructField = StructField::new( - "domainMetadata", - StructType::new(vec![ - StructField::new("domain", DataType::STRING, false), - StructField::new( - "configuration", - MapType::new( - DataType::STRING, - DataType::STRING, - true, - ), - false, - ), - StructField::new("removed", DataType::BOOLEAN, false), - ]), - true, - ); - // https://github.com/delta-io/delta/blob/master/PROTOCOL.md#checkpoint-metadata - static ref CHECKPOINT_METADATA_FIELD: StructField = StructField::new( - "checkpointMetadata", - StructType::new(vec![ - StructField::new("flavor", DataType::STRING, false), - tags_field(), - ]), - true, - ); - // https://github.com/delta-io/delta/blob/master/PROTOCOL.md#sidecar-file-information - static ref SIDECAR_FIELD: StructField = StructField::new( - "sidecar", - StructType::new(vec![ - StructField::new("path", DataType::STRING, false), - StructField::new("sizeInBytes", DataType::LONG, false), - StructField::new("modificationTime", DataType::LONG, false), - StructField::new("type", DataType::STRING, false), - tags_field(), - ]), - true, - ); +pub(crate) trait ToDataType { + fn to_data_type() -> DataType; +} + +macro_rules! 
impl_to_data_type { + ( $(($rust_type: ty, $kernel_type: expr)), * ) => { + $( + impl ToDataType for $rust_type { + fn to_data_type() -> DataType { + $kernel_type + } + } + )* + }; +} + +impl_to_data_type!( + (String, DataType::STRING), + (i64, DataType::LONG), + (i32, DataType::INTEGER), + (i16, DataType::SHORT), + (char, DataType::BYTE), + (f32, DataType::FLOAT), + (f64, DataType::DOUBLE), + (bool, DataType::BOOLEAN) +); - static ref LOG_SCHEMA: StructType = StructType::new( - vec![ - ADD_FIELD.clone(), - CDC_FIELD.clone(), - COMMIT_INFO_FIELD.clone(), - DOMAIN_METADATA_FIELD.clone(), - METADATA_FIELD.clone(), - PROTOCOL_FIELD.clone(), - REMOVE_FIELD.clone(), - TXN_FIELD.clone(), - ] - ); +// ToDataType impl for non-nullable array types +impl<T: ToDataType> ToDataType for Vec<T> { + fn to_data_type() -> DataType { + ArrayType::new(T::to_data_type(), false).into() + } } -fn tags_field() -> StructField { - StructField::new( - "tags", - MapType::new(DataType::STRING, DataType::STRING, true), - true, - ) +// ToDataType impl for non-nullable map types +impl<K: ToDataType, V: ToDataType> ToDataType for HashMap<K, V> { + fn to_data_type() -> DataType { + MapType::new(K::to_data_type(), V::to_data_type(), false).into() + } } -fn partition_values_field() -> StructField { - StructField::new( - "partitionValues", - MapType::new(DataType::STRING, DataType::STRING, true), - false, - ) +pub(crate) trait GetStructField { + fn get_struct_field(name: impl Into<String>) -> StructField; } -fn deletion_vector_field() -> StructField { - StructField::new( - "deletionVector", - DataType::Struct(Box::new(StructType::new(vec![ - StructField::new("storageType", DataType::STRING, false), - StructField::new("pathOrInlineDv", DataType::STRING, false), - StructField::new("offset", DataType::INTEGER, true), - StructField::new("sizeInBytes", DataType::INTEGER, false), - StructField::new("cardinality", DataType::LONG, false), - ]))), - true, - ) +// Normal types produce non-nullable fields +impl<T: ToDataType> GetStructField for T { + fn get_struct_field(name: impl Into<String>) -> StructField { + StructField::new(name, T::to_data_type(), false) + } } -#[cfg(test)] -pub(crate) fn log_schema() -> &'static StructType { - &LOG_SCHEMA +// Option types produce nullable fields +impl<T: ToDataType> GetStructField for Option<T> { + fn get_struct_field(name: impl Into<String>) -> StructField { + StructField::new(name, T::to_data_type(), true) + } } diff --git a/kernel/src/actions/visitors.rs b/kernel/src/actions/visitors.rs index 0e23d39cb..d845646eb 100644 --- a/kernel/src/actions/visitors.rs +++ b/kernel/src/actions/visitors.rs @@ -169,7 +169,7 @@ impl AddVisitor { modification_time, data_change, stats, - tags: HashMap::new(), + tags: None, deletion_vector, base_row_id, default_row_commit_version, @@ -211,20 +211,20 @@ impl RemoveVisitor { let size: Option<i64> = getters[5].get_opt(row_index, "remove.size")?; - // TODO(nick) stats are skipped in getters[6] and tags are skipped in getters[7] + // TODO(nick) tags are skipped in getters[6] let deletion_vector = if let Some(storage_type) = - getters[8].get_opt(row_index, "remove.deletionVector.storageType")? + getters[7].get_opt(row_index, "remove.deletionVector.storageType")? 
{ // there is a storageType, so the whole DV must be there + let path_or_inline_dv: String = - getters[9].get(row_index, "remove.deletionVector.pathOrInlineDv")?; + getters[8].get(row_index, "remove.deletionVector.pathOrInlineDv")?; let offset: Option<i32> = - getters[10].get_opt(row_index, "remove.deletionVector.offset")?; + getters[9].get_opt(row_index, "remove.deletionVector.offset")?; let size_in_bytes: i32 = - getters[11].get(row_index, "remove.deletionVector.sizeInBytes")?; + getters[10].get(row_index, "remove.deletionVector.sizeInBytes")?; let cardinality: i64 = - getters[12].get(row_index, "remove.deletionVector.cardinality")?; + getters[11].get(row_index, "remove.deletionVector.cardinality")?; Some(DeletionVectorDescriptor { storage_type, path_or_inline_dv, @@ -236,9 +236,9 @@ None }; - let base_row_id: Option<i64> = getters[13].get_opt(row_index, "remove.baseRowId")?; + let base_row_id: Option<i64> = getters[12].get_opt(row_index, "remove.baseRowId")?; let default_row_commit_version: Option<i64> = - getters[14].get_opt(row_index, "remove.defaultRowCommitVersion")?; + getters[13].get_opt(row_index, "remove.defaultRowCommitVersion")?; Ok(Remove { path, @@ -277,10 +277,9 @@ mod tests { use super::*; use crate::{ - actions::schemas::log_schema, + actions::{get_log_schema, ADD_NAME}, client::arrow_data::ArrowEngineData, client::sync::{json::SyncJsonHandler, SyncEngineInterface}, - schema::StructType, EngineData, EngineInterface, JsonHandler, }; @@ -302,7 +301,7 @@ r#"{"metaData":{"id":"testId","format":{"provider":"parquet","options":{}},"schemaString":"{\"type\":\"struct\",\"fields\":[{\"name\":\"value\",\"type\":\"integer\",\"nullable\":true,\"metadata\":{}}]}","partitionColumns":[],"configuration":{"delta.enableDeletionVectors":"true","delta.columnMapping.mode":"none"},"createdTime":1677811175819}}"#, ] .into(); - let output_schema = Arc::new(log_schema().clone()); + let output_schema = Arc::new(get_log_schema().clone()); let parsed = handler .parse_json(string_array_to_engine_data(json_strings), output_schema) .unwrap(); @@ -331,12 +330,9 @@ let configuration = HashMap::from_iter([ ( "delta.enableDeletionVectors".to_string(), - Some("true".to_string()), - ), - ( - "delta.columnMapping.mode".to_string(), - Some("none".to_string()), + "true".to_string(), ), + ("delta.columnMapping.mode".to_string(), "none".to_string()), ]); let expected = Metadata { id: "testId".into(), @@ -368,26 +364,26 @@ r#"{"add":{"path":"c1=6/c2=a/part-00011-10619b10-b691-4fd0-acc4-2a9608499d7c.c000.snappy.parquet","partitionValues":{"c1":"6","c2":"a"},"size":452,"modificationTime":1670892998137,"dataChange":true,"stats":"{\"numRecords\":1,\"minValues\":{\"c3\":4},\"maxValues\":{\"c3\":4},\"nullCount\":{\"c3\":0}}"}}"#, ] .into(); - let output_schema = Arc::new(log_schema().clone()); + let output_schema = Arc::new(get_log_schema().clone()); let batch = json_handler .parse_json(string_array_to_engine_data(json_strings), output_schema) .unwrap(); - let add_schema = StructType::new(vec![crate::actions::schemas::ADD_FIELD.clone()]); + let add_schema = get_log_schema() + .project(&[ADD_NAME]) + .expect("Can't get add schema"); let mut add_visitor = AddVisitor::default(); - batch - .extract(Arc::new(add_schema), &mut add_visitor) - .unwrap(); + batch.extract(add_schema, &mut add_visitor).unwrap(); let add1 = Add { path: "c1=4/c2=c/part-00003-f525f459-34f9-46f5-82d6-d42121d883fd.c000.snappy.parquet".into(), partition_values: HashMap::from([ - ("c1".to_string(), 
Some("4".to_string())), - ("c2".to_string(), Some("c".to_string())), + ("c1".to_string(), "4".to_string()), + ("c2".to_string(), "c".to_string()), ]), size: 452, modification_time: 1670892998135, data_change: true, stats: Some("{\"numRecords\":1,\"minValues\":{\"c3\":5},\"maxValues\":{\"c3\":5},\"nullCount\":{\"c3\":0}}".into()), - tags: HashMap::new(), + tags: None, deletion_vector: None, base_row_id: None, default_row_commit_version: None, @@ -396,8 +392,8 @@ mod tests { let add2 = Add { path: "c1=5/c2=b/part-00007-4e73fa3b-2c88-424a-8051-f8b54328ffdb.c000.snappy.parquet".into(), partition_values: HashMap::from([ - ("c1".to_string(), Some("5".to_string())), - ("c2".to_string(), Some("b".to_string())), + ("c1".to_string(), "5".to_string()), + ("c2".to_string(), "b".to_string()), ]), modification_time: 1670892998136, stats: Some("{\"numRecords\":1,\"minValues\":{\"c3\":6},\"maxValues\":{\"c3\":6},\"nullCount\":{\"c3\":0}}".into()), @@ -406,8 +402,8 @@ mod tests { let add3 = Add { path: "c1=6/c2=a/part-00011-10619b10-b691-4fd0-acc4-2a9608499d7c.c000.snappy.parquet".into(), partition_values: HashMap::from([ - ("c1".to_string(), Some("6".to_string())), - ("c2".to_string(), Some("a".to_string())), + ("c1".to_string(), "6".to_string()), + ("c2".to_string(), "a".to_string()), ]), modification_time: 1670892998137, stats: Some("{\"numRecords\":1,\"minValues\":{\"c3\":4},\"maxValues\":{\"c3\":4},\"nullCount\":{\"c3\":0}}".into()), diff --git a/kernel/src/client/arrow_data.rs b/kernel/src/client/arrow_data.rs index 3ce883e25..1dd5c5f27 100644 --- a/kernel/src/client/arrow_data.rs +++ b/kernel/src/client/arrow_data.rs @@ -137,14 +137,14 @@ impl EngineMap for MapArray { None } - fn materialize(&self, row_index: usize) -> HashMap> { + fn materialize(&self, row_index: usize) -> HashMap { let mut ret = HashMap::new(); let map_val = self.value(row_index); let keys = map_val.column(0).as_string::(); let values = map_val.column(1).as_string::(); for (key, value) in keys.iter().zip(values.iter()) { - if let Some(key) = key { - ret.insert(key.into(), value.map(|v| v.into())); + if let (Some(key), Some(value)) = (key, value) { + ret.insert(key.into(), value.into()); } } ret @@ -338,7 +338,7 @@ mod tests { use arrow_schema::{DataType, Field, Schema as ArrowSchema}; use crate::{ - actions::{schemas::log_schema, Metadata}, + actions::{get_log_schema, Metadata, Protocol}, client::sync::SyncEngineInterface, DeltaResult, EngineData, EngineInterface, }; @@ -361,7 +361,7 @@ mod tests { r#"{"metaData":{"id":"aff5cb91-8cd9-4195-aef9-446908507302","format":{"provider":"parquet","options":{}},"schemaString":"{\"type\":\"struct\",\"fields\":[{\"name\":\"c1\",\"type\":\"integer\",\"nullable\":true,\"metadata\":{}},{\"name\":\"c2\",\"type\":\"string\",\"nullable\":true,\"metadata\":{}},{\"name\":\"c3\",\"type\":\"integer\",\"nullable\":true,\"metadata\":{}}]}","partitionColumns":["c1","c2"],"configuration":{},"createdTime":1670892997849}}"#, ] .into(); - let output_schema = Arc::new(log_schema().clone()); + let output_schema = Arc::new(get_log_schema().clone()); let parsed = handler .parse_json(string_array_to_engine_data(json_strings), output_schema) .unwrap(); @@ -371,4 +371,21 @@ mod tests { assert_eq!(metadata.partition_columns, vec!("c1", "c2")); Ok(()) } + + #[test] + fn test_nullable_struct() -> DeltaResult<()> { + let client = SyncEngineInterface::new(); + let handler = client.get_json_handler(); + let json_strings: StringArray = vec![ + 
r#"{"metaData":{"id":"aff5cb91-8cd9-4195-aef9-446908507302","format":{"provider":"parquet","options":{}},"schemaString":"{\"type\":\"struct\",\"fields\":[{\"name\":\"c1\",\"type\":\"integer\",\"nullable\":true,\"metadata\":{}},{\"name\":\"c2\",\"type\":\"string\",\"nullable\":true,\"metadata\":{}},{\"name\":\"c3\",\"type\":\"integer\",\"nullable\":true,\"metadata\":{}}]}","partitionColumns":["c1","c2"],"configuration":{},"createdTime":1670892997849}}"#, + ] + .into(); + let output_schema = get_log_schema().project(&["metaData"])?; + let parsed = handler + .parse_json(string_array_to_engine_data(json_strings), output_schema) + .unwrap(); + let protocol = Protocol::try_new_from_data(parsed.as_ref())?; + assert!(protocol.is_none()); + Ok(()) + } } diff --git a/kernel/src/client/arrow_utils.rs b/kernel/src/client/arrow_utils.rs index 88ea3311f..6e7518364 100644 --- a/kernel/src/client/arrow_utils.rs +++ b/kernel/src/client/arrow_utils.rs @@ -2,39 +2,60 @@ use std::sync::Arc; -use crate::DeltaResult; +use crate::{schema::SchemaRef, DeltaResult, Error}; use arrow_array::RecordBatch; use arrow_schema::{Schema as ArrowSchema, SchemaRef as ArrowSchemaRef}; -use itertools::Itertools; use parquet::{arrow::ProjectionMask, schema::types::SchemaDescriptor}; -/// Get the indicies in `parquet_schema` of the specified columns in `requested_schema` +/// Get the indicies in `parquet_schema` of the specified columns in `requested_schema`. This +/// returns a tuples of (mask_indicies: Vec, reorder_indicies: +/// Vec). `mask_indicies` is used for generating the mask for reading from the +/// parquet file, and simply contains an entry for each index we wish to select from the parquet +/// file set to the index of the requested column in the parquet. `reorder_indicies` is used for +/// re-ordering and will be the same size as `requested_schema`. Each index in `reorder_indicies` +/// represents a column that will be in the read parquet data at that index. The value stored in +/// `reorder_indicies` is the position that the column should appear in the final output. For +/// example, if `reorder_indicies` is `[2,0,1]`, then the re-ordering code should take the third +/// column in the raw-read parquet data, and move it to the first column in the final output, the +/// first column to the second, and the second to the third. pub(crate) fn get_requested_indices( - requested_schema: &ArrowSchema, + requested_schema: &SchemaRef, parquet_schema: &ArrowSchemaRef, -) -> DeltaResult> { - let indicies = requested_schema - .fields +) -> DeltaResult<(Vec, Vec)> { + let requested_len = requested_schema.fields.len(); + let mut mask_indicies = vec![0; requested_len]; + let mut found_count = 0; // verify that we found all requested fields + let reorder_indicies = parquet_schema + .fields() .iter() - .map(|field| { - // todo: handle nested (then use `leaves` not `roots` below in generate_mask) - parquet_schema.index_of(field.name()) + .enumerate() + .filter_map(|(parquet_position, field)| { + requested_schema.index_of(field.name()).map(|index| { + found_count += 1; + mask_indicies[index] = parquet_position; + index + }) }) - .try_collect()?; - Ok(indicies) + .collect(); + if found_count != requested_len { + return Err(Error::generic( + "Didn't find all requested columns in parquet schema", + )); + } + Ok((mask_indicies, reorder_indicies)) } /// Create a mask that will only select the specified indicies from the parquet. 
Currently we only /// handle "root" level columns, and hence use `ProjectionMask::roots`, but will support leaf /// selection in the future. See issues #86 and #96 as well. pub(crate) fn generate_mask( - requested_schema: &ArrowSchema, + requested_schema: &SchemaRef, parquet_schema: &ArrowSchemaRef, parquet_physical_schema: &SchemaDescriptor, indicies: &[usize], ) -> Option<ProjectionMask> { - if parquet_schema.fields.size() == requested_schema.fields.size() { + if parquet_schema.fields.size() == requested_schema.fields.len() { // we assume that in get_requested_indicies we will have caught any column name mismatches, // so here we can just say that if we request the same # of columns as the parquet file // actually has, we don't need to mask anything out @@ -47,26 +68,32 @@ pub(crate) fn generate_mask( } } -/// Reorder a RecordBatch to match `requested_schema`. This method takes `indicies` as computed by -/// [`get_requested_indicies`] as an optimization. If the indicies are in order, then we don't need -/// to do any re-ordering. +/// Reorder a RecordBatch to match `requested_ordering`. This method takes `mask_indicies` as +/// computed by [`get_requested_indices`] as an optimization. If the indicies are in order, then we +/// don't need to do any re-ordering. Otherwise, for each non-zero value in `requested_ordering`, +/// the column at that index will be added in order to the returned batch pub(crate) fn reorder_record_batch( - requested_schema: Arc<ArrowSchema>, input_data: RecordBatch, - indicies: &[usize], + mask_indicies: &[usize], + requested_ordering: &[usize], ) -> DeltaResult<RecordBatch> { - if indicies.windows(2).all(|is| is[0] <= is[1]) { + if mask_indicies.windows(2).all(|is| is[0] <= is[1]) { // indicies is already sorted, meaning we requested in the order that the columns were // stored in the parquet Ok(input_data) } else { // requested an order different from the parquet, reorder - let reordered_columns = indicies + let input_schema = input_data.schema(); + let mut fields = Vec::with_capacity(requested_ordering.len()); + let reordered_columns = requested_ordering .iter() .map(|index| { - input_data.column(*index).clone() // cheap clones of `Arc`s + // cheap clones of `Arc`s + fields.push(input_schema.field(*index).clone()); + input_data.column(*index).clone() }) .collect(); - Ok(RecordBatch::try_new(requested_schema, reordered_columns)?) + let schema = Arc::new(ArrowSchema::new(fields)); + Ok(RecordBatch::try_new(schema, reordered_columns)?) 
} } diff --git a/kernel/src/client/default/json.rs b/kernel/src/client/default/json.rs index 6927ae957..4b537a0ab 100644 --- a/kernel/src/client/default/json.rs +++ b/kernel/src/client/default/json.rs @@ -249,7 +249,7 @@ mod tests { use super::*; use crate::{ - actions::schemas::log_schema, client::default::executor::tokio::TokioBackgroundExecutor, + actions::get_log_schema, client::default::executor::tokio::TokioBackgroundExecutor, }; fn string_array_to_engine_data(string_array: StringArray) -> Box { @@ -271,7 +271,7 @@ mod tests { r#"{"protocol":{"minReaderVersion":3,"minWriterVersion":7,"readerFeatures":["deletionVectors"],"writerFeatures":["deletionVectors"]}}"#, r#"{"metaData":{"id":"testId","format":{"provider":"parquet","options":{}},"schemaString":"{\"type\":\"struct\",\"fields\":[{\"name\":\"value\",\"type\":\"integer\",\"nullable\":true,\"metadata\":{}}]}","partitionColumns":[],"configuration":{"delta.enableDeletionVectors":"true","delta.columnMapping.mode":"none"},"createdTime":1677811175819}}"#, ]); - let output_schema = Arc::new(log_schema().clone()); + let output_schema = Arc::new(get_log_schema().clone()); let batch = handler .parse_json(string_array_to_engine_data(json_strings), output_schema) @@ -298,7 +298,7 @@ mod tests { }]; let handler = DefaultJsonHandler::new(store, Arc::new(TokioBackgroundExecutor::new())); - let physical_schema = Arc::new(ArrowSchema::try_from(log_schema()).unwrap()); + let physical_schema = Arc::new(ArrowSchema::try_from(get_log_schema()).unwrap()); let data: Vec = handler .read_json_files(files, Arc::new(physical_schema.try_into().unwrap()), None) .unwrap() diff --git a/kernel/src/client/default/parquet.rs b/kernel/src/client/default/parquet.rs index 671967f91..51d25bed0 100644 --- a/kernel/src/client/default/parquet.rs +++ b/kernel/src/client/default/parquet.rs @@ -55,9 +55,9 @@ impl ParquetHandler for DefaultParquetHandler { return Ok(Box::new(std::iter::empty())); } - let schema: ArrowSchemaRef = Arc::new(physical_schema.as_ref().try_into()?); - let file_reader = ParquetOpener::new(1024, schema.clone(), self.store.clone()); - let mut stream = FileStream::new(files.to_vec(), schema, file_reader)?; + let arrow_schema: ArrowSchemaRef = Arc::new(physical_schema.as_ref().try_into()?); + let file_reader = ParquetOpener::new(1024, physical_schema.clone(), self.store.clone()); + let mut stream = FileStream::new(files.to_vec(), arrow_schema, file_reader)?; // This channel will become the output iterator. // The stream will execute in the background and send results to this channel. 
@@ -95,19 +95,19 @@ struct ParquetOpener { // projection: Arc<[usize]>, batch_size: usize, limit: Option<usize>, - table_schema: ArrowSchemaRef, + table_schema: SchemaRef, store: Arc<DynObjectStore>, } impl ParquetOpener { pub(crate) fn new( batch_size: usize, - schema: ArrowSchemaRef, + table_schema: SchemaRef, store: Arc<DynObjectStore>, ) -> Self { Self { batch_size, - table_schema: schema, + table_schema, limit: None, store, } @@ -130,7 +130,8 @@ impl FileOpener for ParquetOpener { let mut reader = ParquetObjectReader::new(store, meta); let metadata = ArrowReaderMetadata::load_async(&mut reader, Default::default()).await?; let parquet_schema = metadata.schema(); - let indicies = get_requested_indices(&table_schema, parquet_schema)?; + let (indicies, requested_ordering) = + get_requested_indices(&table_schema, parquet_schema)?; let options = ArrowReaderOptions::new(); //.with_page_index(enable_page_index); let mut builder = ParquetRecordBatchStreamBuilder::new_with_options(reader, options).await?; @@ -152,7 +153,7 @@ let stream = stream.map(move |rbr| { // re-order each batch if needed rbr.map_err(Error::Parquet) - .and_then(|rb| reorder_record_batch(table_schema.clone(), rb, &indicies)) + .and_then(|rb| reorder_record_batch(rb, &indicies, &requested_ordering)) }); Ok(stream.boxed()) })) diff --git a/kernel/src/client/sync/parquet.rs b/kernel/src/client/sync/parquet.rs index 6d8b09aba..c279883a2 100644 --- a/kernel/src/client/sync/parquet.rs +++ b/kernel/src/client/sync/parquet.rs @@ -1,6 +1,5 @@ use std::fs::File; -use arrow_schema::Schema as ArrowSchema; use parquet::arrow::arrow_reader::{ArrowReaderMetadata, ParquetRecordBatchReaderBuilder}; use tracing::debug; use url::Url; @@ -20,15 +19,10 @@ fn try_create_from_parquet(schema: SchemaRef, location: Url) -> DeltaResult DeltaResult` fn get<'a>(&'a self, row_index: usize, key: &str) -> Option<&'a str>; /// Materialize the entire map at `row_index` in the raw data into a `HashMap` - fn materialize(&self, row_index: usize) -> HashMap<String, Option<String>>; + fn materialize(&self, row_index: usize) -> HashMap<String, String>; } /// A map item is useful if the Engine needs to know what row of raw data it needs to access to @@ -70,7 +70,7 @@ impl<'a> MapItem<'a> { self.map.get(self.row, key) } - pub fn materialize(&self) -> HashMap<String, Option<String>> { + pub fn materialize(&self) -> HashMap<String, String> { self.map.materialize(self.row) } } @@ -149,14 +149,14 @@ } } -/// Provide an impl to get a map field as a `HashMap<String, Option<String>>`. Note that this will +/// Provide an impl to get a map field as a `HashMap<String, String>`. 
Note that this will /// allocate the map and allocate for each entry -impl<'a> TypedGetData<'a, HashMap<String, Option<String>>> for dyn GetData<'a> + '_ { +impl<'a> TypedGetData<'a, HashMap<String, String>> for dyn GetData<'a> + '_ { fn get_opt( &'a self, row_index: usize, field_name: &str, - ) -> DeltaResult<Option<HashMap<String, Option<String>>>> { + ) -> DeltaResult<Option<HashMap<String, String>>> { let map_opt: Option<MapItem<'_>> = self.get_opt(row_index, field_name)?; Ok(map_opt.map(|map| map.materialize())) } } diff --git a/kernel/src/scan/file_stream.rs b/kernel/src/scan/file_stream.rs index 7f46017b2..3bdbef2a2 100644 --- a/kernel/src/scan/file_stream.rs +++ b/kernel/src/scan/file_stream.rs @@ -1,14 +1,14 @@ use std::collections::HashSet; -use std::sync::Arc; use either::Either; use tracing::debug; use super::data_skipping::DataSkippingFilter; +use crate::actions::{get_log_schema, ADD_NAME, REMOVE_NAME}; use crate::actions::{visitors::AddVisitor, visitors::RemoveVisitor, Add, Remove}; use crate::engine_data::{GetData, TypedGetData}; use crate::expressions::Expression; -use crate::schema::{SchemaRef, StructType}; +use crate::schema::SchemaRef; use crate::{DataVisitor, DeltaResult, EngineData, EngineInterface}; struct LogReplayScanner { @@ -101,18 +101,18 @@ impl LogReplayScanner { .map(|filter| filter.apply(actions)) .transpose()?; - let schema_to_use = StructType::new(if is_log_batch { - vec![ - crate::actions::schemas::ADD_FIELD.clone(), - crate::actions::schemas::REMOVE_FIELD.clone(), - ] + let schema_to_use = if is_log_batch { + // NB: We _must_ pass these in the order `ADD_NAME, REMOVE_NAME` as the visitor assumes + // the Add action comes first. The [`project`] method honors this order, so this works + // as long as we keep this order here. + get_log_schema().project(&[ADD_NAME, REMOVE_NAME])? } else { // All checkpoint actions are already reconciled and Remove actions in checkpoint files // only serve as tombstones for vacuum jobs. So no need to load them here. - vec![crate::actions::schemas::ADD_FIELD.clone()] - }); + get_log_schema().project(&[ADD_NAME])? + }; let mut visitor = AddRemoveVisitor::new(selection_vector, is_log_batch); - actions.extract(Arc::new(schema_to_use), &mut visitor)?; + actions.extract(schema_to_use, &mut visitor)?; for remove in visitor.removes.into_iter() { let dv_id = remove.dv_unique_id(); diff --git a/kernel/src/scan/mod.rs b/kernel/src/scan/mod.rs index 8417c7702..f69aa3dba 100644 --- a/kernel/src/scan/mod.rs +++ b/kernel/src/scan/mod.rs @@ -4,7 +4,7 @@ use itertools::Itertools; use tracing::debug; use self::file_stream::log_replay_iter; -use crate::actions::Add; +use crate::actions::{get_log_schema, Add, ADD_NAME, REMOVE_NAME}; use crate::expressions::{Expression, Scalar}; use crate::schema::{DataType, SchemaRef, StructField, StructType}; use crate::snapshot::Snapshot; @@ -51,8 +51,8 @@ impl ScanBuilder { self } - /// Optionally provide a [`Schema`] for columns to select from the [`Snapshot`]. See - /// [`with_schema`] for details. If schema_opt is `None` this is a no-op. + /// Optionally provide a [`SchemaRef`] for columns to select from the [`Snapshot`]. See + /// [`ScanBuilder::with_schema`] for details. If schema_opt is `None` this is a no-op. 
pub fn with_schema_opt(self, schema_opt: Option<SchemaRef>) -> Self { match schema_opt { Some(schema) => self.with_schema(schema), @@ -144,13 +144,8 @@ impl Scan { &self, engine_interface: &dyn EngineInterface, ) -> DeltaResult<impl Iterator<Item = DeltaResult<Add>>> { - let commit_read_schema = Arc::new(StructType::new(vec![ - crate::actions::schemas::ADD_FIELD.clone(), - crate::actions::schemas::REMOVE_FIELD.clone(), - ])); - let checkpoint_read_schema = Arc::new(StructType::new(vec![ - crate::actions::schemas::ADD_FIELD.clone(), - ])); + let commit_read_schema = get_log_schema().project(&[ADD_NAME, REMOVE_NAME])?; + let checkpoint_read_schema = get_log_schema().project(&[ADD_NAME])?; let log_iter = self.snapshot.log_segment.replay( engine_interface, @@ -285,12 +280,9 @@ impl Scan { } } -fn parse_partition_value( - raw: Option<&Option<String>>, - data_type: &DataType, -) -> DeltaResult<Scalar> { +fn parse_partition_value(raw: Option<&String>, data_type: &DataType) -> DeltaResult<Scalar> { match raw { - Some(Some(v)) => match data_type { + Some(v) => match data_type { DataType::Primitive(primitive) => primitive.parse_scalar(v), _ => Err(Error::generic(format!( "Unexpected partition column type: {data_type:?}" @@ -386,7 +378,7 @@ mod tests { for (raw, data_type, expected) in &cases { let value = parse_partition_value( - Some(&Some(raw.to_string())), + Some(&raw.to_string()), &DataType::Primitive(data_type.clone()), ) .unwrap(); diff --git a/kernel/src/schema.rs b/kernel/src/schema.rs index 7d3050af2..396e82242 100644 --- a/kernel/src/schema.rs +++ b/kernel/src/schema.rs @@ -3,8 +3,11 @@ use std::sync::Arc; use std::{collections::HashMap, fmt::Display}; use indexmap::IndexMap; +use itertools::Itertools; use serde::{Deserialize, Serialize}; +use crate::{DeltaResult, Error}; + pub type Schema = StructType; pub type SchemaRef = Arc<StructType>; @@ -140,10 +143,38 @@ impl StructType { } } + /// Get a [`StructType`] containing [`StructField`]s of the given names. The order of fields in + /// the returned schema will match the order passed to this function, which can be different + /// from the order in this schema. Returns an Err if a specified field doesn't exist. + pub fn project_as_struct(&self, names: &[impl AsRef<str>]) -> DeltaResult<StructType> { + let fields = names + .iter() + .map(|name| { + self.fields + .get(name.as_ref()) + .cloned() + .ok_or_else(|| Error::missing_column(name.as_ref())) + }) + .try_collect()?; + Ok(Self::new(fields)) + } + + /// Get a [`SchemaRef`] containing [`StructField`]s of the given names. The order of fields in + /// the returned schema will match the order passed to this function, which can be different + /// from the order in this schema. Returns an Err if a specified field doesn't exist. 
+ pub fn project(&self, names: &[impl AsRef<str>]) -> DeltaResult<SchemaRef> { + let struct_type = self.project_as_struct(names)?; + Ok(Arc::new(struct_type)) + } + pub fn field(&self, name: impl AsRef<str>) -> Option<&StructField> { self.fields.get(name.as_ref()) } + pub fn index_of(&self, name: impl AsRef<str>) -> Option<usize> { + self.fields.get_index_of(name.as_ref()) + } + pub fn fields(&self) -> impl Iterator<Item = &StructField> { self.fields.values() } diff --git a/kernel/src/snapshot.rs b/kernel/src/snapshot.rs index 850097e62..74bf61174 100644 --- a/kernel/src/snapshot.rs +++ b/kernel/src/snapshot.rs @@ -9,9 +9,9 @@ use itertools::Itertools; use serde::{Deserialize, Serialize}; use url::Url; -use crate::actions::{Metadata, Protocol}; +use crate::actions::{get_log_schema, Metadata, Protocol, METADATA_NAME, PROTOCOL_NAME}; use crate::path::LogPath; -use crate::schema::{Schema, SchemaRef, StructType}; +use crate::schema::{Schema, SchemaRef}; use crate::{DeltaResult, EngineInterface, Error, FileMeta, FileSystemClient, Version}; use crate::{EngineData, Expression}; @@ -69,10 +69,7 @@ impl LogSegment { &self, engine_interface: &dyn EngineInterface, ) -> DeltaResult<Option<(Metadata, Protocol)>> { - let schema = Arc::new(StructType::new(vec![ - crate::actions::schemas::METADATA_FIELD.clone(), - crate::actions::schemas::PROTOCOL_FIELD.clone(), - ])); + let schema = get_log_schema().project(&[PROTOCOL_NAME, METADATA_NAME])?; // read the same protocol and metadata schema for both commits and checkpoints // TODO add metadata.table_id is not null and protocol.something_required is not null let data_batches = self.replay(engine_interface, schema.clone(), schema, None)?;
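
A minimal sketch (not part of the diff itself) of how the new pieces are meant to be used together inside the kernel crate, assuming the module layout shown above: the `Schema` derive converts snake_case field names to camelCase schema fields and maps `Option<T>` to nullable columns, while readers project the shared log schema by action name instead of cloning per-action `StructField` statics. `MyAction` and `my_read_schema` are hypothetical names used only for illustration.

use std::collections::HashMap;

use derive_macros::Schema;

use crate::actions::{get_log_schema, ADD_NAME, REMOVE_NAME};
use crate::schema::SchemaRef;
use crate::DeltaResult;

// Hypothetical action struct: `some_path` becomes a non-nullable STRING field named
// "somePath", `size` a nullable LONG named "size", and `partition_values` a
// non-nullable map<string, string> named "partitionValues".
#[derive(Debug, Clone, PartialEq, Eq, Schema)]
pub struct MyAction {
    pub some_path: String,
    pub size: Option<i64>,
    pub partition_values: HashMap<String, String>,
}

// Hypothetical helper mirroring how Scan and Snapshot now build their read schemas.
pub fn my_read_schema() -> DeltaResult<SchemaRef> {
    get_log_schema().project(&[ADD_NAME, REMOVE_NAME])
}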