From 8b49989cb82619c31e3b515e2545fbb81e8a1f24 Mon Sep 17 00:00:00 2001 From: Costi Ciudatu Date: Tue, 20 Aug 2024 00:06:00 +0300 Subject: [PATCH] Generic FlightTableFactory with a default FlightSqlDriver --- .gitignore | 2 +- Cargo.toml | 37 ++++ README.md | 12 +- examples/flight-sql.rs | 56 +++++ src/codec.rs | 67 ++++++ src/exec.rs | 313 +++++++++++++++++++++++++++ src/lib.rs | 276 ++++++++++++++++++++++++ src/sql.rs | 465 +++++++++++++++++++++++++++++++++++++++++ 8 files changed, 1225 insertions(+), 3 deletions(-) create mode 100644 Cargo.toml create mode 100644 examples/flight-sql.rs create mode 100644 src/codec.rs create mode 100644 src/exec.rs create mode 100644 src/lib.rs create mode 100644 src/sql.rs diff --git a/.gitignore b/.gitignore index d01bd1a..290f8c5 100644 --- a/.gitignore +++ b/.gitignore @@ -18,4 +18,4 @@ Cargo.lock # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ \ No newline at end of file +.idea/ diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..e75a32e --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,37 @@ +[package] +name = "datafusion-table-provider-flight" +version = "0.1.0" +edition = "2021" + +[dependencies] +arrow-array = "52.2.0" +arrow-flight = { version = "52.2.0", features = ["flight-sql-experimental", "tls"] } +arrow-schema = { version = "52.2.0", features = ["serde"] } +async-trait = "0.1.81" +base64 = "0.22.1" +bytes = "1.7.1" +datafusion = "41.0.0" +datafusion-expr = "41.0.0" +datafusion-physical-expr = "41.0.0" +datafusion-physical-plan = "41.0.0" +datafusion-proto = "41.0.0" +futures = "0.3.30" +prost = "0.12" # pinned for arrow-flight compat +serde = { version = "1.0.208", features = ["derive"] } +serde_json = "1.0.125" +tokio = { version = "1.36", features = [ + "macros", + "rt", + "sync", + "rt-multi-thread", + "parking_lot", + "fs", +] } +tonic = "0.11" # pinned for arrow-flight compat + +[dev-dependencies] +tokio-stream = { version = "0.1.15", features = ["net"] } + +[[example]] +name = "flight-sql" +path = "examples/flight-sql.rs" diff --git a/README.md b/README.md index 925e5a1..287d5b4 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,10 @@ -# datafusion-table-provider-flightsql -FlightSQL TableProvider for DataFusion +# DataFusion TableProviderFactory for Arrow Flight +A generic `FlightTableFactory` that can integrate any Arrow Flight RPC service +as a `TableProviderFactory`. Relies on a `FlightDriver` trait implementation to +handle the `GetFlightInfo` call and all its prerequisites. + +## Flight SQL +This crate includes a `FlightSqlDriver` that has been tested with +[Ballista](https://github.com/apache/datafusion-ballista), +[Dremio](https://github.com/dremio/dremio-oss) and +[ROAPI](https://github.com/roapi/roapi). \ No newline at end of file diff --git a/examples/flight-sql.rs b/examples/flight-sql.rs new file mode 100644 index 0000000..f985f7b --- /dev/null +++ b/examples/flight-sql.rs @@ -0,0 +1,56 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::prelude::SessionContext; +use datafusion_table_provider_flight::sql::FlightSqlDriver; +use datafusion_table_provider_flight::FlightTableFactory; +use std::sync::Arc; + +/// Prerequisites: +/// ``` +/// $ brew install roapi +/// $ roapi -t taxi=https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_2024-01.parquet +/// ``` +#[tokio::main] +async fn main() -> datafusion::common::Result<()> { + let ctx = SessionContext::new(); + ctx.state_ref().write().table_factories_mut().insert( + "FLIGHT_SQL".into(), + Arc::new(FlightTableFactory::new( + Arc::new(FlightSqlDriver::default()), + )), + ); + let _ = ctx + .sql(r#" + CREATE EXTERNAL TABLE trip_data STORED AS FLIGHT_SQL + LOCATION 'http://localhost:32010' + OPTIONS ( + 'flight.sql.query' 'SELECT * FROM taxi' + ) + "#) + .await?; + + let df = ctx + .sql(r#" + SELECT "VendorID", COUNT(*), SUM(passenger_count), SUM(total_amount) + FROM trip_data + GROUP BY "VendorID" + ORDER BY COUNT(*) DESC + "#) + .await?; + df.show().await +} diff --git a/src/codec.rs b/src/codec.rs new file mode 100644 index 0000000..913bf00 --- /dev/null +++ b/src/codec.rs @@ -0,0 +1,67 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Flight plan codecs + +use std::sync::Arc; + +use datafusion::common::DataFusionError; +use datafusion_expr::registry::FunctionRegistry; +use datafusion_physical_plan::ExecutionPlan; +use datafusion_proto::physical_plan::PhysicalExtensionCodec; + +use crate::exec::{FlightConfig, FlightExec}; + +/// Physical extension codec for FlightExec +#[derive(Clone, Debug, Default)] +pub struct FlightPhysicalCodec {} + +impl PhysicalExtensionCodec for FlightPhysicalCodec { + fn try_decode( + &self, + buf: &[u8], + inputs: &[Arc], + _registry: &dyn FunctionRegistry, + ) -> datafusion::common::Result> { + if inputs.is_empty() { + let config: FlightConfig = serde_json::from_slice(buf) + .map_err(|e| DataFusionError::External(Box::new(e)))?; + Ok(Arc::from(FlightExec::from(config))) + } else { + Err(DataFusionError::Internal( + "FlightExec is not supposed to have any inputs".into(), + )) + } + } + + fn try_encode( + &self, + node: Arc, + buf: &mut Vec, + ) -> datafusion::common::Result<()> { + if let Some(flight) = node.as_any().downcast_ref::() { + let mut bytes = serde_json::to_vec(flight.config()) + .map_err(|e| DataFusionError::External(Box::new(e)))?; + buf.append(&mut bytes); + Ok(()) + } else { + Err(DataFusionError::Internal( + "This codec only supports the FlightExec physical plan".into(), + )) + } + } +} diff --git a/src/exec.rs b/src/exec.rs new file mode 100644 index 0000000..79352cf --- /dev/null +++ b/src/exec.rs @@ -0,0 +1,313 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Execution plan for reading flights from Arrow Flight services + +use std::any::Any; +use std::error::Error; +use std::fmt::Formatter; +use std::str::FromStr; +use std::sync::Arc; + +use arrow_array::RecordBatch; +use arrow_flight::error::FlightError; +use arrow_flight::{FlightClient, FlightEndpoint, Ticket}; +use arrow_schema::SchemaRef; +use datafusion::common::Result; +use datafusion::common::{project_schema, DataFusionError}; +use datafusion::execution::{SendableRecordBatchStream, TaskContext}; +use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; +use datafusion_physical_plan::stream::RecordBatchStreamAdapter; +use datafusion_physical_plan::{ + DisplayAs, DisplayFormatType, ExecutionMode, ExecutionPlan, PlanProperties, +}; +use futures::{StreamExt, TryStreamExt}; +use serde::{Deserialize, Serialize}; +use tonic::metadata::{AsciiMetadataKey, MetadataMap}; +use tonic::transport::Channel; + +use crate::{FlightMetadata, FlightProperties}; + +/// Arrow Flight physical plan that maps flight endpoints to partitions +#[derive(Clone, Debug)] +pub(crate) struct FlightExec { + config: FlightConfig, + plan_properties: PlanProperties, + metadata_map: Arc, +} + +impl FlightExec { + /// Creates a FlightExec with the provided [FlightMetadata] + /// and origin URL (used as fallback location as per the protocol spec). + pub fn try_new( + metadata: FlightMetadata, + projection: Option<&Vec>, + origin: &str, + ) -> Result { + let partitions: Vec<_> = metadata + .info + .endpoint + .iter() + .map(|endpoint| FlightPartition::new(endpoint, origin.to_string())) + .map(Arc::new) + .collect(); + let schema = project_schema(&metadata.schema, projection)?; + let config = FlightConfig { + schema, + partitions, + properties: metadata.props, + }; + Ok(config.into()) + } + + pub(crate) fn config(&self) -> &FlightConfig { + &self.config + } +} + +impl From for FlightExec { + fn from(config: FlightConfig) -> Self { + let exec_mode = if config.properties.unbounded_stream { + ExecutionMode::Unbounded + } else { + ExecutionMode::Bounded + }; + let plan_properties = PlanProperties::new( + EquivalenceProperties::new(config.schema.clone()), + Partitioning::UnknownPartitioning(config.partitions.len()), + exec_mode, + ); + let mut mm = MetadataMap::new(); + for (k, v) in config.properties.grpc_headers.iter() { + let key = AsciiMetadataKey::from_str(k.as_str()) + .expect("invalid header name"); + let value = v.parse() + .expect("invalid header value"); + mm.insert(key, value); + } + Self { + config, + plan_properties, + metadata_map: Arc::from(mm), + } + } +} + +#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)] +pub(crate) struct FlightConfig { + schema: SchemaRef, + partitions: Vec>, + properties: FlightProperties, +} + +/// The minimum information required for fetching a flight stream. +#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)] +struct FlightPartition { + locations: Vec, + ticket: Vec, +} + +impl FlightPartition { + fn new(endpoint: &FlightEndpoint, fallback_location: String) -> Self { + let locations = if endpoint.location.is_empty() { + vec![fallback_location] + } else { + endpoint + .location + .iter() + .map(|loc| { + if loc.uri.starts_with("arrow-flight-reuse-connection://") { + fallback_location.clone() + } else { + loc.uri.clone() + } + }) + .collect() + }; + Self { + locations, + ticket: endpoint + .ticket + .clone() + .expect("No flight ticket") + .ticket + .to_vec(), + } + } +} + +async fn flight_stream( + partition: Arc, + schema: SchemaRef, + grpc_headers: Arc, +) -> Result { + let mut errors: Vec> = vec![]; + for loc in &partition.locations { + match try_fetch_stream( + loc, + partition.ticket.clone(), + schema.clone(), + grpc_headers.clone(), + ) + .await + { + Ok(stream) => return Ok(stream), + Err(e) => errors.push(Box::new(e)), + } + } + let err = errors.into_iter().last().unwrap_or_else(|| { + Box::new(FlightError::ProtocolError(format!( + "No available location for endpoint {:?}", + partition.locations + ))) + }); + Err(DataFusionError::External(err)) +} + +async fn try_fetch_stream( + source: impl Into, + ticket: Vec, + schema: SchemaRef, + grpc_headers: Arc, +) -> arrow_flight::error::Result { + let ticket = Ticket::new(ticket); + let dest = + Channel::from_shared(source.into()).map_err(|e| FlightError::ExternalError(Box::new(e)))?; + let channel = dest + .connect() + .await + .map_err(|e| FlightError::ExternalError(Box::new(e)))?; + let mut client = FlightClient::new(channel); + client.metadata_mut().clone_from(grpc_headers.as_ref()); + let stream = client.do_get(ticket).await?; + Ok(Box::pin(RecordBatchStreamAdapter::new( + schema.clone(), + stream.map(move |rb| { + let schema = schema.clone(); + rb.map(move |rb| { + if schema.fields.is_empty() || rb.schema() == schema { + rb + } else if schema.contains(rb.schema_ref()) { + rb.with_schema(schema.clone()).unwrap() + } else { + let columns = schema + .fields + .iter() + .map(|field| { + rb.column_by_name(field.name()) + .expect("missing fields in record batch") + .clone() + }) + .collect(); + RecordBatch::try_new(schema.clone(), columns) + .expect("cannot impose desired schema on record batch") + } + }) + .map_err(|e| DataFusionError::External(Box::new(e))) + }), + ))) +} + +impl DisplayAs for FlightExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { + match t { + DisplayFormatType::Default => f.write_str("FlightExec"), + DisplayFormatType::Verbose => write!(f, "FlightExec {:?}", self.config), + } + } +} + +impl ExecutionPlan for FlightExec { + fn name(&self) -> &str { + "FlightExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + &self.plan_properties + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + _children: Vec>, + ) -> Result> { + Ok(self) + } + + fn execute( + &self, + partition: usize, + _context: Arc, + ) -> Result { + let future_stream = flight_stream( + self.config.partitions[partition].clone(), + self.schema(), + self.metadata_map.clone(), + ); + let stream = futures::stream::once(future_stream).try_flatten(); + Ok(Box::pin(RecordBatchStreamAdapter::new( + self.schema(), + stream, + ))) + } +} + +#[cfg(test)] +mod tests { + use crate::exec::{FlightConfig, FlightPartition}; + use crate::FlightProperties; + use arrow_schema::{DataType, Field, Schema}; + use std::collections::HashMap; + use std::sync::Arc; + + #[test] + fn test_flight_config_serde() { + let schema = Arc::new(Schema::new(vec![ + Arc::new(Field::new("f1", DataType::Utf8, true)), + Arc::new(Field::new("f2", DataType::Int32, false)), + ])); + let partitions = vec![ + Arc::new(FlightPartition { + locations: vec!["l1".into(), "l2".into()], + ticket: "tichet".as_bytes().to_vec(), + }), + Arc::new(FlightPartition { + locations: vec!["l3".into(), "l4".into()], + ticket: "tichet2".as_bytes().to_vec(), + }), + ]; + let properties = FlightProperties::new( + true, + HashMap::from([("h1".into(), "v1".into()), ("h2".into(), "v2".into())]), + ); + let config = FlightConfig { + schema, + partitions, + properties, + }; + let json = serde_json::to_vec(&config).expect("cannot encode config as json"); + let restored = serde_json::from_slice(json.as_slice()).expect("cannot decode json config"); + assert_eq!(config, restored); + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..ef4b13d --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,276 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Generic [FlightTableFactory] that can connect to Arrow Flight services, +//! with a [sql::FlightSqlDriver] provided out-of-the-box. + +use std::any::Any; +use std::collections::HashMap; +use std::fmt::Debug; +use std::sync::Arc; + +use crate::exec::FlightExec; +use arrow_flight::error::FlightError; +use arrow_flight::FlightInfo; +use arrow_schema::SchemaRef; +use async_trait::async_trait; +use datafusion::catalog::{Session, TableProviderFactory}; +use datafusion::common::stats::Precision; +use datafusion::common::{DataFusionError, Statistics}; +use datafusion::datasource::TableProvider; +use datafusion::physical_plan::ExecutionPlan; +use datafusion_expr::{CreateExternalTable, Expr, TableType}; +use serde::{Deserialize, Serialize}; +use tonic::transport::Channel; + +pub mod codec; +mod exec; +pub mod sql; + +/// Generic Arrow Flight data source. Requires a [FlightDriver] that allows implementors +/// to integrate any custom Flight RPC service by producing a [FlightMetadata] for some DDL. +/// +/// # Sample usage: +/// ``` +/// use std::collections::HashMap; +/// use arrow_flight::{FlightClient, FlightDescriptor}; +/// use tonic::transport::Channel; +/// use datafusion::prelude::SessionContext; +/// use std::sync::Arc; +/// use datafusion_table_provider_flight::{FlightDriver, FlightMetadata}; +/// +/// #[derive(Debug, Clone, Default)] +/// struct CustomFlightDriver {} +/// #[async_trait::async_trait] +/// impl FlightDriver for CustomFlightDriver { +/// async fn metadata(&self, channel: Channel, opts: &HashMap) +/// -> arrow_flight::error::Result { +/// let mut client = FlightClient::new(channel); +/// // the `flight.` prefix is an already registered namespace in datafusion-cli +/// let descriptor = FlightDescriptor::new_cmd(opts["flight.command"].clone()); +/// let flight_info = client.get_flight_info(descriptor).await?; +/// FlightMetadata::try_from(flight_info) +/// } +/// } +/// +/// #[tokio::main] +/// async fn main() -> datafusion::common::Result<()> { +/// use datafusion_table_provider_flight::FlightTableFactory; +/// let ctx = SessionContext::new(); +/// ctx.state_ref().write().table_factories_mut() +/// .insert("CUSTOM_FLIGHT".into(), Arc::new(FlightTableFactory::new( +/// Arc::new(CustomFlightDriver::default()) +/// ))); +/// let _ = ctx.sql(r#" +/// CREATE EXTERNAL TABLE custom_flight_table STORED AS CUSTOM_FLIGHT +/// LOCATION 'https://custom.flight.rpc' +/// OPTIONS ('flight.command' 'select * from everywhere') +/// "#).await; // will fail as it can't connect to the bogus URL, but we ignore the error +/// Ok(()) +/// } +/// ``` +#[derive(Clone, Debug)] +pub struct FlightTableFactory { + driver: Arc, +} + +impl FlightTableFactory { + /// Create a data source using the provided driver + pub fn new(driver: Arc) -> Self { + Self { driver } + } + + /// Convenient way to create a [FlightTable] programatically, as an alternative to DDL. + pub async fn open_table( + &self, + entry_point: impl Into, + options: HashMap, + ) -> datafusion::common::Result { + let origin = entry_point.into(); + let channel = Channel::from_shared(origin.clone()) + .unwrap() + .connect() + .await + .map_err(|e| DataFusionError::External(Box::new(e)))?; + let metadata = self + .driver + .metadata(channel.clone(), &options) + .await + .map_err(|e| DataFusionError::External(Box::new(e)))?; + let num_rows = precision(metadata.info.total_records); + let total_byte_size = precision(metadata.info.total_bytes); + let logical_schema = metadata.schema; + let stats = Statistics { + num_rows, + total_byte_size, + column_statistics: vec![], + }; + Ok(FlightTable { + driver: self.driver.clone(), + channel, + options, + origin, + logical_schema, + stats, + }) + } +} + +fn precision(total: i64) -> Precision { + if total < 0 { + Precision::Absent + } else { + Precision::Exact(total as usize) + } +} + +#[async_trait] +impl TableProviderFactory for FlightTableFactory { + async fn create( + &self, + _state: &dyn Session, + cmd: &CreateExternalTable, + ) -> datafusion::common::Result> { + let table = self.open_table(&cmd.location, cmd.options.clone()).await?; + Ok(Arc::new(table)) + } +} + +/// Extension point for integrating any Flight RPC service as a [FlightTableFactory]. +/// Handles the initial `GetFlightInfo` call and all its prerequisites (such as `Handshake`), +/// to produce a [FlightMetadata]. +#[async_trait] +pub trait FlightDriver: Sync + Send + Debug { + /// Returns a [FlightMetadata] from the specified channel, + /// according to the provided table options. + /// The driver must provide at least a [FlightInfo] in order to construct a flight metadata. + async fn metadata( + &self, + channel: Channel, + options: &HashMap, + ) -> arrow_flight::error::Result; +} + +/// The information that a [FlightDriver] must produce +/// in order to register flights as DataFusion tables. +#[derive(Clone, Debug)] +pub struct FlightMetadata { + /// FlightInfo object produced by the driver + info: FlightInfo, + /// Arrow schema. Can be enforced by the driver or inferred from the FlightInfo + schema: SchemaRef, + /// Various knobs that control execution + props: FlightProperties, +} + +impl FlightMetadata { + /// Customize everything that is in the driver's control + pub fn new(info: FlightInfo, schema: SchemaRef, props: FlightProperties) -> Self { + Self { + info, + schema, + props, + } + } + + /// Customize gRPC headers + pub fn try_new( + info: FlightInfo, + grpc_headers: HashMap, + ) -> arrow_flight::error::Result { + let schema = Arc::new(info.clone().try_decode_schema()?); + let props = grpc_headers.into(); + Ok(Self::new(info, schema, props)) + } +} + +impl TryFrom for FlightMetadata { + type Error = FlightError; + + fn try_from(info: FlightInfo) -> Result { + Self::try_new(info, HashMap::default()) + } +} + +#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)] +pub struct FlightProperties { + unbounded_stream: bool, + grpc_headers: HashMap, +} + +impl FlightProperties { + pub fn new(unbounded_stream: bool, grpc_headers: HashMap) -> Self { + Self { + unbounded_stream, + grpc_headers, + } + } +} + +impl From> for FlightProperties { + fn from(grpc_headers: HashMap) -> Self { + Self::new(false, grpc_headers) + } +} + +/// Table provider that wraps a specific flight from an Arrow Flight service +pub struct FlightTable { + driver: Arc, + channel: Channel, + options: HashMap, + origin: String, + logical_schema: SchemaRef, + stats: Statistics, +} + +#[async_trait] +impl TableProvider for FlightTable { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.logical_schema.clone() + } + + fn table_type(&self) -> TableType { + TableType::View + } + + async fn scan( + &self, + _state: &dyn Session, + projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> datafusion::common::Result> { + let metadata = self + .driver + .metadata(self.channel.clone(), &self.options) + .await + .map_err(|e| DataFusionError::External(Box::new(e)))?; + Ok(Arc::new(FlightExec::try_new( + metadata, + projection, + &self.origin, + )?)) + } + + fn statistics(&self) -> Option { + Some(self.stats.clone()) + } +} diff --git a/src/sql.rs b/src/sql.rs new file mode 100644 index 0000000..db8cb3c --- /dev/null +++ b/src/sql.rs @@ -0,0 +1,465 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Default [FlightDriver] for Flight SQL + +use std::collections::HashMap; +use std::str::FromStr; + +use arrow_flight::error::Result; +use arrow_flight::flight_service_client::FlightServiceClient; +use arrow_flight::sql::{CommandStatementQuery, ProstMessageExt}; +use arrow_flight::{FlightDescriptor, FlightInfo, HandshakeRequest, HandshakeResponse}; +use arrow_schema::ArrowError; +use async_trait::async_trait; +use base64::prelude::BASE64_STANDARD; +use base64::Engine; +use bytes::Bytes; +use futures::{stream, TryStreamExt}; +use prost::Message; +use tonic::metadata::AsciiMetadataKey; +use tonic::transport::Channel; +use tonic::IntoRequest; + +use crate::{FlightDriver, FlightMetadata}; + +/// Default Flight SQL driver. Requires a `flight.sql.query` to be passed as a table option. +/// If `flight.sql.username` (and optionally `flight.sql.password`) are passed, +/// will perform the `Handshake` using basic authentication. +/// Any additional headers can be passed as table options using the `flight.sql.header.` prefix. +/// +/// A [crate::FlightTableFactory] using this driver is registered +/// with the default `SessionContext` under the name `FLIGHT_SQL`. +#[derive(Clone, Debug, Default)] +pub struct FlightSqlDriver {} + +#[async_trait] +impl FlightDriver for FlightSqlDriver { + async fn metadata( + &self, + channel: Channel, + options: &HashMap, + ) -> Result { + let mut client = FlightSqlClient::new(channel); + let headers = options.iter().filter_map(|(key, value)| { + key.strip_prefix("flight.sql.header.") + .map(|header_name| (header_name, value)) + }); + for header in headers { + client.set_header(header.0, header.1) + } + if let Some(username) = options.get("flight.sql.username") { + let default_password = "".to_string(); + let password = options + .get("flight.sql.password") + .unwrap_or(&default_password); + _ = client.handshake(username, password).await?; + } + let info = client + .execute(options["flight.sql.query"].clone(), None) + .await?; + let mut grpc_headers = HashMap::default(); + if let Some(token) = client.token { + grpc_headers.insert("authorization".into(), format!("Bearer {}", token)); + } + FlightMetadata::try_new(info, grpc_headers) + } +} + +///////////////////////////////////////////////////////////////////////// +// Shameless copy/paste from arrow-flight FlightSqlServiceClient +// This is only needed in order to access the bearer token received +// during handshake, as the standard client does not expose this information. +// The bearer token has to be passed to the clients that perform +// the DoGet operation, since Dremio, Ballista and possibly others +// expect the bearer token they produce with the handshake response +// to be set on all subsequent requests, including DoGet. +// +// TODO: remove this and switch to the official client once +// https://github.com/apache/arrow-rs/pull/6254 is released +#[derive(Debug, Clone)] +struct FlightSqlClient { + token: Option, + headers: HashMap, + flight_client: FlightServiceClient, +} + +impl FlightSqlClient { + /// Creates a new FlightSql client that connects to a server over an arbitrary tonic `Channel` + fn new(channel: Channel) -> Self { + Self::new_from_inner(FlightServiceClient::new(channel)) + } + + /// Creates a new higher level client with the provided lower level client + fn new_from_inner(inner: FlightServiceClient) -> Self { + Self { + token: None, + flight_client: inner, + headers: HashMap::default(), + } + } + + /// Perform a `handshake` with the server, passing credentials and establishing a session. + /// + /// If the server returns an "authorization" header, it is automatically parsed and set as + /// a token for future requests. Any other data returned by the server in the handshake + /// response is returned as a binary blob. + async fn handshake( + &mut self, + username: &str, + password: &str, + ) -> std::result::Result { + let cmd = HandshakeRequest { + protocol_version: 0, + payload: Default::default(), + }; + let mut req = tonic::Request::new(stream::iter(vec![cmd])); + let val = BASE64_STANDARD.encode(format!("{username}:{password}")); + let val = format!("Basic {val}") + .parse() + .map_err(|_| ArrowError::ParseError("Cannot parse header".to_string()))?; + req.metadata_mut().insert("authorization", val); + let req = self.set_request_headers(req)?; + let resp = self + .flight_client + .handshake(req) + .await + .map_err(|e| ArrowError::IpcError(format!("Can't handshake {e}")))?; + if let Some(auth) = resp.metadata().get("authorization") { + let auth = auth + .to_str() + .map_err(|_| ArrowError::ParseError("Can't read auth header".to_string()))?; + let bearer = "Bearer "; + if !auth.starts_with(bearer) { + Err(ArrowError::ParseError("Invalid auth header!".to_string()))?; + } + let auth = auth[bearer.len()..].to_string(); + self.token = Some(auth); + } + let responses: Vec = resp + .into_inner() + .try_collect() + .await + .map_err(|_| ArrowError::ParseError("Can't collect responses".to_string()))?; + let resp = match responses.as_slice() { + [resp] => resp.payload.clone(), + [] => Bytes::new(), + _ => Err(ArrowError::ParseError( + "Multiple handshake responses".to_string(), + ))?, + }; + Ok(resp) + } + + async fn execute( + &mut self, + query: String, + transaction_id: Option, + ) -> std::result::Result { + let cmd = CommandStatementQuery { + query, + transaction_id, + }; + self.get_flight_info_for_command(cmd).await + } + + async fn get_flight_info_for_command( + &mut self, + cmd: M, + ) -> std::result::Result { + let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec()); + let req = self.set_request_headers(descriptor.into_request())?; + let fi = self + .flight_client + .get_flight_info(req) + .await + .map_err(|status| ArrowError::IpcError(format!("{status:?}")))? + .into_inner(); + Ok(fi) + } + + fn set_header(&mut self, key: impl Into, value: impl Into) { + let key: String = key.into(); + let value: String = value.into(); + self.headers.insert(key, value); + } + + fn set_request_headers( + &self, + mut req: tonic::Request, + ) -> std::result::Result, ArrowError> { + for (k, v) in &self.headers { + let k = AsciiMetadataKey::from_str(k.as_str()).map_err(|e| { + ArrowError::ParseError(format!("Cannot convert header key \"{k}\": {e}")) + })?; + let v = v.parse().map_err(|e| { + ArrowError::ParseError(format!("Cannot convert header value \"{v}\": {e}")) + })?; + req.metadata_mut().insert(k, v); + } + if let Some(token) = &self.token { + let val = format!("Bearer {token}").parse().map_err(|e| { + ArrowError::ParseError(format!("Cannot convert token to header value: {e}")) + })?; + req.metadata_mut().insert("authorization", val); + } + Ok(req) + } +} +///////////////////////////////////////////////////////////////////////// + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + use std::net::SocketAddr; + use std::pin::Pin; + use std::sync::Arc; + use std::time::Duration; + + use arrow_array::{Array, Float32Array, Int64Array, Int8Array, RecordBatch}; + use arrow_flight::encode::FlightDataEncoderBuilder; + use arrow_flight::flight_service_server::{FlightService, FlightServiceServer}; + use arrow_flight::sql::server::FlightSqlService; + use arrow_flight::sql::{ + CommandStatementQuery, ProstMessageExt, SqlInfo, TicketStatementQuery, + }; + use arrow_flight::{ + FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, HandshakeResponse, Ticket, + }; + use arrow_schema::{DataType, Field, Schema}; + use async_trait::async_trait; + use datafusion::prelude::SessionContext; + use futures::{stream, Stream, TryStreamExt}; + use prost::Message; + use tokio::net::TcpListener; + use tokio::sync::oneshot::{channel, Receiver, Sender}; + use tokio_stream::wrappers::TcpListenerStream; + use tonic::codegen::http::HeaderMap; + use tonic::codegen::tokio_stream; + use tonic::metadata::MetadataMap; + use tonic::transport::Server; + use tonic::{Extensions, Request, Response, Status, Streaming}; + + use crate::sql::FlightSqlDriver; + use crate::FlightTableFactory; + + const AUTH_HEADER: &str = "authorization"; + const BEARER_TOKEN: &str = "Bearer flight-sql-token"; + + struct TestFlightSqlService { + flight_info: FlightInfo, + partition_data: RecordBatch, + expected_handshake_headers: HashMap, + expected_flight_info_query: String, + shutdown_sender: Option>, + } + + impl TestFlightSqlService { + async fn run_in_background(self, rx: Receiver<()>) -> SocketAddr { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let service = FlightServiceServer::new(self); + #[allow(clippy::disallowed_methods)] // spawn allowed only in tests + tokio::spawn(async move { + Server::builder() + .timeout(Duration::from_secs(1)) + .add_service(service) + .serve_with_incoming_shutdown(TcpListenerStream::new(listener), async { + rx.await.ok(); + }) + .await + .unwrap(); + }); + tokio::time::sleep(Duration::from_millis(25)).await; + addr + } + } + + impl Drop for TestFlightSqlService { + fn drop(&mut self) { + if let Some(tx) = self.shutdown_sender.take() { + tx.send(()).ok(); + } + } + } + + fn check_header(request: &Request, rpc: &str, header_name: &str, expected_value: &str) { + let actual_value = request + .metadata() + .get(header_name) + .unwrap_or_else(|| panic!("[{}] missing header `{}`", rpc, header_name)) + .to_str() + .unwrap_or_else(|e| { + panic!( + "[{}] error parsing value for header `{}`: {:?}", + rpc, header_name, e + ) + }); + assert_eq!( + actual_value, expected_value, + "[{}] unexpected value for header `{}`", + rpc, header_name + ) + } + + #[async_trait] + impl FlightSqlService for TestFlightSqlService { + type FlightService = TestFlightSqlService; + + async fn do_handshake( + &self, + request: Request>, + ) -> Result< + Response> + Send>>>, + Status, + > { + for (header_name, expected_value) in self.expected_handshake_headers.iter() { + check_header(&request, "do_handshake", header_name, expected_value); + } + Ok(Response::from_parts( + MetadataMap::from_headers(HeaderMap::from_iter([( + AUTH_HEADER.parse().unwrap(), + BEARER_TOKEN.parse().unwrap(), + )])), // the client should send this header back on the next request (i.e. GetFlightInfo) + Box::pin(tokio_stream::empty()), + Extensions::default(), + )) + } + + async fn get_flight_info_statement( + &self, + query: CommandStatementQuery, + request: Request, + ) -> Result, Status> { + let mut expected_flight_info_headers = self.expected_handshake_headers.clone(); + expected_flight_info_headers.insert(AUTH_HEADER.into(), BEARER_TOKEN.into()); + for (header_name, expected_value) in expected_flight_info_headers.iter() { + check_header(&request, "get_flight_info", header_name, expected_value); + } + assert_eq!( + query.query.to_lowercase(), + self.expected_flight_info_query.to_lowercase() + ); + Ok(Response::new(self.flight_info.clone())) + } + + async fn do_get_statement( + &self, + _ticket: TicketStatementQuery, + request: Request, + ) -> Result::DoGetStream>, Status> { + let data = self.partition_data.clone(); + let rb = async move { Ok(data) }; + check_header(&request, "do_get", "authorization", BEARER_TOKEN); + let stream = FlightDataEncoderBuilder::default() + .with_schema(self.partition_data.schema()) + .build(stream::once(rb)) + .map_err(|e| Status::from_error(Box::new(e))); + + Ok(Response::new(Box::pin(stream))) + } + + async fn register_sql_info(&self, _id: i32, _result: &SqlInfo) {} + } + + #[tokio::test] + async fn test_flight_sql_data_source() -> datafusion::common::Result<()> { + let partition_data = RecordBatch::try_new( + Arc::new(Schema::new([ + Arc::new(Field::new("col1", DataType::Float32, false)), + Arc::new(Field::new("col2", DataType::Int8, false)), + ])), + vec![ + Arc::new(Float32Array::from(vec![0.0, 0.1, 0.2, 0.3])), + Arc::new(Int8Array::from(vec![10, 20, 30, 40])), + ], + ) + .unwrap(); + let rows_per_partition = partition_data.num_rows(); + + let query = "SELECT * FROM some_table"; + let ticket_payload = TicketStatementQuery::default().as_any().encode_to_vec(); + let endpoint_archetype = FlightEndpoint::default().with_ticket(Ticket::new(ticket_payload)); + let endpoints = vec![ + endpoint_archetype.clone(), + endpoint_archetype.clone(), + endpoint_archetype, + ]; + let num_partitions = endpoints.len(); + let flight_info = FlightInfo::default() + .try_with_schema(partition_data.schema().as_ref()) + .unwrap(); + let flight_info = endpoints + .into_iter() + .fold(flight_info, |fi, e| fi.with_endpoint(e)); + let (tx, rx) = channel(); + let service = TestFlightSqlService { + flight_info, + partition_data, + expected_handshake_headers: HashMap::from([ + (AUTH_HEADER.into(), "Basic YWRtaW46cGFzc3dvcmQ=".into()), + ("custom-hdr1".into(), "v1".into()), + ("custom-hdr2".into(), "v2".into()), + ]), + expected_flight_info_query: query.into(), + shutdown_sender: Some(tx), + }; + let port = service.run_in_background(rx).await.port(); + let ctx = SessionContext::new(); + ctx.state_ref().write().table_factories_mut().insert( + "FLIGHT_SQL".into(), + Arc::new(FlightTableFactory::new( + Arc::new(FlightSqlDriver::default()), + )), + ); + let _ = ctx + .sql(&format!( + r#" + CREATE EXTERNAL TABLE fsql STORED AS FLIGHT_SQL + LOCATION 'http://localhost:{port}' + OPTIONS( + 'flight.sql.username' 'admin', + 'flight.sql.password' 'password', + 'flight.sql.query' '{query}', + 'flight.sql.header.custom-hdr1' 'v1', + 'flight.sql.header.custom-hdr2' 'v2', + )"# + )) + .await + .unwrap(); + let df = ctx.sql("select col1 from fsql").await.unwrap(); + assert_eq!( + df.count().await.unwrap(), + rows_per_partition * num_partitions + ); + let df = ctx.sql("select sum(col2) from fsql").await?; + let rb = df + .collect() + .await? + .first() + .cloned() + .expect("no record batch"); + assert_eq!(rb.schema().fields.len(), 1); + let arr = rb + .column(0) + .as_any() + .downcast_ref::() + .expect("wrong type of column"); + assert_eq!(arr.iter().next().unwrap().unwrap(), 300); + Ok(()) + } +}