diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index ba7bbf0c2c..b9906b9a70 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -69,7 +69,6 @@ use datafusion_comet_spark_expr::{create_comet_physical_fun, create_negate_expr}
 use crate::execution::operators::ExecutionError::GeneralError;
 use crate::execution::shuffle::CompressionCodec;
 use crate::execution::spark_plan::SparkPlan;
-use crate::parquet::parquet_exec::init_datasource_exec;
 use crate::parquet::parquet_support::prepare_object_store;
 use datafusion::common::scalar::ScalarStructBuilder;
 use datafusion::common::{
@@ -86,6 +85,7 @@
 use datafusion::physical_expr::expressions::{Literal, StatsType};
 use datafusion::physical_expr::window::WindowExpr;
 use datafusion::physical_expr::LexOrdering;
+use crate::parquet::parquet_exec::init_datasource_exec;
 use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec;
 use datafusion::physical_plan::filter::FilterExec as DataFusionFilterExec;
 use datafusion_comet_proto::spark_operator::SparkFilePartition;
@@ -2501,19 +2501,27 @@ fn create_case_expr(
 
 #[cfg(test)]
 mod tests {
-    use std::{sync::Arc, task::Poll};
-
     use futures::{poll, StreamExt};
+    use std::{sync::Arc, task::Poll};
 
     use arrow::array::{Array, DictionaryArray, Int32Array, StringArray};
-    use arrow::datatypes::DataType;
+    use arrow::datatypes::{DataType, Field, Fields, Schema};
+    use datafusion::catalog::memory::DataSourceExec;
+    use datafusion::datasource::listing::PartitionedFile;
+    use datafusion::datasource::object_store::ObjectStoreUrl;
+    use datafusion::datasource::physical_plan::{FileGroup, FileScanConfigBuilder, ParquetSource};
+    use datafusion::error::DataFusionError;
     use datafusion::logical_expr::ScalarUDF;
+    use datafusion::physical_plan::ExecutionPlan;
     use datafusion::{assert_batches_eq, physical_plan::common::collect, prelude::SessionContext};
+    use tempfile::TempDir;
     use tokio::sync::mpsc;
 
     use crate::execution::{operators::InputBatch, planner::PhysicalPlanner};
 
     use crate::execution::operators::ExecutionError;
+    use crate::parquet::parquet_support::SparkParquetOptions;
+    use crate::parquet::schema_adapter::SparkSchemaAdapterFactory;
     use datafusion_comet_proto::spark_expression::expr::ExprStruct;
     use datafusion_comet_proto::{
         spark_expression::expr::ExprStruct::*,
@@ -2522,6 +2530,7 @@ mod tests {
         spark_operator,
         spark_operator::{operator::OpStruct, Operator},
     };
+    use datafusion_comet_spark_expr::EvalMode;
 
     #[test]
     fn test_unpack_dictionary_primitive() {
@@ -3081,4 +3090,93 @@ mod tests {
             }
         });
     }
+
+    /*
+    Testing a nested types scenario
+
+    select arr[0].a, arr[0].c from (
+        select array(named_struct('a', 1, 'b', 'n', 'c', 'x')) arr)
+    */
+    #[tokio::test]
+    async fn test_nested_types() -> Result<(), DataFusionError> {
+        let session_ctx = SessionContext::new();
+
+        // generate test data in the temp folder
+        let test_data = "select make_array(named_struct('a', 1, 'b', 'n', 'c', 'x')) c0";
+        let tmp_dir = TempDir::new()?;
+        let test_path = tmp_dir.path().to_str().unwrap().to_string();
+
+        let plan = session_ctx
+            .sql(test_data)
+            .await?
+            .create_physical_plan()
+            .await?;
+
+        // Write parquet file into temp folder
+        session_ctx
+            .write_parquet(plan, test_path.clone(), None)
+            .await?;
+
+        // Define schema Comet reads with
+        let required_schema = Schema::new(Fields::from(vec![Field::new(
+            "c0",
+            DataType::List(
+                Field::new(
+                    "element",
+                    DataType::Struct(Fields::from(vec![
+                        Field::new("a", DataType::Int32, true),
+                        Field::new("c", DataType::Utf8, true),
+                    ])),
+                    true,
+                )
+                .into(),
+            ),
+            true,
+        )]));
+
+        // Register all parquet files with temp data as file groups
+        let mut file_groups: Vec<FileGroup> = vec![];
+        for entry in std::fs::read_dir(&test_path)? {
+            let entry = entry?;
+            let path = entry.path();
+
+            if path.extension().and_then(|ext| ext.to_str()) == Some("parquet") {
+                if let Some(path_str) = path.to_str() {
+                    file_groups.push(FileGroup::new(vec![PartitionedFile::from_path(
+                        path_str.into(),
+                    )?]));
+                }
+            }
+        }
+
+        let source = Arc::new(
+            ParquetSource::default().with_schema_adapter_factory(Arc::new(
+                SparkSchemaAdapterFactory::new(SparkParquetOptions::new(EvalMode::Ansi, "", false)),
+            )),
+        );
+
+        let object_store_url = ObjectStoreUrl::local_filesystem();
+        let file_scan_config =
+            FileScanConfigBuilder::new(object_store_url, required_schema.into(), source)
+                .with_file_groups(file_groups)
+                .build();
+
+        // Run native read
+        let scan = Arc::new(DataSourceExec::new(Arc::new(file_scan_config.clone())));
+        let stream = scan.execute(0, session_ctx.task_ctx())?;
+        let result: Vec<_> = stream.collect().await;
+
+        let actual = result.first().unwrap().as_ref().unwrap();
+
+        let expected = [
+            "+----------------+",
+            "| c0             |",
+            "+----------------+",
+            "| [{a: 1, c: x}] |",
+            "+----------------+",
+        ];
+        assert_batches_eq!(expected, &[actual.clone()]);
+
+        Ok(())
+    }
 }
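Note (not part of the patch): test_nested_types writes a List of Struct{a, b, c} to Parquet, then reads it back through a DataSourceExec whose ParquetSource carries the SparkSchemaAdapterFactory and a required schema that keeps only the a and c subfields. The standalone sketch below reuses the field names and types from the test to compare the full file type with the pruned read type via arrow::compute::can_cast_types; it is only an illustration of the kind of conversion the schema adapter and the cast_array changes are there to handle, and the fn main wrapper is an assumption for a runnable example.

    use std::sync::Arc;

    use arrow::compute::can_cast_types;
    use arrow::datatypes::{DataType, Field, Fields};

    fn main() {
        // Struct as written to Parquet by the test: {a: Int32, b: Utf8, c: Utf8}
        let file_item = Field::new(
            "element",
            DataType::Struct(Fields::from(vec![
                Field::new("a", DataType::Int32, true),
                Field::new("b", DataType::Utf8, true),
                Field::new("c", DataType::Utf8, true),
            ])),
            true,
        );
        // Pruned struct the query actually needs: {a: Int32, c: Utf8}
        let read_item = Field::new(
            "element",
            DataType::Struct(Fields::from(vec![
                Field::new("a", DataType::Int32, true),
                Field::new("c", DataType::Utf8, true),
            ])),
            true,
        );

        let file_type = DataType::List(Arc::new(file_item));
        let read_type = DataType::List(Arc::new(read_item));

        // If this prints false, the generic Arrow cast path alone cannot drop the
        // unused subfield, and the Comet schema adapter / cast_array handlers have
        // to perform the per-field mapping instead.
        println!("{}", can_cast_types(&file_type, &read_type));
    }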
diff --git a/native/core/src/parquet/parquet_support.rs b/native/core/src/parquet/parquet_support.rs
index 6bf0f0fe45..40fbb361b6 100644
--- a/native/core/src/parquet/parquet_support.rs
+++ b/native/core/src/parquet/parquet_support.rs
@@ -16,6 +16,8 @@
 // under the License.
 
 use crate::execution::operators::ExecutionError;
+use arrow::array::ListArray;
+use arrow::compute::can_cast_types;
 use arrow::{
     array::{
         cast::AsArray, new_null_array, types::Int32Type, types::TimestampMicrosecondType, Array,
@@ -156,6 +158,8 @@ fn cast_array(
     };
     let from_type = array.data_type();
 
+    // Try Comet specific handlers first, then arrow-rs cast if supported,
+    // return uncasted data otherwise
     match (from_type, to_type) {
         (Struct(_), Struct(_)) => Ok(cast_struct_to_struct(
             array.as_struct(),
@@ -163,6 +167,21 @@
             to_type,
             parquet_options,
         )?),
+        (List(_), List(to_inner_type)) => {
+            let list_arr: &ListArray = array.as_list();
+            let cast_field = cast_array(
+                Arc::clone(list_arr.values()),
+                to_inner_type.data_type(),
+                parquet_options,
+            )?;
+
+            Ok(Arc::new(ListArray::new(
+                Arc::clone(to_inner_type),
+                list_arr.offsets().clone(),
+                cast_field,
+                list_arr.nulls().cloned(),
+            )))
+        }
         (Timestamp(TimeUnit::Microsecond, None), Timestamp(TimeUnit::Microsecond, Some(tz))) => {
             Ok(Arc::new(
                 array
@@ -171,7 +190,11 @@
                 .with_timezone(Arc::clone(tz)),
             ))
         }
-        _ => Ok(cast_with_options(&array, to_type, &PARQUET_OPTIONS)?),
+        // If Arrow cast supports the cast, delegate the cast to Arrow
+        _ if can_cast_types(from_type, to_type) => {
+            Ok(cast_with_options(&array, to_type, &PARQUET_OPTIONS)?)
+        }
+        _ => Ok(array),
     }
 }
 
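Note (not part of the patch): the new (List(_), List(_)) arm in cast_array casts only the child values array and then reassembles a ListArray around the original offsets and null buffer, so the list layout and validity are preserved while the element type changes. Below is a minimal, hedged sketch of that same pattern using only arrow-rs; the child cast here is the plain arrow cast kernel rather than Comet's recursive cast_array, and the List of Int32 to List of Int64 conversion plus the fn main wrapper are illustrative stand-ins.

    use std::sync::Arc;

    use arrow::array::{Array, ArrayRef, ListArray};
    use arrow::compute::cast;
    use arrow::datatypes::{DataType, Field, Int32Type};

    fn main() {
        // A small List<Int32> array: [[1, 2], [3]]
        let list = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
            Some(vec![Some(1), Some(2)]),
            Some(vec![Some(3)]),
        ]);

        // Same shape as the new (List, List) arm: cast only the child values,
        // then rebuild the list around the original offsets and null buffer.
        let to_item = Arc::new(Field::new("element", DataType::Int64, true));
        let cast_values: ArrayRef = cast(list.values(), to_item.data_type()).expect("supported cast");
        let cast_list = ListArray::new(
            Arc::clone(&to_item),
            list.offsets().clone(),
            cast_values,
            list.nulls().cloned(),
        );

        assert_eq!(cast_list.len(), 2);
        println!("{cast_list:?}");
    }

Cloning offsets() and nulls() is cheap here because both are backed by shared buffers; only the values are rewritten, which is what keeps the list-level cast lightweight.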
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometNativeReaderSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometNativeReaderSuite.scala
index d9c71f147d..4115ba4321 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometNativeReaderSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometNativeReaderSuite.scala
@@ -224,4 +224,83 @@ class CometNativeReaderSuite extends CometTestBase with AdaptiveSparkPlanHelper
         |""".stripMargin,
       "select c0 from tbl")
   }
+
+  test("native reader - read a STRUCT subfield from ARRAY of STRUCTS - second field") {
+    testSingleLineQuery(
+      """
+        | select array(str0, str1) c0 from
+        | (
+        |   select
+        |     named_struct('a', 1, 'b', 'n', 'c', 'x') str0,
+        |     named_struct('a', 2, 'b', 'w', 'c', 'y') str1
+        | )
+        |""".stripMargin,
+      "select c0[0].b col0 from tbl")
+  }
+
+  test("native reader - read a STRUCT subfield - field from second") {
+    withSQLConf(
+      CometConf.COMET_EXEC_ENABLED.key -> "true",
+      SQLConf.USE_V1_SOURCE_LIST.key -> "parquet",
+      CometConf.COMET_ENABLED.key -> "true",
+      CometConf.COMET_EXPLAIN_FALLBACK_ENABLED.key -> "false",
+      CometConf.COMET_NATIVE_SCAN_IMPL.key -> "native_datafusion") {
+      testSingleLineQuery(
+        """
+          |select 1 a, named_struct('a', 1, 'b', 'n') c0
+          |""".stripMargin,
+        "select c0.b from tbl")
+    }
+  }
+
+  test("native reader - read a STRUCT subfield from ARRAY of STRUCTS - field from first") {
+    testSingleLineQuery(
+      """
+        | select array(str0, str1) c0 from
+        | (
+        |   select
+        |     named_struct('a', 1, 'b', 'n', 'c', 'x') str0,
+        |     named_struct('a', 2, 'b', 'w', 'c', 'y') str1
+        | )
+        |""".stripMargin,
+      "select c0[0].a, c0[0].b, c0[0].c from tbl")
+  }
+
+  test("native reader - read a STRUCT subfield from ARRAY of STRUCTS - reverse fields") {
+    testSingleLineQuery(
+      """
+        | select array(str0, str1) c0 from
+        | (
+        |   select
+        |     named_struct('a', 1, 'b', 'n', 'c', 'x') str0,
+        |     named_struct('a', 2, 'b', 'w', 'c', 'y') str1
+        | )
+        |""".stripMargin,
+      "select c0[0].c, c0[0].b, c0[0].a from tbl")
+  }
+
+  test("native reader - read a STRUCT subfield from ARRAY of STRUCTS - skip field") {
+    testSingleLineQuery(
+      """
+        | select array(str0, str1) c0 from
+        | (
+        |   select
+        |     named_struct('a', 1, 'b', 'n', 'c', 'x') str0,
+        |     named_struct('a', 2, 'b', 'w', 'c', 'y') str1
+        | )
+        |""".stripMargin,
+      "select c0[0].a, c0[0].c from tbl")
+  }
+
+  test("native reader - read a STRUCT subfield from ARRAY of STRUCTS - duplicate first field") {
+    testSingleLineQuery(
+      """
+        | select array(str0, str1) c0 from
+        | (
+        |   select
+        |     named_struct('a', 1, 'b', 'n', 'c', 'x') str0,
+        |     named_struct('a', 2, 'b', 'w', 'c', 'y') str1
+        | )
+        |""".stripMargin,
+      "select c0[0].a, c0[0].a from tbl")
+  }
 }