From 8229282f07322dcb2be8737657b380e1f923521c Mon Sep 17 00:00:00 2001 From: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com> Date: Fri, 10 Jan 2025 17:28:08 +0100 Subject: [PATCH 1/2] feat: `SparkLikeNamespace` methods (#1779) --------- Co-authored-by: Marco Edward Gorelli --- narwhals/_dask/namespace.py | 4 +- narwhals/_pandas_like/namespace.py | 4 +- narwhals/_spark_like/expr.py | 14 + narwhals/_spark_like/group_by.py | 6 +- narwhals/_spark_like/namespace.py | 260 ++++++++++++++++-- tests/expr_and_series/any_horizontal_test.py | 10 +- tests/expr_and_series/concat_str_test.py | 2 +- tests/expr_and_series/is_null_test.py | 7 +- tests/expr_and_series/len_test.py | 5 +- tests/expr_and_series/max_horizontal_test.py | 12 +- tests/expr_and_series/mean_horizontal_test.py | 4 +- tests/expr_and_series/min_horizontal_test.py | 12 +- tests/expr_and_series/n_unique_test.py | 8 +- tests/expr_and_series/unary_test.py | 4 +- tests/frame/concat_test.py | 9 +- tests/group_by_test.py | 6 +- 16 files changed, 280 insertions(+), 87 deletions(-) diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py index d8b2b7a9a..23805afdc 100644 --- a/narwhals/_dask/namespace.py +++ b/narwhals/_dask/namespace.py @@ -322,8 +322,8 @@ def concat_str( self, exprs: Iterable[IntoDaskExpr], *more_exprs: IntoDaskExpr, - separator: str = "", - ignore_nulls: bool = False, + separator: str, + ignore_nulls: bool, ) -> DaskExpr: parsed_exprs = [ *parse_into_exprs(*exprs, namespace=self), diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index 212c9c938..52e56d34f 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -380,8 +380,8 @@ def concat_str( self, exprs: Iterable[IntoPandasLikeExpr], *more_exprs: IntoPandasLikeExpr, - separator: str = "", - ignore_nulls: bool = False, + separator: str, + ignore_nulls: bool, ) -> PandasLikeExpr: parsed_exprs = [ *parse_into_exprs(*exprs, namespace=self), diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index 353261c21..efd3975ff 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -481,6 +481,20 @@ def skew(self) -> Self: return self._from_call(F.skewness, "skew", returns_scalar=True) + def n_unique(self: Self) -> Self: + from pyspark.sql import functions as F # noqa: N812 + from pyspark.sql.types import IntegerType + + def _n_unique(_input: Column) -> Column: + return F.count_distinct(_input) + F.max(F.isnull(_input).cast(IntegerType())) + + return self._from_call(_n_unique, "n_unique", returns_scalar=True) + + def is_null(self: Self) -> Self: + from pyspark.sql import functions as F # noqa: N812 + + return self._from_call(F.isnull, "is_null", returns_scalar=self._returns_scalar) + @property def str(self: Self) -> SparkLikeExprStringNamespace: return SparkLikeExprStringNamespace(self) diff --git a/narwhals/_spark_like/group_by.py b/narwhals/_spark_like/group_by.py index cbcf87692..66f4bf2b8 100644 --- a/narwhals/_spark_like/group_by.py +++ b/narwhals/_spark_like/group_by.py @@ -128,11 +128,7 @@ def agg_pyspark( if expr._output_names is None: # pragma: no cover msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues" raise AssertionError(msg) - - function_name = POLARS_TO_PYSPARK_AGGREGATIONS.get( - expr._function_name, expr._function_name - ) - agg_func = get_spark_function(function_name, **expr._kwargs) + agg_func = get_spark_function(expr._function_name, **expr._kwargs) 
simple_aggregations.update( {output_name: agg_func(keys[0]) for output_name in expr._output_names} ) diff --git a/narwhals/_spark_like/namespace.py b/narwhals/_spark_like/namespace.py index 56cc4d271..f53f66f77 100644 --- a/narwhals/_spark_like/namespace.py +++ b/narwhals/_spark_like/namespace.py @@ -3,18 +3,21 @@ import operator from functools import reduce from typing import TYPE_CHECKING +from typing import Iterable +from typing import Literal from narwhals._expression_parsing import combine_root_names from narwhals._expression_parsing import parse_into_exprs from narwhals._expression_parsing import reduce_output_names +from narwhals._spark_like.dataframe import SparkLikeLazyFrame from narwhals._spark_like.expr import SparkLikeExpr from narwhals._spark_like.utils import get_column_name from narwhals.typing import CompliantNamespace if TYPE_CHECKING: from pyspark.sql import Column + from pyspark.sql import DataFrame - from narwhals._spark_like.dataframe import SparkLikeLazyFrame from narwhals._spark_like.typing import IntoSparkLikeExpr from narwhals.dtypes import DType from narwhals.utils import Version @@ -43,26 +46,6 @@ def _all(df: SparkLikeLazyFrame) -> list[Column]: kwargs={}, ) - def all_horizontal(self, *exprs: IntoSparkLikeExpr) -> SparkLikeExpr: - parsed_exprs = parse_into_exprs(*exprs, namespace=self) - - def func(df: SparkLikeLazyFrame) -> list[Column]: - cols = [c for _expr in parsed_exprs for c in _expr(df)] - col_name = get_column_name(df, cols[0]) - return [reduce(operator.and_, cols).alias(col_name)] - - return SparkLikeExpr( # type: ignore[abstract] - call=func, - depth=max(x._depth for x in parsed_exprs) + 1, - function_name="all_horizontal", - root_names=combine_root_names(parsed_exprs), - output_names=reduce_output_names(parsed_exprs), - returns_scalar=False, - backend_version=self._backend_version, - version=self._version, - kwargs={"exprs": exprs}, - ) - def col(self, *column_names: str) -> SparkLikeExpr: return SparkLikeExpr.from_column_names( *column_names, backend_version=self._backend_version, version=self._version @@ -90,6 +73,64 @@ def _lit(_: SparkLikeLazyFrame) -> list[Column]: kwargs={}, ) + def len(self) -> SparkLikeExpr: + def func(_: SparkLikeLazyFrame) -> list[Column]: + import pyspark.sql.functions as F # noqa: N812 + + return [F.count("*").alias("len")] + + return SparkLikeExpr( # type: ignore[abstract] + func, + depth=0, + function_name="len", + root_names=None, + output_names=["len"], + returns_scalar=True, + backend_version=self._backend_version, + version=self._version, + kwargs={}, + ) + + def all_horizontal(self, *exprs: IntoSparkLikeExpr) -> SparkLikeExpr: + parsed_exprs = parse_into_exprs(*exprs, namespace=self) + + def func(df: SparkLikeLazyFrame) -> list[Column]: + cols = [c for _expr in parsed_exprs for c in _expr(df)] + col_name = get_column_name(df, cols[0]) + return [reduce(operator.and_, cols).alias(col_name)] + + return SparkLikeExpr( # type: ignore[abstract] + call=func, + depth=max(x._depth for x in parsed_exprs) + 1, + function_name="all_horizontal", + root_names=combine_root_names(parsed_exprs), + output_names=reduce_output_names(parsed_exprs), + returns_scalar=False, + backend_version=self._backend_version, + version=self._version, + kwargs={"exprs": exprs}, + ) + + def any_horizontal(self, *exprs: IntoSparkLikeExpr) -> SparkLikeExpr: + parsed_exprs = parse_into_exprs(*exprs, namespace=self) + + def func(df: SparkLikeLazyFrame) -> list[Column]: + cols = [c for _expr in parsed_exprs for c in _expr(df)] + col_name = 
get_column_name(df, cols[0]) + return [reduce(operator.or_, cols).alias(col_name)] + + return SparkLikeExpr( # type: ignore[abstract] + call=func, + depth=max(x._depth for x in parsed_exprs) + 1, + function_name="any_horizontal", + root_names=combine_root_names(parsed_exprs), + output_names=reduce_output_names(parsed_exprs), + returns_scalar=False, + backend_version=self._backend_version, + version=self._version, + kwargs={"exprs": exprs}, + ) + def sum_horizontal(self, *exprs: IntoSparkLikeExpr) -> SparkLikeExpr: parsed_exprs = parse_into_exprs(*exprs, namespace=self) @@ -116,3 +157,180 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: version=self._version, kwargs={"exprs": exprs}, ) + + def mean_horizontal(self, *exprs: IntoSparkLikeExpr) -> SparkLikeExpr: + from pyspark.sql import functions as F # noqa: N812 + from pyspark.sql.types import IntegerType + + parsed_exprs = parse_into_exprs(*exprs, namespace=self) + + def func(df: SparkLikeLazyFrame) -> list[Column]: + cols = [c for _expr in parsed_exprs for c in _expr(df)] + col_name = get_column_name(df, cols[0]) + return [ + ( + reduce(operator.add, (F.coalesce(col, F.lit(0)) for col in cols)) + / reduce( + operator.add, + (col.isNotNull().cast(IntegerType()) for col in cols), + ) + ).alias(col_name) + ] + + return SparkLikeExpr( # type: ignore[abstract] + call=func, + depth=max(x._depth for x in parsed_exprs) + 1, + function_name="mean_horizontal", + root_names=combine_root_names(parsed_exprs), + output_names=reduce_output_names(parsed_exprs), + returns_scalar=False, + backend_version=self._backend_version, + version=self._version, + kwargs={"exprs": exprs}, + ) + + def max_horizontal(self, *exprs: IntoSparkLikeExpr) -> SparkLikeExpr: + from pyspark.sql import functions as F # noqa: N812 + + parsed_exprs = parse_into_exprs(*exprs, namespace=self) + + def func(df: SparkLikeLazyFrame) -> list[Column]: + cols = [c for _expr in parsed_exprs for c in _expr(df)] + col_name = get_column_name(df, cols[0]) + return [F.greatest(*cols).alias(col_name)] + + return SparkLikeExpr( # type: ignore[abstract] + call=func, + depth=max(x._depth for x in parsed_exprs) + 1, + function_name="max_horizontal", + root_names=combine_root_names(parsed_exprs), + output_names=reduce_output_names(parsed_exprs), + returns_scalar=False, + backend_version=self._backend_version, + version=self._version, + kwargs={"exprs": exprs}, + ) + + def min_horizontal(self, *exprs: IntoSparkLikeExpr) -> SparkLikeExpr: + from pyspark.sql import functions as F # noqa: N812 + + parsed_exprs = parse_into_exprs(*exprs, namespace=self) + + def func(df: SparkLikeLazyFrame) -> list[Column]: + cols = [c for _expr in parsed_exprs for c in _expr(df)] + col_name = get_column_name(df, cols[0]) + return [F.least(*cols).alias(col_name)] + + return SparkLikeExpr( # type: ignore[abstract] + call=func, + depth=max(x._depth for x in parsed_exprs) + 1, + function_name="min_horizontal", + root_names=combine_root_names(parsed_exprs), + output_names=reduce_output_names(parsed_exprs), + returns_scalar=False, + backend_version=self._backend_version, + version=self._version, + kwargs={"exprs": exprs}, + ) + + def concat( + self, + items: Iterable[SparkLikeLazyFrame], + *, + how: Literal["horizontal", "vertical", "diagonal"], + ) -> SparkLikeLazyFrame: + dfs: list[DataFrame] = [item._native_frame for item in items] + if how == "horizontal": + msg = ( + "Horizontal concatenation is not supported for LazyFrame backed by " + "a PySpark DataFrame." 
+ ) + raise NotImplementedError(msg) + + if how == "vertical": + cols_0 = dfs[0].columns + for i, df in enumerate(dfs[1:], start=1): + cols_current = df.columns + if not ((len(cols_current) == len(cols_0)) and (cols_current == cols_0)): + msg = ( + "unable to vstack, column names don't match:\n" + f" - dataframe 0: {cols_0}\n" + f" - dataframe {i}: {cols_current}\n" + ) + raise TypeError(msg) + + return SparkLikeLazyFrame( + native_dataframe=reduce(lambda x, y: x.union(y), dfs), + backend_version=self._backend_version, + version=self._version, + ) + + if how == "diagonal": + return SparkLikeLazyFrame( + native_dataframe=reduce( + lambda x, y: x.unionByName(y, allowMissingColumns=True), dfs + ), + backend_version=self._backend_version, + version=self._version, + ) + raise NotImplementedError + + def concat_str( + self, + exprs: Iterable[IntoSparkLikeExpr], + *more_exprs: IntoSparkLikeExpr, + separator: str, + ignore_nulls: bool, + ) -> SparkLikeExpr: + from pyspark.sql import functions as F # noqa: N812 + from pyspark.sql.types import StringType + + parsed_exprs = [ + *parse_into_exprs(*exprs, namespace=self), + *parse_into_exprs(*more_exprs, namespace=self), + ] + + def func(df: SparkLikeLazyFrame) -> list[Column]: + cols = (s.cast(StringType()) for _expr in parsed_exprs for s in _expr(df)) + null_mask = [F.isnull(s) for _expr in parsed_exprs for s in _expr(df)] + + if not ignore_nulls: + null_mask_result = reduce(lambda x, y: x | y, null_mask) + result = F.when( + ~null_mask_result, + reduce(lambda x, y: F.format_string(f"%s{separator}%s", x, y), cols), + ).otherwise(F.lit(None)) + else: + init_value, *values = [ + F.when(~nm, col).otherwise(F.lit("")) + for col, nm in zip(cols, null_mask) + ] + + separators = ( + F.when(nm, F.lit("")).otherwise(F.lit(separator)) + for nm in null_mask[:-1] + ) + result = reduce( + lambda x, y: F.format_string("%s%s", x, y), + (F.format_string("%s%s", s, v) for s, v in zip(separators, values)), + init_value, + ) + + return [result] + + return SparkLikeExpr( # type: ignore[abstract] + call=func, + depth=max(x._depth for x in parsed_exprs) + 1, + function_name="concat_str", + root_names=combine_root_names(parsed_exprs), + output_names=reduce_output_names(parsed_exprs), + returns_scalar=False, + backend_version=self._backend_version, + version=self._version, + kwargs={ + "exprs": exprs, + "more_exprs": more_exprs, + "separator": separator, + "ignore_nulls": ignore_nulls, + }, + ) diff --git a/tests/expr_and_series/any_horizontal_test.py b/tests/expr_and_series/any_horizontal_test.py index 06157f393..4eb082b51 100644 --- a/tests/expr_and_series/any_horizontal_test.py +++ b/tests/expr_and_series/any_horizontal_test.py @@ -11,11 +11,7 @@ @pytest.mark.parametrize("expr1", ["a", nw.col("a")]) @pytest.mark.parametrize("expr2", ["b", nw.col("b")]) -def test_anyh( - request: pytest.FixtureRequest, constructor: Constructor, expr1: Any, expr2: Any -) -> None: - if "pyspark" in str(constructor): - request.applymarker(pytest.mark.xfail) +def test_anyh(constructor: Constructor, expr1: Any, expr2: Any) -> None: data = { "a": [False, False, True], "b": [False, True, True], @@ -27,9 +23,7 @@ def test_anyh( assert_equal_data(result, expected) -def test_anyh_all(request: pytest.FixtureRequest, constructor: Constructor) -> None: - if "pyspark" in str(constructor): - request.applymarker(pytest.mark.xfail) +def test_anyh_all(constructor: Constructor) -> None: data = { "a": [False, False, True], "b": [False, True, True], diff --git a/tests/expr_and_series/concat_str_test.py 
b/tests/expr_and_series/concat_str_test.py index 37d4a581d..7c9f259ba 100644 --- a/tests/expr_and_series/concat_str_test.py +++ b/tests/expr_and_series/concat_str_test.py @@ -27,7 +27,7 @@ def test_concat_str( expected: list[str], request: pytest.FixtureRequest, ) -> None: - if ("pyspark" in str(constructor)) or "duckdb" in str(constructor): + if "duckdb" in str(constructor): request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) result = ( diff --git a/tests/expr_and_series/is_null_test.py b/tests/expr_and_series/is_null_test.py index cf4d2e73b..5d5250da9 100644 --- a/tests/expr_and_series/is_null_test.py +++ b/tests/expr_and_series/is_null_test.py @@ -1,17 +1,12 @@ from __future__ import annotations -import pytest - import narwhals.stable.v1 as nw from tests.utils import Constructor from tests.utils import ConstructorEager from tests.utils import assert_equal_data -def test_null(request: pytest.FixtureRequest, constructor: Constructor) -> None: - if "pyspark" in str(constructor): - request.applymarker(pytest.mark.xfail) - +def test_null(constructor: Constructor) -> None: data_na = {"a": [None, 3, 2], "z": [7.0, None, None]} expected = {"a": [True, False, False], "z": [True, False, False]} df = nw.from_native(constructor(data_na)) diff --git a/tests/expr_and_series/len_test.py b/tests/expr_and_series/len_test.py index 142fe488b..fffcbd4a3 100644 --- a/tests/expr_and_series/len_test.py +++ b/tests/expr_and_series/len_test.py @@ -34,10 +34,7 @@ def test_len_chaining( assert_equal_data(df, expected) -def test_namespace_len(request: pytest.FixtureRequest, constructor: Constructor) -> None: - if "pyspark" in str(constructor): - request.applymarker(pytest.mark.xfail) - +def test_namespace_len(constructor: Constructor) -> None: df = nw.from_native(constructor({"a": [1, 2, 3], "b": [4, 5, 6]})).select( nw.len(), a=nw.len() ) diff --git a/tests/expr_and_series/max_horizontal_test.py b/tests/expr_and_series/max_horizontal_test.py index 9df17fed3..c86e11318 100644 --- a/tests/expr_and_series/max_horizontal_test.py +++ b/tests/expr_and_series/max_horizontal_test.py @@ -14,12 +14,7 @@ @pytest.mark.parametrize("col_expr", [nw.col("a"), "a"]) @pytest.mark.filterwarnings(r"ignore:.*All-NaN slice encountered:RuntimeWarning") -def test_maxh( - request: pytest.FixtureRequest, constructor: Constructor, col_expr: Any -) -> None: - if "pyspark" in str(constructor): - request.applymarker(pytest.mark.xfail) - +def test_maxh(constructor: Constructor, col_expr: Any) -> None: df = nw.from_native(constructor(data)) result = df.select(horizontal_max=nw.max_horizontal(col_expr, nw.col("b"), "z")) expected = {"horizontal_max": expected_values} @@ -27,10 +22,7 @@ def test_maxh( @pytest.mark.filterwarnings(r"ignore:.*All-NaN slice encountered:RuntimeWarning") -def test_maxh_all(request: pytest.FixtureRequest, constructor: Constructor) -> None: - if "pyspark" in str(constructor): - request.applymarker(pytest.mark.xfail) - +def test_maxh_all(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = df.select(nw.max_horizontal(nw.all()), c=nw.max_horizontal(nw.all())) expected = {"a": expected_values, "c": expected_values} diff --git a/tests/expr_and_series/mean_horizontal_test.py b/tests/expr_and_series/mean_horizontal_test.py index 5ed472e31..c1652c837 100644 --- a/tests/expr_and_series/mean_horizontal_test.py +++ b/tests/expr_and_series/mean_horizontal_test.py @@ -13,7 +13,7 @@ def test_meanh( constructor: Constructor, col_expr: Any, request: pytest.FixtureRequest ) -> 
None: - if ("pyspark" in str(constructor)) or "duckdb" in str(constructor): + if "duckdb" in str(constructor): request.applymarker(pytest.mark.xfail) data = {"a": [1, 3, None, None], "b": [4, None, 6, None]} df = nw.from_native(constructor(data)) @@ -23,7 +23,7 @@ def test_meanh( def test_meanh_all(constructor: Constructor, request: pytest.FixtureRequest) -> None: - if ("pyspark" in str(constructor)) or "duckdb" in str(constructor): + if "duckdb" in str(constructor): request.applymarker(pytest.mark.xfail) data = {"a": [2, 4, 6], "b": [10, 20, 30]} df = nw.from_native(constructor(data)) diff --git a/tests/expr_and_series/min_horizontal_test.py b/tests/expr_and_series/min_horizontal_test.py index bbb0b9149..787e3e2a4 100644 --- a/tests/expr_and_series/min_horizontal_test.py +++ b/tests/expr_and_series/min_horizontal_test.py @@ -14,12 +14,7 @@ @pytest.mark.parametrize("col_expr", [nw.col("a"), "a"]) @pytest.mark.filterwarnings(r"ignore:.*All-NaN slice encountered:RuntimeWarning") -def test_minh( - request: pytest.FixtureRequest, constructor: Constructor, col_expr: Any -) -> None: - if "pyspark" in str(constructor): - request.applymarker(pytest.mark.xfail) - +def test_minh(constructor: Constructor, col_expr: Any) -> None: df = nw.from_native(constructor(data)) result = df.select(horizontal_min=nw.min_horizontal(col_expr, nw.col("b"), "z")) expected = {"horizontal_min": expected_values} @@ -27,10 +22,7 @@ def test_minh( @pytest.mark.filterwarnings(r"ignore:.*All-NaN slice encountered:RuntimeWarning") -def test_minh_all(request: pytest.FixtureRequest, constructor: Constructor) -> None: - if "pyspark" in str(constructor): - request.applymarker(pytest.mark.xfail) - +def test_minh_all(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = df.select(nw.min_horizontal(nw.all()), c=nw.min_horizontal(nw.all())) expected = {"a": expected_values, "c": expected_values} diff --git a/tests/expr_and_series/n_unique_test.py b/tests/expr_and_series/n_unique_test.py index cfa14e0d7..1bcbe89fd 100644 --- a/tests/expr_and_series/n_unique_test.py +++ b/tests/expr_and_series/n_unique_test.py @@ -1,7 +1,5 @@ from __future__ import annotations -import pytest - import narwhals.stable.v1 as nw from tests.utils import Constructor from tests.utils import ConstructorEager @@ -13,11 +11,9 @@ } -def test_n_unique(constructor: Constructor, request: pytest.FixtureRequest) -> None: - if "pyspark" in str(constructor): - request.applymarker(pytest.mark.xfail) +def test_n_unique(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) - result = df.select(nw.all().n_unique()) + result = df.select(nw.col("a", "b").n_unique()) expected = {"a": [3], "b": [4]} assert_equal_data(result, expected) diff --git a/tests/expr_and_series/unary_test.py b/tests/expr_and_series/unary_test.py index 82f616a64..1eb1b0f99 100644 --- a/tests/expr_and_series/unary_test.py +++ b/tests/expr_and_series/unary_test.py @@ -11,7 +11,7 @@ def test_unary(constructor: Constructor, request: pytest.FixtureRequest) -> None: - if ("pyspark" in str(constructor)) or "duckdb" in str(constructor): + if "duckdb" in str(constructor): request.applymarker(pytest.mark.xfail) data = { "a": [1, 3, 2], @@ -82,7 +82,7 @@ def test_unary_series(constructor_eager: ConstructorEager) -> None: def test_unary_two_elements( constructor: Constructor, request: pytest.FixtureRequest ) -> None: - if ("pyspark" in str(constructor)) or "duckdb" in str(constructor): + if "duckdb" in str(constructor): request.applymarker(pytest.mark.xfail) data = 
{"a": [1, 2], "b": [2, 10], "c": [2.0, None]} result = nw.from_native(constructor(data)).select( diff --git a/tests/frame/concat_test.py b/tests/frame/concat_test.py index 6d8fdbda0..e38f4f4ff 100644 --- a/tests/frame/concat_test.py +++ b/tests/frame/concat_test.py @@ -32,12 +32,7 @@ def test_concat_horizontal( nw.concat([]) -def test_concat_vertical( - request: pytest.FixtureRequest, constructor: Constructor -) -> None: - if "pyspark" in str(constructor): - request.applymarker(pytest.mark.xfail) - +def test_concat_vertical(constructor: Constructor) -> None: data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} df_left = ( nw.from_native(constructor(data)).lazy().rename({"a": "c", "b": "d"}).drop("z") @@ -68,7 +63,7 @@ def test_concat_vertical( def test_concat_diagonal( constructor: Constructor, request: pytest.FixtureRequest ) -> None: - if ("pyspark" in str(constructor)) or "duckdb" in str(constructor): + if "duckdb" in str(constructor): request.applymarker(pytest.mark.xfail) data_1 = {"a": [1, 3], "b": [4, 6]} data_2 = {"a": [100, 200], "z": ["x", "y"]} diff --git a/tests/group_by_test.py b/tests/group_by_test.py index 64b3844d0..911cc473e 100644 --- a/tests/group_by_test.py +++ b/tests/group_by_test.py @@ -300,8 +300,12 @@ def test_key_with_nulls( def test_key_with_nulls_ignored( constructor: Constructor, request: pytest.FixtureRequest ) -> None: - if ("pyspark" in str(constructor)) or "duckdb" in str(constructor): + if "duckdb" in str(constructor): request.applymarker(pytest.mark.xfail) + + if "pyspark" in str(constructor): + request.applymarker(pytest.mark.xfail) + data = {"b": [4, 5, None], "a": [1, 2, 3]} result = ( nw.from_native(constructor(data)) From 339683c529ccdfcabfa96627607203984061916f Mon Sep 17 00:00:00 2001 From: Marco Edward Gorelli Date: Fri, 10 Jan 2025 16:42:28 +0000 Subject: [PATCH 2/2] feat: implement anti-join, str.len_chars, and null_count for DuckDB (#1777) --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- narwhals/_arrow/dataframe.py | 2 +- narwhals/_dask/dataframe.py | 2 +- narwhals/_duckdb/dataframe.py | 6 +---- narwhals/_duckdb/expr.py | 25 ++++++++++++++++----- narwhals/_pandas_like/dataframe.py | 2 +- tests/expr_and_series/null_count_test.py | 8 +------ tests/expr_and_series/str/len_chars_test.py | 6 +---- tests/frame/join_test.py | 3 --- 8 files changed, 25 insertions(+), 29 deletions(-) diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index e6bb6fa65..c36f58938 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -333,7 +333,7 @@ def join( self: Self, other: Self, *, - how: Literal["left", "inner", "outer", "cross", "anti", "semi"], + how: Literal["left", "inner", "cross", "anti", "semi"], left_on: str | list[str] | None, right_on: str | list[str] | None, suffix: str, diff --git a/narwhals/_dask/dataframe.py b/narwhals/_dask/dataframe.py index 16053d69a..35a0d045c 100644 --- a/narwhals/_dask/dataframe.py +++ b/narwhals/_dask/dataframe.py @@ -236,7 +236,7 @@ def join( self: Self, other: Self, *, - how: Literal["left", "inner", "outer", "cross", "anti", "semi"] = "inner", + how: Literal["left", "inner", "cross", "anti", "semi"] = "inner", left_on: str | list[str] | None, right_on: str | list[str] | None, suffix: str, diff --git a/narwhals/_duckdb/dataframe.py b/narwhals/_duckdb/dataframe.py index 33cfc19d2..98eca0bdb 100644 --- a/narwhals/_duckdb/dataframe.py +++ b/narwhals/_duckdb/dataframe.py @@ -215,7 +215,7 @@ def join( self: Self, other: Self, 
*, - how: Literal["left", "inner", "outer", "cross", "anti", "semi"] = "inner", + how: Literal["left", "inner", "cross", "anti", "semi"] = "inner", left_on: str | list[str] | None, right_on: str | list[str] | None, suffix: str, @@ -226,10 +226,6 @@ def join( right_on = [right_on] original_alias = self._native_frame.alias - if how not in ("inner", "left", "semi", "cross"): - msg = "Only inner and left join is implemented for DuckDB" - raise NotImplementedError(msg) - if how == "cross": if self._backend_version < (1, 1, 4): msg = f"DuckDB>=1.1.4 is required for cross-join, found version: {self._backend_version}" diff --git a/narwhals/_duckdb/expr.py b/narwhals/_duckdb/expr.py index e5e612085..cfd2efdac 100644 --- a/narwhals/_duckdb/expr.py +++ b/narwhals/_duckdb/expr.py @@ -1,6 +1,5 @@ from __future__ import annotations -import functools from typing import TYPE_CHECKING from typing import Any from typing import Callable @@ -488,6 +487,15 @@ def min(self) -> Self: lambda _input: FunctionExpression("min", _input), "min", returns_scalar=True ) + def null_count(self) -> Self: + from duckdb import FunctionExpression + + return self._from_call( + lambda _input: FunctionExpression("sum", _input.isnull().cast("int")), + "null_count", + returns_scalar=True, + ) + def is_null(self) -> Self: return self._from_call( lambda _input: _input.isnull(), "is_null", returns_scalar=self._returns_scalar @@ -497,11 +505,7 @@ def is_in(self, other: Sequence[Any]) -> Self: from duckdb import ConstantExpression return self._from_call( - lambda _input: functools.reduce( - lambda x, y: x | _input.isin(ConstantExpression(y)), - other[1:], - _input.isin(ConstantExpression(other[0])), - ), + lambda _input: _input.isin(*[ConstantExpression(x) for x in other]), "is_in", returns_scalar=self._returns_scalar, ) @@ -619,6 +623,15 @@ def func(_input: duckdb.Expression) -> duckdb.Expression: func, "slice", returns_scalar=self._compliant_expr._returns_scalar ) + def len_chars(self) -> DuckDBExpr: + from duckdb import FunctionExpression + + return self._compliant_expr._from_call( + lambda _input: FunctionExpression("length", _input), + "len_chars", + returns_scalar=self._compliant_expr._returns_scalar, + ) + def to_lowercase(self) -> DuckDBExpr: from duckdb import FunctionExpression diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index e11c02710..b8b707851 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -512,7 +512,7 @@ def join( self, other: Self, *, - how: Literal["left", "inner", "outer", "cross", "anti", "semi"] = "inner", + how: Literal["left", "inner", "cross", "anti", "semi"] = "inner", left_on: str | list[str] | None, right_on: str | list[str] | None, suffix: str, diff --git a/tests/expr_and_series/null_count_test.py b/tests/expr_and_series/null_count_test.py index a49fd79c8..db162363b 100644 --- a/tests/expr_and_series/null_count_test.py +++ b/tests/expr_and_series/null_count_test.py @@ -1,7 +1,5 @@ from __future__ import annotations -import pytest - import narwhals.stable.v1 as nw from tests.utils import Constructor from tests.utils import ConstructorEager @@ -13,11 +11,7 @@ } -def test_null_count_expr( - constructor: Constructor, request: pytest.FixtureRequest -) -> None: - if "duckdb" in str(constructor): - request.applymarker(pytest.mark.xfail) +def test_null_count_expr(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = df.select(nw.col("a", "b").null_count()) expected = { diff --git 
a/tests/expr_and_series/str/len_chars_test.py b/tests/expr_and_series/str/len_chars_test.py index 1a318801a..f9c63e01c 100644 --- a/tests/expr_and_series/str/len_chars_test.py +++ b/tests/expr_and_series/str/len_chars_test.py @@ -1,7 +1,5 @@ from __future__ import annotations -import pytest - import narwhals.stable.v1 as nw from tests.utils import Constructor from tests.utils import ConstructorEager @@ -10,9 +8,7 @@ data = {"a": ["foo", "foobar", "Café", "345", "東京"]} -def test_str_len_chars(constructor: Constructor, request: pytest.FixtureRequest) -> None: - if "duckdb" in str(constructor): - request.applymarker(pytest.mark.xfail) +def test_str_len_chars(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = df.select(nw.col("a").str.len_chars()) expected = { diff --git a/tests/frame/join_test.py b/tests/frame/join_test.py index f15a1b79e..5ff112f31 100644 --- a/tests/frame/join_test.py +++ b/tests/frame/join_test.py @@ -166,10 +166,7 @@ def test_anti_join( join_key: list[str], filter_expr: nw.Expr, expected: dict[str, list[Any]], - request: pytest.FixtureRequest, ) -> None: - if "duckdb" in str(constructor): - request.applymarker(pytest.mark.xfail) data = {"antananarivo": [1, 3, 2], "bob": [4, 4, 6], "zor ro": [7.0, 8, 9]} df = nw.from_native(constructor(data)) other = df.filter(filter_expr)
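For reference: the new SparkLike `n_unique` has to treat null as its own distinct value to match Polars semantics, but Spark's `count_distinct` ignores nulls. The patch therefore adds `F.max(F.isnull(col).cast(IntegerType()))`, which contributes exactly 1 when the column contains at least one null. A standalone sketch of the same expression (the session and data below are illustrative, not part of the patch):

    from pyspark.sql import SparkSession
    from pyspark.sql import functions as F  # noqa: N812
    from pyspark.sql.types import IntegerType

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([(1,), (1,), (None,)], ["a"])
    # count_distinct sees only {1}; the isnull term adds 1 for the null
    n_unique = F.count_distinct(F.col("a")) + F.max(
        F.isnull(F.col("a")).cast(IntegerType())
    )
    df.select(n_unique.alias("n_unique")).show()  # 2, i.e. {1, null}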
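The horizontal reducers added to `SparkLikeNamespace` all name their output after the first input column (via `get_column_name`). `mean_horizontal` matches Polars' null handling by averaging only the non-null values in each row: the numerator coalesces nulls to 0, and the denominator counts the non-null entries. The equivalent in plain PySpark (illustrative data, mirroring tests/expr_and_series/mean_horizontal_test.py):

    import operator
    from functools import reduce

    from pyspark.sql import SparkSession
    from pyspark.sql import functions as F  # noqa: N812
    from pyspark.sql.types import IntegerType

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([(1, 4), (3, None), (None, 6)], ["a", "b"])
    cols = [F.col("a"), F.col("b")]
    mean_h = (
        reduce(operator.add, (F.coalesce(c, F.lit(0)) for c in cols))
        / reduce(operator.add, (c.isNotNull().cast(IntegerType()) for c in cols))
    )
    df.select(mean_h.alias("a")).show()  # 2.5, 3.0, 6.0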
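The new `concat` maps `how="vertical"` onto `DataFrame.union`, which matches columns by position — hence the explicit column-name check that raises `TypeError` on mismatch; `how="diagonal"` maps onto `unionByName(allowMissingColumns=True)`, which fills absent columns with nulls; and `how="horizontal"` raises `NotImplementedError`, since a lazy PySpark frame has no row order to align on. The two supported cases in plain PySpark (illustrative frames):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    df1 = spark.createDataFrame([(1, 4)], ["a", "b"])
    df2 = spark.createDataFrame([(2, 5)], ["a", "b"])
    df3 = spark.createDataFrame([(100, "x")], ["a", "z"])

    df1.union(df2).show()  # vertical: positional, same schema required
    df1.unionByName(df3, allowMissingColumns=True).show()  # diagonal: b/z null-filled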
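`concat_str` casts every input to string and builds a per-input null mask. With `ignore_nulls=False`, any null input makes the whole row null (`F.when(~any_null, joined)`). With `ignore_nulls=True`, a null value and the separator placed in front of it both collapse to the empty string before the pieces are folded together with `format_string`, so for a row ("x", null, "z") and separator " " the result is "x z". (One edge that may be worth a test: because each separator is dropped based on the value *before* it, a null in the last position appears to leave a trailing separator.) A usage-level sketch against the new backend (data is illustrative; mirrors tests/expr_and_series/concat_str_test.py):

    import narwhals.stable.v1 as nw
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    native = spark.createDataFrame(
        [("x", None, "z"), ("u", "v", "w")], ["a", "b", "c"]
    )
    df = nw.from_native(native)
    result = df.select(
        nw.concat_str(
            nw.col("a"), nw.col("b"), nw.col("c"), separator=" ", ignore_nulls=True
        ).alias("s")
    )
    print(result.to_native().collect())  # rows: "x z" and "u v w"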
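On the DuckDB commit: `null_count` is expressed as `sum(col.isnull().cast("int"))`; `is_in` drops the `functools.reduce` of OR-ed single-element `isin` calls because DuckDB's `Expression.isin` accepts multiple values directly; `str.len_chars` maps to DuckDB's `length`, which counts characters rather than bytes; and deleting the "Only inner and left join" guard accompanies the anti-join support this commit adds (the join_test xfail removal exercises it). A relational-API sketch (illustrative data; passing a list of expressions to `aggregate` assumes a recent duckdb, as narwhals itself does):

    import duckdb
    from duckdb import ColumnExpression, ConstantExpression, FunctionExpression

    rel = duckdb.sql("SELECT * FROM (VALUES ('foo'), (NULL), ('東京')) t(a)")
    a = ColumnExpression("a")

    # null_count: nulls cast to 1, everything else to 0, then summed
    rel.aggregate([FunctionExpression("sum", a.isnull().cast("int"))]).show()  # 1
    # is_in: varargs, no reduce over OR-ed single-element isin calls
    rel.select(a.isin(ConstantExpression("foo"), ConstantExpression("bar"))).show()
    # len_chars: character length, not byte length -> 3, NULL, 2
    rel.select(FunctionExpression("length", a)).show()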