Skip to content

Commit 6c8bf5f

Browse files
emgeee, timsaucer, Michael-J-Ward
authored
Upgrade datafusion (#867)
* update dependencies * update get_logical_plan signature * remove row_number() function: row_number was converted to a UDF in datafusion v42 (apache/datafusion#12030). This specific functionality needs to be added back in. * remove unneeded dependency * fix pyo3 warnings: implicit defaults for trailing optional arguments have been deprecated in pyo3 v0.22.0 (PyO3/pyo3#4078) * update object_store dependency * change PyExpr -> PySortExpr * comment out key.extract::<&PyTuple>() condition statement * change more instances of PyExpr -> PySortExpr * update function signatures to use _bound versions * remove clone * Working through some of the sort requirement changes * remove unused import * expr.display_name is deprecated, use format!() + schema_name() instead * expr.canonical_name() is deprecated, use format!() on the expr instead * remove comment * fix tuple extraction in dataframe.__getitem__() * remove unneeded import * Add docstring comments to SortExpr python class * change extract() to downcast() Co-authored-by: Michael J Ward <[email protected]> * deprecate Expr::display_name Ref: apache/datafusion#11797 * fix lint errors * update datafusion commit hash * fix typo in Cargo file for arrow features * upgrade to datafusion 42 * cleanup --------- Co-authored-by: Tim Saucer <[email protected]> Co-authored-by: Michael J Ward <[email protected]> Co-authored-by: Michael-J-Ward <[email protected]>
1 parent 02d4453 commit 6c8bf5f

22 files changed

+710
-595
lines changed

Cargo.lock

+390-394
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

+8-9
Original file line numberDiff line numberDiff line change
@@ -36,24 +36,24 @@ substrait = ["dep:datafusion-substrait"]
3636
[dependencies]
3737
tokio = { version = "1.39", features = ["macros", "rt", "rt-multi-thread", "sync"] }
3838
rand = "0.8"
39-
pyo3 = { version = "0.21", features = ["extension-module", "abi3", "abi3-py38"] }
40-
arrow = { version = "52", feature = ["pyarrow"] }
41-
datafusion = { version = "41.0.0", features = ["pyarrow", "avro", "unicode_expressions"] }
42-
datafusion-substrait = { version = "41.0.0", optional = true }
43-
prost = "0.12" # keep in line with `datafusion-substrait`
44-
prost-types = "0.12" # keep in line with `datafusion-substrait`
39+
pyo3 = { version = "0.22", features = ["extension-module", "abi3", "abi3-py38"] }
40+
arrow = { version = "53", features = ["pyarrow"] }
41+
datafusion = { version = "42.0.0", features = ["pyarrow", "avro", "unicode_expressions"] }
42+
datafusion-substrait = { version = "42.0.0", optional = true }
43+
prost = "0.13" # keep in line with `datafusion-substrait`
44+
prost-types = "0.13" # keep in line with `datafusion-substrait`
4545
uuid = { version = "1.9", features = ["v4"] }
4646
mimalloc = { version = "0.1", optional = true, default-features = false, features = ["local_dynamic_tls"] }
4747
async-trait = "0.1"
4848
futures = "0.3"
49-
object_store = { version = "0.10.1", features = ["aws", "gcp", "azure"] }
49+
object_store = { version = "0.11.0", features = ["aws", "gcp", "azure"] }
5050
parking_lot = "0.12"
5151
regex-syntax = "0.8"
5252
syn = "2.0.68"
5353
url = "2"
5454

5555
[build-dependencies]
56-
pyo3-build-config = "0.21"
56+
pyo3-build-config = "0.22"
5757

5858
[lib]
5959
name = "datafusion_python"
@@ -62,4 +62,3 @@ crate-type = ["cdylib", "rlib"]
6262
[profile.release]
6363
lto = true
6464
codegen-units = 1
65-

python/datafusion/context.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from datafusion._internal import AggregateUDF
2929
from datafusion.catalog import Catalog, Table
3030
from datafusion.dataframe import DataFrame
31-
from datafusion.expr import Expr
31+
from datafusion.expr import Expr, SortExpr, sort_list_to_raw_sort_list
3232
from datafusion.record_batch import RecordBatchStream
3333
from datafusion.udf import ScalarUDF
3434

@@ -466,7 +466,7 @@ def register_listing_table(
466466
table_partition_cols: list[tuple[str, str]] | None = None,
467467
file_extension: str = ".parquet",
468468
schema: pyarrow.Schema | None = None,
469-
file_sort_order: list[list[Expr]] | None = None,
469+
file_sort_order: list[list[Expr | SortExpr]] | None = None,
470470
) -> None:
471471
"""Register multiple files as a single table.
472472
@@ -484,15 +484,18 @@ def register_listing_table(
484484
"""
485485
if table_partition_cols is None:
486486
table_partition_cols = []
487-
if file_sort_order is not None:
488-
file_sort_order = [[x.expr for x in xs] for xs in file_sort_order]
487+
file_sort_order_raw = (
488+
[sort_list_to_raw_sort_list(f) for f in file_sort_order]
489+
if file_sort_order is not None
490+
else None
491+
)
489492
self.ctx.register_listing_table(
490493
name,
491494
str(path),
492495
table_partition_cols,
493496
file_extension,
494497
schema,
495-
file_sort_order,
498+
file_sort_order_raw,
496499
)
497500

498501
def sql(self, query: str, options: SQLOptions | None = None) -> DataFrame:

python/datafusion/dataframe.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from typing import Callable
3434

3535
from datafusion._internal import DataFrame as DataFrameInternal
36-
from datafusion.expr import Expr
36+
from datafusion.expr import Expr, SortExpr, sort_or_default
3737
from datafusion._internal import (
3838
LogicalPlan,
3939
ExecutionPlan,
@@ -199,7 +199,7 @@ def aggregate(
199199
aggs = [e.expr for e in aggs]
200200
return DataFrame(self.df.aggregate(group_by, aggs))
201201

202-
def sort(self, *exprs: Expr) -> DataFrame:
202+
def sort(self, *exprs: Expr | SortExpr) -> DataFrame:
203203
"""Sort the DataFrame by the specified sorting expressions.
204204
205205
Note that any expression can be turned into a sort expression by
@@ -211,8 +211,8 @@ def sort(self, *exprs: Expr) -> DataFrame:
211211
Returns:
212212
DataFrame after sorting.
213213
"""
214-
exprs = [expr.expr for expr in exprs]
215-
return DataFrame(self.df.sort(*exprs))
214+
exprs_raw = [sort_or_default(expr) for expr in exprs]
215+
return DataFrame(self.df.sort(*exprs_raw))
216216

217217
def limit(self, count: int, offset: int = 0) -> DataFrame:
218218
"""Return a new :py:class:`DataFrame` with a limited number of rows.

python/datafusion/expr.py

+70-13
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,15 @@
2222

2323
from __future__ import annotations
2424

25-
from ._internal import (
26-
expr as expr_internal,
27-
LogicalPlan,
28-
functions as functions_internal,
29-
)
30-
from datafusion.common import NullTreatment, RexType, DataTypeMap
3125
from typing import Any, Optional, Type
26+
3227
import pyarrow as pa
28+
from datafusion.common import DataTypeMap, NullTreatment, RexType
29+
from typing_extensions import deprecated
30+
31+
from ._internal import LogicalPlan
32+
from ._internal import expr as expr_internal
33+
from ._internal import functions as functions_internal
3334

3435
# The following are imported from the internal representation. We may choose to
3536
# give these all proper wrappers, or to simply leave as is. These were added
@@ -84,7 +85,6 @@
8485
ScalarVariable = expr_internal.ScalarVariable
8586
SimilarTo = expr_internal.SimilarTo
8687
Sort = expr_internal.Sort
87-
SortExpr = expr_internal.SortExpr
8888
Subquery = expr_internal.Subquery
8989
SubqueryAlias = expr_internal.SubqueryAlias
9090
TableScan = expr_internal.TableScan
@@ -159,6 +159,27 @@
159159
]
160160

161161

162+
def expr_list_to_raw_expr_list(
163+
expr_list: Optional[list[Expr]],
164+
) -> Optional[list[expr_internal.Expr]]:
165+
"""Helper function to convert an optional list to raw expressions."""
166+
return [e.expr for e in expr_list] if expr_list is not None else None
167+
168+
169+
def sort_or_default(e: Expr | SortExpr) -> expr_internal.SortExpr:
170+
"""Helper function to return a default Sort if an Expr is provided."""
171+
if isinstance(e, SortExpr):
172+
return e.raw_sort
173+
return SortExpr(e.expr, True, True).raw_sort
174+
175+
176+
def sort_list_to_raw_sort_list(
177+
sort_list: Optional[list[Expr | SortExpr]],
178+
) -> Optional[list[expr_internal.SortExpr]]:
179+
"""Helper function to convert an optional sort list to its raw variant."""
180+
return [sort_or_default(e) for e in sort_list] if sort_list is not None else None
181+
182+
162183
class Expr:
163184
"""Expression object.
164185
@@ -174,12 +195,22 @@ def to_variant(self) -> Any:
174195
"""Convert this expression into a python object if possible."""
175196
return self.expr.to_variant()
176197

198+
@deprecated(
199+
"display_name() is deprecated. Use :py:meth:`~Expr.schema_name` instead"
200+
)
177201
def display_name(self) -> str:
178202
"""Returns the name of this expression as it should appear in a schema.
179203
180204
This name will not include any CAST expressions.
181205
"""
182-
return self.expr.display_name()
206+
return self.schema_name()
207+
208+
def schema_name(self) -> str:
209+
"""Returns the name of this expression as it should appear in a schema.
210+
211+
This name will not include any CAST expressions.
212+
"""
213+
return self.expr.schema_name()
183214

184215
def canonical_name(self) -> str:
185216
"""Returns a complete string representation of this expression."""
@@ -355,14 +386,14 @@ def alias(self, name: str) -> Expr:
355386
"""Assign a name to the expression."""
356387
return Expr(self.expr.alias(name))
357388

358-
def sort(self, ascending: bool = True, nulls_first: bool = True) -> Expr:
389+
def sort(self, ascending: bool = True, nulls_first: bool = True) -> SortExpr:
359390
"""Creates a :py:class:`SortExpr` from an existing :py:class:`Expr`.
360391
361392
Args:
362393
ascending: If true, sort in ascending order.
363394
nulls_first: Return null values first.
364395
"""
365-
return Expr(self.expr.sort(ascending=ascending, nulls_first=nulls_first))
396+
return SortExpr(self.expr, ascending=ascending, nulls_first=nulls_first)
366397

367398
def is_null(self) -> Expr:
368399
"""Returns ``True`` if this expression is null."""
@@ -455,14 +486,14 @@ def column_name(self, plan: LogicalPlan) -> str:
455486
"""Compute the output column name based on the provided logical plan."""
456487
return self.expr.column_name(plan)
457488

458-
def order_by(self, *exprs: Expr) -> ExprFuncBuilder:
489+
def order_by(self, *exprs: Expr | SortExpr) -> ExprFuncBuilder:
459490
"""Set the ordering for a window or aggregate function.
460491
461492
This function will create an :py:class:`ExprFuncBuilder` that can be used to
462493
set parameters for either window or aggregate functions. If used on any other
463494
type of expression, an error will be generated when ``build()`` is called.
464495
"""
465-
return ExprFuncBuilder(self.expr.order_by(list(e.expr for e in exprs)))
496+
return ExprFuncBuilder(self.expr.order_by([sort_or_default(e) for e in exprs]))
466497

467498
def filter(self, filter: Expr) -> ExprFuncBuilder:
468499
"""Filter an aggregate function.
@@ -522,7 +553,9 @@ def order_by(self, *exprs: Expr) -> ExprFuncBuilder:
522553
Values given in ``exprs`` may be plain expressions or sort expressions. A plain
523554
expression is sorted ascending with nulls first; use `.sort()` to choose otherwise.
524555
"""
525-
return ExprFuncBuilder(self.builder.order_by(list(e.expr for e in exprs)))
556+
return ExprFuncBuilder(
557+
self.builder.order_by([sort_or_default(e) for e in exprs])
558+
)
526559

527560
def filter(self, filter: Expr) -> ExprFuncBuilder:
528561
"""Filter values during aggregation."""
@@ -659,3 +692,27 @@ def end(self) -> Expr:
659692
Any non-matching cases will end in a `null` value.
660693
"""
661694
return Expr(self.case_builder.end())
695+
696+
697+
class SortExpr:
698+
"""Used to specify sorting on either a DataFrame or function."""
699+
700+
def __init__(self, expr: Expr, ascending: bool, nulls_first: bool) -> None:
701+
"""This constructor should not be called by the end user."""
702+
self.raw_sort = expr_internal.SortExpr(expr, ascending, nulls_first)
703+
704+
def expr(self) -> Expr:
705+
"""Return the expression backing the SortExpr, wrapped as an :py:class:`Expr`."""
706+
return Expr(self.raw_sort.expr())
707+
708+
def ascending(self) -> bool:
709+
"""Return ascending property."""
710+
return self.raw_sort.ascending()
711+
712+
def nulls_first(self) -> bool:
713+
"""Return nulls_first property."""
714+
return self.raw_sort.nulls_first()
715+
716+
def __repr__(self) -> str:
717+
"""Generate a string representation of this expression."""
718+
return self.raw_sort.__repr__()

0 commit comments

Comments
 (0)