Skip to content

Commit

Permalink
refactor test
Browse files Browse the repository at this point in the history
Signed-off-by: jayzhan211 <[email protected]>
  • Loading branch information
jayzhan211 committed Sep 9, 2023
1 parent 644a2c5 commit f01063c
Showing 1 changed file with 50 additions and 103 deletions.
153 changes: 50 additions & 103 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use crate::window_function;
use crate::Operator;
use arrow::datatypes::DataType;
use datafusion_common::internal_err;
use datafusion_common::not_impl_err;
use datafusion_common::{plan_err, Column, DataFusionError, Result, ScalarValue};
use std::collections::HashSet;
use std::fmt;
Expand Down Expand Up @@ -1058,10 +1059,7 @@ impl Expr {
args: flatten_args,
}))
}
_ => Err(DataFusionError::Internal(format!(
"Cannot flatten expression {:?}",
self
))),
_ => not_impl_err!("flatten() is not implemented for {self}"),
}
}

Expand Down Expand Up @@ -1627,6 +1625,8 @@ mod test {
use datafusion_common::Column;
use datafusion_common::{Result, ScalarValue};

use super::ScalarFunction;

#[test]
fn format_case_when() -> Result<()> {
let expr = case(col("a"))
Expand Down Expand Up @@ -1691,117 +1691,64 @@ mod test {
Ok(())
}

fn create_make_array_expr(args: &[Expr]) -> Expr {
Expr::ScalarFunction(ScalarFunction::new(
crate::BuiltinScalarFunction::MakeArray,
args.to_vec(),
))
}

#[test]
fn test_flatten() {
let arr = Expr::ScalarFunction(super::ScalarFunction {
fun: crate::BuiltinScalarFunction::MakeArray,
args: vec![
Expr::ScalarFunction(super::ScalarFunction {
fun: crate::BuiltinScalarFunction::MakeArray,
args: vec![
Expr::Literal(ScalarValue::Int64(Some(10))),
Expr::Literal(ScalarValue::Int64(Some(20))),
Expr::Literal(ScalarValue::Int64(Some(30))),
],
}),
Expr::ScalarFunction(super::ScalarFunction {
fun: crate::BuiltinScalarFunction::MakeArray,
args: vec![
Expr::Literal(ScalarValue::Int64(Some(1))),
Expr::Literal(ScalarValue::Int64(None)),
Expr::Literal(ScalarValue::Int64(Some(10))),
],
}),
Expr::ScalarFunction(super::ScalarFunction {
fun: crate::BuiltinScalarFunction::MakeArray,
args: vec![
Expr::Literal(ScalarValue::Int64(Some(4))),
Expr::Literal(ScalarValue::Int64(Some(5))),
Expr::Literal(ScalarValue::Int64(Some(6))),
],
}),
],
});
let i64_none = ScalarValue::try_from(&DataType::Int64).unwrap();

let arr = create_make_array_expr(&[
create_make_array_expr(&[lit(10i64), lit(20i64), lit(30i64)]),
create_make_array_expr(&[lit(1i64), lit(i64_none.clone()), lit(10i64)]),
create_make_array_expr(&[lit(4i64), lit(5i64), lit(6i64)]),
]);

let flattened = arr.flatten();
assert_eq!(
flattened,
Expr::ScalarFunction(super::ScalarFunction {
fun: crate::BuiltinScalarFunction::MakeArray,
args: vec![
Expr::Literal(ScalarValue::Int64(Some(10))),
Expr::Literal(ScalarValue::Int64(Some(20))),
Expr::Literal(ScalarValue::Int64(Some(30))),
Expr::Literal(ScalarValue::Int64(Some(1))),
Expr::Literal(ScalarValue::Int64(None)),
Expr::Literal(ScalarValue::Int64(Some(10))),
Expr::Literal(ScalarValue::Int64(Some(4))),
Expr::Literal(ScalarValue::Int64(Some(5))),
Expr::Literal(ScalarValue::Int64(Some(6))),
]
})
create_make_array_expr(&[
lit(10i64),
lit(20i64),
lit(30i64),
lit(1i64),
lit(i64_none),
lit(10i64),
lit(4i64),
lit(5i64),
lit(6i64),
])
);

// [[[1, 2], [3, 4]], [[5, 6], [7, 8]]] -> [1, 2, 3, 4, 5, 6, 7, 8]
let arr = Expr::ScalarFunction(super::ScalarFunction {
fun: crate::BuiltinScalarFunction::MakeArray,
args: vec![
Expr::ScalarFunction(super::ScalarFunction {
fun: crate::BuiltinScalarFunction::MakeArray,
args: vec![
Expr::ScalarFunction(super::ScalarFunction {
fun: crate::BuiltinScalarFunction::MakeArray,
args: vec![
Expr::Literal(ScalarValue::Int64(Some(1))),
Expr::Literal(ScalarValue::Int64(Some(2))),
],
}),
Expr::ScalarFunction(super::ScalarFunction {
fun: crate::BuiltinScalarFunction::MakeArray,
args: vec![
Expr::Literal(ScalarValue::Int64(Some(3))),
Expr::Literal(ScalarValue::Int64(Some(4))),
],
}),
],
}),
Expr::ScalarFunction(super::ScalarFunction {
fun: crate::BuiltinScalarFunction::MakeArray,
args: vec![
Expr::ScalarFunction(super::ScalarFunction {
fun: crate::BuiltinScalarFunction::MakeArray,
args: vec![
Expr::Literal(ScalarValue::Int64(Some(5))),
Expr::Literal(ScalarValue::Int64(Some(6))),
],
}),
Expr::ScalarFunction(super::ScalarFunction {
fun: crate::BuiltinScalarFunction::MakeArray,
args: vec![
Expr::Literal(ScalarValue::Int64(Some(7))),
Expr::Literal(ScalarValue::Int64(Some(8))),
],
}),
],
}),
],
});
let arr = create_make_array_expr(&[
create_make_array_expr(&[
create_make_array_expr(&[lit(1i64), lit(2i64)]),
create_make_array_expr(&[lit(3i64), lit(4i64)]),
]),
create_make_array_expr(&[
create_make_array_expr(&[lit(5i64), lit(6i64)]),
create_make_array_expr(&[lit(7i64), lit(8i64)]),
]),
]);

let flattened = arr.flatten();
assert_eq!(
flattened,
Expr::ScalarFunction(super::ScalarFunction {
fun: crate::BuiltinScalarFunction::MakeArray,
args: vec![
Expr::Literal(ScalarValue::Int64(Some(1))),
Expr::Literal(ScalarValue::Int64(Some(2))),
Expr::Literal(ScalarValue::Int64(Some(3))),
Expr::Literal(ScalarValue::Int64(Some(4))),
Expr::Literal(ScalarValue::Int64(Some(5))),
Expr::Literal(ScalarValue::Int64(Some(6))),
Expr::Literal(ScalarValue::Int64(Some(7))),
Expr::Literal(ScalarValue::Int64(Some(8))),
]
})
create_make_array_expr(&[
lit(1i64),
lit(2i64),
lit(3i64),
lit(4i64),
lit(5i64),
lit(6i64),
lit(7i64),
lit(8i64),
])
);
}
}

0 comments on commit f01063c

Please sign in to comment.