diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 90995c1d116a..27fc08592575 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -88,6 +88,12 @@ version = "1.0.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1bec1de6f59aedf83baf9ff929c98f2ad654b97c9510f4e70cf6f661d49fd5b1" +[[package]] +name = "anyhow" +version = "1.0.86" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b3d1d046238990b9cf5bcde22a3fb3584ee5cf65fb2765f454ed428c7a0063da" + [[package]] name = "apache-avro" version = "0.16.0" @@ -244,6 +250,34 @@ dependencies = [ "num", ] +[[package]] +name = "arrow-flight" +version = "52.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e7ffbc96072e466ae5188974725bb46757587eafe427f77a25b828c375ae882" +dependencies = [ + "arrow-arith", + "arrow-array", + "arrow-buffer", + "arrow-cast", + "arrow-data", + "arrow-ipc", + "arrow-ord", + "arrow-row", + "arrow-schema", + "arrow-select", + "arrow-string", + "base64 0.22.1", + "bytes", + "futures", + "once_cell", + "paste", + "prost", + "prost-types", + "tokio", + "tonic", +] + [[package]] name = "arrow-ipc" version = "52.2.0" @@ -379,6 +413,28 @@ dependencies = [ "zstd-safe 7.2.1", ] +[[package]] +name = "async-stream" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd56dd203fef61ac097dd65721a419ddccb106b2d2b70ba60a6b529f03961a51" +dependencies = [ + "async-stream-impl", + "futures-core", + "pin-project-lite", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.74", +] + [[package]] name = "async-trait" version = "0.1.81" @@ -712,6 +768,51 @@ dependencies = [ "tracing", ] +[[package]] +name = "axum" +version = "0.6.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b829e4e32b91e643de6eafe82b1d90675f5874230191a4ffbc1b336dec4d6bf" +dependencies = [ + "async-trait", + "axum-core", + "bitflags 1.3.2", + "bytes", + "futures-util", + "http 0.2.12", + "http-body 0.4.6", + "hyper 0.14.30", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "rustversion", + "serde", + "sync_wrapper 0.1.2", + "tower", + "tower-layer", + "tower-service", +] + +[[package]] +name = "axum-core" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "759fa577a247914fd3f7f76d62972792636412fbfd634cd452f6a385a74d2d2c" +dependencies = [ + "async-trait", + "bytes", + "futures-util", + "http 0.2.12", + "http-body 0.4.6", + "mime", + "rustversion", + "tower-layer", + "tower-service", +] + [[package]] name = "backtrace" version = "0.3.73" @@ -1135,10 +1236,12 @@ dependencies = [ "apache-avro", "arrow", "arrow-array", + "arrow-flight", "arrow-ipc", "arrow-schema", "async-compression", "async-trait", + "base64 0.22.1", "bytes", "bzip2", "chrono", @@ -1173,11 +1276,13 @@ dependencies = [ "parquet", "paste", "pin-project-lite", + "prost", "rand", "sqlparser", "tempfile", "tokio", "tokio-util", + "tonic", "url", "uuid", "xz2", @@ -2094,6 +2199,18 @@ dependencies = [ "tower-service", ] +[[package]] +name = "hyper-timeout" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbb958482e8c7be4bc3cf272a766a2b0bf1a6755e7a6ae777f017a31d11b13b1" 
+dependencies = [ + "hyper 0.14.30", + "pin-project-lite", + "tokio", + "tokio-io-timeout", +] + [[package]] name = "hyper-util" version = "0.1.7" @@ -2401,6 +2518,12 @@ dependencies = [ "pkg-config", ] +[[package]] +name = "matchit" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" + [[package]] name = "md-5" version = "0.10.6" @@ -2886,6 +3009,38 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "prost" +version = "0.12.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "deb1435c188b76130da55f17a466d252ff7b1418b2ad3e037d127b94e3411f29" +dependencies = [ + "bytes", + "prost-derive", +] + +[[package]] +name = "prost-derive" +version = "0.12.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81bddcdb20abf9501610992b6759a4c888aef7d1a7247ef75e2404275ac24af1" +dependencies = [ + "anyhow", + "itertools 0.12.1", + "proc-macro2", + "quote", + "syn 2.0.74", +] + +[[package]] +name = "prost-types" +version = "0.12.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9091c90b0a32608e984ff2fa4091273cbdd755d54935c51d520887f4a1dbd5b0" +dependencies = [ + "prost", +] + [[package]] name = "quad-rand" version = "0.2.1" @@ -3086,7 +3241,7 @@ dependencies = [ "serde", "serde_json", "serde_urlencoded", - "sync_wrapper", + "sync_wrapper 1.0.1", "tokio", "tokio-rustls 0.26.0", "tokio-util", @@ -3629,6 +3784,12 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "sync_wrapper" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" + [[package]] name = "sync_wrapper" version = "1.0.1" @@ -3772,6 +3933,16 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "tokio-io-timeout" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30b74022ada614a1b4834de765f9bb43877f910cc8ce4be40e89042c9223a8bf" +dependencies = [ + "pin-project-lite", + "tokio", +] + [[package]] name = "tokio-macros" version = "2.4.0" @@ -3829,6 +4000,33 @@ dependencies = [ "tokio", ] +[[package]] +name = "tonic" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76c4eb7a4e9ef9d4763600161f12f5070b92a578e1b634db88a6887844c91a13" +dependencies = [ + "async-stream", + "async-trait", + "axum", + "base64 0.21.7", + "bytes", + "h2 0.3.26", + "http 0.2.12", + "http-body 0.4.6", + "hyper 0.14.30", + "hyper-timeout", + "percent-encoding", + "pin-project", + "prost", + "tokio", + "tokio-stream", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "tower" version = "0.4.13" @@ -3837,9 +4035,13 @@ checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" dependencies = [ "futures-core", "futures-util", + "indexmap 1.9.3", "pin-project", "pin-project-lite", + "rand", + "slab", "tokio", + "tokio-util", "tower-layer", "tower-service", "tracing", diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index cbd9ffd0feba..d078868506d1 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -44,6 +44,7 @@ datafusion = { path = "../datafusion/core", version = "41.0.0", features = [ "regex_expressions", "unicode_expressions", "compression", + "flight", ] } dirs = "4.0.0" env_logger = "0.9" diff --git a/datafusion-cli/src/exec.rs b/datafusion-cli/src/exec.rs 
index db4242d97175..223d318aefbf 100644 --- a/datafusion-cli/src/exec.rs +++ b/datafusion-cli/src/exec.rs @@ -42,6 +42,7 @@ use datafusion::physical_plan::{collect, execute_stream, ExecutionPlanProperties use datafusion::sql::parser::{DFParser, Statement}; use datafusion::sql::sqlparser::dialect::dialect_from_str; +use datafusion::datasource::flight::config::FlightOptions; use datafusion::sql::sqlparser; use rustyline::error::ReadlineError; use rustyline::Editor; @@ -386,6 +387,8 @@ pub(crate) async fn register_object_store_and_config_extensions( let mut table_options = ctx.session_state().default_table_options().clone(); if let Some(format) = format { table_options.set_config_format(format); + } else { + table_options.extensions.insert(FlightOptions::default()) } table_options.alter_with_string_hash_map(options)?; diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index e678c93ede8b..ccb66a697e1e 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -59,10 +59,12 @@ default = [ "unicode_expressions", "compression", "parquet", + "flight", ] encoding_expressions = ["datafusion-functions/encoding_expressions"] # Used for testing ONLY: causes all values to hash to the same value (test for collisions) force_hash_collisions = ["datafusion-physical-plan/force_hash_collisions", "datafusion-common/force_hash_collisions"] +flight = ["dep:arrow-flight", "dep:base64", "dep:prost", "dep:tonic"] math_expressions = ["datafusion-functions/math_expressions"] parquet = ["datafusion-common/parquet", "dep:parquet"] pyarrow = ["datafusion-common/pyarrow", "parquet"] @@ -83,6 +85,7 @@ ahash = { workspace = true } apache-avro = { version = "0.16", optional = true } arrow = { workspace = true } arrow-array = { workspace = true } +arrow-flight = { workspace = true, optional = true } arrow-ipc = { workspace = true } arrow-schema = { workspace = true } async-compression = { version = "0.4.0", features = [ @@ -94,6 +97,7 @@ async-compression = { version = "0.4.0", features = [ "tokio", ], optional = true } async-trait = { workspace = true } +base64 = { version = "0.22", optional = true } bytes = { workspace = true } bzip2 = { version = "0.4.3", optional = true } chrono = { workspace = true } @@ -128,11 +132,13 @@ parking_lot = { workspace = true } parquet = { workspace = true, optional = true, default-features = true } paste = "1.0.15" pin-project-lite = "^0.2.7" +prost = { version = "0.12", optional = true } rand = { workspace = true } sqlparser = { workspace = true } tempfile = { workspace = true } tokio = { workspace = true } tokio-util = { version = "0.7.4", features = ["io"], optional = true } +tonic = { version = "0.11", optional = true } url = { workspace = true } uuid = { version = "1.7", features = ["v4"] } xz2 = { version = "0.1", optional = true, features = ["static"] } @@ -161,6 +167,7 @@ test-utils = { path = "../../test-utils" } thiserror = { workspace = true } tokio = { workspace = true, features = ["rt-multi-thread", "parking_lot", "fs"] } tokio-postgres = "0.7.7" +tokio-stream = { version = "0.1.15", features = ["net"] } [target.'cfg(not(target_os = "windows"))'.dev-dependencies] nix = { version = "0.29.0", features = ["fs"] } diff --git a/datafusion/core/src/datasource/flight/config.rs b/datafusion/core/src/datasource/flight/config.rs new file mode 100644 index 000000000000..5737a03a2566 --- /dev/null +++ b/datafusion/core/src/datasource/flight/config.rs @@ -0,0 +1,63 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor 
license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Only meant for registering the `flight` namespace with `datafusion-cli` + +use datafusion_common::config::{ConfigEntry, ConfigExtension, ExtensionOptions}; +use std::any::Any; +use std::collections::HashMap; + +/// Collects and reports back config entries. Only used to persuade `datafusion-cli` +/// to accept the `flight.` prefix for `CREATE EXTERNAL TABLE` options. +#[derive(Default, Debug, Clone)] +pub struct FlightOptions { + inner: HashMap<String, String>, +} + +impl ConfigExtension for FlightOptions { + const PREFIX: &'static str = "flight"; +} + +impl ExtensionOptions for FlightOptions { + fn as_any(&self) -> &dyn Any { + self + } + + fn as_any_mut(&mut self) -> &mut dyn Any { + self + } + + fn cloned(&self) -> Box<dyn ExtensionOptions> { + Box::new(self.clone()) + } + + fn set(&mut self, key: &str, value: &str) -> datafusion_common::Result<()> { + self.inner.insert(key.into(), value.into()); + Ok(()) + } + + fn entries(&self) -> Vec<ConfigEntry> { + self.inner + .iter() + .map(|(key, value)| ConfigEntry { + key: key.to_owned(), + value: Some(value.to_owned()).filter(|s| !s.is_empty()), + description: "", + }) + .collect() + } +} diff --git a/datafusion/core/src/datasource/flight/mod.rs b/datafusion/core/src/datasource/flight/mod.rs new file mode 100644 index 000000000000..25dfcf320713 --- /dev/null +++ b/datafusion/core/src/datasource/flight/mod.rs @@ -0,0 +1,259 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Generic [FlightTableFactory] that can connect to Arrow Flight services, +//! with a [sql::FlightSqlDriver] provided out-of-the-box.
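+//!
+//! For orientation, a minimal sketch of the out-of-the-box path (the URL, port and
+//! table name below are placeholders): the `FLIGHT_SQL` factory registered with the
+//! default `SessionContext` drives the bundled [sql::FlightSqlDriver] purely through
+//! DDL, with the query passed via the `flight.sql.query` table option:
+//! ```no_run
+//! # use datafusion::prelude::SessionContext;
+//! # #[tokio::main]
+//! # async fn main() -> datafusion_common::Result<()> {
+//! let ctx = SessionContext::new();
+//! ctx.sql(r#"
+//!     CREATE EXTERNAL TABLE trades STORED AS FLIGHT_SQL
+//!     LOCATION 'http://localhost:32010'
+//!     OPTIONS ('flight.sql.query' 'SELECT * FROM trades')
+//! "#).await?;
+//! ctx.table("trades").await?.show().await?;
+//! # Ok(())
+//! # }
+//! ```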
+ +use std::any::Any; +use std::collections::HashMap; +use std::fmt::Debug; +use std::sync::Arc; + +use arrow_flight::error::FlightError; +use arrow_flight::FlightInfo; +use arrow_schema::SchemaRef; +use async_trait::async_trait; +use tonic::metadata::MetadataMap; +use tonic::transport::Channel; + +use datafusion_catalog::{Session, TableProvider, TableProviderFactory}; +use datafusion_common::{project_schema, DataFusionError}; +use datafusion_expr::{CreateExternalTable, Expr, TableType}; +use datafusion_physical_expr::EquivalenceProperties; +use datafusion_physical_expr::Partitioning::UnknownPartitioning; +use datafusion_physical_plan::{ExecutionMode, ExecutionPlan, PlanProperties}; + +use crate::datasource::physical_plan::FlightExec; + +pub mod config; +pub mod sql; + +/// Generic Arrow Flight data source. Requires a [FlightDriver] that allows implementors +/// to integrate any custom Flight RPC service by producing a [FlightMetadata] for some DDL. +/// +/// # Sample usage: +/// ``` +/// use std::collections::HashMap; +/// use arrow_flight::{FlightClient, FlightDescriptor}; +/// use tonic::transport::Channel; +/// use datafusion::datasource::flight::{FlightMetadata, FlightDriver}; +/// use datafusion::prelude::SessionContext; +/// use std::sync::Arc; +/// use datafusion::datasource::flight::FlightTableFactory; +/// +/// #[derive(Debug, Clone, Default)] +/// struct CustomFlightDriver {} +/// #[async_trait::async_trait] +/// impl FlightDriver for CustomFlightDriver { +/// async fn metadata(&self, channel: Channel, opts: &HashMap<String, String>) +/// -> arrow_flight::error::Result<FlightMetadata> { +/// let mut client = FlightClient::new(channel); +/// // the `flight.` prefix is an already registered namespace in datafusion-cli +/// let descriptor = FlightDescriptor::new_cmd(opts["flight.command"].clone()); +/// let flight_info = client.get_flight_info(descriptor).await?; +/// FlightMetadata::try_from(flight_info) +/// } +/// } +/// +/// #[tokio::main] +/// async fn main() -> datafusion_common::Result<()> { +/// let ctx = SessionContext::new(); +/// ctx.state_ref().write().table_factories_mut() +/// .insert("CUSTOM_FLIGHT".into(), Arc::new(FlightTableFactory::new( +/// Arc::new(CustomFlightDriver::default()) +/// ))); +/// let _ = ctx.sql(r#" +/// CREATE EXTERNAL TABLE custom_flight_table STORED AS CUSTOM_FLIGHT +/// LOCATION 'https://custom.flight.rpc' +/// OPTIONS ('flight.command' 'select * from everywhere') +/// "#).await; // will fail as it can't connect to the bogus URL, but we ignore the error +/// Ok(()) +/// } +/// +/// ``` +#[derive(Clone, Debug)] +pub struct FlightTableFactory { + driver: Arc<dyn FlightDriver>, +} + +impl FlightTableFactory { + /// Create a data source using the provided driver + pub fn new(driver: Arc<dyn FlightDriver>) -> Self { + Self { driver } + } + + /// Convenient way to create a [FlightTable] programmatically, as an alternative to DDL.
+ pub async fn open_table( + &self, + entry_point: impl Into<String>, + options: HashMap<String, String>, + ) -> datafusion_common::Result<FlightTable> { + let origin = entry_point.into(); + let channel = Channel::from_shared(origin.clone()) + .unwrap() + .connect() + .await + .map_err(|e| DataFusionError::External(Box::new(e)))?; + let metadata = self + .driver + .metadata(channel.clone(), &options) + .await + .map_err(|e| DataFusionError::External(Box::new(e)))?; + let logical_schema = metadata.plan_properties.eq_properties.schema().clone(); + Ok(FlightTable { + driver: self.driver.clone(), + channel, + options, + origin, + logical_schema, + }) + } +} + +#[async_trait] +impl TableProviderFactory for FlightTableFactory { + async fn create( + &self, + _state: &dyn Session, + cmd: &CreateExternalTable, + ) -> datafusion_common::Result<Arc<dyn TableProvider>> { + let table = self.open_table(&cmd.location, cmd.options.clone()).await?; + Ok(Arc::new(table)) + } +} + +/// Extension point for integrating any Flight RPC service as a [FlightTableFactory]. +/// Handles the initial `GetFlightInfo` call and all its prerequisites (such as `Handshake`), +/// to produce a [FlightMetadata]. +#[async_trait] +pub trait FlightDriver: Sync + Send + Debug { + /// Returns a [FlightMetadata] from the specified channel, + /// according to the provided table options. + /// The driver must provide at least a [FlightInfo] in order to construct a flight metadata. + async fn metadata( + &self, + channel: Channel, + options: &HashMap<String, String>, + ) -> arrow_flight::error::Result<FlightMetadata>; +} + +/// The information that a [FlightDriver] must produce +/// in order to register flights as DataFusion tables. +#[derive(Clone, Debug)] +pub struct FlightMetadata { + /// FlightInfo object produced by the driver + pub(super) flight_info: Arc<FlightInfo>, + /// Physical plan properties. Sensible defaults will be used if the + /// driver doesn't need (or care) to customize the execution plan. + pub(super) plan_properties: Arc<PlanProperties>, + /// The gRPC headers to use on the `DoGet` calls + pub(super) grpc_metadata: Arc<MetadataMap>, +} + +impl FlightMetadata { + /// Provide custom [PlanProperties] to account for service specifics, + /// such as known partitioning scheme, unbounded execution mode etc. + pub fn new(info: FlightInfo, props: PlanProperties, grpc: MetadataMap) -> Self { + Self { + flight_info: Arc::new(info), + plan_properties: Arc::new(props), + grpc_metadata: Arc::new(grpc), + } + } + + /// Uses the default [PlanProperties] and infers the schema from the FlightInfo response.
+ pub fn try_new( + info: FlightInfo, + grpc: MetadataMap, + ) -> arrow_flight::error::Result<Self> { + let schema = Arc::new(info.clone().try_decode_schema()?); + let partitions = info.endpoint.len(); + let props = PlanProperties::new( + EquivalenceProperties::new(schema.clone()), + UnknownPartitioning(partitions), + ExecutionMode::Bounded, + ); + Ok(Self::new(info, props, grpc)) + } + + fn with_physical_schema(self, schema: SchemaRef) -> Self { + let eq_props = EquivalenceProperties::new_with_orderings( + schema, + &self.plan_properties.eq_properties.oeq_class().orderings[..], + ); + let pp = PlanProperties::new( + eq_props, + self.plan_properties.partitioning.clone(), + self.plan_properties.execution_mode, + ); + Self { + flight_info: self.flight_info, + plan_properties: Arc::new(pp), + grpc_metadata: self.grpc_metadata, + } + } +} + +/// Uses the default [PlanProperties] and no custom gRPC metadata entries +impl TryFrom<FlightInfo> for FlightMetadata { + type Error = FlightError; + + fn try_from(info: FlightInfo) -> Result<Self, Self::Error> { + Self::try_new(info, MetadataMap::default()) + } +} + +/// Table provider that wraps a specific flight from an Arrow Flight service +pub struct FlightTable { + driver: Arc<dyn FlightDriver>, + channel: Channel, + options: HashMap<String, String>, + origin: String, + logical_schema: SchemaRef, +} + +#[async_trait] +impl TableProvider for FlightTable { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.logical_schema.clone() + } + + fn table_type(&self) -> TableType { + TableType::View + } + + async fn scan( + &self, + _state: &dyn Session, + projection: Option<&Vec<usize>>, + _filters: &[Expr], + _limit: Option<usize>, + ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> { + let schema = project_schema(&self.logical_schema, projection)?; + let metadata = self + .driver + .metadata(self.channel.clone(), &self.options) + .await + .map_err(|e| DataFusionError::External(Box::new(e)))? + .with_physical_schema(schema); + Ok(Arc::new(FlightExec::new(&metadata, &self.origin))) + } +} diff --git a/datafusion/core/src/datasource/flight/sql.rs b/datafusion/core/src/datasource/flight/sql.rs new file mode 100644 index 000000000000..6fb211d885ea --- /dev/null +++ b/datafusion/core/src/datasource/flight/sql.rs @@ -0,0 +1,475 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//!
Default [FlightDriver] for Flight SQL + +use std::collections::HashMap; +use std::str::FromStr; + +use arrow_flight::error::Result; +use arrow_flight::flight_service_client::FlightServiceClient; +use arrow_flight::sql::{CommandStatementQuery, ProstMessageExt}; +use arrow_flight::{FlightDescriptor, FlightInfo, HandshakeRequest, HandshakeResponse}; +use arrow_schema::ArrowError; +use async_trait::async_trait; +use base64::prelude::BASE64_STANDARD; +use base64::Engine; +use bytes::Bytes; +use futures::{stream, TryStreamExt}; +use prost::Message; +use tonic::metadata::{AsciiMetadataKey, MetadataMap}; +use tonic::transport::Channel; +use tonic::IntoRequest; + +use crate::datasource::flight::{FlightDriver, FlightMetadata}; + +/// Default Flight SQL driver. Requires a `flight.sql.query` to be passed as a table option. +/// If `flight.sql.username` (and optionally `flight.sql.password`) are passed, +/// will perform the `Handshake` using basic authentication. +/// Any additional headers can be passed as table options using the `flight.sql.header.` prefix. +/// +/// A [crate::datasource::flight::FlightTableFactory] using this driver is registered +/// with the default `SessionContext` under the name `FLIGHT_SQL`. +#[derive(Clone, Debug, Default)] +pub struct FlightSqlDriver {} + +#[async_trait] +impl FlightDriver for FlightSqlDriver { + async fn metadata( + &self, + channel: Channel, + options: &HashMap<String, String>, + ) -> Result<FlightMetadata> { + let mut client = FlightSqlClient::new(channel); + let headers = options.iter().filter_map(|(key, value)| { + key.strip_prefix("flight.sql.header.") + .map(|header_name| (header_name, value)) + }); + for header in headers { + client.set_header(header.0, header.1) + } + if let Some(username) = options.get("flight.sql.username") { + let default_password = "".to_string(); + let password = options + .get("flight.sql.password") + .unwrap_or(&default_password); + _ = client.handshake(username, password).await?; + } + let info = client + .execute(options["flight.sql.query"].clone(), None) + .await?; + let mut grpc_metadata = MetadataMap::new(); + if let Some(token) = client.token { + grpc_metadata.insert( + "authorization", + format!("Bearer {}", token).parse().unwrap(), + ); + } + FlightMetadata::try_new(info, grpc_metadata) + } +} + +///////////////////////////////////////////////////////////////////////// +// Shameless copy/paste from arrow-flight FlightSqlServiceClient +// This is only needed in order to access the bearer token received +// during handshake, as the standard client does not expose this information. +// The bearer token has to be passed to the clients that perform +// the DoGet operation, since Dremio, Ballista and possibly others +// expect the bearer token they produce with the handshake response +// to be set on all subsequent requests, including DoGet. +#[derive(Debug, Clone)] +struct FlightSqlClient { + token: Option<String>, + headers: HashMap<String, String>, + flight_client: FlightServiceClient<Channel>, +} + +impl FlightSqlClient { + /// Creates a new FlightSql client that connects to a server over an arbitrary tonic `Channel` + fn new(channel: Channel) -> Self { + Self::new_from_inner(FlightServiceClient::new(channel)) + } + + /// Creates a new higher level client with the provided lower level client + fn new_from_inner(inner: FlightServiceClient<Channel>) -> Self { + Self { + token: None, + flight_client: inner, + headers: HashMap::default(), + } + } + + /// Perform a `handshake` with the server, passing credentials and establishing a session.
+ /// + /// If the server returns an "authorization" header, it is automatically parsed and set as + /// a token for future requests. Any other data returned by the server in the handshake + /// response is returned as a binary blob. + async fn handshake( + &mut self, + username: &str, + password: &str, + ) -> std::result::Result<Bytes, ArrowError> { + let cmd = HandshakeRequest { + protocol_version: 0, + payload: Default::default(), + }; + let mut req = tonic::Request::new(stream::iter(vec![cmd])); + let val = BASE64_STANDARD.encode(format!("{username}:{password}")); + let val = format!("Basic {val}") + .parse() + .map_err(|_| ArrowError::ParseError("Cannot parse header".to_string()))?; + req.metadata_mut().insert("authorization", val); + let req = self.set_request_headers(req)?; + let resp = self + .flight_client + .handshake(req) + .await + .map_err(|e| ArrowError::IpcError(format!("Can't handshake {e}")))?; + if let Some(auth) = resp.metadata().get("authorization") { + let auth = auth.to_str().map_err(|_| { + ArrowError::ParseError("Can't read auth header".to_string()) + })?; + let bearer = "Bearer "; + if !auth.starts_with(bearer) { + Err(ArrowError::ParseError("Invalid auth header!".to_string()))?; + } + let auth = auth[bearer.len()..].to_string(); + self.token = Some(auth); + } + let responses: Vec<HandshakeResponse> = + resp.into_inner().try_collect().await.map_err(|_| { + ArrowError::ParseError("Can't collect responses".to_string()) + })?; + let resp = match responses.as_slice() { + [resp] => resp.payload.clone(), + [] => Bytes::new(), + _ => Err(ArrowError::ParseError( + "Multiple handshake responses".to_string(), + ))?, + }; + Ok(resp) + } + + async fn execute( + &mut self, + query: String, + transaction_id: Option<Bytes>, + ) -> std::result::Result<FlightInfo, ArrowError> { + let cmd = CommandStatementQuery { + query, + transaction_id, + }; + self.get_flight_info_for_command(cmd).await + } + + async fn get_flight_info_for_command<M: ProstMessageExt>( + &mut self, + cmd: M, + ) -> std::result::Result<FlightInfo, ArrowError> { + let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec()); + let req = self.set_request_headers(descriptor.into_request())?; + let fi = self + .flight_client + .get_flight_info(req) + .await + .map_err(|status| ArrowError::IpcError(format!("{status:?}")))?
+ .into_inner(); + Ok(fi) + } + + fn set_header(&mut self, key: impl Into<String>, value: impl Into<String>) { + let key: String = key.into(); + let value: String = value.into(); + self.headers.insert(key, value); + } + + fn set_request_headers<T>( + &self, + mut req: tonic::Request<T>, + ) -> std::result::Result<tonic::Request<T>, ArrowError> { + for (k, v) in &self.headers { + let k = AsciiMetadataKey::from_str(k.as_str()).map_err(|e| { + ArrowError::ParseError(format!("Cannot convert header key \"{k}\": {e}")) + })?; + let v = v.parse().map_err(|e| { + ArrowError::ParseError(format!( + "Cannot convert header value \"{v}\": {e}" + )) + })?; + req.metadata_mut().insert(k, v); + } + if let Some(token) = &self.token { + let val = format!("Bearer {token}").parse().map_err(|e| { + ArrowError::ParseError(format!( + "Cannot convert token to header value: {e}" + )) + })?; + req.metadata_mut().insert("authorization", val); + } + Ok(req) + } +} +///////////////////////////////////////////////////////////////////////// + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + use std::net::SocketAddr; + use std::pin::Pin; + use std::sync::Arc; + use std::time::Duration; + + use arrow_array::{Array, Float32Array, Int64Array, Int8Array, RecordBatch}; + use arrow_flight::encode::FlightDataEncoderBuilder; + use arrow_flight::flight_service_server::{FlightService, FlightServiceServer}; + use arrow_flight::sql::server::FlightSqlService; + use arrow_flight::sql::{ + CommandStatementQuery, ProstMessageExt, SqlInfo, TicketStatementQuery, + }; + use arrow_flight::{ + FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, + HandshakeResponse, Ticket, + }; + use arrow_schema::{DataType, Field, Schema}; + use async_trait::async_trait; + use futures::{stream, Stream, TryStreamExt}; + use prost::Message; + use tokio::net::TcpListener; + use tokio::sync::oneshot::{channel, Receiver, Sender}; + use tokio_stream::wrappers::TcpListenerStream; + use tonic::codegen::http::HeaderMap; + use tonic::codegen::tokio_stream; + use tonic::metadata::MetadataMap; + use tonic::transport::Server; + use tonic::{Extensions, Request, Response, Status, Streaming}; + + use crate::prelude::SessionContext; + + const AUTH_HEADER: &str = "authorization"; + const BEARER_TOKEN: &str = "Bearer flight-sql-token"; + + struct TestFlightSqlService { + flight_info: FlightInfo, + partition_data: RecordBatch, + expected_handshake_headers: HashMap<String, String>, + expected_flight_info_query: String, + shutdown_sender: Option<Sender<()>>, + } + + impl TestFlightSqlService { + async fn run_in_background(self, rx: Receiver<()>) -> SocketAddr { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let service = FlightServiceServer::new(self); + #[allow(clippy::disallowed_methods)] // spawn allowed only in tests + tokio::spawn(async move { + Server::builder() + .timeout(Duration::from_secs(1)) + .add_service(service) + .serve_with_incoming_shutdown( + TcpListenerStream::new(listener), + async { + rx.await.ok(); + }, + ) + .await + .unwrap(); + }); + tokio::time::sleep(Duration::from_millis(25)).await; + addr + } + } + + impl Drop for TestFlightSqlService { + fn drop(&mut self) { + if let Some(tx) = self.shutdown_sender.take() { + tx.send(()).ok(); + } + } + } + + fn check_header<T>( + request: &Request<T>, + rpc: &str, + header_name: &str, + expected_value: &str, + ) { + let actual_value = request + .metadata() + .get(header_name) + .unwrap_or_else(|| panic!("[{}] missing header `{}`", rpc, header_name)) + .to_str() + .unwrap_or_else(|e| { + panic!(
"[{}] error parsing value for header `{}`: {:?}", + rpc, header_name, e + ) + }); + assert_eq!( + actual_value, expected_value, + "[{}] unexpected value for header `{}`", + rpc, header_name + ) + } + + #[async_trait] + impl FlightSqlService for TestFlightSqlService { + type FlightService = TestFlightSqlService; + + async fn do_handshake( + &self, + request: Request>, + ) -> Result< + Response< + Pin> + Send>>, + >, + Status, + > { + for (header_name, expected_value) in self.expected_handshake_headers.iter() { + check_header(&request, "do_handshake", header_name, expected_value); + } + Ok(Response::from_parts( + MetadataMap::from_headers(HeaderMap::from_iter([( + AUTH_HEADER.parse().unwrap(), + BEARER_TOKEN.parse().unwrap(), + )])), // the client should send this header back on the next request (i.e. GetFlightInfo) + Box::pin(tokio_stream::empty()), + Extensions::default(), + )) + } + + async fn get_flight_info_statement( + &self, + query: CommandStatementQuery, + request: Request, + ) -> Result, Status> { + let mut expected_flight_info_headers = + self.expected_handshake_headers.clone(); + expected_flight_info_headers.insert(AUTH_HEADER.into(), BEARER_TOKEN.into()); + for (header_name, expected_value) in expected_flight_info_headers.iter() { + check_header(&request, "get_flight_info", header_name, expected_value); + } + assert_eq!( + query.query.to_lowercase(), + self.expected_flight_info_query.to_lowercase() + ); + Ok(Response::new(self.flight_info.clone())) + } + + async fn do_get_statement( + &self, + _ticket: TicketStatementQuery, + request: Request, + ) -> Result::DoGetStream>, Status> { + let data = self.partition_data.clone(); + let rb = async move { Ok(data) }; + check_header(&request, "do_get", "authorization", BEARER_TOKEN); + let stream = FlightDataEncoderBuilder::default() + .with_schema(self.partition_data.schema()) + .build(stream::once(rb)) + .map_err(|e| Status::from_error(Box::new(e))); + + Ok(Response::new(Box::pin(stream))) + } + + async fn register_sql_info(&self, _id: i32, _result: &SqlInfo) {} + } + + #[tokio::test] + async fn flight_sql_data_source() -> datafusion_common::Result<()> { + let partition_data = RecordBatch::try_new( + Arc::new(Schema::new([ + Arc::new(Field::new("col1", DataType::Float32, false)), + Arc::new(Field::new("col2", DataType::Int8, false)), + ])), + vec![ + Arc::new(Float32Array::from(vec![0.0, 0.1, 0.2, 0.3])), + Arc::new(Int8Array::from(vec![10, 20, 30, 40])), + ], + ) + .unwrap(); + let rows_per_partition = partition_data.num_rows(); + + let query = "SELECT * FROM some_table"; + let ticket_payload = TicketStatementQuery::default().as_any().encode_to_vec(); + let endpoint_archetype = + FlightEndpoint::default().with_ticket(Ticket::new(ticket_payload)); + let endpoints = vec![ + endpoint_archetype.clone(), + endpoint_archetype.clone(), + endpoint_archetype, + ]; + let num_partitions = endpoints.len(); + let flight_info = FlightInfo::default() + .try_with_schema(partition_data.schema().as_ref()) + .unwrap(); + let flight_info = endpoints + .into_iter() + .fold(flight_info, |fi, e| fi.with_endpoint(e)); + let (tx, rx) = channel(); + let service = TestFlightSqlService { + flight_info, + partition_data, + expected_handshake_headers: HashMap::from([ + (AUTH_HEADER.into(), "Basic YWRtaW46cGFzc3dvcmQ=".into()), + ("custom-hdr1".into(), "v1".into()), + ("custom-hdr2".into(), "v2".into()), + ]), + expected_flight_info_query: query.into(), + shutdown_sender: Some(tx), + }; + let port = service.run_in_background(rx).await.port(); + let ctx = 
SessionContext::new(); + let _ = ctx + .sql(&format!( + r#" + CREATE EXTERNAL TABLE fsql STORED AS FLIGHT_SQL + LOCATION 'http://localhost:{port}' + OPTIONS( + 'flight.sql.username' 'admin', + 'flight.sql.password' 'password', + 'flight.sql.query' '{query}', + 'flight.sql.header.custom-hdr1' 'v1', + 'flight.sql.header.custom-hdr2' 'v2', + )"# + )) + .await + .unwrap(); + let df = ctx.sql("select col1 from fsql").await.unwrap(); + df.clone().show().await?; + assert_eq!( + df.count().await.unwrap(), + rows_per_partition * num_partitions + ); + let df = ctx.sql("select sum(col2) from fsql").await?; + df.clone().show().await?; + let rb = df + .collect() + .await? + .first() + .cloned() + .expect("no record batch"); + assert_eq!(rb.schema().fields.len(), 1); + let arr = rb + .column(0) + .as_any() + .downcast_ref::<Int64Array>() + .expect("wrong type of column"); + assert_eq!(arr.iter().next().unwrap().unwrap(), 300); + Ok(()) + } +} diff --git a/datafusion/core/src/datasource/mod.rs b/datafusion/core/src/datasource/mod.rs index 1c9924735735..58ca50c6a68b 100644 --- a/datafusion/core/src/datasource/mod.rs +++ b/datafusion/core/src/datasource/mod.rs @@ -24,6 +24,8 @@ pub mod cte_worktable; pub mod default_table_source; pub mod empty; pub mod file_format; +#[cfg(feature = "flight")] +pub mod flight; pub mod function; pub mod listing; pub mod listing_table_factory; diff --git a/datafusion/core/src/datasource/physical_plan/flight.rs b/datafusion/core/src/datasource/physical_plan/flight.rs new file mode 100644 index 000000000000..630c91f1b0ee --- /dev/null +++ b/datafusion/core/src/datasource/physical_plan/flight.rs @@ -0,0 +1,262 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//!
Execution plan for reading flights from Arrow Flight services + +use std::any::Any; +use std::collections::HashMap; +use std::error::Error; +use std::fmt::Formatter; +use std::sync::Arc; + +use arrow_array::RecordBatch; +use arrow_flight::error::FlightError; +use arrow_flight::{FlightClient, FlightEndpoint, Ticket}; +use arrow_schema::SchemaRef; +use futures::{StreamExt, TryStreamExt}; +use tonic::metadata::{MetadataKey, MetadataMap, MetadataValue}; +use tonic::transport::Channel; + +use datafusion_common::DataFusionError; +use datafusion_common::Result; +use datafusion_execution::{SendableRecordBatchStream, TaskContext}; +use datafusion_physical_plan::stream::RecordBatchStreamAdapter; +use datafusion_physical_plan::{ + DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, +}; + +use crate::datasource::flight::FlightMetadata; + +/// Arrow Flight physical plan that maps flight endpoints to partitions +#[derive(Clone, Debug)] +pub struct FlightExec { + /// Visible for proto serialization + pub partitions: Vec<Arc<FlightPartition>>, + /// Visible for proto serialization + pub plan_properties: Arc<PlanProperties>, + /// Visible for proto serialization + pub grpc_metadata: Arc<MetadataMap>, +} + +/// The minimum information required for fetching a flight stream. +#[derive(Clone, Debug)] +pub struct FlightPartition { + /// Visible for proto serialization + pub locations: Vec<String>, + /// Visible for proto serialization + pub ticket: Vec<u8>, +} + +impl FlightPartition { + fn new(endpoint: &FlightEndpoint, fallback_location: String) -> Self { + let locations = if endpoint.location.is_empty() { + vec![fallback_location] + } else { + endpoint + .location + .iter() + .map(|loc| { + if loc.uri.starts_with("arrow-flight-reuse-connection://") { + fallback_location.clone() + } else { + loc.uri.clone() + } + }) + .collect() + }; + Self { + locations, + ticket: endpoint + .ticket + .clone() + .expect("No flight ticket") + .ticket + .to_vec(), + } + } + + /// Primarily used for proto deserialization + pub fn restore(locations: Vec<String>, ticket: Vec<u8>) -> Self { + Self { locations, ticket } + } +} + +impl FlightExec { + /// Creates a FlightExec with the provided [FlightMetadata] + /// and origin URL (used as fallback location as per the protocol spec).
+ pub fn new(metadata: &FlightMetadata, origin: &str) -> Self { + let partitions = metadata + .flight_info + .endpoint + .iter() + .map(|endpoint| FlightPartition::new(endpoint, origin.to_string())) + .map(Arc::new) + .collect(); + Self { + partitions, + plan_properties: metadata.plan_properties.clone(), + grpc_metadata: metadata.grpc_metadata.clone(), + } + } + + /// Primarily used for proto deserialization + pub fn restore( + partitions: Vec<FlightPartition>, + plan_properties: PlanProperties, + grpc_headers: &HashMap<String, Vec<u8>>, + ) -> Self { + let mut grpc_metadata = MetadataMap::new(); + for (key, value) in grpc_headers { + let text_value = String::from_utf8(value.clone()); + if text_value.is_ok() { + let text_key = MetadataKey::from_bytes(key.as_bytes()).unwrap(); + grpc_metadata.insert(text_key, text_value.unwrap().parse().unwrap()); + } else { + let binary_key = MetadataKey::from_bytes(key.as_bytes()).unwrap(); + grpc_metadata.insert_bin(binary_key, MetadataValue::from_bytes(value)); + } + } + Self { + partitions: partitions.into_iter().map(Arc::new).collect(), + plan_properties: Arc::new(plan_properties), + grpc_metadata: Arc::new(grpc_metadata), + } + } +} + +async fn flight_stream( + partition: Arc<FlightPartition>, + schema: SchemaRef, + grpc_metadata: Arc<MetadataMap>, +) -> Result<SendableRecordBatchStream> { + let mut errors: Vec<Box<dyn Error + Send + Sync>> = vec![]; + for loc in &partition.locations { + match try_fetch_stream( + loc, + partition.ticket.clone(), + schema.clone(), + grpc_metadata.clone(), + ) + .await + { + Ok(stream) => return Ok(stream), + Err(e) => errors.push(Box::new(e)), + } + } + let err = errors.into_iter().last().unwrap_or_else(|| { + Box::new(FlightError::ProtocolError(format!( + "No available location for endpoint {:?}", + partition.locations + ))) + }); + Err(DataFusionError::External(err)) +} + +async fn try_fetch_stream( + source: impl Into<String>, + ticket: Vec<u8>, + schema: SchemaRef, + grpc: Arc<MetadataMap>, +) -> arrow_flight::error::Result<SendableRecordBatchStream> { + let ticket = Ticket::new(ticket); + let dest = Channel::from_shared(source.into()) + .map_err(|e| FlightError::ExternalError(Box::new(e)))?; + let channel = dest + .connect() + .await + .map_err(|e| FlightError::ExternalError(Box::new(e)))?; + let mut client = FlightClient::new(channel); + client.metadata_mut().clone_from(grpc.as_ref()); + let stream = client.do_get(ticket).await?; + Ok(Box::pin(RecordBatchStreamAdapter::new( + schema.clone(), + stream.map(move |rb| { + let schema = schema.clone(); + rb.map(move |rb| { + if schema.fields.is_empty() || rb.schema() == schema { + rb + } else if schema.contains(rb.schema_ref()) { + rb.with_schema(schema.clone()).unwrap() + } else { + let columns = schema + .fields + .iter() + .map(|field| { + rb.column_by_name(field.name()) + .expect("missing fields in record batch") + .clone() + }) + .collect(); + RecordBatch::try_new(schema.clone(), columns) + .expect("cannot impose desired schema on record batch") + } + }) + .map_err(|e| DataFusionError::External(Box::new(e))) + }), + ))) +} + +impl DisplayAs for FlightExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { + match t { + DisplayFormatType::Default => f.write_str("FlightExec"), + DisplayFormatType::Verbose => write!(f, "FlightExec {:?}", self.partitions), + } + } +} + +impl ExecutionPlan for FlightExec { + fn name(&self) -> &str { + "FlightExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + self.plan_properties.as_ref() + } + + fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> { + vec![] + } + + fn with_new_children( + self: Arc<Self>, + _children: Vec<Arc<dyn ExecutionPlan>>, + ) -> Result<Arc<dyn ExecutionPlan>> {
Ok(self) + } + + fn execute( + &self, + partition: usize, + _context: Arc<TaskContext>, + ) -> Result<SendableRecordBatchStream> { + let future_stream = flight_stream( + self.partitions[partition].clone(), + self.schema(), + self.grpc_metadata.clone(), + ); + let stream = futures::stream::once(future_stream).try_flatten(); + Ok(Box::pin(RecordBatchStreamAdapter::new( + self.schema(), + stream, + ))) + } +} diff --git a/datafusion/core/src/datasource/physical_plan/mod.rs b/datafusion/core/src/datasource/physical_plan/mod.rs index f810fb86bd89..b995917b8b55 100644 --- a/datafusion/core/src/datasource/physical_plan/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/mod.rs @@ -23,6 +23,8 @@ mod csv; mod file_groups; mod file_scan_config; mod file_stream; +#[cfg(feature = "flight")] +mod flight; mod json; #[cfg(feature = "parquet")] pub mod parquet; @@ -43,6 +45,9 @@ pub use file_scan_config::{ pub use file_stream::{FileOpenFuture, FileOpener, FileStream, OnError}; pub use json::{JsonOpener, NdJsonExec}; +#[cfg(feature = "flight")] +pub use self::flight::{FlightExec, FlightPartition}; + use std::{ fmt::{Debug, Formatter, Result as FmtResult}, ops::Range, diff --git a/datafusion/core/src/execution/session_state_defaults.rs b/datafusion/core/src/execution/session_state_defaults.rs index 07420afe842f..1577e793861b 100644 --- a/datafusion/core/src/execution/session_state_defaults.rs +++ b/datafusion/core/src/execution/session_state_defaults.rs @@ -25,6 +25,8 @@ use crate::datasource::file_format::json::JsonFormatFactory; #[cfg(feature = "parquet")] use crate::datasource::file_format::parquet::ParquetFormatFactory; use crate::datasource::file_format::FileFormatFactory; +#[cfg(feature = "flight")] +use crate::datasource::flight::{sql::FlightSqlDriver, FlightTableFactory}; use crate::datasource::provider::DefaultTableFactory; use crate::execution::context::SessionState; #[cfg(feature = "nested_expressions")] @@ -55,6 +57,13 @@ impl SessionStateDefaults { table_factories.insert("NDJSON".into(), Arc::new(DefaultTableFactory::new())); table_factories.insert("AVRO".into(), Arc::new(DefaultTableFactory::new())); table_factories.insert("ARROW".into(), Arc::new(DefaultTableFactory::new())); + #[cfg(feature = "flight")] + table_factories.insert( + "FLIGHT_SQL".into(), + Arc::new(FlightTableFactory::new( + Arc::new(FlightSqlDriver::default()), + )), + ); table_factories } diff --git a/datafusion/proto/Cargo.toml b/datafusion/proto/Cargo.toml index 95d9e6700a50..cf5660d79ccc 100644 --- a/datafusion/proto/Cargo.toml +++ b/datafusion/proto/Cargo.toml @@ -40,9 +40,10 @@ name = "datafusion_proto" path = "src/lib.rs" [features] -default = ["parquet"] +default = ["parquet", "flight"] json = ["pbjson", "serde", "serde_json"] parquet = ["datafusion/parquet", "datafusion-common/parquet"] +flight = [] [dependencies] arrow = { workspace = true } diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 819130b08e86..5e0172a53a9b 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -693,6 +693,7 @@ message PhysicalPlanNode { PlaceholderRowExecNode placeholder_row = 27; CsvSinkExecNode csv_sink = 28; ParquetSinkExecNode parquet_sink = 29; + FlightScanExecNode flight_scan = 30; } } @@ -1199,3 +1200,28 @@ message PartitionStats { int64 num_bytes = 3; repeated datafusion_common.ColumnStats column_stats = 4; } + +message FlightScanExecNode { + repeated FlightPartitionNode partitions = 1; + PlanPropertiesNode plan_properties = 2; + map<string, bytes> grpc_headers = 3; +} + +message
FlightPartitionNode { + repeated string locations = 1; + bytes token = 2; +} + +message PlanPropertiesNode { + datafusion_common.Schema schema = 1; + repeated PhysicalSortExprNodeCollection output_ordering = 2; + Partitioning partitioning = 3; + ExecutionMode execution_mode = 4; +} + +enum ExecutionMode { + Bounded = 0; + Unbounded = 1; + PipelineBreaking = 2; +} + diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 521a0d90c1ed..1cd71b0eaf2b 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -4676,6 +4676,80 @@ impl<'de> serde::Deserialize<'de> for EmptyRelationNode { deserializer.deserialize_struct("datafusion.EmptyRelationNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for ExecutionMode { + #[allow(deprecated)] + fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error> + where + S: serde::Serializer, + { + let variant = match self { + Self::Bounded => "Bounded", + Self::Unbounded => "Unbounded", + Self::PipelineBreaking => "PipelineBreaking", + }; + serializer.serialize_str(variant) + } +} +impl<'de> serde::Deserialize<'de> for ExecutionMode { + #[allow(deprecated)] + fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error> + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "Bounded", + "Unbounded", + "PipelineBreaking", + ]; + + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ExecutionMode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + fn visit_i64<E>(self, v: i64) -> std::result::Result<Self::Value, E> + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) + }) + } + + fn visit_u64<E>(self, v: u64) -> std::result::Result<Self::Value, E> + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) + }) + } + + fn visit_str<E>(self, value: &str) -> std::result::Result<Self::Value, E> + where + E: serde::de::Error, + { + match value { + "Bounded" => Ok(ExecutionMode::Bounded), + "Unbounded" => Ok(ExecutionMode::Unbounded), + "PipelineBreaking" => Ok(ExecutionMode::PipelineBreaking), + _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), + } + } + } + deserializer.deserialize_any(GeneratedVisitor) + } +} impl serde::Serialize for ExplainExecNode { #[allow(deprecated)] fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error> @@ -5752,6 +5826,249 @@ impl<'de> serde::Deserialize<'de> for FixedSizeBinary { deserializer.deserialize_struct("datafusion.FixedSizeBinary", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for FlightPartitionNode { + #[allow(deprecated)] + fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error> + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.locations.is_empty() { + len += 1; + } + if !self.token.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.FlightPartitionNode", len)?; + if !self.locations.is_empty() { + struct_ser.serialize_field("locations", &self.locations)?; + } + if !self.token.is_empty() { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("token", pbjson::private::base64::encode(&self.token).as_str())?; + } + struct_ser.end() + } +}
+impl<'de> serde::Deserialize<'de> for FlightPartitionNode { + #[allow(deprecated)] + fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error> + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "locations", + "token", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Locations, + Token, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize<D>(deserializer: D) -> std::result::Result<GeneratedField, D::Error> + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str<E>(self, value: &str) -> std::result::Result<GeneratedField, E> + where + E: serde::de::Error, + { + match value { + "locations" => Ok(GeneratedField::Locations), + "token" => Ok(GeneratedField::Token), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = FlightPartitionNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.FlightPartitionNode") + } + + fn visit_map<V>(self, mut map_: V) -> std::result::Result<FlightPartitionNode, V::Error> + where + V: serde::de::MapAccess<'de>, + { + let mut locations__ = None; + let mut token__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Locations => { + if locations__.is_some() { + return Err(serde::de::Error::duplicate_field("locations")); + } + locations__ = Some(map_.next_value()?); + } + GeneratedField::Token => { + if token__.is_some() { + return Err(serde::de::Error::duplicate_field("token")); + } + token__ = + Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + ; + } + } + } + Ok(FlightPartitionNode { + locations: locations__.unwrap_or_default(), + token: token__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.FlightPartitionNode", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for FlightScanExecNode { + #[allow(deprecated)] + fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error> + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.partitions.is_empty() { + len += 1; + } + if self.plan_properties.is_some() { + len += 1; + } + if !self.grpc_headers.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.FlightScanExecNode", len)?; + if !self.partitions.is_empty() { + struct_ser.serialize_field("partitions", &self.partitions)?; + } + if let Some(v) = self.plan_properties.as_ref() { + struct_ser.serialize_field("planProperties", v)?; + } + if !self.grpc_headers.is_empty() { + let v: std::collections::HashMap<_, _> = self.grpc_headers.iter() + .map(|(k, v)| (k, pbjson::private::base64::encode(v))).collect(); + struct_ser.serialize_field("grpcHeaders", &v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for FlightScanExecNode { + #[allow(deprecated)] + fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error> + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "partitions", + "plan_properties", + "planProperties", + "grpc_headers", + "grpcHeaders", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Partitions,
PlanProperties, + GrpcHeaders, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize<D>(deserializer: D) -> std::result::Result<GeneratedField, D::Error> + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str<E>(self, value: &str) -> std::result::Result<GeneratedField, E> + where + E: serde::de::Error, + { + match value { + "partitions" => Ok(GeneratedField::Partitions), + "planProperties" | "plan_properties" => Ok(GeneratedField::PlanProperties), + "grpcHeaders" | "grpc_headers" => Ok(GeneratedField::GrpcHeaders), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = FlightScanExecNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.FlightScanExecNode") + } + + fn visit_map<V>(self, mut map_: V) -> std::result::Result<FlightScanExecNode, V::Error> + where + V: serde::de::MapAccess<'de>, + { + let mut partitions__ = None; + let mut plan_properties__ = None; + let mut grpc_headers__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Partitions => { + if partitions__.is_some() { + return Err(serde::de::Error::duplicate_field("partitions")); + } + partitions__ = Some(map_.next_value()?); + } + GeneratedField::PlanProperties => { + if plan_properties__.is_some() { + return Err(serde::de::Error::duplicate_field("planProperties")); + } + plan_properties__ = map_.next_value()?; + } + GeneratedField::GrpcHeaders => { + if grpc_headers__.is_some() { + return Err(serde::de::Error::duplicate_field("grpcHeaders")); + } + grpc_headers__ = Some( + map_.next_value::<std::collections::HashMap<_, ::pbjson::private::BytesDeserialize<_>>>()?
+ .into_iter().map(|(k,v)| (k, v.0)).collect() + ); + } + } + } + Ok(FlightScanExecNode { + partitions: partitions__.unwrap_or_default(), + plan_properties: plan_properties__, + grpc_headers: grpc_headers__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.FlightScanExecNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for FullTableReference { #[allow(deprecated)] fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error> @@ -14686,6 +15003,9 @@ impl serde::Serialize for PhysicalPlanNode { physical_plan_node::PhysicalPlanType::ParquetSink(v) => { struct_ser.serialize_field("parquetSink", v)?; } + physical_plan_node::PhysicalPlanType::FlightScan(v) => { + struct_ser.serialize_field("flightScan", v)?; + } } } struct_ser.end() @@ -14741,6 +15061,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { "csvSink", "parquet_sink", "parquetSink", + "flight_scan", + "flightScan", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { @@ -14773,6 +15095,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { PlaceholderRow, CsvSink, ParquetSink, + FlightScan, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize<D>(deserializer: D) -> std::result::Result<GeneratedField, D::Error> @@ -14822,6 +15145,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { "placeholderRow" | "placeholder_row" => Ok(GeneratedField::PlaceholderRow), "csvSink" | "csv_sink" => Ok(GeneratedField::CsvSink), "parquetSink" | "parquet_sink" => Ok(GeneratedField::ParquetSink), + "flightScan" | "flight_scan" => Ok(GeneratedField::FlightScan), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -15038,6 +15362,13 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { return Err(serde::de::Error::duplicate_field("parquetSink")); } physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::ParquetSink) +; + } + GeneratedField::FlightScan => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("flightScan")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::FlightScan) ; } } @@ -16054,6 +16385,152 @@ impl<'de> serde::Deserialize<'de> for PlaceholderRowExecNode { deserializer.deserialize_struct("datafusion.PlaceholderRowExecNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for PlanPropertiesNode { + #[allow(deprecated)] + fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error> + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.schema.is_some() { + len += 1; + } + if !self.output_ordering.is_empty() { + len += 1; + } + if self.partitioning.is_some() { + len += 1; + } + if self.execution_mode != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.PlanPropertiesNode", len)?; + if let Some(v) = self.schema.as_ref() { + struct_ser.serialize_field("schema", v)?; + } + if !self.output_ordering.is_empty() { + struct_ser.serialize_field("outputOrdering", &self.output_ordering)?; + } + if let Some(v) = self.partitioning.as_ref() { + struct_ser.serialize_field("partitioning", v)?; + } + if self.execution_mode != 0 { + let v = ExecutionMode::try_from(self.execution_mode) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.execution_mode)))?; + struct_ser.serialize_field("executionMode", &v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for PlanPropertiesNode { + #[allow(deprecated)] + fn
@@ -16054,6 +16385,152 @@ impl<'de> serde::Deserialize<'de> for PlaceholderRowExecNode {
         deserializer.deserialize_struct("datafusion.PlaceholderRowExecNode", FIELDS, GeneratedVisitor)
     }
 }
+impl serde::Serialize for PlanPropertiesNode {
+    #[allow(deprecated)]
+    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
+    where
+        S: serde::Serializer,
+    {
+        use serde::ser::SerializeStruct;
+        let mut len = 0;
+        if self.schema.is_some() {
+            len += 1;
+        }
+        if !self.output_ordering.is_empty() {
+            len += 1;
+        }
+        if self.partitioning.is_some() {
+            len += 1;
+        }
+        if self.execution_mode != 0 {
+            len += 1;
+        }
+        let mut struct_ser = serializer.serialize_struct("datafusion.PlanPropertiesNode", len)?;
+        if let Some(v) = self.schema.as_ref() {
+            struct_ser.serialize_field("schema", v)?;
+        }
+        if !self.output_ordering.is_empty() {
+            struct_ser.serialize_field("outputOrdering", &self.output_ordering)?;
+        }
+        if let Some(v) = self.partitioning.as_ref() {
+            struct_ser.serialize_field("partitioning", v)?;
+        }
+        if self.execution_mode != 0 {
+            let v = ExecutionMode::try_from(self.execution_mode)
+                .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.execution_mode)))?;
+            struct_ser.serialize_field("executionMode", &v)?;
+        }
+        struct_ser.end()
+    }
+}
+impl<'de> serde::Deserialize<'de> for PlanPropertiesNode {
+    #[allow(deprecated)]
+    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+    where
+        D: serde::Deserializer<'de>,
+    {
+        const FIELDS: &[&str] = &[
+            "schema",
+            "output_ordering",
+            "outputOrdering",
+            "partitioning",
+            "execution_mode",
+            "executionMode",
+        ];
+
+        #[allow(clippy::enum_variant_names)]
+        enum GeneratedField {
+            Schema,
+            OutputOrdering,
+            Partitioning,
+            ExecutionMode,
+        }
+        impl<'de> serde::Deserialize<'de> for GeneratedField {
+            fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+            where
+                D: serde::Deserializer<'de>,
+            {
+                struct GeneratedVisitor;
+
+                impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+                    type Value = GeneratedField;
+
+                    fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+                        write!(formatter, "expected one of: {:?}", &FIELDS)
+                    }
+
+                    #[allow(unused_variables)]
+                    fn visit_str<E>(self, value: &str) -> std::result::Result<GeneratedField, E>
+                    where
+                        E: serde::de::Error,
+                    {
+                        match value {
+                            "schema" => Ok(GeneratedField::Schema),
+                            "outputOrdering" | "output_ordering" => Ok(GeneratedField::OutputOrdering),
+                            "partitioning" => Ok(GeneratedField::Partitioning),
+                            "executionMode" | "execution_mode" => Ok(GeneratedField::ExecutionMode),
+                            _ => Err(serde::de::Error::unknown_field(value, FIELDS)),
+                        }
+                    }
+                }
+                deserializer.deserialize_identifier(GeneratedVisitor)
+            }
+        }
+        struct GeneratedVisitor;
+        impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+            type Value = PlanPropertiesNode;
+
+            fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+                formatter.write_str("struct datafusion.PlanPropertiesNode")
+            }
+
+            fn visit_map<V>(self, mut map_: V) -> std::result::Result<PlanPropertiesNode, V::Error>
+            where
+                V: serde::de::MapAccess<'de>,
+            {
+                let mut schema__ = None;
+                let mut output_ordering__ = None;
+                let mut partitioning__ = None;
+                let mut execution_mode__ = None;
+                while let Some(k) = map_.next_key()? {
+                    match k {
+                        GeneratedField::Schema => {
+                            if schema__.is_some() {
+                                return Err(serde::de::Error::duplicate_field("schema"));
+                            }
+                            schema__ = map_.next_value()?;
+                        }
+                        GeneratedField::OutputOrdering => {
+                            if output_ordering__.is_some() {
+                                return Err(serde::de::Error::duplicate_field("outputOrdering"));
+                            }
+                            output_ordering__ = Some(map_.next_value()?);
+                        }
+                        GeneratedField::Partitioning => {
+                            if partitioning__.is_some() {
+                                return Err(serde::de::Error::duplicate_field("partitioning"));
+                            }
+                            partitioning__ = map_.next_value()?;
+                        }
+                        GeneratedField::ExecutionMode => {
+                            if execution_mode__.is_some() {
+                                return Err(serde::de::Error::duplicate_field("executionMode"));
+                            }
+                            execution_mode__ = Some(map_.next_value::<ExecutionMode>()? as i32);
+                        }
+                    }
+                }
+                Ok(PlanPropertiesNode {
+                    schema: schema__,
+                    output_ordering: output_ordering__.unwrap_or_default(),
+                    partitioning: partitioning__,
+                    execution_mode: execution_mode__.unwrap_or_default(),
+                })
+            }
+        }
+        deserializer.deserialize_struct("datafusion.PlanPropertiesNode", FIELDS, GeneratedVisitor)
+    }
+}
 impl serde::Serialize for PlanType {
     #[allow(deprecated)]
     fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs
index 070c9b31d3d4..ae4f34d815d5 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -1030,7 +1030,7 @@ pub mod table_reference {
 pub struct PhysicalPlanNode {
     #[prost(
         oneof = "physical_plan_node::PhysicalPlanType",
-        tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29"
+        tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30"
     )]
     pub physical_plan_type: ::core::option::Option<physical_plan_node::PhysicalPlanType>,
 }
@@ -1097,6 +1097,8 @@ pub mod physical_plan_node {
         CsvSink(::prost::alloc::boxed::Box<super::CsvSinkExecNode>),
         #[prost(message, tag = "29")]
         ParquetSink(::prost::alloc::boxed::Box<super::ParquetSinkExecNode>),
+        #[prost(message, tag = "30")]
+        FlightScan(super::FlightScanExecNode),
     }
 }
 #[allow(clippy::derive_partial_eq_without_eq)]
@@ -1916,6 +1918,39 @@ pub struct PartitionStats {
     #[prost(message, repeated, tag = "4")]
     pub column_stats: ::prost::alloc::vec::Vec<ColumnStats>,
 }
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct FlightScanExecNode {
+    #[prost(message, repeated, tag = "1")]
+    pub partitions: ::prost::alloc::vec::Vec<FlightPartitionNode>,
+    #[prost(message, optional, tag = "2")]
+    pub plan_properties: ::core::option::Option<PlanPropertiesNode>,
+    #[prost(map = "string, bytes", tag = "3")]
+    pub grpc_headers: ::std::collections::HashMap<
+        ::prost::alloc::string::String,
+        ::prost::alloc::vec::Vec<u8>,
+    >,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct FlightPartitionNode {
+    #[prost(string, repeated, tag = "1")]
+    pub locations: ::prost::alloc::vec::Vec<::prost::alloc::string::String>,
+    #[prost(bytes = "vec", tag = "2")]
+    pub token: ::prost::alloc::vec::Vec<u8>,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct PlanPropertiesNode {
+    #[prost(message, optional, tag = "1")]
+    pub schema: ::core::option::Option<Schema>,
+    #[prost(message, repeated, tag = "2")]
+    pub output_ordering: ::prost::alloc::vec::Vec<PhysicalSortExprNodeCollection>,
+    #[prost(message, optional, tag = "3")]
+    pub partitioning: ::core::option::Option<Partitioning>,
+    #[prost(enumeration = "ExecutionMode", tag = "4")]
+    pub execution_mode: i32,
+}
 #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)]
 #[repr(i32)]
 pub enum BuiltInWindowFunction {
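To make the new prost messages concrete, here is a hedged sketch (all field values invented, generated types assumed to be re-exported as `datafusion_proto::protobuf`) that builds a `FlightScanExecNode` and round-trips it through the prost wire format:

```rust
use datafusion_proto::protobuf::{FlightPartitionNode, FlightScanExecNode};
use prost::Message;
use std::collections::HashMap;

fn main() -> Result<(), prost::DecodeError> {
    let node = FlightScanExecNode {
        partitions: vec![FlightPartitionNode {
            locations: vec!["grpc://flight-server:50051".to_string()], // invented endpoint
            token: b"opaque-ticket".to_vec(),
        }],
        plan_properties: None, // the decode path treats a missing value as an internal error
        grpc_headers: HashMap::from([("x-trace-id".to_string(), b"abc123".to_vec())]),
    };
    let bytes = node.encode_to_vec(); // protobuf wire encoding
    let decoded = FlightScanExecNode::decode(bytes.as_slice())?;
    assert_eq!(node, decoded);
    Ok(())
}
```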
@@ -2143,3 +2178,32 @@ impl AggregateMode {
         }
     }
 }
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)]
+#[repr(i32)]
+pub enum ExecutionMode {
+    Bounded = 0,
+    Unbounded = 1,
+    PipelineBreaking = 2,
+}
+impl ExecutionMode {
+    /// String value of the enum field names used in the ProtoBuf definition.
+    ///
+    /// The values are not transformed in any way and thus are considered stable
+    /// (if the ProtoBuf definition does not change) and safe for programmatic use.
+    pub fn as_str_name(&self) -> &'static str {
+        match self {
+            ExecutionMode::Bounded => "Bounded",
+            ExecutionMode::Unbounded => "Unbounded",
+            ExecutionMode::PipelineBreaking => "PipelineBreaking",
+        }
+    }
+    /// Creates an enum from field names used in the ProtoBuf definition.
+    pub fn from_str_name(value: &str) -> ::core::option::Option<Self> {
+        match value {
+            "Bounded" => Some(Self::Bounded),
+            "Unbounded" => Some(Self::Unbounded),
+            "PipelineBreaking" => Some(Self::PipelineBreaking),
+            _ => None,
+        }
+    }
+}
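A short sketch of the `ExecutionMode` conversions the serde code above relies on: the message stores the enum as a raw `i32`, `try_from` validates it, and the `*_str_name` helpers map to and from the ProtoBuf variant names (again assuming the type is reachable at `datafusion_proto::protobuf`):

```rust
use datafusion_proto::protobuf::ExecutionMode;

fn main() {
    // `PlanPropertiesNode::serialize` above uses exactly this validation.
    assert_eq!(ExecutionMode::try_from(2).unwrap(), ExecutionMode::PipelineBreaking);
    assert!(ExecutionMode::try_from(42).is_err()); // unknown discriminant is rejected

    // Name round-trip, stable as long as the .proto definition is stable.
    assert_eq!(ExecutionMode::Unbounded.as_str_name(), "Unbounded");
    assert_eq!(ExecutionMode::from_str_name("Bounded"), Some(ExecutionMode::Bounded));
}
```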
diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs
index 0f6722dd375b..4e7ee43ceae5 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -15,12 +15,14 @@
 // specific language governing permissions and limitations
 // under the License.

-use std::fmt::Debug;
-use std::sync::Arc;
-
+#![allow(unused_imports)] // required for disabling features
+use arrow::datatypes::Schema;
 use datafusion::physical_expr_functions_aggregate::aggregate::AggregateExprBuilder;
 use prost::bytes::BufMut;
 use prost::Message;
+use std::collections::HashMap;
+use std::fmt::Debug;
+use std::sync::Arc;

 use datafusion::arrow::compute::SortOptions;
 use datafusion::arrow::datatypes::SchemaRef;
@@ -32,9 +34,13 @@ use datafusion::datasource::file_format::parquet::ParquetSink;
 #[cfg(feature = "parquet")]
 use datafusion::datasource::physical_plan::ParquetExec;
 use datafusion::datasource::physical_plan::{AvroExec, CsvExec};
+#[cfg(feature = "flight")]
+use datafusion::datasource::physical_plan::{FlightExec, FlightPartition};
 use datafusion::execution::runtime_env::RuntimeEnv;
 use datafusion::execution::FunctionRegistry;
-use datafusion::physical_expr::{PhysicalExprRef, PhysicalSortRequirement};
+use datafusion::physical_expr::{
+    EquivalenceProperties, PhysicalExprRef, PhysicalSortRequirement,
+};
 use datafusion::physical_plan::aggregates::AggregateMode;
 use datafusion::physical_plan::aggregates::{AggregateExec, PhysicalGroupBy};
 use datafusion::physical_plan::analyze::AnalyzeExec;
@@ -59,7 +65,8 @@ use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMerge
 use datafusion::physical_plan::union::{InterleaveExec, UnionExec};
 use datafusion::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec};
 use datafusion::physical_plan::{
-    AggregateExpr, ExecutionPlan, InputOrderMode, PhysicalExpr, WindowExpr,
+    AggregateExpr, ExecutionMode, ExecutionPlan, InputOrderMode, PhysicalExpr,
+    PlanProperties, WindowExpr,
 };
 use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result};
 use datafusion_expr::{AggregateUDF, ScalarUDF};
@@ -72,12 +79,15 @@ use crate::physical_plan::from_proto::{
 };
 use crate::physical_plan::to_proto::{
     serialize_file_scan_config, serialize_maybe_filter, serialize_physical_aggr_expr,
-    serialize_physical_window_expr,
+    serialize_physical_sort_exprs, serialize_physical_window_expr,
 };
 use crate::protobuf::physical_aggregate_expr_node::AggregateFunction;
 use crate::protobuf::physical_expr_node::ExprType;
 use crate::protobuf::physical_plan_node::PhysicalPlanType;
-use crate::protobuf::{self, proto_error, window_agg_exec_node};
+use crate::protobuf::{
+    self, proto_error, window_agg_exec_node, FlightPartitionNode, FlightScanExecNode,
+    PhysicalSortExprNodeCollection, PlanPropertiesNode,
+};

 use self::from_proto::parse_protobuf_partitioning;
 use self::to_proto::{serialize_partitioning, serialize_physical_expr};
@@ -245,6 +255,10 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode {
                 }
                 Ok(builder.build_arc())
             }
+            #[cfg(not(feature = "parquet"))]
+            PhysicalPlanType::ParquetScan(_) => {
+                unreachable!("The `parquet` feature is disabled")
+            }
             PhysicalPlanType::AvroScan(scan) => {
                 Ok(Arc::new(AvroExec::new(parse_protobuf_file_scan_config(
                     scan.base_conf.as_ref().unwrap(),
@@ -1051,6 +1065,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode {
                     sort_order,
                 )))
             }
+            #[cfg(feature = "parquet")]
             PhysicalPlanType::ParquetSink(sink) => {
                 let input =
                     into_physical_plan(&sink.input, registry, runtime, extension_codec)?;
@@ -1081,6 +1096,75 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode {
                     sort_order,
                 )))
             }
+            #[cfg(not(feature = "parquet"))]
+            PhysicalPlanType::ParquetSink(_) => {
+                unreachable!("The `parquet` feature is disabled")
+            }
+            #[cfg(feature = "flight")]
+            PhysicalPlanType::FlightScan(flight_scan) => {
+                let partitions = flight_scan
+                    .partitions
+                    .iter()
+                    .map(|partition_node| {
+                        FlightPartition::restore(
+                            partition_node.locations.clone(),
+                            partition_node.token.clone(),
+                        )
+                    })
+                    .collect();
+                let plan_props =
+                    flight_scan.plan_properties.clone().ok_or_else(|| {
+                        DataFusionError::Internal(
+                            "Missing plan_properties in FlightExec".into(),
+                        )
+                    })?;
+                let schema: Arc<Schema> = Arc::new(convert_required!(plan_props.schema)?);
+                let codec = DefaultPhysicalExtensionCodec {};
+                let partitioning = parse_protobuf_partitioning(
+                    plan_props.partitioning.as_ref(),
+                    registry,
+                    schema.as_ref(),
+                    &codec,
+                )?
+                .ok_or_else(|| {
+                    DataFusionError::Internal(
+                        "Missing partitioning from plan properties".into(),
+                    )
+                })?;
+                let orderings = &plan_props
+                    .output_ordering
+                    .iter()
+                    .map(|node_collection| {
+                        parse_physical_sort_exprs(
+                            &node_collection.physical_sort_expr_nodes,
+                            registry,
+                            &schema,
+                            &codec,
+                        )
+                        .unwrap()
+                    })
+                    .collect::<Vec<_>>();
+                let eq_props =
+                    EquivalenceProperties::new_with_orderings(schema, orderings);
+                let execution_mode = match plan_props.execution_mode {
+                    0 => ExecutionMode::Bounded,
+                    1 => ExecutionMode::Unbounded,
+                    2 => ExecutionMode::PipelineBreaking,
+                    _ => unreachable!("Unexpected execution mode"),
+                };
+                let plan_properties =
+                    PlanProperties::new(eq_props, partitioning, execution_mode);

+                Ok(Arc::new(FlightExec::restore(
+                    partitions,
+                    plan_properties,
+                    &flight_scan.grpc_headers,
+                )))
+            }
+            #[cfg(not(feature = "flight"))]
+            PhysicalPlanType::FlightScan(_) => {
+                unreachable!("The `flight` feature is disabled")
+            }
         }
     }
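End to end, these branches are what let a `FlightExec` survive plan serialization. A hedged sketch using the existing `datafusion_proto::bytes` helpers (the plan value is assumed to have been built from a flight-backed table; both crates need the `flight` feature):

```rust
use std::sync::Arc;

use datafusion::error::Result;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::prelude::SessionContext;
use datafusion_proto::bytes::{physical_plan_from_bytes, physical_plan_to_bytes};

fn roundtrip(
    ctx: &SessionContext,
    plan: Arc<dyn ExecutionPlan>,
) -> Result<Arc<dyn ExecutionPlan>> {
    let bytes = physical_plan_to_bytes(Arc::clone(&plan))?; // hits the FlightScan encode branch
    let restored = physical_plan_from_bytes(&bytes, ctx)?; // hits FlightExec::restore
    assert_eq!(plan.schema(), restored.schema());
    Ok(restored)
}
```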
@@ -1923,6 +2007,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode {
                 });
             }

+            #[cfg(feature = "parquet")]
             if let Some(sink) = exec.sink().as_any().downcast_ref::<ParquetSink>() {
                 return Ok(protobuf::PhysicalPlanNode {
                     physical_plan_type: Some(PhysicalPlanType::ParquetSink(Box::new(
@@ -1939,6 +2024,64 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode {

             // If unknown DataSink then let extension handle it
         }

+        #[cfg(feature = "flight")]
+        if let Some(exec) = plan.downcast_ref::<FlightExec>() {
+            let partitions = exec
+                .partitions
+                .iter()
+                .map(|part| FlightPartitionNode {
+                    locations: part.locations.clone(),
+                    token: part.ticket.clone(),
+                })
+                .collect();
+            let pp = exec.plan_properties.as_ref();
+            let partitioning =
+                Some(serialize_partitioning(&pp.partitioning, extension_codec)?);
+            let output_ordering = pp
+                .eq_properties
+                .oeq_class
+                .orderings
+                .clone()
+                .into_iter()
+                .map(|ordering| PhysicalSortExprNodeCollection {
+                    physical_sort_expr_nodes: serialize_physical_sort_exprs(
+                        ordering,
+                        extension_codec,
+                    )
+                    .unwrap(),
+                })
+                .collect();

+            // Must mirror the decode-side mapping (0/1/2) so PipelineBreaking
+            // round-trips instead of collapsing into Unbounded.
+            let execution_mode = match pp.execution_mode {
+                ExecutionMode::Bounded => 0,
+                ExecutionMode::Unbounded => 1,
+                ExecutionMode::PipelineBreaking => 2,
+            };
+            let schema = Some(exec.schema().try_into()?);
+            let plan_properties = Some(PlanPropertiesNode {
+                schema,
+                output_ordering,
+                partitioning,
+                execution_mode,
+            });
+            let grpc_headers = HashMap::from_iter(
+                (*exec.grpc_metadata)
+                    .clone()
+                    .into_headers()
+                    .iter()
+                    .map(|(k, v)| (k.as_str().into(), Vec::from(v.as_bytes()))),
+            );
+            return Ok(protobuf::PhysicalPlanNode {
+                physical_plan_type: Some(PhysicalPlanType::FlightScan(
+                    FlightScanExecNode {
+                        partitions,
+                        plan_properties,
+                        grpc_headers,
+                    },
+                )),
+            });
+        }
+
         let mut buf: Vec<u8> = vec![];
         match extension_codec.try_encode(Arc::clone(&plan_clone), &mut buf) {
             Ok(_) => {