Skip to content

Commit

Permalink
Fix coalesce expr_fn function to take multiple arguments (#10321)
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb authored May 7, 2024
1 parent 89443bf commit 40a2055
Show file tree
Hide file tree
Showing 2 changed files with 244 additions and 12 deletions.
177 changes: 176 additions & 1 deletion datafusion/core/tests/dataframe/dataframe_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use datafusion::error::Result;
use datafusion::prelude::*;

use datafusion::assert_batches_eq;
use datafusion_common::DFSchema;
use datafusion_common::{DFSchema, ScalarValue};
use datafusion_expr::expr::Alias;
use datafusion_expr::ExprSchemable;

Expand Down Expand Up @@ -161,6 +161,181 @@ async fn test_fn_btrim_with_chars() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn test_fn_nullif() -> Result<()> {
let expr = nullif(col("a"), lit("abcDEF"));

let expected = [
"+-------------------------------+",
"| nullif(test.a,Utf8(\"abcDEF\")) |",
"+-------------------------------+",
"| |",
"| abc123 |",
"| CBAdef |",
"| 123AbcDef |",
"+-------------------------------+",
];

assert_fn_batches!(expr, expected);

Ok(())
}

#[tokio::test]
async fn test_fn_arrow_cast() -> Result<()> {
let expr = arrow_typeof(arrow_cast(col("b"), lit("Float64")));

let expected = [
"+--------------------------------------------------+",
"| arrow_typeof(arrow_cast(test.b,Utf8(\"Float64\"))) |",
"+--------------------------------------------------+",
"| Float64 |",
"| Float64 |",
"| Float64 |",
"| Float64 |",
"+--------------------------------------------------+",
];

assert_fn_batches!(expr, expected);

Ok(())
}

#[tokio::test]
async fn test_nvl() -> Result<()> {
let lit_null = lit(ScalarValue::Utf8(None));
// nvl(CASE WHEN a = 'abcDEF' THEN NULL ELSE a END, 'TURNED_NULL')
let expr = nvl(
when(col("a").eq(lit("abcDEF")), lit_null)
.otherwise(col("a"))
.unwrap(),
lit("TURNED_NULL"),
)
.alias("nvl_expr");

let expected = [
"+-------------+",
"| nvl_expr |",
"+-------------+",
"| TURNED_NULL |",
"| abc123 |",
"| CBAdef |",
"| 123AbcDef |",
"+-------------+",
];

assert_fn_batches!(expr, expected);

Ok(())
}
#[tokio::test]
async fn test_nvl2() -> Result<()> {
let lit_null = lit(ScalarValue::Utf8(None));
// nvl2(CASE WHEN a = 'abcDEF' THEN NULL ELSE a END, 'NON_NUll', 'TURNED_NULL')
let expr = nvl2(
when(col("a").eq(lit("abcDEF")), lit_null)
.otherwise(col("a"))
.unwrap(),
lit("NON_NULL"),
lit("TURNED_NULL"),
)
.alias("nvl2_expr");

let expected = [
"+-------------+",
"| nvl2_expr |",
"+-------------+",
"| TURNED_NULL |",
"| NON_NULL |",
"| NON_NULL |",
"| NON_NULL |",
"+-------------+",
];

assert_fn_batches!(expr, expected);

Ok(())
}
#[tokio::test]
async fn test_fn_arrow_typeof() -> Result<()> {
let expr = arrow_typeof(col("l"));

let expected = [
"+------------------------------------------------------------------------------------------------------------------+",
"| arrow_typeof(test.l) |",
"+------------------------------------------------------------------------------------------------------------------+",
"| List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) |",
"| List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) |",
"| List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) |",
"| List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) |",
"+------------------------------------------------------------------------------------------------------------------+",
];

assert_fn_batches!(expr, expected);

Ok(())
}

#[tokio::test]
async fn test_fn_struct() -> Result<()> {
let expr = r#struct(vec![col("a"), col("b")]);

let expected = [
"+--------------------------+",
"| struct(test.a,test.b) |",
"+--------------------------+",
"| {c0: abcDEF, c1: 1} |",
"| {c0: abc123, c1: 10} |",
"| {c0: CBAdef, c1: 10} |",
"| {c0: 123AbcDef, c1: 100} |",
"+--------------------------+",
];

assert_fn_batches!(expr, expected);

Ok(())
}

#[tokio::test]
async fn test_fn_named_struct() -> Result<()> {
let expr = named_struct(vec![lit("column_a"), col("a"), lit("column_b"), col("b")]);

let expected = [
"+---------------------------------------------------------------+",
"| named_struct(Utf8(\"column_a\"),test.a,Utf8(\"column_b\"),test.b) |",
"+---------------------------------------------------------------+",
"| {column_a: abcDEF, column_b: 1} |",
"| {column_a: abc123, column_b: 10} |",
"| {column_a: CBAdef, column_b: 10} |",
"| {column_a: 123AbcDef, column_b: 100} |",
"+---------------------------------------------------------------+",
];

assert_fn_batches!(expr, expected);

Ok(())
}

#[tokio::test]
async fn test_fn_coalesce() -> Result<()> {
let expr = coalesce(vec![lit(ScalarValue::Utf8(None)), lit("ab")]);

let expected = [
"+---------------------------------+",
"| coalesce(Utf8(NULL),Utf8(\"ab\")) |",
"+---------------------------------+",
"| ab |",
"| ab |",
"| ab |",
"| ab |",
"+---------------------------------+",
];

assert_fn_batches!(expr, expected);

Ok(())
}

#[tokio::test]
async fn test_fn_approx_median() -> Result<()> {
let expr = approx_median(col("b"));
Expand Down
79 changes: 68 additions & 11 deletions datafusion/functions/src/core/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@

//! "core" DataFusion functions

use datafusion_expr::ScalarUDF;
use std::sync::Arc;

pub mod arrow_cast;
pub mod arrowtypeof;
pub mod coalesce;
Expand All @@ -39,14 +42,68 @@ make_udf_function!(getfield::GetFieldFunc, GET_FIELD, get_field);
make_udf_function!(coalesce::CoalesceFunc, COALESCE, coalesce);

// Export the functions out of this package, both as expr_fn as well as a list of functions
export_functions!(
(nullif, arg_1 arg_2, "returns NULL if value1 equals value2; otherwise it returns value1. This can be used to perform the inverse operation of the COALESCE expression."),
(arrow_cast, arg_1 arg_2, "returns arg_1 cast to the `arrow_type` given the second argument. This can be used to cast to a specific `arrow_type`."),
(nvl, arg_1 arg_2, "returns value2 if value1 is NULL; otherwise it returns value1"),
(nvl2, arg_1 arg_2 arg_3, "Returns value2 if value1 is not NULL; otherwise, it returns value3."),
(arrow_typeof, arg_1, "Returns the Arrow type of the input expression."),
(r#struct, args, "Returns a struct with the given arguments"),
(named_struct, args, "Returns a struct with the given names and arguments pairs"),
(get_field, arg_1 arg_2, "Returns the value of the field with the given name from the struct"),
(coalesce, args, "Returns `coalesce(args...)`, which evaluates to the value of the first expr which is not NULL")
);
pub mod expr_fn {
use datafusion_expr::Expr;

/// returns NULL if value1 equals value2; otherwise it returns value1. This
/// can be used to perform the inverse operation of the COALESCE expression
pub fn nullif(arg1: Expr, arg2: Expr) -> Expr {
super::nullif().call(vec![arg1, arg2])
}

/// returns value1 cast to the `arrow_type` given the second argument. This
/// can be used to cast to a specific `arrow_type`.
pub fn arrow_cast(arg1: Expr, arg2: Expr) -> Expr {
super::arrow_cast().call(vec![arg1, arg2])
}

/// Returns value2 if value1 is NULL; otherwise it returns value1
pub fn nvl(arg1: Expr, arg2: Expr) -> Expr {
super::nvl().call(vec![arg1, arg2])
}

/// Returns value2 if value1 is not NULL; otherwise, it returns value3.
pub fn nvl2(arg1: Expr, arg2: Expr, arg3: Expr) -> Expr {
super::nvl2().call(vec![arg1, arg2, arg3])
}

/// Returns the Arrow type of the input expression.
pub fn arrow_typeof(arg1: Expr) -> Expr {
super::arrow_typeof().call(vec![arg1])
}

/// Returns a struct with the given arguments
pub fn r#struct(args: Vec<Expr>) -> Expr {
super::r#struct().call(args)
}

/// Returns a struct with the given names and arguments pairs
pub fn named_struct(args: Vec<Expr>) -> Expr {
super::named_struct().call(args)
}

/// Returns the value of the field with the given name from the struct
pub fn get_field(arg1: Expr, arg2: Expr) -> Expr {
super::get_field().call(vec![arg1, arg2])
}

/// Returns `coalesce(args...)`, which evaluates to the value of the first expr which is not NULL
pub fn coalesce(args: Vec<Expr>) -> Expr {
super::coalesce().call(args)
}
}

/// Return a list of all functions in this package
pub fn functions() -> Vec<Arc<ScalarUDF>> {
vec![
nullif(),
arrow_cast(),
nvl(),
nvl2(),
arrow_typeof(),
r#struct(),
named_struct(),
get_field(),
coalesce(),
]
}

0 comments on commit 40a2055

Please sign in to comment.