Skip to content

Commit

Permalink
feat: make cast accept built-in Python types (#858)
Browse files Browse the repository at this point in the history
  • Loading branch information
mesejo committed Sep 6, 2024
1 parent fe0738a commit 859acb4
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 11 deletions.
21 changes: 19 additions & 2 deletions python/datafusion/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
functions as functions_internal,
)
from datafusion.common import NullTreatment, RexType, DataTypeMap
from typing import Any, Optional
from typing import Any, Optional, Type
import pyarrow as pa

# The following are imported from the internal representation. We may choose to
Expand Down Expand Up @@ -372,8 +372,25 @@ def is_not_null(self) -> Expr:
"""Returns ``True`` if this expression is not null."""
return Expr(self.expr.is_not_null())

def cast(self, to: pa.DataType[Any]) -> Expr:
# Built-in Python types accepted by ``cast`` and the pyarrow data type each
# one is translated to before the cast is applied.
_to_pyarrow_types = {
float: pa.float64(),
int: pa.int64(),
str: pa.string(),
bool: pa.bool_(),
}

def cast(
    self, to: pa.DataType[Any] | Type[float] | Type[int] | Type[str] | Type[bool]
) -> Expr:
    """Cast to a new data type.

    Args:
        to: Target type. Either a ``pyarrow.DataType`` instance, or one of
            the built-in Python types listed in ``_to_pyarrow_types``
            (``float``, ``int``, ``str``, ``bool``), which is translated to
            the corresponding pyarrow type first.

    Returns:
        A new expression wrapping the cast of this expression to ``to``.

    Raises:
        TypeError: If ``to`` is neither a ``pyarrow.DataType`` instance nor a
            key of ``_to_pyarrow_types``.
    """
    if not isinstance(to, pa.DataType):
        try:
            to = self._to_pyarrow_types[to]
        except (KeyError, TypeError):
            # KeyError: hashable but unsupported value.
            # TypeError: unhashable value (e.g. a list) rejected by the dict
            # lookup itself. Both get the same message; the lookup failure is
            # an implementation detail, so it is dropped from the traceback.
            raise TypeError(
                "Expected instance of pyarrow.DataType or builtins.type"
            ) from None

    return Expr(self.expr.cast(to))

def rex_type(self) -> RexType:
Expand Down
37 changes: 28 additions & 9 deletions python/datafusion/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@ def df():
datetime(2020, 7, 2),
]
),
pa.array([False, True, True]),
],
names=["a", "b", "c", "d"],
names=["a", "b", "c", "d", "e"],
)
return ctx.create_dataframe([[batch]])

Expand All @@ -63,15 +64,14 @@ def test_named_struct(df):
)

expected = """DataFrame()
+-------+---+---------+------------------------------+
| a | b | c | d |
+-------+---+---------+------------------------------+
| Hello | 4 | hello | {a: Hello, b: 4, c: hello } |
| World | 5 | world | {a: World, b: 5, c: world } |
| ! | 6 | ! | {a: !, b: 6, c: !} |
+-------+---+---------+------------------------------+
+-------+---+---------+------------------------------+-------+
| a | b | c | d | e |
+-------+---+---------+------------------------------+-------+
| Hello | 4 | hello | {a: Hello, b: 4, c: hello } | false |
| World | 5 | world | {a: World, b: 5, c: world } | true |
| ! | 6 | ! | {a: !, b: 6, c: !} | true |
+-------+---+---------+------------------------------+-------+
""".strip()

assert str(df) == expected


Expand Down Expand Up @@ -978,3 +978,22 @@ def test_binary_string_functions(df):
assert pa.array(result.column(1)).cast(pa.string()) == pa.array(
["Hello", "World", "!"]
)


@pytest.mark.parametrize(
    "python_datatype, name, expected",
    [
        pytest.param(bool, "e", pa.bool_(), id="bool"),
        pytest.param(int, "b", pa.int64(), id="int"),
        pytest.param(float, "b", pa.float64(), id="float"),
        pytest.param(str, "b", pa.string(), id="str"),
    ],
)
def test_cast(df, python_datatype, name: str, expected):
    """Casting with a built-in Python type matches casting with ``expected``.

    The same column is cast twice, once with the Python type and once with
    the pyarrow type, and the two resulting columns are compared.
    """
    via_python_type = column(name).cast(python_datatype).alias("actual")
    via_pyarrow_type = column(name).cast(expected).alias("expected")

    projected = df.select(via_python_type, via_pyarrow_type)

    # Only the first collected batch is inspected.
    first_batch = projected.collect()[0]
    assert first_batch.column(0) == first_batch.column(1)

0 comments on commit 859acb4

Please sign in to comment.