Skip to content

Commit

Permalink
fixup! [substrait] Add support for ExtensionTable
Browse files Browse the repository at this point in the history
  • Loading branch information
ccciudatu committed Jan 8, 2025
1 parent 20ba857 commit 64fb0e5
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 84 deletions.
24 changes: 14 additions & 10 deletions datafusion/expr/src/registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ pub trait SerializerRegistry: Debug + Send + Sync {
fn serialize_logical_plan(
&self,
node: &dyn UserDefinedLogicalNode,
) -> Result<Vec<u8>> {
) -> Result<NamedBytes> {
not_impl_err!(
"Serializing user defined logical plan node `{}` is not supported",
node.name()
Expand All @@ -151,17 +151,16 @@ pub trait SerializerRegistry: Debug + Send + Sync {
)
}

/// Serialized table definition for UDTFs or manually registered table providers that can't be
/// marshaled by reference. Should return some benign error for regular tables that can be
/// found/restored by name in the destination execution context.
fn serialize_custom_table(&self, _table: &dyn TableSource) -> Result<Vec<u8>> {
not_impl_err!("No custom table support")
/// Serializes the table definition for UDTFs or some other table provider implementation that
/// can't be marshaled by reference. Returns `Ok(None)` (the default) when the table needs no
/// custom serialization.
fn serialize_custom_table(
&self,
_table: &dyn TableSource,
) -> Result<Option<NamedBytes>> {
Ok(None)
}

/// Deserialize the custom table with the given name.
/// Note: more often than not, the name can't be used as a discriminator if multiple different
/// `TableSource` and/or `TableProvider` implementations are expected (this is particularly true
/// for UDTFs in DataFusion, which are always registered under the same name: `tmp_table`).
/// Deserialize a custom table.
fn deserialize_custom_table(
&self,
name: &str,
Expand All @@ -171,6 +170,11 @@ pub trait SerializerRegistry: Debug + Send + Sync {
}
}

/// A sequence of bytes paired with a string qualifier.
///
/// Meant to encapsulate serialized extensions that need to carry their type,
/// e.g. the `type_url` for protobuf messages. Field `0` is the qualifier and
/// field `1` is the serialized payload.
///
/// Values compare by content (qualifier and payload), which makes them
/// convenient to check in assertions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NamedBytes(pub String, pub Vec<u8>);

/// A [`FunctionRegistry`] that uses in memory [`HashMap`]s
#[derive(Default, Debug)]
pub struct MemoryFunctionRegistry {
Expand Down
80 changes: 41 additions & 39 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use datafusion::logical_expr::expr::{Exists, InSubquery, Sort};

use datafusion::logical_expr::{
Aggregate, BinaryExpr, Case, Cast, EmptyRelation, Expr, ExprSchemable, Extension,
LogicalPlan, Operator, Projection, SortExpr, Subquery, TableScan, TryCast, Values,
LogicalPlan, Operator, Projection, SortExpr, Subquery, TableSource, TryCast, Values,
};
use substrait::proto::aggregate_rel::Grouping;
use substrait::proto::expression as substrait_expression;
Expand Down Expand Up @@ -462,9 +462,7 @@ pub trait SubstraitConsumer: Send + Sync + Sized {
fn consume_extension_table(
&self,
extension_table: &ExtensionTable,
_schema: &DFSchema,
_projection: &Option<MaskExpression>,
) -> Result<LogicalPlan> {
) -> Result<Arc<dyn TableSource>> {
if let Some(ext_detail) = extension_table.detail.as_ref() {
substrait_err!(
"Missing handler for extension table: {}",
Expand Down Expand Up @@ -599,24 +597,11 @@ impl SubstraitConsumer for DefaultSubstraitConsumer<'_> {
fn consume_extension_table(
&self,
extension_table: &ExtensionTable,
schema: &DFSchema,
projection: &Option<MaskExpression>,
) -> Result<LogicalPlan> {
) -> Result<Arc<dyn TableSource>> {
if let Some(ext_detail) = &extension_table.detail {
let source = self
.state
self.state
.serializer_registry()
.deserialize_custom_table(&ext_detail.type_url, &ext_detail.value)?;
let table_name = ext_detail
.type_url
.rsplit_once('/')
.map(|(_, name)| name)
.unwrap_or(&ext_detail.type_url);
let table_scan = TableScan::try_new(table_name, source, None, vec![], None)?;
let plan = LogicalPlan::TableScan(table_scan);
ensure_schema_compatibility(plan.schema(), schema.clone())?;
let schema = apply_masking(schema.clone(), projection)?;
apply_projection(plan, schema)
.deserialize_custom_table(&ext_detail.type_url, &ext_detail.value)
} else {
substrait_err!("Unexpected empty detail in ExtensionTable")
}
Expand Down Expand Up @@ -1366,26 +1351,14 @@ pub async fn from_read_rel(
read: &ReadRel,
) -> Result<LogicalPlan> {
async fn read_with_schema(
consumer: &impl SubstraitConsumer,
table_ref: TableReference,
table_source: Arc<dyn TableSource>,
schema: DFSchema,
projection: &Option<MaskExpression>,
) -> Result<LogicalPlan> {
let schema = schema.replace_qualifier(table_ref.clone());

let plan = {
let provider = match consumer.resolve_table_ref(&table_ref).await? {
Some(ref provider) => Arc::clone(provider),
_ => return plan_err!("No table named '{table_ref}'"),
};

LogicalPlanBuilder::scan(
table_ref,
provider_as_source(Arc::clone(&provider)),
None,
)?
.build()?
};
let plan = { LogicalPlanBuilder::scan(table_ref, table_source, None)?.build()? };

ensure_schema_compatibility(plan.schema(), schema.clone())?;

Expand All @@ -1394,6 +1367,17 @@ pub async fn from_read_rel(
apply_projection(plan, schema)
}

/// Resolves `table_ref` through the consumer and adapts the resulting
/// provider into a table source. Yields a plan error naming the reference
/// when the consumer finds no matching table.
async fn table_source(
    consumer: &impl SubstraitConsumer,
    table_ref: &TableReference,
) -> Result<Arc<dyn TableSource>> {
    match consumer.resolve_table_ref(table_ref).await? {
        Some(provider) => Ok(provider_as_source(provider)),
        None => plan_err!("No table named '{table_ref}'"),
    }
}

let named_struct = read.base_schema.as_ref().ok_or_else(|| {
substrait_datafusion_err!("No base schema provided for Read Relation")
})?;
Expand All @@ -1419,10 +1403,10 @@ pub async fn from_read_rel(
table: nt.names[2].clone().into(),
},
};

let table_source = table_source(consumer, &table_reference).await?;
read_with_schema(
consumer,
table_reference,
table_source,
substrait_schema,
&read.projection,
)
Expand Down Expand Up @@ -1501,17 +1485,35 @@ pub async fn from_read_rel(
let name = filename.unwrap();
// directly use unwrap here since we could determine it is a valid one
let table_reference = TableReference::Bare { table: name.into() };
let table_source = table_source(consumer, &table_reference).await?;

read_with_schema(
consumer,
table_reference,
table_source,
substrait_schema,
&read.projection,
)
.await
}
Some(ReadType::ExtensionTable(ext)) => {
consumer.consume_extension_table(ext, &substrait_schema, &read.projection)
// look for the original table name under `rel.common.hint.alias`
// in case the producer was kind enough to put it there.
let name_hint = read
.common
.as_ref()
.and_then(|rel_common| rel_common.hint.as_ref())
.map(|hint| hint.alias.as_str().trim())
.filter(|alias| !alias.is_empty());
// if no name hint was provided, use the name that datafusion
// sets for UDTFs
let table_name = name_hint.unwrap_or("tmp_table");
read_with_schema(
TableReference::from(table_name),
consumer.consume_extension_table(ext)?,
substrait_schema,
&read.projection,
)
.await
}
None => {
substrait_err!("Unexpected empty read_type")
Expand Down Expand Up @@ -1917,7 +1919,7 @@ pub async fn from_substrait_sorts(
},
None => not_impl_err!("Sort without sort kind is invalid"),
};
let (asc, nulls_first) = asc_nullfirst.unwrap();
let (asc, nulls_first) = asc_nullfirst?;
sorts.push(Sort {
expr,
asc,
Expand Down
94 changes: 61 additions & 33 deletions datafusion/substrait/src/logical_plan/producer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@ use std::sync::Arc;
use substrait::proto::expression_reference::ExprType;

use datafusion::arrow::datatypes::{Field, IntervalUnit};
use datafusion::logical_expr::{Aggregate, Distinct, EmptyRelation, Extension, Filter, Join, Like, Limit, Partitioning, Projection, Repartition, Sort, SortExpr, SubqueryAlias, TableScan, TableSource, TryCast, Union, Values, Window, WindowFrameUnits};
use datafusion::logical_expr::{
Aggregate, Distinct, EmptyRelation, Extension, Filter, Join, Like, Limit,
Partitioning, Projection, Repartition, Sort, SortExpr, SubqueryAlias, TableScan,
TableSource, TryCast, Union, Values, Window, WindowFrameUnits,
};
use datafusion::{
arrow::datatypes::{DataType, TimeUnit},
error::{DataFusionError, Result},
Expand Down Expand Up @@ -50,9 +54,10 @@ use datafusion::execution::SessionState;
use datafusion::logical_expr::expr::{
Alias, BinaryExpr, Case, Cast, GroupingSet, InList, InSubquery, WindowFunction,
};
use datafusion::logical_expr::registry::NamedBytes;
use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Operator};
use datafusion::prelude::Expr;
use pbjson_types::{Any as ProtoAny, Any};
use pbjson_types::Any as ProtoAny;
use substrait::proto::exchange_rel::{ExchangeKind, RoundRobin, ScatterFields};
use substrait::proto::expression::cast::FailureBehavior;
use substrait::proto::expression::field_reference::{RootReference, RootType};
Expand All @@ -66,8 +71,8 @@ use substrait::proto::expression::subquery::InPredicate;
use substrait::proto::expression::window_function::BoundsType;
use substrait::proto::expression::ScalarFunction;
use substrait::proto::read_rel::{ExtensionTable, VirtualTable};
use substrait::proto::rel_common::EmitKind;
use substrait::proto::rel_common::EmitKind::Emit;
use substrait::proto::rel_common::{EmitKind, Hint};
use substrait::proto::{
fetch_rel, rel_common, ExchangeRel, ExpressionReference, ExtendedExpression,
RelCommon,
Expand Down Expand Up @@ -363,10 +368,10 @@ pub trait SubstraitProducer: Send + Sync + Sized {
from_in_subquery(self, in_subquery, schema)
}

fn handle_extension_table(
fn handle_custom_table(
&mut self,
_table: &dyn TableSource,
) -> Result<ExtensionTable> {
) -> Result<Option<ExtensionTable>> {
not_impl_err!("Not implemented")
}
}
Expand Down Expand Up @@ -395,12 +400,12 @@ impl SubstraitProducer for DefaultSubstraitProducer<'_> {
}

fn handle_extension(&mut self, plan: &Extension) -> Result<Box<Rel>> {
let extension_bytes = self
let NamedBytes(type_url, bytes) = self
.serializer_registry
.serialize_logical_plan(plan.node.as_ref())?;
let detail = ProtoAny {
type_url: plan.node.name().to_string(),
value: extension_bytes.into(),
type_url,
value: bytes.into(),
};
let mut inputs_rel = plan
.node
Expand Down Expand Up @@ -429,14 +434,22 @@ impl SubstraitProducer for DefaultSubstraitProducer<'_> {
}))
}

fn handle_extension_table(&mut self, table: &dyn TableSource) -> Result<ExtensionTable> {
let bytes = self.serializer_registry.serialize_custom_table(table)?;
Ok(ExtensionTable {
detail: Some(Any {
type_url: "/substrait.ExtensionTable".into(),
value: bytes.into(),
})
})
/// Asks the serializer registry for a custom serialization of `table` and,
/// when one is produced, wraps it into an `ExtensionTable` whose detail
/// carries the returned type URL and bytes. Yields `Ok(None)` when the
/// registry produces nothing; registry errors are propagated.
fn handle_custom_table(
    &mut self,
    table: &dyn TableSource,
) -> Result<Option<ExtensionTable>> {
    let serialized = self.serializer_registry.serialize_custom_table(table)?;
    Ok(serialized.map(|NamedBytes(type_url, bytes)| ExtensionTable {
        detail: Some(ProtoAny {
            type_url,
            value: bytes.into(),
        }),
    }))
}
}

Expand Down Expand Up @@ -572,21 +585,32 @@ pub fn from_table_scan(
let table_schema = scan.source.schema().to_dfschema_ref()?;
let base_schema = to_substrait_named_struct(&table_schema)?;

let table = if let Ok(ext_table) = producer
.handle_extension_table(scan.source.as_ref())
{
ReadType::ExtensionTable(ext_table)
} else {
ReadType::NamedTable(NamedTable {
names: scan.table_name.to_vec(),
advanced_extension: None,
})
};

let (table, common) =
if let Ok(Some(ext_table)) = producer.handle_custom_table(scan.source.as_ref()) {
(
ReadType::ExtensionTable(ext_table),
Some(RelCommon {
hint: Some(Hint {
// store the original table name as rel.common.hint.alias
alias: scan.table_name.to_string(),
..Default::default()
}),
..Default::default()
}),
)
} else {
(
ReadType::NamedTable(NamedTable {
names: scan.table_name.to_vec(),
advanced_extension: None,
}),
None,
)
};

Ok(Box::new(Rel {
rel_type: Some(RelType::Read(Box::new(ReadRel {
common: None,
common,
base_schema: Some(base_schema),
filter: None,
best_effort_filter: None,
Expand Down Expand Up @@ -1715,7 +1739,7 @@ pub fn from_in_subquery(
subquery_type: Some(
substrait::proto::expression::subquery::SubqueryType::InPredicate(
Box::new(InPredicate {
needles: (vec![substrait_expr]),
needles: vec![substrait_expr],
haystack: Some(subquery_plan),
}),
),
Expand Down Expand Up @@ -2909,6 +2933,7 @@ mod test {
#[tokio::test]
async fn round_trip_extension_table() {
const TABLE_NAME: &str = "custom_table";
const TYPE_URL: &str = "/substrait.test.CustomTable";
const SERIALIZED: &[u8] = "table definition".as_bytes();

fn custom_table() -> Arc<dyn TableProvider> {
Expand All @@ -2921,9 +2946,12 @@ mod test {
#[derive(Debug)]
struct Registry;
impl SerializerRegistry for Registry {
fn serialize_custom_table(&self, table: &dyn TableSource) -> Result<Vec<u8>> {
fn serialize_custom_table(
&self,
table: &dyn TableSource,
) -> Result<Option<NamedBytes>> {
if table.schema() == custom_table().schema() {
Ok(SERIALIZED.to_vec())
Ok(Some(NamedBytes(TYPE_URL.to_string(), SERIALIZED.to_vec())))
} else {
Err(DataFusionError::Internal("Not our table".into()))
}
Expand All @@ -2933,7 +2961,7 @@ mod test {
name: &str,
bytes: &[u8],
) -> Result<Arc<dyn TableSource>> {
if name == TABLE_NAME && bytes == SERIALIZED {
if name == TYPE_URL && bytes == SERIALIZED {
Ok(Arc::new(DefaultTableSource::new(custom_table())))
} else {
panic!("Unexpected extension table: {name}");
Expand Down Expand Up @@ -2965,7 +2993,7 @@ mod test {
assert_contains!(
// confirm that the Substrait plan contains our custom_table as an ExtensionTable
serde_json::to_string(substrait.as_ref()).unwrap(),
format!(r#""extensionTable":{{"detail":{{"typeUrl":"{TABLE_NAME}","#)
format!(r#""extensionTable":{{"detail":{{"typeUrl":"{TYPE_URL}","#)
);
remote // make sure the restored plan is fully working in the remote context
.execute_logical_plan(restored.clone())
Expand Down
Loading

0 comments on commit 64fb0e5

Please sign in to comment.