From b10b21bdcca93ba3c50a5437f3d3feb33371df1f Mon Sep 17 00:00:00 2001 From: Daniel Mesejo Date: Fri, 4 Aug 2023 15:32:17 +0200 Subject: [PATCH] feat: add case function (#447) --- datafusion/tests/test_functions.py | 22 ++++++++++++ src/expr.rs | 1 + src/expr/conditional_expr.rs | 58 ++++++++++++++++++++++++++++++ src/functions.rs | 13 ++++++- 4 files changed, 93 insertions(+), 1 deletion(-) create mode 100644 src/expr/conditional_expr.rs diff --git a/datafusion/tests/test_functions.py b/datafusion/tests/test_functions.py index ec334747c..f3f00fd5b 100644 --- a/datafusion/tests/test_functions.py +++ b/datafusion/tests/test_functions.py @@ -411,3 +411,25 @@ def test_temporal_functions(df): assert result.column(9) == pa.array( [datetime(2023, 9, 7, 5, 6, 14, 523952)] * 3, type=pa.timestamp("us") ) + + +def test_case(df): + df = df.select( + f.case(column("b")) + .when(literal(4), literal(10)) + .otherwise(literal(8)), + f.case(column("a")) + .when(literal("Hello"), literal("Hola")) + .when(literal("World"), literal("Mundo")) + .otherwise(literal("!!")), + f.case(column("a")) + .when(literal("Hello"), literal("Hola")) + .when(literal("World"), literal("Mundo")) + .end() + ) + + result = df.collect() + result = result[0] + assert result.column(0) == pa.array([10, 8, 8]) + assert result.column(1) == pa.array(["Hola", "Mundo", "!!"]) + assert result.column(2) == pa.array(["Hola", "Mundo", None]) diff --git a/src/expr.rs b/src/expr.rs index 2fd638a13..d1022e905 100644 --- a/src/expr.rs +++ b/src/expr.rs @@ -59,6 +59,7 @@ pub mod bool_expr; pub mod case; pub mod cast; pub mod column; +pub mod conditional_expr; pub mod create_memory_table; pub mod create_view; pub mod cross_join; diff --git a/src/expr/conditional_expr.rs b/src/expr/conditional_expr.rs new file mode 100644 index 000000000..d745ad8a7 --- /dev/null +++ b/src/expr/conditional_expr.rs @@ -0,0 +1,58 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::expr::PyExpr; +use datafusion_expr::conditional_expressions::CaseBuilder; +use pyo3::prelude::*; + +#[pyclass(name = "CaseBuilder", module = "datafusion.expr", subclass)] +pub struct PyCaseBuilder { + pub case_builder: CaseBuilder, +} + +impl From for CaseBuilder { + fn from(case_builder: PyCaseBuilder) -> Self { + case_builder.case_builder + } +} + +impl From for PyCaseBuilder { + fn from(case_builder: CaseBuilder) -> PyCaseBuilder { + PyCaseBuilder { case_builder } + } +} + + +#[pymethods] +impl PyCaseBuilder { + + fn when(&mut self, when: PyExpr, then: PyExpr) -> PyCaseBuilder { + PyCaseBuilder { + case_builder : self.case_builder.when(when.expr, then.expr) + } + } + + fn otherwise(&mut self, else_expr: PyExpr) -> PyResult { + Ok(self.case_builder.otherwise(else_expr.expr)?.clone().into()) + } + + fn end(&mut self) -> PyResult { + Ok(self.case_builder.end()?.clone().into()) + } + + +} diff --git a/src/functions.rs b/src/functions.rs index 8b60e6433..4d3cd3e85 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -26,8 +26,8 @@ use datafusion_expr::{ window_function::find_df_window_func, BuiltinScalarFunction, Expr, WindowFrame, }; - use crate::errors::DataFusionError; +use crate::expr::conditional_expr::PyCaseBuilder; use crate::expr::PyExpr; #[pyfunction] @@ -115,6 +115,16 @@ fn count_star() -> PyResult { }) } +/// Create a CASE WHEN statement with literal WHEN expressions for comparison to the base expression. +#[pyfunction] +fn case(expr: PyExpr) -> PyResult { + Ok( + PyCaseBuilder{ + case_builder: datafusion_expr:: case(expr.expr) + } + ) +} + /// Creates a new Window function expression #[pyfunction] fn window( @@ -355,6 +365,7 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(chr))?; m.add_wrapped(wrap_pyfunction!(char_length))?; m.add_wrapped(wrap_pyfunction!(coalesce))?; + m.add_wrapped(wrap_pyfunction!(case))?; m.add_wrapped(wrap_pyfunction!(col))?; m.add_wrapped(wrap_pyfunction!(concat_ws))?; m.add_wrapped(wrap_pyfunction!(concat))?;