diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 9059ae07e648..63423bc3dc53 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -22,7 +22,7 @@ mod struct_builder; use std::borrow::Borrow; use std::cmp::Ordering; -use std::collections::{HashSet, VecDeque}; +use std::collections::{HashMap, HashSet, VecDeque}; use std::convert::Infallible; use std::fmt; use std::hash::Hash; @@ -297,6 +297,7 @@ pub enum ScalarValue { Union(Option<(i8, Box)>, UnionFields, UnionMode), /// Dictionary type: index type and value Dictionary(Box, Box), + Extension(Arc, Arc, Option>), } impl Hash for Fl { @@ -420,6 +421,15 @@ impl PartialEq for ScalarValue { (Dictionary(_, _), _) => false, (Null, Null) => true, (Null, _) => false, + ( + Extension(storage, extension_name, extension_metadata), + Extension(storage2, extension_name2, extension_metadata2), + ) => { + storage == storage2 + && extension_name == extension_name2 + && extension_metadata == extension_metadata2 + } + (Extension(_, _, _), _) => false, } } } @@ -571,6 +581,10 @@ impl PartialOrd for ScalarValue { (Dictionary(_, _), _) => None, (Null, Null) => Some(Ordering::Equal), (Null, _) => None, + (Extension(_, _, _), _) => { + // Treat Extension scalars as Opaque storage with undefined ordering + None + } } } } @@ -765,6 +779,11 @@ impl Hash for ScalarValue { } // stable hash for Null value Null => 1.hash(state), + Extension(storage, name, metadata) => { + storage.hash(state); + name.hash(state); + metadata.hash(state); + } } } } @@ -1393,6 +1412,28 @@ impl ScalarValue { }) } + /// return the [`Field`] of this `ScalarValue` + pub fn field(&self) -> Field { + match self { + ScalarValue::Extension(storage, name, metadata) => { + let mut metadata_fields = HashMap::from([( + "ARROW:extension:name".to_string(), + name.to_string(), + )]); + + if let Some(metadata_value) = metadata { + metadata_fields.insert( + "ARROW:extension:metadata".to_string(), + 
metadata_value.to_string(), + ); + } + + Field::new("", storage.data_type(), true).with_metadata(metadata_fields) + }, + _ => Field::new("", self.data_type(), true) + } + } + /// return the [`DataType`] of this `ScalarValue` pub fn data_type(&self) -> DataType { match self { @@ -1466,6 +1507,10 @@ impl ScalarValue { DataType::Dictionary(k.clone(), Box::new(v.data_type())) } ScalarValue::Null => DataType::Null, + ScalarValue::Extension(storage, _, _) => { + // TODO: Drops extension information + storage.data_type() + } } } @@ -1724,6 +1769,7 @@ impl ScalarValue { None => true, }, ScalarValue::Dictionary(_, v) => v.is_null(), + ScalarValue::Extension(storage, _, _) => storage.is_null(), } } @@ -2642,6 +2688,10 @@ impl ScalarValue { } } ScalarValue::Null => new_null_array(&DataType::Null, size), + ScalarValue::Extension(storage, _, _) => { + // TODO: Drops extension information + storage.to_array_of_size(size)? + } }) } @@ -3277,6 +3327,11 @@ impl ScalarValue { } } ScalarValue::Null => array.is_null(index), + ScalarValue::Extension(storage, _, _) => { + // TODO: Drops extension information (will compare equal to storage + // whether or not the storage is from an extension type) + storage.eq_array(array, index)? + } }) } @@ -3353,6 +3408,14 @@ impl ScalarValue { // `dt` and `sv` are boxed, so they are NOT already included in `self` dt.size() + sv.size() } + ScalarValue::Extension(storage, name, metadata) => { + storage.size() + + name.len() + + match metadata { + Some(value) => value.len(), + None => 0, + } + } } } @@ -3743,6 +3806,12 @@ impl fmt::Display for ScalarValue { }, ScalarValue::Dictionary(_k, v) => write!(f, "{v}")?, ScalarValue::Null => write!(f, "NULL")?, + ScalarValue::Extension(storage, name, metadata) => match metadata { + Some(metadata_value) => { + write!(f, "<{}: {}> {}", name, metadata_value, storage)? 
+ } + None => write!(f, "<{}> {}", name, storage)?, + }, }; Ok(()) } @@ -3920,6 +3989,12 @@ impl fmt::Debug for ScalarValue { }, ScalarValue::Dictionary(k, v) => write!(f, "Dictionary({k:?}, {v:?})"), ScalarValue::Null => write!(f, "NULL"), + ScalarValue::Extension(storage, name, metadata) => match metadata { + Some(metadata_value) => { + write!(f, "Extension<{}: {}>({:?})", name, metadata_value, storage) + } + None => write!(f, "Extension<{}>({:?})", name, storage), + }, } } } diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 3323ea1614fd..5e0791cd384e 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -29,7 +29,7 @@ use crate::utils::expr_to_columns; use crate::Volatility; use crate::{udaf, ExprSchemable, Operator, Signature, WindowFrame, WindowUDF}; -use arrow::datatypes::{DataType, FieldRef}; +use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::cse::{HashNode, NormalizeEq, Normalizeable}; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeContainer, TreeNodeRecursion, @@ -228,7 +228,7 @@ pub enum Expr { /// A named reference to a qualified field in a schema. Column(Column), /// A named reference to a variable in a registry. - ScalarVariable(DataType, Vec), + ScalarVariable(Field, Vec), /// A constant value. Literal(ScalarValue), /// A binary expression such as "age > 21" @@ -327,7 +327,7 @@ pub enum Expr { Placeholder(Placeholder), /// A place holder which hold a reference to a qualified field /// in the outer query, used for correlated sub queries. 
- OuterReferenceColumn(DataType, Column), + OuterReferenceColumn(Field, Column), /// Unnest expression Unnest(Unnest), } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index a8e7fd76d037..39c425869558 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -68,7 +68,7 @@ pub fn col(ident: impl Into) -> Expr { /// Create an out reference column which hold a reference that has been resolved to a field /// outside of the current plan. pub fn out_ref_col(dt: DataType, ident: impl Into) -> Expr { - Expr::OuterReferenceColumn(dt, ident.into()) + Expr::OuterReferenceColumn(Field::new("", dt, true), ident.into()) } /// Create an unqualified column expression from the provided name, without normalizing diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 0a14cb5c60a0..67620b7b8b00 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -113,8 +113,8 @@ impl ExprSchemable for Expr { }, Expr::Negative(expr) => expr.get_type(schema), Expr::Column(c) => Ok(schema.data_type(c)?.clone()), - Expr::OuterReferenceColumn(ty, _) => Ok(ty.clone()), - Expr::ScalarVariable(ty, _) => Ok(ty.clone()), + Expr::OuterReferenceColumn(field, _) => Ok(field.data_type().clone()), + Expr::ScalarVariable(field, _) => Ok(field.data_type().clone()), Expr::Literal(l) => Ok(l.data_type()), Expr::Case(case) => { for (_, then_expr) in &case.when_then_expr { @@ -345,6 +345,8 @@ impl ExprSchemable for Expr { Expr::Column(c) => Ok(schema.metadata(c)?.clone()), Expr::Alias(Alias { expr, .. }) => expr.metadata(schema), Expr::Cast(Cast { expr, .. 
}) => expr.metadata(schema), + Expr::ScalarVariable(field, _) => Ok(field.metadata().clone()), + Expr::OuterReferenceColumn(field, _) => Ok(field.metadata().clone()), _ => Ok(HashMap::new()), } } @@ -377,8 +379,12 @@ impl ExprSchemable for Expr { Expr::Column(c) => schema .data_type_and_nullable(c) .map(|(d, n)| (d.clone(), n)), - Expr::OuterReferenceColumn(ty, _) => Ok((ty.clone(), true)), - Expr::ScalarVariable(ty, _) => Ok((ty.clone(), true)), + Expr::OuterReferenceColumn(field, _) => { + Ok((field.data_type().clone(), field.is_nullable())) + } + Expr::ScalarVariable(field, _) => { + Ok((field.data_type().clone(), field.is_nullable())) + } Expr::Literal(l) => Ok((l.data_type(), l.is_null())), Expr::IsNull(_) | Expr::IsNotNull(_) @@ -463,10 +469,17 @@ impl ExprSchemable for Expr { ) -> Result<(Option, Arc)> { let (relation, schema_name) = self.qualified_name(); let (data_type, nullable) = self.data_type_and_nullable(input_schema)?; - let field = Field::new(schema_name, data_type, nullable) - .with_metadata(self.metadata(input_schema)?) - .into(); - Ok((relation, field)) + let field = match self { + Expr::ScalarVariable(field, _) | Expr::OuterReferenceColumn(field, _) => { + field.clone().with_name(schema_name) + } + _ => Field::new(schema_name, data_type, nullable), + }; + + Ok(( + relation, + field.with_metadata(self.metadata(input_schema)?).into(), + )) } /// Wraps this expression in a cast to a target [arrow::datatypes::DataType]. 
diff --git a/datafusion/proto-common/src/to_proto/mod.rs b/datafusion/proto-common/src/to_proto/mod.rs index decd0cf63038..aa7580367ed5 100644 --- a/datafusion/proto-common/src/to_proto/mod.rs +++ b/datafusion/proto-common/src/to_proto/mod.rs @@ -635,6 +635,9 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { ))), }) } + ScalarValue::Extension(_, _, _) => Err(Error::General(format!( + "Proto serialization error: {val} not yet supported" + ))), } } } diff --git a/datafusion/sql/src/expr/identifier.rs b/datafusion/sql/src/expr/identifier.rs index 7d358d0b6624..5522208b0b7c 100644 --- a/datafusion/sql/src/expr/identifier.rs +++ b/datafusion/sql/src/expr/identifier.rs @@ -38,13 +38,14 @@ impl SqlToRel<'_, S> { if id.value.starts_with('@') { // TODO: figure out if ScalarVariables should be insensitive. let var_names = vec![id.value]; + // TODO: dropping extension information let ty = self .context_provider .get_variable_type(&var_names) .ok_or_else(|| { plan_datafusion_err!("variable {var_names:?} has no type information") })?; - Ok(Expr::ScalarVariable(ty, var_names)) + Ok(Expr::ScalarVariable(Field::new("", ty, true), var_names)) } else { // Don't use `col()` here because it will try to // interpret names with '.' 
as if they were @@ -75,7 +76,7 @@ impl SqlToRel<'_, S> { { // Found an exact match on a qualified name in the outer plan schema, so this is an outer reference column return Ok(Expr::OuterReferenceColumn( - field.data_type().clone(), + field.clone().with_name("").with_nullable(true), Column::from((qualifier, field)), )); } @@ -112,6 +113,7 @@ impl SqlToRel<'_, S> { .into_iter() .map(|id| self.ident_normalizer.normalize(id)) .collect(); + // TODO: dropping extension information let ty = self .context_provider .get_variable_type(&var_names) @@ -120,7 +122,7 @@ impl SqlToRel<'_, S> { "variable {var_names:?} has no type information" )) })?; - Ok(Expr::ScalarVariable(ty, var_names)) + Ok(Expr::ScalarVariable(Field::new("", ty, true), var_names)) } else { let ids = ids .into_iter() @@ -182,7 +184,7 @@ impl SqlToRel<'_, S> { Some((field, qualifier, _nested_names)) => { // Found an exact match on a qualified name in the outer plan schema, so this is an outer reference column Ok(Expr::OuterReferenceColumn( - field.data_type().clone(), + field.clone().with_name("").with_nullable(true), Column::from((qualifier, field)), )) } diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index bf6361312727..365433604e13 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -1283,6 +1283,7 @@ impl Unparser<'_> { ScalarValue::Map(_) => not_impl_err!("Unsupported scalar: {v:?}"), ScalarValue::Union(..) 
=> not_impl_err!("Unsupported scalar: {v:?}"), ScalarValue::Dictionary(_k, v) => self.scalar_to_sql(v), + ScalarValue::Extension(_, _, _) => not_impl_err!("Unsupported scalar: {v:?}"), } } @@ -2011,12 +2012,15 @@ mod tests { r#"TRY_CAST(a AS INTEGER UNSIGNED)"#, ), ( - Expr::ScalarVariable(Int8, vec![String::from("@a")]), + Expr::ScalarVariable( + Field::new("", Int8, true), + vec![String::from("@a")], + ), r#"@a"#, ), ( Expr::ScalarVariable( - Int8, + Field::new("", Int8, true), vec![String::from("@root"), String::from("foo")], ), r#"@root.foo"#,