Skip to content

Commit 1fde8e4

Browse files
authored
feat: add case function (#447) (#448)
1 parent 92ca34b commit 1fde8e4

File tree

4 files changed

+89
-3
lines changed

4 files changed

+89
-3
lines changed

datafusion/tests/test_functions.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,3 +411,25 @@ def test_temporal_functions(df):
411411
assert result.column(9) == pa.array(
412412
[datetime(2023, 9, 7, 5, 6, 14, 523952)] * 3, type=pa.timestamp("us")
413413
)
414+
415+
416+
def test_case(df):
417+
df = df.select(
418+
f.case(column("b"))
419+
.when(literal(4), literal(10))
420+
.otherwise(literal(8)),
421+
f.case(column("a"))
422+
.when(literal("Hello"), literal("Hola"))
423+
.when(literal("World"), literal("Mundo"))
424+
.otherwise(literal("!!")),
425+
f.case(column("a"))
426+
.when(literal("Hello"), literal("Hola"))
427+
.when(literal("World"), literal("Mundo"))
428+
.end(),
429+
)
430+
431+
result = df.collect()
432+
result = result[0]
433+
assert result.column(0) == pa.array([10, 8, 8])
434+
assert result.column(1) == pa.array(["Hola", "Mundo", "!!"])
435+
assert result.column(2) == pa.array(["Hola", "Mundo", None])

src/expr.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ pub mod bool_expr;
5959
pub mod case;
6060
pub mod cast;
6161
pub mod column;
62+
pub mod conditional_expr;
6263
pub mod create_memory_table;
6364
pub mod create_view;
6465
pub mod cross_join;

src/expr/conditional_expr.rs

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use crate::expr::PyExpr;
19+
use datafusion_expr::conditional_expressions::CaseBuilder;
20+
use pyo3::prelude::*;
21+
22+
#[pyclass(name = "CaseBuilder", module = "datafusion.expr", subclass)]
23+
pub struct PyCaseBuilder {
24+
pub case_builder: CaseBuilder,
25+
}
26+
27+
impl From<PyCaseBuilder> for CaseBuilder {
28+
fn from(case_builder: PyCaseBuilder) -> Self {
29+
case_builder.case_builder
30+
}
31+
}
32+
33+
impl From<CaseBuilder> for PyCaseBuilder {
34+
fn from(case_builder: CaseBuilder) -> PyCaseBuilder {
35+
PyCaseBuilder { case_builder }
36+
}
37+
}
38+
39+
#[pymethods]
40+
impl PyCaseBuilder {
41+
fn when(&mut self, when: PyExpr, then: PyExpr) -> PyCaseBuilder {
42+
PyCaseBuilder {
43+
case_builder: self.case_builder.when(when.expr, then.expr),
44+
}
45+
}
46+
47+
fn otherwise(&mut self, else_expr: PyExpr) -> PyResult<PyExpr> {
48+
Ok(self.case_builder.otherwise(else_expr.expr)?.clone().into())
49+
}
50+
51+
fn end(&mut self) -> PyResult<PyExpr> {
52+
Ok(self.case_builder.end()?.clone().into())
53+
}
54+
}

src/functions.rs

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717

1818
use pyo3::{prelude::*, wrap_pyfunction};
1919

20+
use crate::errors::DataFusionError;
21+
use crate::expr::conditional_expr::PyCaseBuilder;
22+
use crate::expr::PyExpr;
2023
use datafusion_common::Column;
2124
use datafusion_expr::expr::Alias;
2225
use datafusion_expr::{
@@ -27,9 +30,6 @@ use datafusion_expr::{
2730
BuiltinScalarFunction, Expr, WindowFrame,
2831
};
2932

30-
use crate::errors::DataFusionError;
31-
use crate::expr::PyExpr;
32-
3333
#[pyfunction]
3434
fn in_list(expr: PyExpr, value: Vec<PyExpr>, negated: bool) -> PyExpr {
3535
datafusion_expr::in_list(
@@ -115,6 +115,14 @@ fn count_star() -> PyResult<PyExpr> {
115115
})
116116
}
117117

118+
/// Create a CASE WHEN statement with literal WHEN expressions for comparison to the base expression.
119+
#[pyfunction]
120+
fn case(expr: PyExpr) -> PyResult<PyCaseBuilder> {
121+
Ok(PyCaseBuilder {
122+
case_builder: datafusion_expr::case(expr.expr),
123+
})
124+
}
125+
118126
/// Creates a new Window function expression
119127
#[pyfunction]
120128
fn window(
@@ -355,6 +363,7 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
355363
m.add_wrapped(wrap_pyfunction!(chr))?;
356364
m.add_wrapped(wrap_pyfunction!(char_length))?;
357365
m.add_wrapped(wrap_pyfunction!(coalesce))?;
366+
m.add_wrapped(wrap_pyfunction!(case))?;
358367
m.add_wrapped(wrap_pyfunction!(col))?;
359368
m.add_wrapped(wrap_pyfunction!(concat_ws))?;
360369
m.add_wrapped(wrap_pyfunction!(concat))?;

0 commit comments

Comments
 (0)