Skip to content

Commit b51b7b1

Browse files
committed
switch to using "project" to get schemas
1 parent c7785c2 commit b51b7b1

File tree

9 files changed

+125
-120
lines changed

9 files changed

+125
-120
lines changed

derive-macros/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ proc-macro = true
1414

1515
[dependencies]
1616
proc-macro2 = "1"
17-
syn = "2.0"
17+
syn = { version = "2.0", features = ["extra-traits"] }
1818
quote = "1.0"
1919

2020

derive-macros/src/lib.rs

Lines changed: 2 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,66 +1,15 @@
1-
use proc_macro2::{Ident, Spacing, TokenStream, TokenTree};
1+
use proc_macro2::{Ident, TokenStream};
22
use quote::{quote, quote_spanned};
33
use syn::spanned::Spanned;
4-
use syn::{
5-
parse_macro_input, Attribute, Data, DataStruct, DeriveInput, Fields, Meta, PathArguments, Type,
6-
};
7-
8-
static SCHEMA_ERR_STR: &str = "schema(...) only supports schema(name = name)";
9-
10-
// Return the ident to use as the schema name if it's been specified in the attributes of the struct
11-
fn get_schema_name_from_attr<'a>(attrs: impl Iterator<Item = &'a Attribute>) -> Option<Ident> {
12-
for attr in attrs {
13-
if let Meta::List(list) = &attr.meta {
14-
if let Some(attr_name) = list.path.segments.iter().last() {
15-
if attr_name.ident == "schema" {
16-
// We have some schema(...) attribute, see if we've specified a different name
17-
let tokens: Vec<TokenTree> = list.tokens.clone().into_iter().collect();
18-
match tokens[..] {
19-
// we only support `name = name` style
20-
[TokenTree::Ident(ref name_ident), TokenTree::Punct(ref punct), TokenTree::Ident(ref schema_ident)] =>
21-
{
22-
assert!(name_ident == "name", "{}", SCHEMA_ERR_STR);
23-
assert!(punct.as_char() == '=', "{}", SCHEMA_ERR_STR);
24-
assert!(punct.spacing() == Spacing::Alone, "{}", SCHEMA_ERR_STR);
25-
return Some(schema_ident.clone());
26-
}
27-
_ => panic!("{}", SCHEMA_ERR_STR),
28-
}
29-
} else {
30-
panic!("Schema only accepts `schema` as an extra attribute")
31-
}
32-
}
33-
}
34-
}
35-
None
36-
}
4+
use syn::{parse_macro_input, Data, DataStruct, DeriveInput, Fields, PathArguments, Type};
375

386
#[proc_macro_derive(Schema, attributes(schema))]
397
pub fn derive_schema(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
408
let input = parse_macro_input!(input as DeriveInput);
419
let struct_ident = input.ident;
42-
let schema_name = get_schema_name_from_attr(input.attrs.iter()).unwrap_or_else(|| {
43-
// default to the struct name, but lowercased
44-
Ident::new(
45-
&struct_ident.to_string().to_lowercase(),
46-
struct_ident.span(),
47-
)
48-
});
4910

5011
let schema_fields = gen_schema_fields(&input.data);
5112
let output = quote! {
52-
impl crate::actions::schemas::GetSchema for #struct_ident {
53-
fn get_schema() -> crate::schema::SchemaRef {
54-
use crate::actions::schemas::GetField;
55-
static SCHEMA_LOCK: std::sync::OnceLock<crate::schema::SchemaRef> = std::sync::OnceLock::new();
56-
SCHEMA_LOCK.get_or_init(|| {
57-
std::sync::Arc::new(crate::schema::StructType::new(vec![
58-
Self::get_field(stringify!(#schema_name))
59-
]))
60-
}).clone() // cheap clone, it's an Arc
61-
}
62-
}
63-
6413
impl crate::actions::schemas::GetField for #struct_ident {
6514
fn get_field(name: impl Into<String>) -> crate::schema::StructField {
6615
use crate::actions::schemas::GetField;

kernel/src/actions/mod.rs

Lines changed: 50 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,39 @@ pub(crate) mod schemas;
44
pub(crate) mod visitors;
55

66
use derive_macros::Schema;
7-
use std::collections::HashMap;
7+
use lazy_static::lazy_static;
88
use visitors::{AddVisitor, MetadataVisitor, ProtocolVisitor};
99

10+
use self::deletion_vector::DeletionVectorDescriptor;
11+
use crate::actions::schemas::GetField;
1012
use crate::{schema::StructType, DeltaResult, EngineData};
1113

12-
use self::{deletion_vector::DeletionVectorDescriptor, schemas::GetSchema};
14+
use std::collections::HashMap;
15+
16+
lazy_static! {
17+
static ref LOG_SCHEMA: StructType = StructType::new(
18+
vec![
19+
Option::<Add>::get_field("add"),
20+
Option::<Remove>::get_field("remove"),
21+
Option::<Metadata>::get_field("metaData"),
22+
Option::<Protocol>::get_field("protocol"),
23+
// We don't support the following actions yet
24+
//Option<Cdc>::get_field("cdc"),
25+
//Option<CommitInfo>::get_field("commitInfo"),
26+
//Option<DomainMetadata>::get_field("domainMetadata"),
27+
//Option<Transaction>::get_field("txn"),
28+
]
29+
);
30+
}
31+
32+
pub(crate) static ADD_NAME: &str = "add";
33+
pub(crate) static REMOVE_NAME: &str = "remove";
34+
pub(crate) static METADATA_NAME: &str = "metaData";
35+
pub(crate) static PROTOCOL_NAME: &str = "protocol";
36+
37+
pub(crate) fn get_log_schema() -> &'static StructType {
38+
&LOG_SCHEMA
39+
}
1340

1441
#[derive(Debug, Clone, PartialEq, Eq, Schema)]
1542
pub struct Format {
@@ -52,7 +79,10 @@ pub struct Metadata {
5279
impl Metadata {
5380
pub fn try_new_from_data(data: &dyn EngineData) -> DeltaResult<Option<Metadata>> {
5481
let mut visitor = MetadataVisitor::default();
55-
data.extract(Metadata::get_schema(), &mut visitor)?;
82+
data.extract(
83+
get_log_schema().project_as_schema(&[METADATA_NAME])?,
84+
&mut visitor,
85+
)?;
5686
Ok(visitor.metadata)
5787
}
5888

@@ -80,7 +110,10 @@ pub struct Protocol {
80110
impl Protocol {
81111
pub fn try_new_from_data(data: &dyn EngineData) -> DeltaResult<Option<Protocol>> {
82112
let mut visitor = ProtocolVisitor::default();
83-
data.extract(Protocol::get_schema(), &mut visitor)?;
113+
data.extract(
114+
get_log_schema().project_as_schema(&[PROTOCOL_NAME])?,
115+
&mut visitor,
116+
)?;
84117
Ok(visitor.protocol)
85118
}
86119
}
@@ -134,7 +167,10 @@ impl Add {
134167
/// Since we always want to parse multiple adds from data, we return a `Vec<Add>`
135168
pub fn parse_from_data(data: &dyn EngineData) -> DeltaResult<Vec<Add>> {
136169
let mut visitor = AddVisitor::default();
137-
data.extract(Add::get_schema(), &mut visitor)?;
170+
data.extract(
171+
get_log_schema().project_as_schema(&[ADD_NAME])?,
172+
&mut visitor,
173+
)?;
138174
Ok(visitor.adds)
139175
}
140176

@@ -189,43 +225,18 @@ impl Remove {
189225
}
190226
}
191227

192-
use crate::actions::schemas::GetField;
193-
use lazy_static::lazy_static;
194-
195-
lazy_static! {
196-
static ref LOG_SCHEMA: StructType = StructType::new(
197-
vec![
198-
Option::<Add>::get_field("add"),
199-
Option::<Remove>::get_field("remove"),
200-
Option::<Metadata>::get_field("metaData"),
201-
Option::<Protocol>::get_field("protocol"),
202-
// We don't support the following actions yet
203-
//Option<Cdc>::get_field("cdc"),
204-
//Option<CommitInfo>::get_field("commitInfo"),
205-
//Option<DomainMetadata>::get_field("domainMetadata"),
206-
//Option<Transaction>::get_field("txn"),
207-
]
208-
);
209-
}
210-
211-
#[cfg(test)]
212-
pub(crate) fn get_log_schema() -> &'static StructType {
213-
&LOG_SCHEMA
214-
}
215-
216228
#[cfg(test)]
217229
mod tests {
218230
use std::sync::Arc;
219231

220232
use super::*;
221-
use crate::{
222-
actions::schemas::GetSchema,
223-
schema::{ArrayType, DataType, MapType, StructField},
224-
};
233+
use crate::schema::{ArrayType, DataType, MapType, StructField};
225234

226235
#[test]
227236
fn test_metadata_schema() {
228-
let schema = Metadata::get_schema();
237+
let schema = get_log_schema()
238+
.project_as_schema(&["metaData"])
239+
.expect("Couldn't get metaData field");
229240

230241
let expected = Arc::new(StructType::new(vec![StructField::new(
231242
"metaData",
@@ -258,7 +269,7 @@ mod tests {
258269
false,
259270
),
260271
]),
261-
false,
272+
true,
262273
)]));
263274
assert_eq!(schema, expected);
264275
}
@@ -295,7 +306,9 @@ mod tests {
295306

296307
#[test]
297308
fn test_remove_schema() {
298-
let schema = Remove::get_schema();
309+
let schema = get_log_schema()
310+
.project_as_schema(&["remove"])
311+
.expect("Couldn't get remove field");
299312
let expected = Arc::new(StructType::new(vec![StructField::new(
300313
"remove",
301314
StructType::new(vec![
@@ -310,7 +323,7 @@ mod tests {
310323
StructField::new("baseRowId", DataType::LONG, true),
311324
StructField::new("defaultRowCommitVersion", DataType::LONG, true),
312325
]),
313-
false,
326+
true,
314327
)]));
315328
assert_eq!(schema, expected);
316329
}

kernel/src/actions/schemas.rs

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,7 @@
22
33
use std::collections::HashMap;
44

5-
use crate::schema::{ArrayType, DataType, MapType, SchemaRef, StructField};
6-
7-
/// A trait that says you can ask for the [`Schema`] of the implementor
8-
pub(crate) trait GetSchema {
9-
fn get_schema() -> SchemaRef;
10-
}
5+
use crate::schema::{ArrayType, DataType, MapType, StructField};
116

127
/// A trait that allows getting a `StructField` based on the provided name and nullability
138
pub(crate) trait GetField {

kernel/src/actions/visitors.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -262,8 +262,7 @@ mod tests {
262262

263263
use super::*;
264264
use crate::{
265-
actions::get_log_schema,
266-
actions::schemas::GetSchema,
265+
actions::{get_log_schema, ADD_NAME},
267266
simple_client::{data::SimpleData, json::SimpleJsonHandler, SimpleClient},
268267
EngineData, EngineInterface, JsonHandler,
269268
};
@@ -356,7 +355,9 @@ mod tests {
356355
let batch = json_handler
357356
.parse_json(string_array_to_engine_data(json_strings), output_schema)
358357
.unwrap();
359-
let add_schema = Add::get_schema();
358+
let add_schema = get_log_schema()
359+
.project_as_schema(&[ADD_NAME])
360+
.expect("Can't get add schema");
360361
let mut add_visitor = AddVisitor::default();
361362
batch.extract(add_schema, &mut add_visitor).unwrap();
362363
let add1 = Add {

kernel/src/scan/file_stream.rs

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
use std::collections::HashSet;
2-
use std::sync::Arc;
32

43
use either::Either;
54
use tracing::debug;
65

76
use super::data_skipping::DataSkippingFilter;
8-
use crate::actions::schemas::{GetField, GetSchema};
7+
use crate::actions::{get_log_schema, ADD_NAME, REMOVE_NAME};
98
use crate::actions::{visitors::AddVisitor, visitors::RemoveVisitor, Add, Remove};
109
use crate::engine_data::{GetData, TypedGetData};
1110
use crate::expressions::Expression;
12-
use crate::schema::{SchemaRef, StructType};
11+
use crate::schema::SchemaRef;
1312
use crate::{DataVisitor, DeltaResult, EngineData, EngineInterface};
1413

1514
struct LogReplayScanner {
@@ -82,14 +81,11 @@ impl LogReplayScanner {
8281
};
8382

8483
let schema_to_use = if is_log_batch {
85-
Arc::new(StructType::new(vec![
86-
Option::<Add>::get_field("add"),
87-
Option::<Remove>::get_field("remove"),
88-
]))
84+
get_log_schema().project_as_schema(&[ADD_NAME, REMOVE_NAME])?
8985
} else {
9086
// All checkpoint actions are already reconciled and Remove actions in checkpoint files
9187
// only serve as tombstones for vacuum jobs. So no need to load them here.
92-
Add::get_schema()
88+
get_log_schema().project_as_schema(&[ADD_NAME])?
9389
};
9490
let mut visitor = AddRemoveVisitor::default();
9591
actions.extract(schema_to_use, &mut visitor)?;

kernel/src/scan/mod.rs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@ use std::sync::Arc;
33
use itertools::Itertools;
44

55
use self::file_stream::log_replay_iter;
6-
use crate::actions::schemas::GetField;
7-
use crate::actions::{Add, Remove};
6+
use crate::actions::{get_log_schema, Add, ADD_NAME, REMOVE_NAME};
87
use crate::expressions::{Expression, Scalar};
98
use crate::schema::{DataType, SchemaRef, StructType};
109
use crate::snapshot::Snapshot;
@@ -129,10 +128,7 @@ impl Scan {
129128
&self,
130129
engine_interface: &dyn EngineInterface,
131130
) -> DeltaResult<impl Iterator<Item = DeltaResult<Add>>> {
132-
let action_schema = Arc::new(StructType::new(vec![
133-
Option::<Add>::get_field("add"),
134-
Option::<Remove>::get_field("remove"),
135-
]));
131+
let action_schema = get_log_schema().project_as_schema(&[ADD_NAME, REMOVE_NAME])?;
136132

137133
let log_iter = self.snapshot.log_segment.replay(
138134
engine_interface,

kernel/src/schema.rs

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,11 @@ use std::sync::Arc;
33
use std::{collections::HashMap, fmt::Display};
44

55
use indexmap::IndexMap;
6+
use itertools::Itertools;
67
use serde::{Deserialize, Serialize};
78

9+
use crate::{DeltaResult, Error};
10+
811
pub type Schema = StructType;
912
pub type SchemaRef = Arc<StructType>;
1013

@@ -140,6 +143,38 @@ impl StructType {
140143
}
141144
}
142145

146+
/// Get a [`StructType`] containing [`StructField`]s of the given names, preserving the original
147+
/// order of fields. Returns an `Err` if any of the specified fields doesn't exist.
148+
pub fn project(&self, names: &[impl AsRef<str>]) -> DeltaResult<StructType> {
149+
let mut indexes: Vec<usize> = names
150+
.iter()
151+
.map(|name| {
152+
self.fields
153+
.get_index_of(name.as_ref())
154+
.ok_or_else(|| Error::missing_column(name.as_ref()))
155+
})
156+
.try_collect()?;
157+
indexes.sort(); // keep schema order
158+
let fields: Vec<StructField> = indexes
159+
.iter()
160+
.map(|index| {
161+
self.fields
162+
.get_index(*index)
163+
.expect("get_index_of returned non-existant index")
164+
.1
165+
.clone()
166+
})
167+
.collect();
168+
Ok(Self::new(fields))
169+
}
170+
171+
/// Get a [`SchemaRef`] containing [`StructField`]s of the given names, preserving the original
172+
/// order of fields. Returns an `Err` if any of the specified fields doesn't exist.
173+
pub fn project_as_schema(&self, names: &[impl AsRef<str>]) -> DeltaResult<SchemaRef> {
174+
let struct_type = self.project(names)?;
175+
Ok(Arc::new(struct_type))
176+
}
177+
143178
pub fn field(&self, name: impl AsRef<str>) -> Option<&StructField> {
144179
self.fields.get(name.as_ref())
145180
}

0 commit comments

Comments
 (0)