diff --git a/datafusion/expr/src/registry.rs b/datafusion/expr/src/registry.rs index 588181b14421..a2f5a45e7b9b 100644 --- a/datafusion/expr/src/registry.rs +++ b/datafusion/expr/src/registry.rs @@ -132,7 +132,7 @@ pub trait SerializerRegistry: Debug + Send + Sync { fn serialize_logical_plan( &self, node: &dyn UserDefinedLogicalNode, - ) -> Result> { + ) -> Result { not_impl_err!( "Serializing user defined logical plan node `{}` is not supported", node.name() @@ -151,17 +151,16 @@ pub trait SerializerRegistry: Debug + Send + Sync { ) } - /// Serialized table definition for UDTFs or manually registered table providers that can't be - /// marshaled by reference. Should return some benign error for regular tables that can be - /// found/restored by name in the destination execution context. - fn serialize_custom_table(&self, _table: &dyn TableSource) -> Result> { - not_impl_err!("No custom table support") + /// Serialized table definition for UDTFs or some other table provider implementation that + /// can't be marshaled by reference. + fn serialize_custom_table( + &self, + _table: &dyn TableSource, + ) -> Result> { + Ok(None) } - /// Deserialize the custom table with the given name. - /// Note: more often than not, the name can't be used as a discriminator if multiple different - /// `TableSource` and/or `TableProvider` implementations are expected (this is particularly true - /// for UDTFs in DataFusion, which are always registered under the same name: `tmp_table`). + /// Deserialize a custom table. fn deserialize_custom_table( &self, name: &str, @@ -171,6 +170,11 @@ pub trait SerializerRegistry: Debug + Send + Sync { } } +/// A sequence of bytes with a string qualifier. Meant to encapsulate serialized extensions +/// that need to carry their type, e.g. the `type_url` for protobuf messages. +#[derive(Debug, Clone)] +pub struct NamedBytes(pub String, pub Vec); + /// A [`FunctionRegistry`] that uses in memory [`HashMap`]s #[derive(Default, Debug)] pub struct MemoryFunctionRegistry { diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 9618bc4a59fc..6a7857bc0a1e 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -30,7 +30,7 @@ use datafusion::logical_expr::expr::{Exists, InSubquery, Sort}; use datafusion::logical_expr::{ Aggregate, BinaryExpr, Case, Cast, EmptyRelation, Expr, ExprSchemable, Extension, - LogicalPlan, Operator, Projection, SortExpr, Subquery, TableScan, TryCast, Values, + LogicalPlan, Operator, Projection, SortExpr, Subquery, TableSource, TryCast, Values, }; use substrait::proto::aggregate_rel::Grouping; use substrait::proto::expression as substrait_expression; @@ -462,9 +462,7 @@ pub trait SubstraitConsumer: Send + Sync + Sized { fn consume_extension_table( &self, extension_table: &ExtensionTable, - _schema: &DFSchema, - _projection: &Option, - ) -> Result { + ) -> Result> { if let Some(ext_detail) = extension_table.detail.as_ref() { substrait_err!( "Missing handler for extension table: {}", @@ -599,24 +597,11 @@ impl SubstraitConsumer for DefaultSubstraitConsumer<'_> { fn consume_extension_table( &self, extension_table: &ExtensionTable, - schema: &DFSchema, - projection: &Option, - ) -> Result { + ) -> Result> { if let Some(ext_detail) = &extension_table.detail { - let source = self - .state + self.state .serializer_registry() - .deserialize_custom_table(&ext_detail.type_url, &ext_detail.value)?; - let table_name = ext_detail - .type_url - .rsplit_once('/') - .map(|(_, name)| name) - .unwrap_or(&ext_detail.type_url); - let table_scan = TableScan::try_new(table_name, source, None, vec![], None)?; - let plan = LogicalPlan::TableScan(table_scan); - ensure_schema_compatibility(plan.schema(), schema.clone())?; - let schema = apply_masking(schema.clone(), projection)?; - apply_projection(plan, schema) + .deserialize_custom_table(&ext_detail.type_url, &ext_detail.value) } else { substrait_err!("Unexpected empty detail in ExtensionTable") } @@ -1366,26 +1351,14 @@ pub async fn from_read_rel( read: &ReadRel, ) -> Result { async fn read_with_schema( - consumer: &impl SubstraitConsumer, table_ref: TableReference, + table_source: Arc, schema: DFSchema, projection: &Option, ) -> Result { let schema = schema.replace_qualifier(table_ref.clone()); - let plan = { - let provider = match consumer.resolve_table_ref(&table_ref).await? { - Some(ref provider) => Arc::clone(provider), - _ => return plan_err!("No table named '{table_ref}'"), - }; - - LogicalPlanBuilder::scan( - table_ref, - provider_as_source(Arc::clone(&provider)), - None, - )? - .build()? - }; + let plan = { LogicalPlanBuilder::scan(table_ref, table_source, None)?.build()? }; ensure_schema_compatibility(plan.schema(), schema.clone())?; @@ -1394,6 +1367,17 @@ pub async fn from_read_rel( apply_projection(plan, schema) } + async fn table_source( + consumer: &impl SubstraitConsumer, + table_ref: &TableReference, + ) -> Result> { + if let Some(provider) = consumer.resolve_table_ref(table_ref).await? { + Ok(provider_as_source(provider)) + } else { + plan_err!("No table named '{table_ref}'") + } + } + let named_struct = read.base_schema.as_ref().ok_or_else(|| { substrait_datafusion_err!("No base schema provided for Read Relation") })?; @@ -1419,10 +1403,10 @@ pub async fn from_read_rel( table: nt.names[2].clone().into(), }, }; - + let table_source = table_source(consumer, &table_reference).await?; read_with_schema( - consumer, table_reference, + table_source, substrait_schema, &read.projection, ) @@ -1501,17 +1485,35 @@ pub async fn from_read_rel( let name = filename.unwrap(); // directly use unwrap here since we could determine it is a valid one let table_reference = TableReference::Bare { table: name.into() }; + let table_source = table_source(consumer, &table_reference).await?; read_with_schema( - consumer, table_reference, + table_source, substrait_schema, &read.projection, ) .await } Some(ReadType::ExtensionTable(ext)) => { - consumer.consume_extension_table(ext, &substrait_schema, &read.projection) + // look for the original table name under `rel.common.hint.alias` + // in case the producer was kind enough to put it there. + let name_hint = read + .common + .as_ref() + .and_then(|rel_common| rel_common.hint.as_ref()) + .map(|hint| hint.alias.as_str().trim()) + .filter(|alias| !alias.is_empty()); + // if no name hint was provided, use the name that datafusion + // sets for UDTFs + let table_name = name_hint.unwrap_or("tmp_table"); + read_with_schema( + TableReference::from(table_name), + consumer.consume_extension_table(ext)?, + substrait_schema, + &read.projection, + ) + .await } None => { substrait_err!("Unexpected empty read_type") @@ -1917,7 +1919,7 @@ pub async fn from_substrait_sorts( }, None => not_impl_err!("Sort without sort kind is invalid"), }; - let (asc, nulls_first) = asc_nullfirst.unwrap(); + let (asc, nulls_first) = asc_nullfirst?; sorts.push(Sort { expr, asc, diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index e5cbcd4dbe66..ac0b101f4225 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -22,7 +22,11 @@ use std::sync::Arc; use substrait::proto::expression_reference::ExprType; use datafusion::arrow::datatypes::{Field, IntervalUnit}; -use datafusion::logical_expr::{Aggregate, Distinct, EmptyRelation, Extension, Filter, Join, Like, Limit, Partitioning, Projection, Repartition, Sort, SortExpr, SubqueryAlias, TableScan, TableSource, TryCast, Union, Values, Window, WindowFrameUnits}; +use datafusion::logical_expr::{ + Aggregate, Distinct, EmptyRelation, Extension, Filter, Join, Like, Limit, + Partitioning, Projection, Repartition, Sort, SortExpr, SubqueryAlias, TableScan, + TableSource, TryCast, Union, Values, Window, WindowFrameUnits, +}; use datafusion::{ arrow::datatypes::{DataType, TimeUnit}, error::{DataFusionError, Result}, @@ -50,9 +54,10 @@ use datafusion::execution::SessionState; use datafusion::logical_expr::expr::{ Alias, BinaryExpr, Case, Cast, GroupingSet, InList, InSubquery, WindowFunction, }; +use datafusion::logical_expr::registry::NamedBytes; use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Operator}; use datafusion::prelude::Expr; -use pbjson_types::{Any as ProtoAny, Any}; +use pbjson_types::Any as ProtoAny; use substrait::proto::exchange_rel::{ExchangeKind, RoundRobin, ScatterFields}; use substrait::proto::expression::cast::FailureBehavior; use substrait::proto::expression::field_reference::{RootReference, RootType}; @@ -66,8 +71,8 @@ use substrait::proto::expression::subquery::InPredicate; use substrait::proto::expression::window_function::BoundsType; use substrait::proto::expression::ScalarFunction; use substrait::proto::read_rel::{ExtensionTable, VirtualTable}; -use substrait::proto::rel_common::EmitKind; use substrait::proto::rel_common::EmitKind::Emit; +use substrait::proto::rel_common::{EmitKind, Hint}; use substrait::proto::{ fetch_rel, rel_common, ExchangeRel, ExpressionReference, ExtendedExpression, RelCommon, @@ -363,10 +368,10 @@ pub trait SubstraitProducer: Send + Sync + Sized { from_in_subquery(self, in_subquery, schema) } - fn handle_extension_table( + fn handle_custom_table( &mut self, _table: &dyn TableSource, - ) -> Result { + ) -> Result> { not_impl_err!("Not implemented") } } @@ -395,12 +400,12 @@ impl SubstraitProducer for DefaultSubstraitProducer<'_> { } fn handle_extension(&mut self, plan: &Extension) -> Result> { - let extension_bytes = self + let NamedBytes(type_url, bytes) = self .serializer_registry .serialize_logical_plan(plan.node.as_ref())?; let detail = ProtoAny { - type_url: plan.node.name().to_string(), - value: extension_bytes.into(), + type_url, + value: bytes.into(), }; let mut inputs_rel = plan .node @@ -429,14 +434,22 @@ impl SubstraitProducer for DefaultSubstraitProducer<'_> { })) } - fn handle_extension_table(&mut self, table: &dyn TableSource) -> Result { - let bytes = self.serializer_registry.serialize_custom_table(table)?; - Ok(ExtensionTable { - detail: Some(Any { - type_url: "/substrait.ExtensionTable".into(), - value: bytes.into(), - }) - }) + fn handle_custom_table( + &mut self, + table: &dyn TableSource, + ) -> Result> { + if let Some(NamedBytes(type_url, bytes)) = + self.serializer_registry.serialize_custom_table(table)? + { + Ok(Some(ExtensionTable { + detail: Some(ProtoAny { + type_url, + value: bytes.into(), + }), + })) + } else { + Ok(None) + } } } @@ -572,21 +585,32 @@ pub fn from_table_scan( let table_schema = scan.source.schema().to_dfschema_ref()?; let base_schema = to_substrait_named_struct(&table_schema)?; - let table = if let Ok(ext_table) = producer - .handle_extension_table(scan.source.as_ref()) - { - ReadType::ExtensionTable(ext_table) - } else { - ReadType::NamedTable(NamedTable { - names: scan.table_name.to_vec(), - advanced_extension: None, - }) - }; - + let (table, common) = + if let Ok(Some(ext_table)) = producer.handle_custom_table(scan.source.as_ref()) { + ( + ReadType::ExtensionTable(ext_table), + Some(RelCommon { + hint: Some(Hint { + // store the original table name as rel.common.hint.alias + alias: scan.table_name.to_string(), + ..Default::default() + }), + ..Default::default() + }), + ) + } else { + ( + ReadType::NamedTable(NamedTable { + names: scan.table_name.to_vec(), + advanced_extension: None, + }), + None, + ) + }; Ok(Box::new(Rel { rel_type: Some(RelType::Read(Box::new(ReadRel { - common: None, + common, base_schema: Some(base_schema), filter: None, best_effort_filter: None, @@ -1715,7 +1739,7 @@ pub fn from_in_subquery( subquery_type: Some( substrait::proto::expression::subquery::SubqueryType::InPredicate( Box::new(InPredicate { - needles: (vec![substrait_expr]), + needles: vec![substrait_expr], haystack: Some(subquery_plan), }), ), @@ -2909,6 +2933,7 @@ mod test { #[tokio::test] async fn round_trip_extension_table() { const TABLE_NAME: &str = "custom_table"; + const TYPE_URL: &str = "/substrait.test.CustomTable"; const SERIALIZED: &[u8] = "table definition".as_bytes(); fn custom_table() -> Arc { @@ -2921,9 +2946,12 @@ mod test { #[derive(Debug)] struct Registry; impl SerializerRegistry for Registry { - fn serialize_custom_table(&self, table: &dyn TableSource) -> Result> { + fn serialize_custom_table( + &self, + table: &dyn TableSource, + ) -> Result> { if table.schema() == custom_table().schema() { - Ok(SERIALIZED.to_vec()) + Ok(Some(NamedBytes(TYPE_URL.to_string(), SERIALIZED.to_vec()))) } else { Err(DataFusionError::Internal("Not our table".into())) } @@ -2933,7 +2961,7 @@ mod test { name: &str, bytes: &[u8], ) -> Result> { - if name == TABLE_NAME && bytes == SERIALIZED { + if name == TYPE_URL && bytes == SERIALIZED { Ok(Arc::new(DefaultTableSource::new(custom_table()))) } else { panic!("Unexpected extension table: {name}"); @@ -2965,7 +2993,7 @@ mod test { assert_contains!( // confirm that the Substrait plan contains our custom_table as an ExtensionTable serde_json::to_string(substrait.as_ref()).unwrap(), - format!(r#""extensionTable":{{"detail":{{"typeUrl":"{TABLE_NAME}","#) + format!(r#""extensionTable":{{"detail":{{"typeUrl":"{TYPE_URL}","#) ); remote // make sure the restored plan is fully working in the remote context .execute_logical_plan(restored.clone()) diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 7045729493b1..0a9c8e525745 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -31,6 +31,7 @@ use datafusion::error::Result; use datafusion::execution::registry::SerializerRegistry; use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::execution::session_state::SessionStateBuilder; +use datafusion::logical_expr::registry::NamedBytes; use datafusion::logical_expr::{ Extension, LogicalPlan, PartitionEvaluator, Repartition, UserDefinedLogicalNode, Values, Volatility, @@ -50,13 +51,13 @@ impl SerializerRegistry for MockSerializerRegistry { fn serialize_logical_plan( &self, node: &dyn UserDefinedLogicalNode, - ) -> Result> { + ) -> Result { if node.name() == "MockUserDefinedLogicalPlan" { let node = node .as_any() .downcast_ref::() .unwrap(); - node.serialize() + Ok(NamedBytes(node.name().to_string(), node.serialize()?)) } else { unreachable!() }