diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 3455cce132b62..96b7c05d67251 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -63,7 +63,7 @@ use datafusion_expr::{ expr_rewriter::FunctionRewrite, logical_plan::{DdlStatement, Statement}, planner::ExprPlanner, - Expr, UserDefinedLogicalNode, WindowUDF, + Expr, WindowUDF, }; // backwards compatibility @@ -1679,27 +1679,7 @@ pub enum RegisterFunction { #[derive(Debug)] pub struct EmptySerializerRegistry; -impl SerializerRegistry for EmptySerializerRegistry { - fn serialize_logical_plan( - &self, - node: &dyn UserDefinedLogicalNode, - ) -> Result> { - not_impl_err!( - "Serializing user defined logical plan node `{}` is not supported", - node.name() - ) - } - - fn deserialize_logical_plan( - &self, - name: &str, - _bytes: &[u8], - ) -> Result> { - not_impl_err!( - "Deserializing user defined logical plan node `{name}` is not supported" - ) - } -} +impl SerializerRegistry for EmptySerializerRegistry {} /// Describes which SQL statements can be run. /// diff --git a/datafusion/expr/src/registry.rs b/datafusion/expr/src/registry.rs index 4eb49710bcf85..588181b144218 100644 --- a/datafusion/expr/src/registry.rs +++ b/datafusion/expr/src/registry.rs @@ -19,7 +19,7 @@ use crate::expr_rewriter::FunctionRewrite; use crate::planner::ExprPlanner; -use crate::{AggregateUDF, ScalarUDF, UserDefinedLogicalNode, WindowUDF}; +use crate::{AggregateUDF, ScalarUDF, TableSource, UserDefinedLogicalNode, WindowUDF}; use datafusion_common::{not_impl_err, plan_datafusion_err, HashMap, Result}; use std::collections::HashSet; use std::fmt::Debug; @@ -123,22 +123,52 @@ pub trait FunctionRegistry { } } -/// Serializer and deserializer registry for extensions like [UserDefinedLogicalNode]. +/// Serializer and deserializer registry for extensions like [UserDefinedLogicalNode] +/// and custom table providers for which the name alone is meaningless in the target +/// execution context, e.g. UDTFs, manually registered tables etc. pub trait SerializerRegistry: Debug + Send + Sync { /// Serialize this node to a byte array. This serialization should not include /// input plans. fn serialize_logical_plan( &self, node: &dyn UserDefinedLogicalNode, - ) -> Result>; + ) -> Result> { + not_impl_err!( + "Serializing user defined logical plan node `{}` is not supported", + node.name() + ) + } /// Deserialize user defined logical plan node ([UserDefinedLogicalNode]) from /// bytes. fn deserialize_logical_plan( &self, name: &str, - bytes: &[u8], - ) -> Result>; + _bytes: &[u8], + ) -> Result> { + not_impl_err!( + "Deserializing user defined logical plan node `{name}` is not supported" + ) + } + + /// Serialized table definition for UDTFs or manually registered table providers that can't be + /// marshaled by reference. Should return some benign error for regular tables that can be + /// found/restored by name in the destination execution context. + fn serialize_custom_table(&self, _table: &dyn TableSource) -> Result> { + not_impl_err!("No custom table support") + } + + /// Deserialize the custom table with the given name. + /// Note: more often than not, the name can't be used as a discriminator if multiple different + /// `TableSource` and/or `TableProvider` implementations are expected (this is particularly true + /// for UDTFs in DataFusion, which are always registered under the same name: `tmp_table`). + fn deserialize_custom_table( + &self, + name: &str, + _bytes: &[u8], + ) -> Result> { + not_impl_err!("Deserializing custom table `{name}` is not supported") + } } /// A [`FunctionRegistry`] that uses in memory [`HashMap`]s diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 515553152659a..3c417ba423c93 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -30,7 +30,7 @@ use datafusion::logical_expr::expr::{Exists, InSubquery, Sort}; use datafusion::logical_expr::{ Aggregate, BinaryExpr, Case, Cast, EmptyRelation, Expr, ExprSchemable, Extension, - LogicalPlan, Operator, Projection, SortExpr, Subquery, TryCast, Values, + LogicalPlan, Operator, Projection, SortExpr, Subquery, TableScan, TryCast, Values, }; use substrait::proto::aggregate_rel::Grouping; use substrait::proto::expression as substrait_expression; @@ -86,6 +86,7 @@ use substrait::proto::expression::{ SingularOrList, SwitchExpression, WindowFunction, }; use substrait::proto::read_rel::local_files::file_or_files::PathType::UriFile; +use substrait::proto::read_rel::ExtensionTable; use substrait::proto::rel_common::{Emit, EmitKind}; use substrait::proto::set_rel::SetOp; use substrait::proto::{ @@ -438,6 +439,22 @@ pub trait SubstraitConsumer: Send + Sync + Sized { user_defined_literal.type_reference ) } + + fn consume_extension_table( + &self, + extension_table: &ExtensionTable, + _schema: &DFSchema, + _projection: &Option, + ) -> Result { + if let Some(ext_detail) = extension_table.detail.as_ref() { + substrait_err!( + "Missing handler for extension table: {}", + &ext_detail.type_url + ) + } else { + substrait_err!("Unexpected empty detail in ExtensionTable") + } + } } /// Convert Substrait Rel to DataFusion DataFrame @@ -559,6 +576,32 @@ impl SubstraitConsumer for DefaultSubstraitConsumer<'_> { let plan = plan.with_exprs_and_inputs(plan.expressions(), inputs)?; Ok(LogicalPlan::Extension(Extension { node: plan })) } + + fn consume_extension_table( + &self, + extension_table: &ExtensionTable, + schema: &DFSchema, + projection: &Option, + ) -> Result { + if let Some(ext_detail) = &extension_table.detail { + let source = self + .state + .serializer_registry() + .deserialize_custom_table(&ext_detail.type_url, &ext_detail.value)?; + let table_name = ext_detail + .type_url + .rsplit_once('/') + .map(|(_, name)| name) + .unwrap_or(&ext_detail.type_url); + let table_scan = TableScan::try_new(table_name, source, None, vec![], None)?; + let plan = LogicalPlan::TableScan(table_scan); + ensure_schema_compatibility(plan.schema(), schema.clone())?; + let schema = apply_masking(schema.clone(), projection)?; + apply_projection(plan, schema) + } else { + substrait_err!("Unexpected empty detail in ExtensionTable") + } + } } // Substrait PrecisionTimestampTz indicates that the timestamp is relative to UTC, which @@ -1449,8 +1492,11 @@ pub async fn from_read_rel( ) .await } - _ => { - not_impl_err!("Unsupported ReadType: {:?}", read.read_type) + Some(ReadType::ExtensionTable(ext)) => { + consumer.consume_extension_table(ext, &substrait_schema, &read.projection) + } + None => { + substrait_err!("Unexpected empty read_type") } } } diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index b73d246e19899..8d5c5fe6f0aa2 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -63,7 +63,7 @@ use substrait::proto::expression::literal::{ }; use substrait::proto::expression::subquery::InPredicate; use substrait::proto::expression::window_function::BoundsType; -use substrait::proto::read_rel::VirtualTable; +use substrait::proto::read_rel::{ExtensionTable, VirtualTable}; use substrait::proto::rel_common::EmitKind; use substrait::proto::rel_common::EmitKind::Emit; use substrait::proto::{ @@ -211,6 +211,23 @@ pub fn to_substrait_rel( let table_schema = scan.source.schema().to_dfschema_ref()?; let base_schema = to_substrait_named_struct(&table_schema)?; + let table = if let Ok(bytes) = state + .serializer_registry() + .serialize_custom_table(scan.source.as_ref()) + { + ReadType::ExtensionTable(ExtensionTable { + detail: Some(ProtoAny { + type_url: scan.table_name.to_string(), + value: bytes.into(), + }), + }) + } else { + ReadType::NamedTable(NamedTable { + names: scan.table_name.to_vec(), + advanced_extension: None, + }) + }; + Ok(Box::new(Rel { rel_type: Some(RelType::Read(Box::new(ReadRel { common: None, @@ -219,10 +236,7 @@ pub fn to_substrait_rel( best_effort_filter: None, projection, advanced_extension: None, - read_type: Some(ReadType::NamedTable(NamedTable { - names: scan.table_name.to_vec(), - advanced_extension: None, - })), + read_type: Some(table), }))), })) } @@ -2238,8 +2252,8 @@ mod test { use super::*; use crate::logical_plan::consumer::{ from_substrait_extended_expr, from_substrait_literal_without_names, - from_substrait_named_struct, from_substrait_type_without_names, - DefaultSubstraitConsumer, + from_substrait_named_struct, from_substrait_plan, + from_substrait_type_without_names, DefaultSubstraitConsumer, }; use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; use datafusion::arrow::array::{ @@ -2247,8 +2261,12 @@ mod test { }; use datafusion::arrow::datatypes::{Field, Fields, Schema}; use datafusion::common::scalar::ScalarStructBuilder; - use datafusion::common::DFSchema; + use datafusion::common::{assert_contains, DFSchema}; + use datafusion::datasource::empty::EmptyTable; + use datafusion::datasource::{DefaultTableSource, TableProvider}; use datafusion::execution::{SessionState, SessionStateBuilder}; + use datafusion::logical_expr::registry::SerializerRegistry; + use datafusion::logical_expr::TableSource; use datafusion::prelude::SessionContext; use std::sync::OnceLock; @@ -2585,4 +2603,110 @@ mod test { assert!(matches!(err, Err(DataFusionError::SchemaError(_, _)))); } + + #[tokio::test] + async fn round_trip_extension_table() { + const TABLE_NAME: &str = "custom_table"; + const SERIALIZED: &[u8] = "table definition".as_bytes(); + + fn custom_table() -> Arc { + Arc::new(EmptyTable::new(Arc::new(Schema::new([ + Arc::new(Field::new("id", DataType::Int32, false)), + Arc::new(Field::new("name", DataType::Utf8, false)), + ])))) + } + + #[derive(Debug)] + struct Registry; + impl SerializerRegistry for Registry { + fn serialize_custom_table(&self, table: &dyn TableSource) -> Result> { + if table.schema() == custom_table().schema() { + Ok(SERIALIZED.to_vec()) + } else { + Err(DataFusionError::Internal("Not our table".into())) + } + } + fn deserialize_custom_table( + &self, + name: &str, + bytes: &[u8], + ) -> Result> { + if name == TABLE_NAME && bytes == SERIALIZED { + Ok(Arc::new(DefaultTableSource::new(custom_table()))) + } else { + panic!("Unexpected extension table: {name}"); + } + } + } + + async fn round_trip_logical_plans( + local: &SessionContext, + remote: &SessionContext, + ) -> Result<()> { + local.register_table(TABLE_NAME, custom_table())?; + remote.table_provider(TABLE_NAME).await.expect_err( + "The remote context is not supposed to know about custom_table", + ); + let initial_plan = local + .sql(&format!("select id from {TABLE_NAME}")) + .await? + .logical_plan() + .clone(); + + // write substrait locally + let substrait = to_substrait_plan(&initial_plan, &local.state())?; + + // read substrait remotely + // since we know there's no `custom_table` registered in the remote context, this will only succeed + // if our table got encoded as an ExtensionTable and is now decoded back to a table source. + let restored = from_substrait_plan(&remote.state(), &substrait).await?; + assert_contains!( + // confirm that the Substrait plan contains our custom_table as an ExtensionTable + serde_json::to_string(substrait.as_ref()).unwrap(), + format!(r#""extensionTable":{{"detail":{{"typeUrl":"{TABLE_NAME}","#) + ); + remote // make sure the restored plan is fully working in the remote context + .execute_logical_plan(restored.clone()) + .await? + .collect() + .await + .expect("Restored plan cannot be executed remotely"); + assert_eq!( + // check that the restored plan is functionally equivalent (and almost identical) to the initial one + initial_plan.to_string(), + restored.to_string().replace( + // substrait will add an explicit full-schema projection if the original table had none + &format!("TableScan: {TABLE_NAME} projection=[id, name]"), + &format!("TableScan: {TABLE_NAME}"), + ) + ); + Ok(()) + } + + // take 1 + let failed_attempt = + round_trip_logical_plans(&SessionContext::new(), &SessionContext::new()) + .await + .expect_err( + "The round trip should fail in the absence of a SerializerRegistry", + ); + assert_contains!( + failed_attempt.message(), + format!("No table named '{TABLE_NAME}'") + ); + + // take 2 + fn proper_context() -> SessionContext { + SessionContext::new_with_state( + SessionStateBuilder::new() + // This will transport our custom_table as a Substrait ExtensionTable + .with_serializer_registry(Arc::new(Registry)) + .build(), + ) + } + + round_trip_logical_plans(&proper_context(), &proper_context()) + .await + .expect("Local plan could not be restored remotely"); + } }