diff --git a/datafusion-examples/examples/custom_file_format.rs b/datafusion-examples/examples/custom_file_format.rs
index bdb702375c94..8612a1cc4430 100644
--- a/datafusion-examples/examples/custom_file_format.rs
+++ b/datafusion-examples/examples/custom_file_format.rs
@@ -131,7 +131,7 @@ impl FileFormat for TSVFileFormat {
     }
 }
 
-#[derive(Default)]
+#[derive(Default, Debug)]
 /// Factory for creating TSV file formats
 ///
 /// This factory is a wrapper around the CSV file format factory
@@ -166,6 +166,10 @@ impl FileFormatFactory for TSVFileFactory {
     fn default(&self) -> std::sync::Arc<dyn FileFormat> {
         todo!()
     }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
 }
 
 impl GetExt for TSVFileFactory {
diff --git a/datafusion/core/src/datasource/file_format/arrow.rs b/datafusion/core/src/datasource/file_format/arrow.rs
index 6bcbd4347682..8b6a8800119d 100644
--- a/datafusion/core/src/datasource/file_format/arrow.rs
+++ b/datafusion/core/src/datasource/file_format/arrow.rs
@@ -66,7 +66,7 @@ const INITIAL_BUFFER_BYTES: usize = 1048576;
 /// If the buffered Arrow data exceeds this size, it is flushed to object store
 const BUFFER_FLUSH_BYTES: usize = 1024000;
 
-#[derive(Default)]
+#[derive(Default, Debug)]
 /// Factory struct used to create [ArrowFormat]
 pub struct ArrowFormatFactory;
 
@@ -89,6 +89,10 @@ impl FileFormatFactory for ArrowFormatFactory {
     fn default(&self) -> Arc<dyn FileFormat> {
         Arc::new(ArrowFormat)
     }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
 }
 
 impl GetExt for ArrowFormatFactory {
diff --git a/datafusion/core/src/datasource/file_format/avro.rs b/datafusion/core/src/datasource/file_format/avro.rs
index f4f9adcba7ed..5190bdbe153a 100644
--- a/datafusion/core/src/datasource/file_format/avro.rs
+++ b/datafusion/core/src/datasource/file_format/avro.rs
@@ -19,6 +19,7 @@
 
 use std::any::Any;
 use std::collections::HashMap;
+use std::fmt;
 use std::sync::Arc;
 
 use arrow::datatypes::Schema;
@@ -64,6 +65,16 @@ impl FileFormatFactory for AvroFormatFactory {
     fn default(&self) -> Arc<dyn FileFormat> {
         Arc::new(AvroFormat)
     }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+}
+
+impl fmt::Debug for AvroFormatFactory {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("AvroFormatFactory").finish()
+    }
 }
 
 impl GetExt for AvroFormatFactory {
diff --git a/datafusion/core/src/datasource/file_format/csv.rs b/datafusion/core/src/datasource/file_format/csv.rs
index 958d2694aa04..e1b6daac092d 100644
--- a/datafusion/core/src/datasource/file_format/csv.rs
+++ b/datafusion/core/src/datasource/file_format/csv.rs
@@ -58,7 +58,8 @@ use object_store::{delimited::newline_delimited_stream, ObjectMeta, ObjectStore};
 
 #[derive(Default)]
 /// Factory struct used to create [CsvFormatFactory]
 pub struct CsvFormatFactory {
-    options: Option<CsvOptions>,
+    /// The options for the CSV file read
+    pub options: Option<CsvOptions>,
 }
 
 impl CsvFormatFactory {
@@ -75,6 +76,14 @@ impl CsvFormatFactory {
     }
 }
 
+impl fmt::Debug for CsvFormatFactory {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("CsvFormatFactory")
+            .field("options", &self.options)
+            .finish()
+    }
+}
+
 impl FileFormatFactory for CsvFormatFactory {
     fn create(
         &self,
@@ -103,6 +112,10 @@ impl FileFormatFactory for CsvFormatFactory {
     fn default(&self) -> Arc<dyn FileFormat> {
         Arc::new(CsvFormat::default())
    }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
 }
 
 impl GetExt for CsvFormatFactory {
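Each built-in factory satisfies the new `as_any` requirement by simply returning `self`. To illustrate what this enables, here is a minimal consumer-side sketch (a hypothetical helper, not part of the patch, assuming the crate paths touched above): a caller holding a type-erased `Arc<dyn FileFormatFactory>` can recover the concrete factory and its options.

```rust
// Hypothetical helper: `as_any` lets a type-erased factory be downcast
// back to its concrete type to reach format-specific state.
use std::sync::Arc;

use datafusion::config::CsvOptions;
use datafusion::datasource::file_format::csv::CsvFormatFactory;
use datafusion::datasource::file_format::FileFormatFactory;

fn csv_options_of(factory: &Arc<dyn FileFormatFactory>) -> Option<&CsvOptions> {
    factory
        .as_any()
        .downcast_ref::<CsvFormatFactory>()
        .and_then(|csv| csv.options.as_ref())
}
```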
diff --git a/datafusion/core/src/datasource/file_format/json.rs b/datafusion/core/src/datasource/file_format/json.rs
index 007b084f504d..9de9c3d7d871 100644
--- a/datafusion/core/src/datasource/file_format/json.rs
+++ b/datafusion/core/src/datasource/file_format/json.rs
@@ -102,6 +102,10 @@ impl FileFormatFactory for JsonFormatFactory {
     fn default(&self) -> Arc<dyn FileFormat> {
         Arc::new(JsonFormat::default())
     }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
 }
 
 impl GetExt for JsonFormatFactory {
@@ -111,6 +115,14 @@ impl GetExt for JsonFormatFactory {
     }
 }
 
+impl fmt::Debug for JsonFormatFactory {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("JsonFormatFactory")
+            .field("options", &self.options)
+            .finish()
+    }
+}
+
 /// New line delimited JSON `FileFormat` implementation.
 #[derive(Debug, Default)]
 pub struct JsonFormat {
diff --git a/datafusion/core/src/datasource/file_format/mod.rs b/datafusion/core/src/datasource/file_format/mod.rs
index 1aa93a106aff..500f20af474f 100644
--- a/datafusion/core/src/datasource/file_format/mod.rs
+++ b/datafusion/core/src/datasource/file_format/mod.rs
@@ -49,11 +49,11 @@ use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement};
 use async_trait::async_trait;
 use file_compression_type::FileCompressionType;
 use object_store::{ObjectMeta, ObjectStore};
-
+use std::fmt::Debug;
 /// Factory for creating [`FileFormat`] instances based on session and command level options
 ///
 /// Users can provide their own `FileFormatFactory` to support arbitrary file formats
-pub trait FileFormatFactory: Sync + Send + GetExt {
+pub trait FileFormatFactory: Sync + Send + GetExt + Debug {
     /// Initialize a [FileFormat] and configure based on session and command level options
     fn create(
         &self,
@@ -63,6 +63,10 @@ pub trait FileFormatFactory: Sync + Send + GetExt {
 
     /// Initialize a [FileFormat] with all options set to default values
     fn default(&self) -> Arc<dyn FileFormat>;
+
+    /// Returns the factory as [`Any`] so that it can be
+    /// downcast to a specific implementation.
+    fn as_any(&self) -> &dyn Any;
 }
 
 /// This trait abstracts all the file format specific implementations
@@ -138,6 +142,7 @@ pub trait FileFormat: Send + Sync + fmt::Debug {
 /// The former trait is a superset of the latter trait, which includes execution time
 /// relevant methods. [FileType] is only used in logical planning and only implements
 /// the subset of methods required during logical planning.
+#[derive(Debug)]
 pub struct DefaultFileType {
     file_format_factory: Arc<dyn FileFormatFactory>,
 }
@@ -149,6 +154,11 @@ impl DefaultFileType {
             file_format_factory,
         }
     }
+
+    /// Get a reference to the inner [FileFormatFactory] struct
+    pub fn as_format_factory(&self) -> &Arc<dyn FileFormatFactory> {
+        &self.file_format_factory
+    }
 }
 
 impl FileType for DefaultFileType {
@@ -159,7 +169,7 @@ impl Display for DefaultFileType {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        self.file_format_factory.default().fmt(f)
+        write!(f, "{:?}", self.file_format_factory)
     }
 }
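A side effect of the `Display` change: `DefaultFileType` now renders the wrapped factory's `Debug` representation, so configured (or unset) options become visible wherever the file type is printed. A hedged sketch of the observable behavior, assuming the paths shown in this file:

```rust
// Illustration only: Display on DefaultFileType now delegates to the
// factory's Debug impl rather than the default FileFormat's.
use std::sync::Arc;

use datafusion::datasource::file_format::csv::CsvFormatFactory;
use datafusion::datasource::file_format::format_as_file_type;

fn show_display() {
    let file_type = format_as_file_type(Arc::new(CsvFormatFactory::new()));
    // Prints something like `CsvFormatFactory { options: None }`.
    println!("{file_type}");
}
```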
diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs
index d4e77b911c9f..3250b59fa1d1 100644
--- a/datafusion/core/src/datasource/file_format/parquet.rs
+++ b/datafusion/core/src/datasource/file_format/parquet.rs
@@ -140,6 +140,10 @@ impl FileFormatFactory for ParquetFormatFactory {
     fn default(&self) -> Arc<dyn FileFormat> {
         Arc::new(ParquetFormat::default())
     }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
 }
 
 impl GetExt for ParquetFormatFactory {
@@ -149,6 +153,13 @@ impl GetExt for ParquetFormatFactory {
     }
 }
 
+impl fmt::Debug for ParquetFormatFactory {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("ParquetFormatFactory")
+            .field("options", &self.options)
+            .finish()
+    }
+}
 /// The Apache Parquet `FileFormat` implementation
 #[derive(Debug, Default)]
 pub struct ParquetFormat {
diff --git a/datafusion/proto/src/logical_plan/file_formats.rs b/datafusion/proto/src/logical_plan/file_formats.rs
index 09e36a650b9f..2c4085b88869 100644
--- a/datafusion/proto/src/logical_plan/file_formats.rs
+++ b/datafusion/proto/src/logical_plan/file_formats.rs
@@ -18,19 +18,129 @@
 use std::sync::Arc;
 
 use datafusion::{
+    config::CsvOptions,
     datasource::file_format::{
         arrow::ArrowFormatFactory, csv::CsvFormatFactory, json::JsonFormatFactory,
         parquet::ParquetFormatFactory, FileFormatFactory,
     },
     prelude::SessionContext,
 };
-use datafusion_common::{not_impl_err, TableReference};
+use datafusion_common::{
+    exec_err, not_impl_err, parsers::CompressionTypeVariant, DataFusionError,
+    TableReference,
+};
+use prost::Message;
+
+use crate::protobuf::CsvOptions as CsvOptionsProto;
 
 use super::LogicalExtensionCodec;
 
 #[derive(Debug)]
 pub struct CsvLogicalExtensionCodec;
 
+impl CsvOptionsProto {
+    fn from_factory(factory: &CsvFormatFactory) -> Self {
+        if let Some(options) = &factory.options {
+            CsvOptionsProto {
+                has_header: options.has_header.map_or(vec![], |v| vec![v as u8]),
+                delimiter: vec![options.delimiter],
+                quote: vec![options.quote],
+                escape: options.escape.map_or(vec![], |v| vec![v]),
+                double_quote: options.double_quote.map_or(vec![], |v| vec![v as u8]),
+                compression: options.compression as i32,
+                schema_infer_max_rec: options.schema_infer_max_rec as u64,
+                date_format: options.date_format.clone().unwrap_or_default(),
+                datetime_format: options.datetime_format.clone().unwrap_or_default(),
+                timestamp_format: options.timestamp_format.clone().unwrap_or_default(),
+                timestamp_tz_format: options
+                    .timestamp_tz_format
+                    .clone()
+                    .unwrap_or_default(),
+                time_format: options.time_format.clone().unwrap_or_default(),
+                null_value: options.null_value.clone().unwrap_or_default(),
+                comment: options.comment.map_or(vec![], |v| vec![v]),
+                newlines_in_values: options
+                    .newlines_in_values
+                    .map_or(vec![], |v| vec![v as u8]),
+            }
+        } else {
+            CsvOptionsProto::default()
+        }
+    }
+}
+
+impl From<&CsvOptionsProto> for CsvOptions {
+    fn from(proto: &CsvOptionsProto) -> Self {
+        CsvOptions {
+            has_header: if !proto.has_header.is_empty() {
+                Some(proto.has_header[0] != 0)
+            } else {
+                None
+            },
+            delimiter: proto.delimiter.first().copied().unwrap_or(b','),
+            quote: proto.quote.first().copied().unwrap_or(b'"'),
+            escape: if !proto.escape.is_empty() {
+                Some(proto.escape[0])
+            } else {
+                None
+            },
+            double_quote: if !proto.double_quote.is_empty() {
+                Some(proto.double_quote[0] != 0)
+            } else {
+                None
+            },
+            compression: match proto.compression {
+                0 => CompressionTypeVariant::GZIP,
+                1 => CompressionTypeVariant::BZIP2,
+                2 => CompressionTypeVariant::XZ,
+                3 => CompressionTypeVariant::ZSTD,
+                _ => CompressionTypeVariant::UNCOMPRESSED,
+            },
+            schema_infer_max_rec: proto.schema_infer_max_rec as usize,
+            date_format: if proto.date_format.is_empty() {
+                None
+            } else {
+                Some(proto.date_format.clone())
+            },
+            datetime_format: if proto.datetime_format.is_empty() {
+                None
+            } else {
+                Some(proto.datetime_format.clone())
+            },
+            timestamp_format: if proto.timestamp_format.is_empty() {
+                None
+            } else {
+                Some(proto.timestamp_format.clone())
+            },
+            timestamp_tz_format: if proto.timestamp_tz_format.is_empty() {
+                None
+            } else {
+                Some(proto.timestamp_tz_format.clone())
+            },
+            time_format: if proto.time_format.is_empty() {
+                None
+            } else {
+                Some(proto.time_format.clone())
+            },
+            null_value: if proto.null_value.is_empty() {
+                None
+            } else {
+                Some(proto.null_value.clone())
+            },
+            comment: if !proto.comment.is_empty() {
+                Some(proto.comment[0])
+            } else {
+                None
+            },
+            newlines_in_values: if proto.newlines_in_values.is_empty() {
+                None
+            } else {
+                Some(proto.newlines_in_values[0] != 0)
+            },
+        }
+    }
+}
+
 // TODO! This is a placeholder for now and needs to be implemented for real.
 impl LogicalExtensionCodec for CsvLogicalExtensionCodec {
     fn try_decode(
@@ -73,17 +183,41 @@ impl LogicalExtensionCodec for CsvLogicalExtensionCodec {
 
     fn try_decode_file_format(
         &self,
-        __buf: &[u8],
-        __ctx: &SessionContext,
+        buf: &[u8],
+        _ctx: &SessionContext,
     ) -> datafusion_common::Result<Arc<dyn FileFormatFactory>> {
-        Ok(Arc::new(CsvFormatFactory::new()))
+        let proto = CsvOptionsProto::decode(buf).map_err(|e| {
+            DataFusionError::Execution(format!(
+                "Failed to decode CsvOptionsProto: {:?}",
+                e
+            ))
+        })?;
+        let options: CsvOptions = (&proto).into();
+        Ok(Arc::new(CsvFormatFactory {
+            options: Some(options),
+        }))
     }
 
     fn try_encode_file_format(
         &self,
-        __buf: &[u8],
-        __node: Arc<dyn FileFormatFactory>,
+        buf: &mut Vec<u8>,
+        node: Arc<dyn FileFormatFactory>,
     ) -> datafusion_common::Result<()> {
+        let options =
+            if let Some(csv_factory) = node.as_any().downcast_ref::<CsvFormatFactory>() {
+                csv_factory.options.clone().unwrap_or_default()
+            } else {
+                return exec_err!("Unsupported FileFormatFactory type");
+            };
+
+        let proto = CsvOptionsProto::from_factory(&CsvFormatFactory {
+            options: Some(options),
+        });
+
+        proto.encode(buf).map_err(|e| {
+            DataFusionError::Execution(format!("Failed to encode CsvOptions: {:?}", e))
+        })?;
+
         Ok(())
     }
 }
@@ -141,7 +275,7 @@ impl LogicalExtensionCodec for JsonLogicalExtensionCodec {
     fn try_encode_file_format(
         &self,
-        __buf: &[u8],
+        __buf: &mut Vec<u8>,
         __node: Arc<dyn FileFormatFactory>,
     ) -> datafusion_common::Result<()> {
         Ok(())
     }
@@ -201,7 +335,7 @@ impl LogicalExtensionCodec for ParquetLogicalExtensionCodec {
     fn try_encode_file_format(
         &self,
-        __buf: &[u8],
+        __buf: &mut Vec<u8>,
         __node: Arc<dyn FileFormatFactory>,
     ) -> datafusion_common::Result<()> {
         Ok(())
     }
@@ -261,7 +395,7 @@ impl LogicalExtensionCodec for ArrowLogicalExtensionCodec {
     fn try_encode_file_format(
         &self,
-        __buf: &[u8],
+        __buf: &mut Vec<u8>,
         __node: Arc<dyn FileFormatFactory>,
     ) -> datafusion_common::Result<()> {
         Ok(())
     }
@@ -321,7 +455,7 @@ impl LogicalExtensionCodec for AvroLogicalExtensionCodec {
     fn try_encode_file_format(
         &self,
-        __buf: &[u8],
+        __buf: &mut Vec<u8>,
         __node: Arc<dyn FileFormatFactory>,
     ) -> datafusion_common::Result<()> {
         Ok(())
     }
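The wire encoding above emulates optional fields: optional flags and single-byte settings ride in byte-vector fields where an empty vector means `None` and one byte carries the value, while optional strings use the empty string as `None`; that is why both conversion directions special-case `is_empty()`. A hedged round-trip sketch of the new codec methods (assuming the `datafusion-proto` re-exports also used by the tests below):

```rust
use std::sync::Arc;

use datafusion::common::GetExt;
use datafusion::datasource::file_format::csv::CsvFormatFactory;
use datafusion::prelude::SessionContext;
use datafusion_proto::logical_plan::file_formats::CsvLogicalExtensionCodec;
use datafusion_proto::logical_plan::LogicalExtensionCodec;

fn csv_codec_roundtrip() -> datafusion::common::Result<()> {
    let codec = CsvLogicalExtensionCodec;

    // Encode: a non-CSV factory would hit the "Unsupported
    // FileFormatFactory type" error path in try_encode_file_format.
    let mut buf = Vec::new();
    codec.try_encode_file_format(&mut buf, Arc::new(CsvFormatFactory::new()))?;

    // Decode: yields a CsvFormatFactory with options rebuilt from the bytes.
    let ctx = SessionContext::new();
    let decoded = codec.try_decode_file_format(&buf, &ctx)?;
    assert_eq!(decoded.get_ext(), "csv");
    Ok(())
}
```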
diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs
index 2a963fb13ccf..5427f34e8e07 100644
--- a/datafusion/proto/src/logical_plan/mod.rs
+++ b/datafusion/proto/src/logical_plan/mod.rs
@@ -131,7 +131,7 @@ pub trait LogicalExtensionCodec: Debug + Send + Sync {
     fn try_encode_file_format(
         &self,
-        _buf: &[u8],
+        _buf: &mut Vec<u8>,
         _node: Arc<dyn FileFormatFactory>,
     ) -> Result<()> {
         Ok(())
     }
@@ -1666,10 +1666,9 @@ impl AsLogicalPlan for LogicalPlanNode {
                     input,
                     extension_codec,
                 )?;
-
-                let buf = Vec::new();
+                let mut buf = Vec::new();
                 extension_codec
-                    .try_encode_file_format(&buf, file_type_to_format(file_type)?)?;
+                    .try_encode_file_format(&mut buf, file_type_to_format(file_type)?)?;
 
                 Ok(protobuf::LogicalPlanNode {
                     logical_plan_type: Some(LogicalPlanType::CopyTo(Box::new(
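This signature change matters for third-party codecs: the previous `&[u8]` parameter gave encoders no way to write anything back. A minimal sketch of an encoder under the new shape (a hypothetical standalone method, not a full `LogicalExtensionCodec` impl):

```rust
use std::sync::Arc;

use datafusion::datasource::file_format::FileFormatFactory;

struct NoopFileFormatCodec;

impl NoopFileFormatCodec {
    // Mirrors the new trait method shape: the mutable Vec lets the encoder
    // actually append serialized bytes for the decoder to read back.
    fn try_encode_file_format(
        &self,
        buf: &mut Vec<u8>,
        _node: Arc<dyn FileFormatFactory>,
    ) -> datafusion::common::Result<()> {
        buf.extend_from_slice(b"noop");
        Ok(())
    }
}
```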
diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
index f6557c7b2d8f..e17515086ecd 100644
--- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
@@ -15,12 +15,6 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::any::Any;
-use std::collections::HashMap;
-use std::fmt::{self, Debug, Formatter};
-use std::sync::Arc;
-use std::vec;
-
 use arrow::array::{
     ArrayRef, FixedSizeListArray, Int32Builder, MapArray, MapBuilder, StringBuilder,
 };
@@ -30,11 +24,16 @@ use arrow::datatypes::{
     DECIMAL256_MAX_PRECISION,
 };
 use prost::Message;
+use std::any::Any;
+use std::collections::HashMap;
+use std::fmt::{self, Debug, Formatter};
+use std::sync::Arc;
+use std::vec;
 
 use datafusion::datasource::file_format::arrow::ArrowFormatFactory;
 use datafusion::datasource::file_format::csv::CsvFormatFactory;
-use datafusion::datasource::file_format::format_as_file_type;
 use datafusion::datasource::file_format::parquet::ParquetFormatFactory;
+use datafusion::datasource::file_format::{format_as_file_type, DefaultFileType};
 use datafusion::datasource::provider::TableProviderFactory;
 use datafusion::datasource::TableProvider;
 use datafusion::execution::session_state::SessionStateBuilder;
@@ -380,7 +379,9 @@ async fn roundtrip_logical_plan_copy_to_writer_options() -> Result<()> {
     parquet_format.global.dictionary_page_size_limit = 444;
     parquet_format.global.max_row_group_size = 555;
 
-    let file_type = format_as_file_type(Arc::new(ParquetFormatFactory::new()));
+    let file_type = format_as_file_type(Arc::new(
+        ParquetFormatFactory::new_with_options(parquet_format),
+    ));
 
     let plan = LogicalPlan::Copy(CopyTo {
         input: Arc::new(input),
@@ -395,7 +396,6 @@ async fn roundtrip_logical_plan_copy_to_writer_options() -> Result<()> {
     let logical_round_trip =
         logical_plan_from_bytes_with_extension_codec(&bytes, &ctx, &codec)?;
     assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}"));
-
     match logical_round_trip {
         LogicalPlan::Copy(copy_to) => {
             assert_eq!("test.parquet", copy_to.output_url);
@@ -458,7 +458,9 @@ async fn roundtrip_logical_plan_copy_to_csv() -> Result<()> {
     csv_format.time_format = Some("HH:mm:ss".to_string());
     csv_format.null_value = Some("NIL".to_string());
 
-    let file_type = format_as_file_type(Arc::new(CsvFormatFactory::new()));
+    let file_type = format_as_file_type(Arc::new(CsvFormatFactory::new_with_options(
+        csv_format.clone(),
+    )));
 
     let plan = LogicalPlan::Copy(CopyTo {
         input: Arc::new(input),
@@ -479,6 +481,27 @@ async fn roundtrip_logical_plan_copy_to_csv() -> Result<()> {
             assert_eq!("test.csv", copy_to.output_url);
             assert_eq!("csv".to_string(), copy_to.file_type.get_ext());
             assert_eq!(vec!["a", "b", "c"], copy_to.partition_by);
+
+            let file_type = copy_to
+                .file_type
+                .as_ref()
+                .as_any()
+                .downcast_ref::<DefaultFileType>()
+                .unwrap();
+
+            let format_factory = file_type.as_format_factory();
+            let csv_factory = format_factory
+                .as_ref()
+                .as_any()
+                .downcast_ref::<CsvFormatFactory>()
+                .unwrap();
+            let csv_config = csv_factory.options.as_ref().unwrap();
+            assert_eq!(csv_format.delimiter, csv_config.delimiter);
+            assert_eq!(csv_format.date_format, csv_config.date_format);
+            assert_eq!(csv_format.datetime_format, csv_config.datetime_format);
+            assert_eq!(csv_format.timestamp_format, csv_config.timestamp_format);
+            assert_eq!(csv_format.time_format, csv_config.time_format);
+            assert_eq!(csv_format.null_value, csv_config.null_value);
         }
         _ => panic!(),
     }