From feee363af0c38fb4c64c5f844aece35f577b70c0 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 10 Sep 2024 08:50:22 -0400 Subject: [PATCH] Working through some of the sort requirement changes --- Cargo.lock | 108 +++++++----------- Cargo.toml | 8 +- python/datafusion/context.py | 13 ++- python/datafusion/dataframe.py | 8 +- python/datafusion/expr.py | 52 ++++++++- python/datafusion/functions.py | 93 +++++++-------- python/datafusion/tests/test_sql.py | 11 +- .../datafusion/tests/test_wrapper_coverage.py | 5 +- src/dataframe.rs | 5 +- src/expr.rs | 18 +-- src/expr/sort_expr.rs | 16 ++- src/functions.rs | 69 ++++++----- src/pyarrow_filter_expression.rs | 2 +- src/udf.rs | 2 +- 14 files changed, 225 insertions(+), 185 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9b698997..bb1d800b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -747,6 +747,7 @@ dependencies = [ [[package]] name = "datafusion" version = "41.0.0" +source = "git+https://github.com/apache/datafusion.git?rev=c71a9d7508e37e5d082e22d2953a12b61d290df5#c71a9d7508e37e5d082e22d2953a12b61d290df5" dependencies = [ "ahash", "apache-avro", @@ -781,7 +782,7 @@ dependencies = [ "half", "hashbrown", "indexmap", - "itertools 0.13.0", + "itertools", "log", "num-traits", "num_cpus", @@ -804,6 +805,7 @@ dependencies = [ [[package]] name = "datafusion-catalog" version = "41.0.0" +source = "git+https://github.com/apache/datafusion.git?rev=c71a9d7508e37e5d082e22d2953a12b61d290df5#c71a9d7508e37e5d082e22d2953a12b61d290df5" dependencies = [ "arrow-schema", "async-trait", @@ -817,6 +819,7 @@ dependencies = [ [[package]] name = "datafusion-common" version = "41.0.0" +source = "git+https://github.com/apache/datafusion.git?rev=c71a9d7508e37e5d082e22d2953a12b61d290df5#c71a9d7508e37e5d082e22d2953a12b61d290df5" dependencies = [ "ahash", "apache-avro", @@ -841,6 +844,7 @@ dependencies = [ [[package]] name = "datafusion-common-runtime" version = "41.0.0" +source = "git+https://github.com/apache/datafusion.git?rev=c71a9d7508e37e5d082e22d2953a12b61d290df5#c71a9d7508e37e5d082e22d2953a12b61d290df5" dependencies = [ "log", "tokio", @@ -849,6 +853,7 @@ dependencies = [ [[package]] name = "datafusion-execution" version = "41.0.0" +source = "git+https://github.com/apache/datafusion.git?rev=c71a9d7508e37e5d082e22d2953a12b61d290df5#c71a9d7508e37e5d082e22d2953a12b61d290df5" dependencies = [ "arrow", "chrono", @@ -868,6 +873,7 @@ dependencies = [ [[package]] name = "datafusion-expr" version = "41.0.0" +source = "git+https://github.com/apache/datafusion.git?rev=c71a9d7508e37e5d082e22d2953a12b61d290df5#c71a9d7508e37e5d082e22d2953a12b61d290df5" dependencies = [ "ahash", "arrow", @@ -888,6 +894,7 @@ dependencies = [ [[package]] name = "datafusion-expr-common" version = "41.0.0" +source = "git+https://github.com/apache/datafusion.git?rev=c71a9d7508e37e5d082e22d2953a12b61d290df5#c71a9d7508e37e5d082e22d2953a12b61d290df5" dependencies = [ "arrow", "datafusion-common", @@ -897,6 +904,7 @@ dependencies = [ [[package]] name = "datafusion-functions" version = "41.0.0" +source = "git+https://github.com/apache/datafusion.git?rev=c71a9d7508e37e5d082e22d2953a12b61d290df5#c71a9d7508e37e5d082e22d2953a12b61d290df5" dependencies = [ "arrow", "arrow-buffer", @@ -909,7 +917,7 @@ dependencies = [ "datafusion-expr", "hashbrown", "hex", - "itertools 0.13.0", + "itertools", "log", "md-5", "rand", @@ -922,6 +930,7 @@ dependencies = [ [[package]] name = "datafusion-functions-aggregate" version = "41.0.0" +source = 
"git+https://github.com/apache/datafusion.git?rev=c71a9d7508e37e5d082e22d2953a12b61d290df5#c71a9d7508e37e5d082e22d2953a12b61d290df5" dependencies = [ "ahash", "arrow", @@ -941,6 +950,7 @@ dependencies = [ [[package]] name = "datafusion-functions-aggregate-common" version = "41.0.0" +source = "git+https://github.com/apache/datafusion.git?rev=c71a9d7508e37e5d082e22d2953a12b61d290df5#c71a9d7508e37e5d082e22d2953a12b61d290df5" dependencies = [ "ahash", "arrow", @@ -953,6 +963,7 @@ dependencies = [ [[package]] name = "datafusion-functions-nested" version = "41.0.0" +source = "git+https://github.com/apache/datafusion.git?rev=c71a9d7508e37e5d082e22d2953a12b61d290df5#c71a9d7508e37e5d082e22d2953a12b61d290df5" dependencies = [ "arrow", "arrow-array", @@ -965,7 +976,7 @@ dependencies = [ "datafusion-functions", "datafusion-functions-aggregate", "datafusion-physical-expr-common", - "itertools 0.13.0", + "itertools", "log", "paste", "rand", @@ -974,6 +985,7 @@ dependencies = [ [[package]] name = "datafusion-functions-window" version = "41.0.0" +source = "git+https://github.com/apache/datafusion.git?rev=c71a9d7508e37e5d082e22d2953a12b61d290df5#c71a9d7508e37e5d082e22d2953a12b61d290df5" dependencies = [ "datafusion-common", "datafusion-expr", @@ -984,6 +996,7 @@ dependencies = [ [[package]] name = "datafusion-optimizer" version = "41.0.0" +source = "git+https://github.com/apache/datafusion.git?rev=c71a9d7508e37e5d082e22d2953a12b61d290df5#c71a9d7508e37e5d082e22d2953a12b61d290df5" dependencies = [ "arrow", "async-trait", @@ -993,7 +1006,7 @@ dependencies = [ "datafusion-physical-expr", "hashbrown", "indexmap", - "itertools 0.13.0", + "itertools", "log", "paste", "regex-syntax", @@ -1002,6 +1015,7 @@ dependencies = [ [[package]] name = "datafusion-physical-expr" version = "41.0.0" +source = "git+https://github.com/apache/datafusion.git?rev=c71a9d7508e37e5d082e22d2953a12b61d290df5#c71a9d7508e37e5d082e22d2953a12b61d290df5" dependencies = [ "ahash", "arrow", @@ -1022,7 +1036,7 @@ dependencies = [ "hashbrown", "hex", "indexmap", - "itertools 0.13.0", + "itertools", "log", "paste", "petgraph", @@ -1032,6 +1046,7 @@ dependencies = [ [[package]] name = "datafusion-physical-expr-common" version = "41.0.0" +source = "git+https://github.com/apache/datafusion.git?rev=c71a9d7508e37e5d082e22d2953a12b61d290df5#c71a9d7508e37e5d082e22d2953a12b61d290df5" dependencies = [ "ahash", "arrow", @@ -1044,18 +1059,20 @@ dependencies = [ [[package]] name = "datafusion-physical-optimizer" version = "41.0.0" +source = "git+https://github.com/apache/datafusion.git?rev=c71a9d7508e37e5d082e22d2953a12b61d290df5#c71a9d7508e37e5d082e22d2953a12b61d290df5" dependencies = [ "arrow-schema", "datafusion-common", "datafusion-execution", "datafusion-physical-expr", "datafusion-physical-plan", - "itertools 0.13.0", + "itertools", ] [[package]] name = "datafusion-physical-plan" version = "41.0.0" +source = "git+https://github.com/apache/datafusion.git?rev=c71a9d7508e37e5d082e22d2953a12b61d290df5#c71a9d7508e37e5d082e22d2953a12b61d290df5" dependencies = [ "ahash", "arrow", @@ -1077,7 +1094,7 @@ dependencies = [ "half", "hashbrown", "indexmap", - "itertools 0.13.0", + "itertools", "log", "once_cell", "parking_lot", @@ -1098,8 +1115,8 @@ dependencies = [ "mimalloc", "object_store", "parking_lot", - "prost 0.12.6", - "prost-types 0.12.6", + "prost", + "prost-types", "pyo3", "pyo3-build-config", "rand", @@ -1113,6 +1130,7 @@ dependencies = [ [[package]] name = "datafusion-sql" version = "41.0.0" +source = 
"git+https://github.com/apache/datafusion.git?rev=c71a9d7508e37e5d082e22d2953a12b61d290df5#c71a9d7508e37e5d082e22d2953a12b61d290df5" dependencies = [ "arrow", "arrow-array", @@ -1128,15 +1146,16 @@ dependencies = [ [[package]] name = "datafusion-substrait" version = "41.0.0" +source = "git+https://github.com/apache/datafusion.git?rev=c71a9d7508e37e5d082e22d2953a12b61d290df5#c71a9d7508e37e5d082e22d2953a12b61d290df5" dependencies = [ "arrow-buffer", "async-recursion", "chrono", "datafusion", - "itertools 0.13.0", + "itertools", "object_store", "pbjson-types", - "prost 0.13.2", + "prost", "substrait", "url", ] @@ -1590,15 +1609,6 @@ version = "2.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "187674a687eed5fe42285b40c6291f9a01517d415fad1c3cbc6a9f778af7fcd4" -[[package]] -name = "itertools" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" -dependencies = [ - "either", -] - [[package]] name = "itertools" version = "0.13.0" @@ -1963,7 +1973,7 @@ dependencies = [ "futures", "humantime", "hyper", - "itertools 0.13.0", + "itertools", "md-5", "parking_lot", "percent-encoding", @@ -2093,9 +2103,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6eea3058763d6e656105d1403cb04e0a41b7bbac6362d413e7c33be0c32279c9" dependencies = [ "heck 0.5.0", - "itertools 0.13.0", - "prost 0.13.2", - "prost-types 0.13.2", + "itertools", + "prost", + "prost-types", ] [[package]] @@ -2108,7 +2118,7 @@ dependencies = [ "chrono", "pbjson", "pbjson-build", - "prost 0.13.2", + "prost", "prost-build", "serde", ] @@ -2239,16 +2249,6 @@ dependencies = [ "unicode-ident", ] -[[package]] -name = "prost" -version = "0.12.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "deb1435c188b76130da55f17a466d252ff7b1418b2ad3e037d127b94e3411f29" -dependencies = [ - "bytes", - "prost-derive 0.12.6", -] - [[package]] name = "prost" version = "0.13.2" @@ -2256,7 +2256,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3b2ecbe40f08db5c006b5764a2645f7f3f141ce756412ac9e1dd6087e6d32995" dependencies = [ "bytes", - "prost-derive 0.13.2", + "prost-derive", ] [[package]] @@ -2267,32 +2267,19 @@ checksum = "f8650aabb6c35b860610e9cff5dc1af886c9e25073b7b1712a68972af4281302" dependencies = [ "bytes", "heck 0.5.0", - "itertools 0.13.0", + "itertools", "log", "multimap", "once_cell", "petgraph", "prettyplease", - "prost 0.13.2", - "prost-types 0.13.2", + "prost", + "prost-types", "regex", "syn", "tempfile", ] -[[package]] -name = "prost-derive" -version = "0.12.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81bddcdb20abf9501610992b6759a4c888aef7d1a7247ef75e2404275ac24af1" -dependencies = [ - "anyhow", - "itertools 0.12.1", - "proc-macro2", - "quote", - "syn", -] - [[package]] name = "prost-derive" version = "0.13.2" @@ -2300,28 +2287,19 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "acf0c195eebb4af52c752bec4f52f645da98b6e92077a04110c7f349477ae5ac" dependencies = [ "anyhow", - "itertools 0.13.0", + "itertools", "proc-macro2", "quote", "syn", ] -[[package]] -name = "prost-types" -version = "0.12.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9091c90b0a32608e984ff2fa4091273cbdd755d54935c51d520887f4a1dbd5b0" -dependencies = [ - "prost 0.12.6", -] - [[package]] name = "prost-types" version = "0.13.2" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "60caa6738c7369b940c3d49246a8d1749323674c65cb13010134f5c9bad5b519" dependencies = [ - "prost 0.13.2", + "prost", ] [[package]] @@ -3058,9 +3036,9 @@ dependencies = [ "pbjson-build", "pbjson-types", "prettyplease", - "prost 0.13.2", + "prost", "prost-build", - "prost-types 0.13.2", + "prost-types", "protobuf-src", "schemars", "semver", diff --git a/Cargo.toml b/Cargo.toml index 126aeb7c..a7fae57f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -40,8 +40,8 @@ pyo3 = { version = "0.22", features = ["extension-module", "abi3", "abi3-py38"] arrow = { version = "53", feature = ["pyarrow"] } datafusion = { version = "41.0.0", features = ["pyarrow", "avro", "unicode_expressions"] } datafusion-substrait = { version = "41.0.0", optional = true } -prost = "0.12" # keep in line with `datafusion-substrait` -prost-types = "0.12" # keep in line with `datafusion-substrait` +prost = "0.13" # keep in line with `datafusion-substrait` +prost-types = "0.13" # keep in line with `datafusion-substrait` uuid = { version = "1.9", features = ["v4"] } mimalloc = { version = "0.1", optional = true, default-features = false, features = ["local_dynamic_tls"] } async-trait = "0.1" @@ -64,5 +64,5 @@ lto = true codegen-units = 1 [patch.crates-io] -datafusion = { path = "../../arrow-datafusion-main/datafusion/core" } -datafusion-substrait = { path = "../../arrow-datafusion-main/datafusion/substrait" } +datafusion = { git = "https://github.com/apache/datafusion.git", rev = "c71a9d7508e37e5d082e22d2953a12b61d290df5" } +datafusion-substrait = { git = "https://github.com/apache/datafusion.git", rev = "c71a9d7508e37e5d082e22d2953a12b61d290df5" } diff --git a/python/datafusion/context.py b/python/datafusion/context.py index 903d4a10..35a40ccd 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -28,7 +28,7 @@ from datafusion._internal import AggregateUDF from datafusion.catalog import Catalog, Table from datafusion.dataframe import DataFrame -from datafusion.expr import Expr +from datafusion.expr import Expr, SortExpr, sort_list_to_raw_sort_list from datafusion.record_batch import RecordBatchStream from datafusion.udf import ScalarUDF @@ -466,7 +466,7 @@ def register_listing_table( table_partition_cols: list[tuple[str, str]] | None = None, file_extension: str = ".parquet", schema: pyarrow.Schema | None = None, - file_sort_order: list[list[Expr]] | None = None, + file_sort_order: list[list[Expr | SortExpr]] | None = None, ) -> None: """Register multiple files as a single table. 
@@ -484,15 +484,18 @@ def register_listing_table( """ if table_partition_cols is None: table_partition_cols = [] - if file_sort_order is not None: - file_sort_order = [[x.expr for x in xs] for xs in file_sort_order] + file_sort_order_raw = ( + [sort_list_to_raw_sort_list(f) for f in file_sort_order] + if file_sort_order is not None + else None + ) self.ctx.register_listing_table( name, str(path), table_partition_cols, file_extension, schema, - file_sort_order, + file_sort_order_raw, ) def sql(self, query: str, options: SQLOptions | None = None) -> DataFrame: diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py index 56dff22a..2328ef8f 100644 --- a/python/datafusion/dataframe.py +++ b/python/datafusion/dataframe.py @@ -33,7 +33,7 @@ from typing import Callable from datafusion._internal import DataFrame as DataFrameInternal -from datafusion.expr import Expr +from datafusion.expr import Expr, SortExpr, sort_or_default from datafusion._internal import ( LogicalPlan, ExecutionPlan, @@ -199,7 +199,7 @@ def aggregate( aggs = [e.expr for e in aggs] return DataFrame(self.df.aggregate(group_by, aggs)) - def sort(self, *exprs: Expr) -> DataFrame: + def sort(self, *exprs: Expr | SortExpr) -> DataFrame: """Sort the DataFrame by the specified sorting expressions. Note that any expression can be turned into a sort expression by @@ -211,8 +211,8 @@ def sort(self, *exprs: Expr) -> DataFrame: Returns: DataFrame after sorting. """ - exprs = [expr.expr for expr in exprs] - return DataFrame(self.df.sort(*exprs)) + exprs_raw = [sort_or_default(expr) for expr in exprs] + return DataFrame(self.df.sort(*exprs_raw)) def limit(self, count: int, offset: int = 0) -> DataFrame: """Return a new :py:class:`DataFrame` with a limited number of rows. diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index bd6a86fb..60f87cf0 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -159,6 +159,27 @@ ] +def expr_list_to_raw_expr_list( + expr_list: Optional[list[Expr]], +) -> Optional[list[expr_internal.Expr]]: + """Helper function to convert an optional list to raw expressions.""" + return [e.expr for e in expr_list] if expr_list is not None else None + + +def sort_or_default(e: Expr | SortExpr) -> expr_internal.SortExpr: + """Helper function to return a default Sort if an Expr is provided.""" + if isinstance(e, SortExpr): + return e.raw_sort + return SortExpr(e.expr, True, True).raw_sort + + +def sort_list_to_raw_sort_list( + sort_list: Optional[list[Expr | SortExpr]], +) -> Optional[list[expr_internal.SortExpr]]: + """Helper function to return an optional sort list to raw variant.""" + return [sort_or_default(e) for e in sort_list] if sort_list is not None else None + + class Expr: """Expression object. @@ -355,14 +376,14 @@ def alias(self, name: str) -> Expr: """Assign a name to the expression.""" return Expr(self.expr.alias(name)) - def sort(self, ascending: bool = True, nulls_first: bool = True) -> Expr: + def sort(self, ascending: bool = True, nulls_first: bool = True) -> SortExpr: """Creates a sort :py:class:`Expr` from an existing :py:class:`Expr`. Args: ascending: If true, sort in ascending order. nulls_first: Return null values first. 
""" - return Expr(self.expr.sort(ascending=ascending, nulls_first=nulls_first)) + return SortExpr(self.expr, ascending=ascending, nulls_first=nulls_first) def is_null(self) -> Expr: """Returns ``True`` if this expression is null.""" @@ -439,14 +460,14 @@ def column_name(self, plan: LogicalPlan) -> str: """Compute the output column name based on the provided logical plan.""" return self.expr.column_name(plan) - def order_by(self, *exprs: Expr) -> ExprFuncBuilder: + def order_by(self, *exprs: Expr | SortExpr) -> ExprFuncBuilder: """Set the ordering for a window or aggregate function. This function will create an :py:class:`ExprFuncBuilder` that can be used to set parameters for either window or aggregate functions. If used on any other type of expression, an error will be generated when ``build()`` is called. """ - return ExprFuncBuilder(self.expr.order_by(list(e.expr for e in exprs))) + return ExprFuncBuilder(self.expr.order_by([sort_or_default(e) for e in exprs])) def filter(self, filter: Expr) -> ExprFuncBuilder: """Filter an aggregate function. @@ -506,7 +527,9 @@ def order_by(self, *exprs: Expr) -> ExprFuncBuilder: Values given in ``exprs`` must be sort expressions. You can convert any other expression to a sort expression using `.sort()`. """ - return ExprFuncBuilder(self.builder.order_by(list(e.expr for e in exprs))) + return ExprFuncBuilder( + self.builder.order_by([sort_or_default(e) for e in exprs]) + ) def filter(self, filter: Expr) -> ExprFuncBuilder: """Filter values during aggregation.""" @@ -643,3 +666,22 @@ def end(self) -> Expr: Any non-matching cases will end in a `null` value. """ return Expr(self.case_builder.end()) + + +class SortExpr: + """Used to specify sorting on either a DataFrame or function""" + + def __init__(self, expr: Expr, ascending: bool, nulls_first: bool) -> None: + self.raw_sort = expr_internal.SortExpr(expr, ascending, nulls_first) + + def expr(self) -> Expr: + return Expr(self.raw_sort.expr()) + + def ascending(self) -> bool: + return self.raw_sort.ascending() + + def nulls_first(self) -> bool: + return self.raw_sort.nulls_first() + + def __repr__(self) -> str: + return self.raw_sort.__repr__() diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index 163ff04e..e17449ae 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -19,7 +19,14 @@ from __future__ import annotations from datafusion._internal import functions as f, expr as expr_internal -from datafusion.expr import CaseBuilder, Expr, WindowFrame +from datafusion.expr import ( + CaseBuilder, + Expr, + WindowFrame, + SortExpr, + sort_list_to_raw_sort_list, + expr_list_to_raw_expr_list, +) from datafusion.context import SessionContext from datafusion.common import NullTreatment @@ -261,12 +268,6 @@ ] -def expr_list_to_raw_expr_list( - expr_list: Optional[list[Expr]], -) -> Optional[list[expr_internal.Expr]]: - return [e.expr for e in expr_list] if expr_list is not None else None - - def isnan(expr: Expr) -> Expr: """Returns true if a given number is +NaN or -NaN otherwise returns false.""" return Expr(f.isnan(expr.expr)) @@ -352,9 +353,9 @@ def concat_ws(separator: str, *args: Expr) -> Expr: return Expr(f.concat_ws(separator, args)) -def order_by(expr: Expr, ascending: bool = True, nulls_first: bool = True) -> Expr: +def order_by(expr: Expr, ascending: bool = True, nulls_first: bool = True) -> SortExpr: """Creates a new sort expression.""" - return Expr(f.order_by(expr.expr, ascending, nulls_first)) + return SortExpr(expr.expr, 
ascending=ascending, nulls_first=nulls_first) def alias(expr: Expr, name: str) -> Expr: @@ -405,7 +406,7 @@ def window( name: str, args: list[Expr], partition_by: list[Expr] | None = None, - order_by: list[Expr] | None = None, + order_by: list[Expr | SortExpr] | None = None, window_frame: WindowFrame | None = None, ctx: SessionContext | None = None, ) -> Expr: @@ -419,9 +420,9 @@ def window( """ args = [a.expr for a in args] partition_by = expr_list_to_raw_expr_list(partition_by) - order_by = expr_list_to_raw_expr_list(order_by) + order_by_raw = sort_list_to_raw_sort_list(order_by) window_frame = window_frame.window_frame if window_frame is not None else None - return Expr(f.window(name, args, partition_by, order_by, window_frame, ctx)) + return Expr(f.window(name, args, partition_by, order_by_raw, window_frame, ctx)) # scalar functions @@ -1608,7 +1609,7 @@ def array_agg( expression: Expr, distinct: bool = False, filter: Optional[Expr] = None, - order_by: Optional[list[Expr]] = None, + order_by: Optional[list[Expr | SortExpr]] = None, ) -> Expr: """Aggregate values into an array. @@ -1625,7 +1626,7 @@ def array_agg( filter: If provided, only compute against rows for which the filter is True order_by: Order the resultant array values """ - order_by_raw = expr_list_to_raw_expr_list(order_by) + order_by_raw = sort_list_to_raw_sort_list(order_by) filter_raw = filter.expr if filter is not None else None return Expr( @@ -2107,7 +2108,7 @@ def regr_syy( def first_value( expression: Expr, filter: Optional[Expr] = None, - order_by: Optional[list[Expr]] = None, + order_by: Optional[list[Expr | SortExpr]] = None, null_treatment: NullTreatment = NullTreatment.RESPECT_NULLS, ) -> Expr: """Returns the first value in a group of values. @@ -2123,7 +2124,7 @@ def first_value( order_by: Set the ordering of the expression to evaluate null_treatment: Assign whether to respect or ignull null values. """ - order_by_raw = expr_list_to_raw_expr_list(order_by) + order_by_raw = sort_list_to_raw_sort_list(order_by) filter_raw = filter.expr if filter is not None else None return Expr( @@ -2139,7 +2140,7 @@ def first_value( def last_value( expression: Expr, filter: Optional[Expr] = None, - order_by: Optional[list[Expr]] = None, + order_by: Optional[list[Expr | SortExpr]] = None, null_treatment: NullTreatment = NullTreatment.RESPECT_NULLS, ) -> Expr: """Returns the last value in a group of values. @@ -2155,7 +2156,7 @@ def last_value( order_by: Set the ordering of the expression to evaluate null_treatment: Assign whether to respect or ignull null values. """ - order_by_raw = expr_list_to_raw_expr_list(order_by) + order_by_raw = sort_list_to_raw_sort_list(order_by) filter_raw = filter.expr if filter is not None else None return Expr( @@ -2172,7 +2173,7 @@ def nth_value( expression: Expr, n: int, filter: Optional[Expr] = None, - order_by: Optional[list[Expr]] = None, + order_by: Optional[list[Expr | SortExpr]] = None, null_treatment: NullTreatment = NullTreatment.RESPECT_NULLS, ) -> Expr: """Returns the n-th value in a group of values. @@ -2189,7 +2190,7 @@ def nth_value( order_by: Set the ordering of the expression to evaluate null_treatment: Assign whether to respect or ignull null values. 
""" - order_by_raw = expr_list_to_raw_expr_list(order_by) + order_by_raw = sort_list_to_raw_sort_list(order_by) filter_raw = filter.expr if filter is not None else None return Expr( @@ -2293,7 +2294,7 @@ def lead( shift_offset: int = 1, default_value: Optional[Any] = None, partition_by: Optional[list[Expr]] = None, - order_by: Optional[list[Expr]] = None, + order_by: Optional[list[Expr | SortExpr]] = None, ) -> Expr: """Create a lead window function. @@ -2330,7 +2331,7 @@ def lead( partition_cols = ( [col.expr for col in partition_by] if partition_by is not None else None ) - order_cols = [col.expr for col in order_by] if order_by is not None else None + order_by_raw = sort_list_to_raw_sort_list(order_by) return Expr( f.lead( @@ -2338,7 +2339,7 @@ def lead( shift_offset, default_value, partition_by=partition_cols, - order_by=order_cols, + order_by=order_by_raw, ) ) @@ -2348,7 +2349,7 @@ def lag( shift_offset: int = 1, default_value: Optional[Any] = None, partition_by: Optional[list[Expr]] = None, - order_by: Optional[list[Expr]] = None, + order_by: Optional[list[Expr | SortExpr]] = None, ) -> Expr: """Create a lag window function. @@ -2382,7 +2383,7 @@ def lag( partition_cols = ( [col.expr for col in partition_by] if partition_by is not None else None ) - order_cols = [col.expr for col in order_by] if order_by is not None else None + order_by_raw = sort_list_to_raw_sort_list(order_by) return Expr( f.lag( @@ -2390,14 +2391,14 @@ def lag( shift_offset, default_value, partition_by=partition_cols, - order_by=order_cols, + order_by=order_by_raw, ) ) def row_number( partition_by: Optional[list[Expr]] = None, - order_by: Optional[list[Expr]] = None, + order_by: Optional[list[Expr | SortExpr]] = None, ) -> Expr: """Create a row number window function. @@ -2421,19 +2422,19 @@ def row_number( partition_cols = ( [col.expr for col in partition_by] if partition_by is not None else None ) - order_cols = [col.expr for col in order_by] if order_by is not None else None + order_by_raw = sort_list_to_raw_sort_list(order_by) return Expr( f.row_number( partition_by=partition_cols, - order_by=order_cols, + order_by=order_by_raw, ) ) def rank( partition_by: Optional[list[Expr]] = None, - order_by: Optional[list[Expr]] = None, + order_by: Optional[list[Expr | SortExpr]] = None, ) -> Expr: """Create a rank window function. @@ -2462,19 +2463,19 @@ def rank( partition_cols = ( [col.expr for col in partition_by] if partition_by is not None else None ) - order_cols = [col.expr for col in order_by] if order_by is not None else None + order_by_raw = sort_list_to_raw_sort_list(order_by) return Expr( f.rank( partition_by=partition_cols, - order_by=order_cols, + order_by=order_by_raw, ) ) def dense_rank( partition_by: Optional[list[Expr]] = None, - order_by: Optional[list[Expr]] = None, + order_by: Optional[list[Expr | SortExpr]] = None, ) -> Expr: """Create a dense_rank window function. @@ -2498,19 +2499,19 @@ def dense_rank( partition_cols = ( [col.expr for col in partition_by] if partition_by is not None else None ) - order_cols = [col.expr for col in order_by] if order_by is not None else None + order_by_raw = sort_list_to_raw_sort_list(order_by) return Expr( f.dense_rank( partition_by=partition_cols, - order_by=order_cols, + order_by=order_by_raw, ) ) def percent_rank( partition_by: Optional[list[Expr]] = None, - order_by: Optional[list[Expr]] = None, + order_by: Optional[list[Expr | SortExpr]] = None, ) -> Expr: """Create a percent_rank window function. 
@@ -2535,19 +2536,19 @@ def percent_rank( partition_cols = ( [col.expr for col in partition_by] if partition_by is not None else None ) - order_cols = [col.expr for col in order_by] if order_by is not None else None + order_by_raw = sort_list_to_raw_sort_list(order_by) return Expr( f.percent_rank( partition_by=partition_cols, - order_by=order_cols, + order_by=order_by_raw, ) ) def cume_dist( partition_by: Optional[list[Expr]] = None, - order_by: Optional[list[Expr]] = None, + order_by: Optional[list[Expr | SortExpr]] = None, ) -> Expr: """Create a cumulative distribution window function. @@ -2572,12 +2573,12 @@ def cume_dist( partition_cols = ( [col.expr for col in partition_by] if partition_by is not None else None ) - order_cols = [col.expr for col in order_by] if order_by is not None else None + order_by_raw = sort_list_to_raw_sort_list(order_by) return Expr( f.cume_dist( partition_by=partition_cols, - order_by=order_cols, + order_by=order_by_raw, ) ) @@ -2585,7 +2586,7 @@ def cume_dist( def ntile( groups: int, partition_by: Optional[list[Expr]] = None, - order_by: Optional[list[Expr]] = None, + order_by: Optional[list[Expr | SortExpr]] = None, ) -> Expr: """Create a n-tile window function. @@ -2613,13 +2614,13 @@ def ntile( partition_cols = ( [col.expr for col in partition_by] if partition_by is not None else None ) - order_cols = [col.expr for col in order_by] if order_by is not None else None + order_by_raw = sort_list_to_raw_sort_list(order_by) return Expr( f.ntile( Expr.literal(groups).expr, partition_by=partition_cols, - order_by=order_cols, + order_by=order_by_raw, ) ) @@ -2628,7 +2629,7 @@ def string_agg( expression: Expr, delimiter: str, filter: Optional[Expr] = None, - order_by: Optional[list[Expr]] = None, + order_by: Optional[list[Expr | SortExpr]] = None, ) -> Expr: """Concatenates the input strings. 
@@ -2645,7 +2646,7 @@ def string_agg( filter: If provided, only compute against rows for which the filter is True order_by: Set the ordering of the expression to evaluate """ - order_by_raw = expr_list_to_raw_expr_list(order_by) + order_by_raw = sort_list_to_raw_sort_list(order_by) filter_raw = filter.expr if filter is not None else None return Expr( diff --git a/python/datafusion/tests/test_sql.py b/python/datafusion/tests/test_sql.py index e41d0100..cbb2e9f5 100644 --- a/python/datafusion/tests/test_sql.py +++ b/python/datafusion/tests/test_sql.py @@ -264,14 +264,17 @@ def test_execute(ctx, tmp_path): # count result = ctx.sql("SELECT COUNT(a) AS cnt FROM t WHERE a IS NOT NULL").collect() + ctx.sql("SELECT COUNT(a) AS cnt FROM t WHERE a IS NOT NULL").show() + + expected_schema = pa.schema([("cnt", pa.int64(), False)]) + expected_values = pa.array([7], type=pa.int64()) + expected = [pa.RecordBatch.from_arrays([expected_values], schema=expected_schema)] - expected = pa.array([7], pa.int64()) - expected = [pa.RecordBatch.from_arrays([expected], ["cnt"])] assert result == expected # where - expected = pa.array([2], pa.int64()) - expected = [pa.RecordBatch.from_arrays([expected], ["cnt"])] + expected_values = pa.array([2], type=pa.int64()) + expected = [pa.RecordBatch.from_arrays([expected_values], schema=expected_schema)] result = ctx.sql("SELECT COUNT(a) AS cnt FROM t WHERE a > 10").collect() assert result == expected diff --git a/python/datafusion/tests/test_wrapper_coverage.py b/python/datafusion/tests/test_wrapper_coverage.py index 4a47de2e..c53a89c5 100644 --- a/python/datafusion/tests/test_wrapper_coverage.py +++ b/python/datafusion/tests/test_wrapper_coverage.py @@ -39,7 +39,10 @@ def missing_exports(internal_obj, wrapped_obj) -> None: internal_attr = getattr(internal_obj, attr) wrapped_attr = getattr(wrapped_obj, attr) - assert wrapped_attr is not None if internal_attr is not None else True + if internal_attr is not None: + if wrapped_attr is None: + print("Missing attribute: ", attr) + assert False if attr in ["__self__", "__class__"]: continue diff --git a/src/dataframe.rs b/src/dataframe.rs index 07c5f7f3..69c02782 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -45,7 +45,10 @@ use crate::physical_plan::PyExecutionPlan; use crate::record_batch::PyRecordBatchStream; use crate::sql::logical::PyLogicalPlan; use crate::utils::{get_tokio_runtime, wait_for_future}; -use crate::{errors::DataFusionError, expr::{PyExpr, sort_expr::PySortExpr}}; +use crate::{ + errors::DataFusionError, + expr::{sort_expr::PySortExpr, PyExpr}, +}; /// A PyDataFrame is a representation of a logical plan and an API to compose statements. /// Use it to build a plan and `.collect()` to execute the plan and collect the result. 
diff --git a/src/expr.rs b/src/expr.rs index 823b1143..c4ebedc6 100644 --- a/src/expr.rs +++ b/src/expr.rs @@ -94,7 +94,7 @@ pub mod unnest; pub mod unnest_expr; pub mod window; -use sort_expr::{PySortExpr, to_sort_expressions}; +use sort_expr::{to_sort_expressions, PySortExpr}; /// A PyExpr that can be used on a DataFrame #[pyclass(name = "Expr", module = "datafusion.expr", subclass)] @@ -152,7 +152,6 @@ impl PyExpr { Expr::Case(value) => Ok(case::PyCase::from(value.clone()).into_py(py)), Expr::Cast(value) => Ok(cast::PyCast::from(value.clone()).into_py(py)), Expr::TryCast(value) => Ok(cast::PyTryCast::from(value.clone()).into_py(py)), - Expr::Sort(value) => Ok(sort_expr::PySortExpr::from(value.clone()).into_py(py)), Expr::ScalarFunction(value) => Err(py_unsupported_variant_err(format!( "Converting Expr::ScalarFunction to a Python object is not implemented: {:?}", value @@ -169,9 +168,9 @@ impl PyExpr { Expr::ScalarSubquery(value) => { Ok(scalar_subquery::PyScalarSubquery::from(value.clone()).into_py(py)) } - Expr::Wildcard { qualifier } => Err(py_unsupported_variant_err(format!( - "Converting Expr::Wildcard to a Python object is not implemented : {:?}", - qualifier + Expr::Wildcard { qualifier, options } => Err(py_unsupported_variant_err(format!( + "Converting Expr::Wildcard to a Python object is not implemented : {:?} {:?}", + qualifier, options ))), Expr::GroupingSet(value) => { Ok(grouping_set::PyGroupingSet::from(value.clone()).into_py(py)) @@ -276,7 +275,7 @@ impl PyExpr { /// Create a sort PyExpr from an existing PyExpr. #[pyo3(signature = (ascending=true, nulls_first=true))] - pub fn sort(&self, ascending: bool, nulls_first: bool) -> PyExpr { + pub fn sort(&self, ascending: bool, nulls_first: bool) -> PySortExpr { self.expr.clone().sort(ascending, nulls_first).into() } @@ -314,7 +313,6 @@ impl PyExpr { | Expr::Case { .. } | Expr::Cast { .. } | Expr::TryCast { .. } - | Expr::Sort { .. } | Expr::ScalarFunction { .. } | Expr::AggregateFunction { .. } | Expr::WindowFunction { .. } @@ -378,7 +376,6 @@ impl PyExpr { | Expr::Negative(expr) | Expr::Cast(Cast { expr, .. }) | Expr::TryCast(TryCast { expr, .. }) - | Expr::Sort(Sort { expr, .. }) | Expr::InSubquery(InSubquery { expr, .. }) => Ok(vec![PyExpr::from(*expr.clone())]), // Expr variants containing a collection of Expr(s) for operands @@ -621,11 +618,6 @@ impl PyExpr { input_plan: &LogicalPlan, ) -> Result, DataFusionError> { match expr { - Expr::Sort(Sort { expr, .. }) => { - // DataFusion does not support create_name for sort expressions (since they never - // appear in projections) so we just delegate to the contained expression instead - Self::expr_to_field(expr, input_plan) - } Expr::Wildcard { .. 
} => { // Since * could be any of the valid column names just return the first one Ok(Arc::new(input_plan.schema().field(0).clone())) diff --git a/src/expr/sort_expr.rs b/src/expr/sort_expr.rs index 34ed91f6..12f74e4d 100644 --- a/src/expr/sort_expr.rs +++ b/src/expr/sort_expr.rs @@ -52,10 +52,7 @@ impl Display for PySortExpr { } pub fn to_sort_expressions(order_by: Vec) -> Vec { - order_by - .iter() - .map(|e| e.sort.clone()) - .collect() + order_by.iter().map(|e| e.sort.clone()).collect() } pub fn py_sort_expr_list(expr: &[SortExpr]) -> PyResult> { @@ -64,6 +61,17 @@ pub fn py_sort_expr_list(expr: &[SortExpr]) -> PyResult> { #[pymethods] impl PySortExpr { + #[new] + fn new(expr: PyExpr, asc: bool, nulls_first: bool) -> Self { + Self { + sort: SortExpr { + expr: expr.into(), + asc, + nulls_first, + }, + } + } + fn expr(&self) -> PyResult { Ok(self.sort.expr.clone().into()) } diff --git a/src/functions.rs b/src/functions.rs index 7765f522..d4ba67de 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -176,15 +176,11 @@ fn regexp_replace( /// Creates a new Sort Expr #[pyfunction] fn order_by(expr: PyExpr, asc: bool, nulls_first: bool) -> PyResult { - Ok( - PySortExpr::from( - datafusion::logical_expr::expr::Sort { - expr: expr.expr, - asc, - nulls_first, - } - ) - ) + Ok(PySortExpr::from(datafusion::logical_expr::expr::Sort { + expr: expr.expr, + asc, + nulls_first, + })) } /// Creates a new Alias Expr @@ -296,7 +292,7 @@ fn window( name: &str, args: Vec, partition_by: Option>, - order_by: Option>, + order_by: Option>, window_frame: Option, ctx: Option, ) -> PyResult { @@ -318,11 +314,7 @@ fn window( order_by: order_by .unwrap_or_default() .into_iter() - .map(|x| x.expr) - .map(|e| match e { - Expr::Sort(_) => e, - _ => e.sort(true, true), - }) + .map(|x| x.into()) .collect::>(), window_frame, null_treatment: None, @@ -690,20 +682,20 @@ pub fn first_value( } // nth_value requires a non-expr argument -// #[pyfunction] -// #[pyo3(signature = (expr, n, distinct=None, filter=None, order_by=None, null_treatment=None))] -// pub fn nth_value( -// expr: PyExpr, -// n: i64, -// distinct: Option, -// filter: Option, -// order_by: Option>, -// null_treatment: Option, -// ) -> PyResult { -// // @todo: Commenting this function out for now as it requires some reworking -// let agg_fn = datafusion::functions_aggregate::nth_value::nth_value(vec![expr.expr, lit(n)]); -// add_builder_fns_to_aggregate(agg_fn, distinct, filter, order_by, null_treatment) -// } +#[pyfunction] +#[pyo3(signature = (expr, n, distinct=None, filter=None, order_by=None, null_treatment=None))] +pub fn nth_value( + expr: PyExpr, + n: i64, + distinct: Option, + filter: Option, + order_by: Option>, + null_treatment: Option, +) -> PyResult { + // @todo: Commenting this function out for now as it requires some reworking + let agg_fn = datafusion::functions_aggregate::nth_value::nth_value(expr.expr, n, vec![]); + add_builder_fns_to_aggregate(agg_fn, distinct, filter, order_by, null_treatment) +} // string_agg requires a non-expr argument #[pyfunction] @@ -776,7 +768,21 @@ pub fn lag( #[pyfunction] #[pyo3(signature = (partition_by=None, order_by=None))] -pub fn rank(partition_by: Option>, order_by: Option>) -> PyResult { +pub fn row_number( + partition_by: Option>, + order_by: Option>, +) -> PyResult { + let window_fn = datafusion::functions_window::expr_fn::row_number(); + + add_builder_fns_to_window(window_fn, partition_by, order_by) +} + +#[pyfunction] +#[pyo3(signature = (partition_by=None, order_by=None))] +pub fn rank( + 
partition_by: Option>, + order_by: Option>, +) -> PyResult { let window_fn = window_function::rank(); add_builder_fns_to_window(window_fn, partition_by, order_by) @@ -969,7 +975,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(regr_syy))?; m.add_wrapped(wrap_pyfunction!(first_value))?; m.add_wrapped(wrap_pyfunction!(last_value))?; - // m.add_wrapped(wrap_pyfunction!(nth_value))?; + m.add_wrapped(wrap_pyfunction!(nth_value))?; m.add_wrapped(wrap_pyfunction!(bit_and))?; m.add_wrapped(wrap_pyfunction!(bit_or))?; m.add_wrapped(wrap_pyfunction!(bit_xor))?; @@ -1017,6 +1023,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(lead))?; m.add_wrapped(wrap_pyfunction!(lag))?; m.add_wrapped(wrap_pyfunction!(rank))?; + m.add_wrapped(wrap_pyfunction!(row_number))?; m.add_wrapped(wrap_pyfunction!(dense_rank))?; m.add_wrapped(wrap_pyfunction!(percent_rank))?; m.add_wrapped(wrap_pyfunction!(cume_dist))?; diff --git a/src/pyarrow_filter_expression.rs b/src/pyarrow_filter_expression.rs index 6e2a45e1..0f97ea44 100644 --- a/src/pyarrow_filter_expression.rs +++ b/src/pyarrow_filter_expression.rs @@ -27,7 +27,7 @@ use datafusion::logical_expr::{expr::InList, Between, BinaryExpr, Expr, Operator use crate::errors::DataFusionError; -#[derive(Debug, Clone)] +#[derive(Debug)] #[repr(transparent)] pub(crate) struct PyArrowFilterExpression(PyObject); diff --git a/src/udf.rs b/src/udf.rs index 4d57f87b..7d5db2f9 100644 --- a/src/udf.rs +++ b/src/udf.rs @@ -20,8 +20,8 @@ use std::sync::Arc; use pyo3::{prelude::*, types::PyTuple}; use datafusion::arrow::array::{make_array, Array, ArrayData, ArrayRef}; -use datafusion::arrow::pyarrow::FromPyArrow; use datafusion::arrow::datatypes::DataType; +use datafusion::arrow::pyarrow::FromPyArrow; use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow}; use datafusion::error::DataFusionError; use datafusion::logical_expr::create_udf;
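
Reviewer note (illustrative only, not part of the patch): a minimal sketch of how the reworked sort API in this change set is expected to be used from Python. It assumes only names that already exist in the package or are introduced above — SessionContext, col, functions as f, Expr.sort() now returning SortExpr, and DataFrame.sort / order_by parameters accepting either Expr or SortExpr.

    from datafusion import SessionContext, col, functions as f

    ctx = SessionContext()
    # Small in-memory table for demonstration purposes.
    df = ctx.from_pydict({"a": [3, 1, 2], "b": ["x", "y", "z"]})

    # DataFrame.sort accepts either a plain Expr (sorted with the defaults
    # ascending=True, nulls_first=True) or an explicit SortExpr built via Expr.sort().
    df_sorted = df.sort(col("a").sort(ascending=False, nulls_first=False), col("b"))

    # Aggregates that take an ordering accept the same sort expressions for order_by.
    df_agg = df.aggregate([], [f.array_agg(col("b"), order_by=[col("a").sort()])])

    df_sorted.show()
    df_agg.show()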