Commit 3c66201

Upgrade to Datafusion 43 (#905)
* patch datafusion deps
* migrate from deprecated RuntimeEnv::new to RuntimeEnv::try_new (Ref: apache/datafusion#12566)
* remove Arc from create_udf call (Ref: apache/datafusion#12489)
* fix doc typo
* migrate to the new UnnestOptions API (Ref: https://github.com/apache/datafusion/pull/12836/files)
* update API for logical expr Limit (Ref: apache/datafusion#12836)
* remove logical expr CrossJoin; it was removed upstream (Ref: apache/datafusion#13076)
* update PyWindowUDF (Ref: apache/datafusion#12803)
* migrate window functions lead and lag to udwf (Ref: apache/datafusion#12802)
* migrate window functions rank, dense_rank, and percent_rank to udwf (Ref: apache/datafusion#12648)
* convert window function cume_dist to udwf (Ref: apache/datafusion#12695)
* convert window function ntile to udwf (Ref: apache/datafusion#12694)
* clean up functions_window invocation
* fix a bug where only one column was being passed to the udwf
* update to DF 43.0.0
* update tests to look for the string_view type
* string_view is now the default type for strings
* make a variety of adjustments in wrappers and unit tests to account for the switch from string to string_view as the default
* resolve errors in doc building

---------

Co-authored-by: Tim Saucer <[email protected]>
1 parent 4a6c4d1 commit 3c66201

19 files changed (+338, -338 lines)

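Most of the Python-side churn in the files below follows one pattern: with DataFusion 43, Utf8View (Arrow's string_view) becomes the default string type, so string literals are built as string_view scalars and call sites that still need Utf8 cast explicitly. The following is a minimal sketch of that pattern, not part of the diff, assuming a datafusion-python build that includes this commit plus pyarrow:

import pyarrow as pa
from datafusion import SessionContext, column, literal, functions as f

ctx = SessionContext()
batch = pa.RecordBatch.from_arrays(
    # string_view column, matching the updated test fixtures below
    [pa.array(["Hello", "World"], type=pa.string_view())],
    names=["a"],
)
df = ctx.create_dataframe([[batch]], name="t")

# literal("!") now produces a string_view scalar (see python/datafusion/expr.py below);
# the column is cast back to Utf8 where a function still expects it, mirroring the tests.
df = df.select(f.concat(column("a").cast(pa.string()), literal("!")))
print(df.collect()[0].column(0))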
Cargo.lock (+199, -174): generated file, diff not rendered.

Cargo.toml (+5, -4)

@@ -37,9 +37,10 @@ substrait = ["dep:datafusion-substrait"]
 tokio = { version = "1.39", features = ["macros", "rt", "rt-multi-thread", "sync"] }
 pyo3 = { version = "0.22", features = ["extension-module", "abi3", "abi3-py38"] }
 arrow = { version = "53", features = ["pyarrow"] }
-datafusion = { version = "42.0.0", features = ["pyarrow", "avro", "unicode_expressions"] }
-datafusion-substrait = { version = "42.0.0", optional = true }
-datafusion-proto = { version = "42.0.0" }
+datafusion = { version = "43.0.0", features = ["pyarrow", "avro", "unicode_expressions"] }
+datafusion-substrait = { version = "43.0.0", optional = true }
+datafusion-proto = { version = "43.0.0" }
+datafusion-functions-window-common = { version = "43.0.0" }
 prost = "0.13" # keep in line with `datafusion-substrait`
 uuid = { version = "1.11", features = ["v4"] }
 mimalloc = { version = "0.1", optional = true, default-features = false, features = ["local_dynamic_tls"] }
@@ -58,4 +59,4 @@ crate-type = ["cdylib", "rlib"]
 
 [profile.release]
 lto = true
-codegen-units = 1
+codegen-units = 1

examples/tpch/_tests.py (+2, -2)

@@ -25,7 +25,7 @@
 def df_selection(col_name, col_type):
     if col_type == pa.float64() or isinstance(col_type, pa.Decimal128Type):
         return F.round(col(col_name), lit(2)).alias(col_name)
-    elif col_type == pa.string():
+    elif col_type == pa.string() or col_type == pa.string_view():
         return F.trim(col(col_name)).alias(col_name)
     else:
         return col(col_name)
@@ -43,7 +43,7 @@ def load_schema(col_name, col_type):
 def expected_selection(col_name, col_type):
     if col_type == pa.int64() or col_type == pa.int32():
         return F.trim(col(col_name)).cast(col_type).alias(col_name)
-    elif col_type == pa.string():
+    elif col_type == pa.string() or col_type == pa.string_view():
         return F.trim(col(col_name)).alias(col_name)
     else:
         return col(col_name)

python/datafusion/expr.py (+2, -2)

@@ -51,7 +51,6 @@
 Column = expr_internal.Column
 CreateMemoryTable = expr_internal.CreateMemoryTable
 CreateView = expr_internal.CreateView
-CrossJoin = expr_internal.CrossJoin
 Distinct = expr_internal.Distinct
 DropTable = expr_internal.DropTable
 EmptyRelation = expr_internal.EmptyRelation
@@ -140,7 +139,6 @@
     "Join",
     "JoinType",
     "JoinConstraint",
-    "CrossJoin",
     "Union",
     "Unnest",
     "UnnestExpr",
@@ -376,6 +374,8 @@ def literal(value: Any) -> Expr:
 
         ``value`` must be a valid PyArrow scalar value or easily castable to one.
         """
+        if isinstance(value, str):
+            value = pa.scalar(value, type=pa.string_view())
         if not isinstance(value, pa.Scalar):
            value = pa.scalar(value)
         return Expr(expr_internal.Expr.literal(value))

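In practical terms, the literal() change above means a bare Python str becomes a string_view (Utf8View) scalar, while a pre-built PyArrow scalar is passed through unchanged; that is the escape hatch the functions.py wrappers below rely on when they still need Utf8. A small usage sketch, not part of the diff:

import pyarrow as pa
from datafusion import literal

view_lit = literal("hello")                               # plain str, coerced to pa.string_view()
utf8_lit = literal(pa.scalar("hello", type=pa.string()))  # explicit Utf8 scalar, kept as-is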
python/datafusion/functions.py (+8, -3)

@@ -297,7 +297,7 @@ def decode(input: Expr, encoding: Expr) -> Expr:
 
 def array_to_string(expr: Expr, delimiter: Expr) -> Expr:
     """Converts each element to its text representation."""
-    return Expr(f.array_to_string(expr.expr, delimiter.expr))
+    return Expr(f.array_to_string(expr.expr, delimiter.expr.cast(pa.string())))
 
 
 def array_join(expr: Expr, delimiter: Expr) -> Expr:
@@ -1067,7 +1067,10 @@ def struct(*args: Expr) -> Expr:
 
 def named_struct(name_pairs: list[tuple[str, Expr]]) -> Expr:
     """Returns a struct with the given names and arguments pairs."""
-    name_pair_exprs = [[Expr.literal(pair[0]), pair[1]] for pair in name_pairs]
+    name_pair_exprs = [
+        [Expr.literal(pa.scalar(pair[0], type=pa.string())), pair[1]]
+        for pair in name_pairs
+    ]
 
     # flatten
     name_pairs = [x.expr for xs in name_pair_exprs for x in xs]
@@ -1424,7 +1427,9 @@ def array_sort(array: Expr, descending: bool = False, null_first: bool = False)
     nulls_first = "NULLS FIRST" if null_first else "NULLS LAST"
     return Expr(
         f.array_sort(
-            array.expr, Expr.literal(desc).expr, Expr.literal(nulls_first).expr
+            array.expr,
+            Expr.literal(pa.scalar(desc, type=pa.string())).expr,
+            Expr.literal(pa.scalar(nulls_first, type=pa.string())).expr,
         )
     )

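The named_struct() and array_sort() changes above keep the public signatures unchanged; the wrappers simply pin their internal name and flag literals to Utf8 instead of letting them default to string_view. A short usage sketch, not part of the diff, assuming a build with this commit:

import pyarrow as pa
from datafusion import SessionContext, col, functions as f

ctx = SessionContext()
batch = pa.RecordBatch.from_arrays([pa.array([1, 2])], names=["a"])
df = ctx.create_dataframe([[batch]], name="t")

# Callers still pass plain Python strings for field names; the wrapper turns
# them into Utf8 literals internally.
df = df.select(f.named_struct([("x", col("a"))]).alias("s"))
print(df.collect()[0].column(0))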
python/datafusion/udf.py (+1)

@@ -229,6 +229,7 @@ def udaf(
     which this UDAF is used. The following examples are all valid.
 
     .. code-block:: python
+
         import pyarrow as pa
         import pyarrow.compute as pc

python/tests/test_expr.py (+12, -4)

@@ -85,14 +85,18 @@ def test_limit(test_ctx):
 
     plan = plan.to_variant()
     assert isinstance(plan, Limit)
-    assert plan.skip() == 0
+    # TODO: Upstream now has expressions for skip and fetch
+    # REF: https://github.com/apache/datafusion/pull/12836
+    # assert plan.skip() == 0
 
     df = test_ctx.sql("select c1 from test LIMIT 10 OFFSET 5")
     plan = df.logical_plan()
 
     plan = plan.to_variant()
     assert isinstance(plan, Limit)
-    assert plan.skip() == 5
+    # TODO: Upstream now has expressions for skip and fetch
+    # REF: https://github.com/apache/datafusion/pull/12836
+    # assert plan.skip() == 5
 
 
 def test_aggregate_query(test_ctx):
@@ -126,7 +130,10 @@ def test_relational_expr(test_ctx):
     ctx = SessionContext()
 
     batch = pa.RecordBatch.from_arrays(
-        [pa.array([1, 2, 3]), pa.array(["alpha", "beta", "gamma"])],
+        [
+            pa.array([1, 2, 3]),
+            pa.array(["alpha", "beta", "gamma"], type=pa.string_view()),
+        ],
         names=["a", "b"],
     )
     df = ctx.create_dataframe([[batch]], name="batch_array")
@@ -141,7 +148,8 @@ def test_relational_expr(test_ctx):
     assert df.filter(col("b") == "beta").count() == 1
     assert df.filter(col("b") != "beta").count() == 2
 
-    assert df.filter(col("a") == "beta").count() == 0
+    with pytest.raises(Exception):
+        df.filter(col("a") == "beta").count()
 
 
 def test_expr_to_variant():

python/tests/test_functions.py (+47, -20)

@@ -34,9 +34,9 @@ def df():
     # create a RecordBatch and a new DataFrame from it
     batch = pa.RecordBatch.from_arrays(
         [
-            pa.array(["Hello", "World", "!"]),
+            pa.array(["Hello", "World", "!"], type=pa.string_view()),
             pa.array([4, 5, 6]),
-            pa.array(["hello ", " world ", " !"]),
+            pa.array(["hello ", " world ", " !"], type=pa.string_view()),
             pa.array(
                 [
                     datetime(2022, 12, 31),
@@ -88,16 +88,18 @@ def test_literal(df):
     assert len(result) == 1
     result = result[0]
     assert result.column(0) == pa.array([1] * 3)
-    assert result.column(1) == pa.array(["1"] * 3)
-    assert result.column(2) == pa.array(["OK"] * 3)
+    assert result.column(1) == pa.array(["1"] * 3, type=pa.string_view())
+    assert result.column(2) == pa.array(["OK"] * 3, type=pa.string_view())
     assert result.column(3) == pa.array([3.14] * 3)
     assert result.column(4) == pa.array([True] * 3)
     assert result.column(5) == pa.array([b"hello world"] * 3)
 
 
 def test_lit_arith(df):
     """Test literals with arithmetic operations"""
-    df = df.select(literal(1) + column("b"), f.concat(column("a"), literal("!")))
+    df = df.select(
+        literal(1) + column("b"), f.concat(column("a").cast(pa.string()), literal("!"))
+    )
     result = df.collect()
     assert len(result) == 1
     result = result[0]
@@ -600,21 +602,33 @@ def test_array_function_obj_tests(stmt, py_expr):
             f.ascii(column("a")),
             pa.array([72, 87, 33], type=pa.int32()),
         ), # H = 72; W = 87; ! = 33
-        (f.bit_length(column("a")), pa.array([40, 40, 8], type=pa.int32())),
-        (f.btrim(literal(" World ")), pa.array(["World", "World", "World"])),
+        (
+            f.bit_length(column("a").cast(pa.string())),
+            pa.array([40, 40, 8], type=pa.int32()),
+        ),
+        (
+            f.btrim(literal(" World ")),
+            pa.array(["World", "World", "World"], type=pa.string_view()),
+        ),
         (f.character_length(column("a")), pa.array([5, 5, 1], type=pa.int32())),
         (f.chr(literal(68)), pa.array(["D", "D", "D"])),
         (
             f.concat_ws("-", column("a"), literal("test")),
             pa.array(["Hello-test", "World-test", "!-test"]),
         ),
-        (f.concat(column("a"), literal("?")), pa.array(["Hello?", "World?", "!?"])),
+        (
+            f.concat(column("a").cast(pa.string()), literal("?")),
+            pa.array(["Hello?", "World?", "!?"]),
+        ),
         (f.initcap(column("c")), pa.array(["Hello ", " World ", " !"])),
         (f.left(column("a"), literal(3)), pa.array(["Hel", "Wor", "!"])),
         (f.length(column("c")), pa.array([6, 7, 2], type=pa.int32())),
         (f.lower(column("a")), pa.array(["hello", "world", "!"])),
         (f.lpad(column("a"), literal(7)), pa.array([" Hello", " World", " !"])),
-        (f.ltrim(column("c")), pa.array(["hello ", "world ", "!"])),
+        (
+            f.ltrim(column("c")),
+            pa.array(["hello ", "world ", "!"], type=pa.string_view()),
+        ),
         (
             f.md5(column("a")),
             pa.array(
@@ -640,19 +654,25 @@ def test_array_function_obj_tests(stmt, py_expr):
             f.rpad(column("a"), literal(8)),
             pa.array(["Hello ", "World ", "! "]),
         ),
-        (f.rtrim(column("c")), pa.array(["hello", " world", " !"])),
+        (
+            f.rtrim(column("c")),
+            pa.array(["hello", " world", " !"], type=pa.string_view()),
+        ),
         (
             f.split_part(column("a"), literal("l"), literal(1)),
             pa.array(["He", "Wor", "!"]),
         ),
         (f.starts_with(column("a"), literal("Wor")), pa.array([False, True, False])),
         (f.strpos(column("a"), literal("o")), pa.array([5, 2, 0], type=pa.int32())),
-        (f.substr(column("a"), literal(3)), pa.array(["llo", "rld", ""])),
+        (
+            f.substr(column("a"), literal(3)),
+            pa.array(["llo", "rld", ""], type=pa.string_view()),
+        ),
         (
             f.translate(column("a"), literal("or"), literal("ld")),
             pa.array(["Helll", "Wldld", "!"]),
         ),
-        (f.trim(column("c")), pa.array(["hello", "world", "!"])),
+        (f.trim(column("c")), pa.array(["hello", "world", "!"], type=pa.string_view())),
         (f.upper(column("c")), pa.array(["HELLO ", " WORLD ", " !"])),
         (f.ends_with(column("a"), literal("llo")), pa.array([True, False, False])),
         (
@@ -794,9 +814,9 @@ def test_temporal_functions(df):
         f.date_trunc(literal("month"), column("d")),
         f.datetrunc(literal("day"), column("d")),
         f.date_bin(
-            literal("15 minutes"),
+            literal("15 minutes").cast(pa.string()),
             column("d"),
-            literal("2001-01-01 00:02:30"),
+            literal("2001-01-01 00:02:30").cast(pa.string()),
         ),
         f.from_unixtime(literal(1673383974)),
         f.to_timestamp(literal("2023-09-07 05:06:14.523952")),
@@ -858,8 +878,8 @@ def test_case(df):
     result = df.collect()
     result = result[0]
     assert result.column(0) == pa.array([10, 8, 8])
-    assert result.column(1) == pa.array(["Hola", "Mundo", "!!"])
-    assert result.column(2) == pa.array(["Hola", "Mundo", None])
+    assert result.column(1) == pa.array(["Hola", "Mundo", "!!"], type=pa.string_view())
+    assert result.column(2) == pa.array(["Hola", "Mundo", None], type=pa.string_view())
 
 
 def test_when_with_no_base(df):
@@ -877,8 +897,10 @@ def test_when_with_no_base(df):
     result = df.collect()
     result = result[0]
     assert result.column(0) == pa.array([4, 5, 6])
-    assert result.column(1) == pa.array(["too small", "just right", "too big"])
-    assert result.column(2) == pa.array(["Hello", None, None])
+    assert result.column(1) == pa.array(
+        ["too small", "just right", "too big"], type=pa.string_view()
+    )
+    assert result.column(2) == pa.array(["Hello", None, None], type=pa.string_view())
 
 
 def test_regr_funcs_sql(df):
@@ -1021,8 +1043,13 @@ def test_regr_funcs_df(func, expected):
 
 def test_binary_string_functions(df):
     df = df.select(
-        f.encode(column("a"), literal("base64")),
-        f.decode(f.encode(column("a"), literal("base64")), literal("base64")),
+        f.encode(column("a").cast(pa.string()), literal("base64").cast(pa.string())),
+        f.decode(
+            f.encode(
+                column("a").cast(pa.string()), literal("base64").cast(pa.string())
+            ),
+            literal("base64").cast(pa.string()),
+        ),
     )
     result = df.collect()
     assert len(result) == 1

python/tests/test_imports.py (-2)

@@ -46,7 +46,6 @@
     Join,
     JoinType,
     JoinConstraint,
-    CrossJoin,
     Union,
     Like,
     ILike,
@@ -129,7 +128,6 @@ def test_class_module_is_datafusion():
         Join,
         JoinType,
         JoinConstraint,
-        CrossJoin,
         Union,
         Like,
         ILike,

python/tests/test_sql.py (+7)

@@ -468,6 +468,13 @@ def test_simple_select(ctx, tmp_path, arr):
     batches = ctx.sql("SELECT a AS tt FROM t").collect()
     result = batches[0].column(0)
 
+    # In DF 43.0.0 we now default to having BinaryView and StringView
+    # so the array that is saved to the parquet is slightly different
+    # than the array read. Convert to values for comparison.
+    if isinstance(result, pa.BinaryViewArray) or isinstance(result, pa.StringViewArray):
+        arr = arr.tolist()
+        result = result.tolist()
+
     np.testing.assert_equal(result, arr)

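The comment in the hunk above captures the core of the test churn in this commit: data written as Utf8 or Binary can come back from Parquet as Utf8View or BinaryView under DF 43, so the test compares Python values rather than the differently typed arrays. A minimal sketch of the same idea, not part of the diff, assuming only pyarrow and numpy:

import numpy as np
import pyarrow as pa

written = pa.array(["a", "b", None])                            # Utf8 at write time
read_back = pa.array(["a", "b", None], type=pa.string_view())   # Utf8View after a round trip

# The arrays have different types, so compare their Python values instead.
np.testing.assert_equal(read_back.tolist(), written.tolist())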
src/context.rs (+1, -1)

@@ -287,7 +287,7 @@ impl PySessionContext {
         } else {
             RuntimeConfig::default()
         };
-        let runtime = Arc::new(RuntimeeEnv::new(runtime_config)?);
+        let runtime = Arc::new(RuntimeEnv::try_new(runtime_config)?);
         let session_state = SessionStateBuilder::new()
             .with_config(config)
             .with_runtime_env(runtime)

src/dataframe.rs (+6, -2)

@@ -402,7 +402,9 @@ impl PyDataFrame {
 
     #[pyo3(signature = (column, preserve_nulls=true))]
     fn unnest_column(&self, column: &str, preserve_nulls: bool) -> PyResult<Self> {
-        let unnest_options = UnnestOptions { preserve_nulls };
+        // TODO: expose RecursionUnnestOptions
+        // REF: https://github.com/apache/datafusion/pull/11577
+        let unnest_options = UnnestOptions::default().with_preserve_nulls(preserve_nulls);
         let df = self
             .df
             .as_ref()
@@ -413,7 +415,9 @@
 
     #[pyo3(signature = (columns, preserve_nulls=true))]
     fn unnest_columns(&self, columns: Vec<String>, preserve_nulls: bool) -> PyResult<Self> {
-        let unnest_options = UnnestOptions { preserve_nulls };
+        // TODO: expose RecursionUnnestOptions
+        // REF: https://github.com/apache/datafusion/pull/11577
+        let unnest_options = UnnestOptions::default().with_preserve_nulls(preserve_nulls);
         let cols = columns.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
         let df = self
             .df

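For context, preserve_nulls keeps a null output row for input rows whose list value is null instead of dropping them; the builder call above carries that flag into DataFusion 43's UnnestOptions. A hedged sketch of the behavior from the Python side, not part of the diff, assuming the Python DataFrame wrapper exposes unnest_columns over this method:

import pyarrow as pa
from datafusion import SessionContext

ctx = SessionContext()
batch = pa.RecordBatch.from_arrays(
    [pa.array([[1, 2], None]), pa.array(["x", "y"], type=pa.string_view())],
    names=["vals", "tag"],
)
df = ctx.create_dataframe([[batch]], name="lists")

# With preserve_nulls=True the second row survives unnesting with a null "vals";
# with preserve_nulls=False it would be dropped. (Python signature assumed here.)
print(df.unnest_columns("vals", preserve_nulls=True).collect())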
src/expr.rs (-2)

@@ -65,7 +65,6 @@ pub mod column;
 pub mod conditional_expr;
 pub mod create_memory_table;
 pub mod create_view;
-pub mod cross_join;
 pub mod distinct;
 pub mod drop_table;
 pub mod empty_relation;
@@ -775,7 +774,6 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_class::<join::PyJoin>()?;
     m.add_class::<join::PyJoinType>()?;
     m.add_class::<join::PyJoinConstraint>()?;
-    m.add_class::<cross_join::PyCrossJoin>()?;
     m.add_class::<union::PyUnion>()?;
     m.add_class::<unnest::PyUnnest>()?;
     m.add_class::<unnest_expr::PyUnnestExpr>()?;

0 commit comments
