Skip to content

Commit

Permalink
fix: coalesce schema issues (#12308)
Browse files Browse the repository at this point in the history
closes #12307
  • Loading branch information
mesejo committed Sep 27, 2024
1 parent 9a3f8d1 commit 1b3608d
Show file tree
Hide file tree
Showing 14 changed files with 335 additions and 128 deletions.
37 changes: 37 additions & 0 deletions datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2029,6 +2029,43 @@ mod tests {
Ok(())
}

/// Runs `SELECT COALESCE(null, 5)` through `SessionContext::sql` and passes the
/// outcome to `assert_logical_expr_schema_eq_physical_expr_schema`.
#[tokio::test]
async fn test_coalesce_schema() -> Result<()> {
    const QUERY: &str = r#"SELECT COALESCE(null, 5)"#;

    let session = SessionContext::new();
    let planned = session.sql(QUERY).await?;

    assert_logical_expr_schema_eq_physical_expr_schema(planned).await?;
    Ok(())
}

/// Runs a `COALESCE` over two columns of a single-row `VALUES` list (a NULL and
/// a decimal literal) and passes the outcome to
/// `assert_logical_expr_schema_eq_physical_expr_schema`.
#[tokio::test]
async fn test_coalesce_from_values_schema() -> Result<()> {
    const QUERY: &str = r#"SELECT COALESCE(column1, column2) FROM VALUES (null, 1.2)"#;

    let session = SessionContext::new();
    let planned = session.sql(QUERY).await?;

    assert_logical_expr_schema_eq_physical_expr_schema(planned).await?;
    Ok(())
}

/// Same check as the single-row `VALUES` test, but with three rows whose NULLs
/// fall in different columns and whose last row holds integer literals.
#[tokio::test]
async fn test_coalesce_from_values_schema_multiple_rows() -> Result<()> {
    // The literal is kept exactly as written; only the surrounding code is reshaped.
    const QUERY: &str = r#"SELECT COALESCE(column1, column2)
FROM VALUES
(null, 1.2),
(1.1, null),
(2, 5);"#;

    let session = SessionContext::new();
    let planned = session.sql(QUERY).await?;

    assert_logical_expr_schema_eq_physical_expr_schema(planned).await?;
    Ok(())
}

#[tokio::test]
async fn test_array_agg_schema() -> Result<()> {
let ctx = SessionContext::new();
Expand Down
23 changes: 12 additions & 11 deletions datafusion/expr/src/expr_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,21 +151,22 @@ impl ExprSchemable for Expr {
.collect::<Result<Vec<_>>>()?;

// verify that function is invoked with correct number and type of arguments as defined in `TypeSignature`
data_types_with_scalar_udf(&arg_data_types, func).map_err(|err| {
plan_datafusion_err!(
"{} {}",
err,
utils::generate_signature_error_msg(
func.name(),
func.signature().clone(),
&arg_data_types,
let new_data_types = data_types_with_scalar_udf(&arg_data_types, func)
.map_err(|err| {
plan_datafusion_err!(
"{} {}",
err,
utils::generate_signature_error_msg(
func.name(),
func.signature().clone(),
&arg_data_types,
)
)
)
})?;
})?;

// perform additional function arguments validation (due to limited
// expressiveness of `TypeSignature`), then infer return type
Ok(func.return_type_from_exprs(args, schema, &arg_data_types)?)
Ok(func.return_type_from_exprs(args, schema, &new_data_types)?)
}
Expr::WindowFunction(window_function) => self
.data_type_and_nullable_with_window_function(schema, window_function)
Expand Down
4 changes: 3 additions & 1 deletion datafusion/expr/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,9 @@ impl LogicalPlanBuilder {
common_type = Some(data_type);
}
}
field_types.push(common_type.unwrap_or(DataType::Utf8));
// If common_type was never set and no error was raised, every value seen for this
// column was NULL (the loop skips NULL values), so the field type should be Null.
field_types.push(common_type.unwrap_or(DataType::Null));
}
// wrap cast if data type is not same as common type.
for row in &mut values {
Expand Down
30 changes: 26 additions & 4 deletions datafusion/functions/src/core/coalesce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ use arrow::array::{new_null_array, BooleanArray};
use arrow::compute::kernels::zip::zip;
use arrow::compute::{and, is_not_null, is_null};
use arrow::datatypes::DataType;

use datafusion_common::{exec_err, ExprSchema, Result};
use datafusion_expr::type_coercion::binary::type_union_resolution;
use datafusion_expr::{ColumnarValue, Expr, ExprSchemable};
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
use itertools::Itertools;

#[derive(Debug)]
pub struct CoalesceFunc {
Expand Down Expand Up @@ -60,12 +60,16 @@ impl ScalarUDFImpl for CoalesceFunc {
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
Ok(arg_types[0].clone())
Ok(arg_types
.iter()
.find_or_first(|d| !d.is_null())
.unwrap()
.clone())
}

// If all the element in coalesce is non-null, the result is non-null
// If any of the arguments in coalesce is non-null, the result is non-null
fn is_nullable(&self, args: &[Expr], schema: &dyn ExprSchema) -> bool {
args.iter().any(|e| e.nullable(schema).ok().unwrap_or(true))
args.iter().all(|e| e.nullable(schema).ok().unwrap_or(true))
}

/// coalesce evaluates to the first value which is not NULL
Expand Down Expand Up @@ -154,4 +158,22 @@ mod test {
.unwrap();
assert_eq!(return_type, DataType::Date32);
}

/// A leading `Null` argument type must not decide the return type: the
/// following `Date32` is expected instead.
#[test]
fn test_coalesce_return_types_with_nulls_first() {
    let arg_types = [DataType::Null, DataType::Date32];

    let func = core::coalesce::CoalesceFunc::new();
    let actual = func.return_type(&arg_types).unwrap();

    assert_eq!(actual, DataType::Date32);
}

/// A trailing `Null` argument type leaves the return type at the leading
/// `Int64`.
#[test]
fn test_coalesce_return_types_with_nulls_last() {
    let arg_types = [DataType::Int64, DataType::Null];

    let func = core::coalesce::CoalesceFunc::new();
    let actual = func.return_type(&arg_types).unwrap();

    assert_eq!(actual, DataType::Int64);
}
}
57 changes: 30 additions & 27 deletions datafusion/functions/src/datetime/to_local_time.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,16 @@ use std::sync::Arc;
use arrow::array::timezone::Tz;
use arrow::array::{Array, ArrayRef, PrimitiveBuilder};
use arrow::datatypes::DataType::Timestamp;
use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second};
use arrow::datatypes::{
ArrowTimestampType, DataType, TimestampMicrosecondType, TimestampMillisecondType,
TimestampNanosecondType, TimestampSecondType,
};
use arrow::datatypes::{
TimeUnit,
TimeUnit::{Microsecond, Millisecond, Nanosecond, Second},
};

use chrono::{DateTime, MappedLocalTime, Offset, TimeDelta, TimeZone, Utc};
use datafusion_common::cast::as_primitive_array;
use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue};
use datafusion_expr::TypeSignature::Exact;
use datafusion_expr::{
ColumnarValue, ScalarUDFImpl, Signature, Volatility, TIMEZONE_WILDCARD,
};
use datafusion_common::{exec_err, plan_err, DataFusionError, Result, ScalarValue};
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};

/// A UDF function that converts a timezone-aware timestamp to local time (with no offset or
/// timezone information). In other words, this function strips off the timezone from the timestamp,
Expand All @@ -55,20 +49,8 @@ impl Default for ToLocalTimeFunc {

impl ToLocalTimeFunc {
pub fn new() -> Self {
let base_sig = |array_type: TimeUnit| {
[
Exact(vec![Timestamp(array_type, None)]),
Exact(vec![Timestamp(array_type, Some(TIMEZONE_WILDCARD.into()))]),
]
};

let full_sig = [Nanosecond, Microsecond, Millisecond, Second]
.into_iter()
.flat_map(base_sig)
.collect::<Vec<_>>();

Self {
signature: Signature::one_of(full_sig, Volatility::Immutable),
signature: Signature::user_defined(Volatility::Immutable),
}
}

Expand Down Expand Up @@ -328,13 +310,10 @@ impl ScalarUDFImpl for ToLocalTimeFunc {
}

match &arg_types[0] {
Timestamp(Nanosecond, _) => Ok(Timestamp(Nanosecond, None)),
Timestamp(Microsecond, _) => Ok(Timestamp(Microsecond, None)),
Timestamp(Millisecond, _) => Ok(Timestamp(Millisecond, None)),
Timestamp(Second, _) => Ok(Timestamp(Second, None)),
Timestamp(timeunit, _) => Ok(Timestamp(*timeunit, None)),
_ => exec_err!(
"The to_local_time function can only accept timestamp as the arg, got {:?}", arg_types[0]
),
)
}
}

Expand All @@ -348,6 +327,30 @@ impl ScalarUDFImpl for ToLocalTimeFunc {

self.to_local_time(args)
}

/// Validates the argument types given to `to_local_time` and returns the
/// types the arguments should be coerced to.
///
/// Exactly one argument is required and it must be a `Timestamp`; its time
/// unit and timezone are passed through unchanged.
///
/// # Errors
/// Returns a plan error when the argument count is not 1 or when the single
/// argument is not a `Timestamp`.
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
    if arg_types.len() != 1 {
        return plan_err!(
            "to_local_time function requires 1 argument, got {:?}",
            arg_types.len()
        );
    }

    // Borrow the argument type; it is only inspected and formatted.
    let first_arg = &arg_types[0];
    match first_arg {
        // Every time unit is treated identically, so one arm that binds the
        // unit replaces a separate arm per unit.
        Timestamp(timeunit, timezone) => {
            Ok(vec![Timestamp(*timeunit, timezone.clone())])
        }
        _ => plan_err!("The to_local_time function can only accept Timestamp as the arg got {first_arg}"),
    }
}
}

#[cfg(test)]
Expand Down
103 changes: 58 additions & 45 deletions datafusion/functions/src/encoding/inner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ use datafusion_expr::ColumnarValue;
use std::sync::Arc;
use std::{fmt, str::FromStr};

use datafusion_expr::TypeSignature::*;
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
use std::any::Any;

Expand All @@ -49,17 +48,8 @@ impl Default for EncodeFunc {

impl EncodeFunc {
pub fn new() -> Self {
use DataType::*;
Self {
signature: Signature::one_of(
vec![
Exact(vec![Utf8, Utf8]),
Exact(vec![LargeUtf8, Utf8]),
Exact(vec![Binary, Utf8]),
Exact(vec![LargeBinary, Utf8]),
],
Volatility::Immutable,
),
signature: Signature::user_defined(Volatility::Immutable),
}
}
}
Expand All @@ -77,23 +67,39 @@ impl ScalarUDFImpl for EncodeFunc {
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
use DataType::*;

Ok(match arg_types[0] {
Utf8 => Utf8,
LargeUtf8 => LargeUtf8,
Binary => Utf8,
LargeBinary => LargeUtf8,
Null => Null,
_ => {
return plan_err!("The encode function can only accept utf8 or binary.");
}
})
Ok(arg_types[0].to_owned())
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
encode(args)
}

/// Checks the argument types passed to this function and returns the types
/// they should be coerced to.
///
/// Two arguments are required and the second must already be `Utf8`. The
/// first is mapped to `Utf8` (from `Utf8`, `Binary` or `Null`) or to
/// `LargeUtf8` (from `LargeUtf8` or `LargeBinary`).
///
/// # Errors
/// Returns a plan error for a wrong argument count, a non-`Utf8` second
/// argument, or an unsupported first argument type.
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
    if arg_types.len() != 2 {
        return plan_err!(
            "{} expects to get 2 arguments, but got {}",
            self.name(),
            arg_types.len()
        );
    }

    if arg_types[1] != DataType::Utf8 {
        // Use `plan_err!` like the other error paths here instead of
        // constructing the error variant by hand.
        return plan_err!("2nd argument should be Utf8");
    }

    match arg_types[0] {
        DataType::Utf8 | DataType::Binary | DataType::Null => {
            Ok(vec![DataType::Utf8, DataType::Utf8])
        }
        DataType::LargeUtf8 | DataType::LargeBinary => {
            Ok(vec![DataType::LargeUtf8, DataType::Utf8])
        }
        _ => plan_err!(
            "1st argument should be Utf8 or Binary or Null, got {:?}",
            arg_types[0]
        ),
    }
}
}

#[derive(Debug)]
Expand All @@ -109,17 +115,8 @@ impl Default for DecodeFunc {

impl DecodeFunc {
pub fn new() -> Self {
use DataType::*;
Self {
signature: Signature::one_of(
vec![
Exact(vec![Utf8, Utf8]),
Exact(vec![LargeUtf8, Utf8]),
Exact(vec![Binary, Utf8]),
Exact(vec![LargeBinary, Utf8]),
],
Volatility::Immutable,
),
signature: Signature::user_defined(Volatility::Immutable),
}
}
}
Expand All @@ -137,23 +134,39 @@ impl ScalarUDFImpl for DecodeFunc {
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
use DataType::*;

Ok(match arg_types[0] {
Utf8 => Binary,
LargeUtf8 => LargeBinary,
Binary => Binary,
LargeBinary => LargeBinary,
Null => Null,
_ => {
return plan_err!("The decode function can only accept utf8 or binary.");
}
})
Ok(arg_types[0].to_owned())
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
decode(args)
}

/// Checks the argument types passed to this function and returns the types
/// they should be coerced to.
///
/// Two arguments are required and the second must already be `Utf8`. The
/// first is mapped to `Binary` (from `Utf8`, `Binary` or `Null`) or to
/// `LargeBinary` (from `LargeUtf8` or `LargeBinary`).
///
/// # Errors
/// Returns a plan error for a wrong argument count, a non-`Utf8` second
/// argument, or an unsupported first argument type.
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
    // Destructure the slice so the arity check and the bindings happen together.
    let (first, second) = match arg_types {
        [first, second] => (first, second),
        _ => {
            return plan_err!(
                "{} expects to get 2 arguments, but got {}",
                self.name(),
                arg_types.len()
            )
        }
    };

    if *second != DataType::Utf8 {
        return plan_err!("2nd argument should be Utf8");
    }

    // Pick the binary width matching the first argument, then pair it with
    // the unchanged `Utf8` second argument.
    let coerced_first = match first {
        DataType::Utf8 | DataType::Binary | DataType::Null => DataType::Binary,
        DataType::LargeUtf8 | DataType::LargeBinary => DataType::LargeBinary,
        _ => {
            return plan_err!(
                "1st argument should be Utf8 or Binary or Null, got {:?}",
                first
            )
        }
    };

    Ok(vec![coerced_first, DataType::Utf8])
}
}

#[derive(Debug, Copy, Clone)]
Expand Down
Loading

0 comments on commit 1b3608d

Please sign in to comment.