From 01e66818a38d2e05d69267b1726f3faaf47a0672 Mon Sep 17 00:00:00 2001 From: lucas-nelson-uiuc Date: Sat, 4 Jan 2025 14:48:32 -0600 Subject: [PATCH 01/18] test: add logical tests, import ConstructorEager type --- narwhals/_spark_like/expr.py | 16 +++++++++++++++ tests/spark_like_test.py | 38 ++++++++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+) diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index 4887e8001..43e3977ee 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -201,6 +201,22 @@ def _alias(df: SparkLikeLazyFrame) -> list[Column]: kwargs={**self._kwargs, "name": name}, ) + def all(self) -> Self: + def _all(_input: Column) -> Column: + from pyspark.sql import functions as F # noqa: N812 + + return F.bool_and(_input) + + return self._from_call(_all, "all", returns_scalar=True) + + def any(self) -> Self: + def _any(_input: Column) -> Column: + from pyspark.sql import functions as F # noqa: N812 + + return F.bool_any(_input) + + return self._from_call(_any, "any", returns_scalar=True) + def count(self) -> Self: def _count(_input: Column) -> Column: from pyspark.sql import functions as F # noqa: N812 diff --git a/tests/spark_like_test.py b/tests/spark_like_test.py index ffb841b4e..46ed65a10 100644 --- a/tests/spark_like_test.py +++ b/tests/spark_like_test.py @@ -15,6 +15,7 @@ import narwhals.stable.v1 as nw from narwhals.exceptions import ColumnNotFoundError +from tests.utils import ConstructorEager from tests.utils import assert_equal_data if TYPE_CHECKING: @@ -324,6 +325,43 @@ def test_sumh_all(pyspark_constructor: Constructor) -> None: assert_equal_data(result, expected) +def test_any_all(pyspark_constructor: Constructor) -> None: + df = nw.from_native( + pyspark_constructor( + { + "a": [True, False, True], + "b": [True, True, True], + "c": [False, False, False], + } + ) + ) + result = df.select(nw.col("a", "b", "c").all()) + expected = {"a": [False], "b": [True], "c": [False]} + assert_equal_data(result, expected) + result = df.select(nw.all().any()) + expected = {"a": [True], "b": [True], "c": [False]} + assert_equal_data(result, expected) + + +def test_any_all_series(constructor_eager: ConstructorEager) -> None: + df = nw.from_native( + constructor_eager( + { + "a": [True, False, True], + "b": [True, True, True], + "c": [False, False, False], + } + ), + eager_only=True, + ) + result = {"a": [df["a"].all()], "b": [df["b"].all()], "c": [df["c"].all()]} + expected = {"a": [False], "b": [True], "c": [False]} + assert_equal_data(result, expected) + result = {"a": [df["a"].any()], "b": [df["b"].any()], "c": [df["c"].any()]} + expected = {"a": [True], "b": [True], "c": [False]} + assert_equal_data(result, expected) + + # copied from tests/expr_and_series/count_test.py def test_count(pyspark_constructor: Constructor) -> None: data = {"a": [1, 2, 3], "b": [4, None, 6], "z": [7.0, None, None]} From 76424d6433a51a29709743488ac4dc7c086e215f Mon Sep 17 00:00:00 2001 From: lucas-nelson-uiuc Date: Sat, 4 Jan 2025 15:07:52 -0600 Subject: [PATCH 02/18] feat: add any_horizontal method --- narwhals/_spark_like/namespace.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/narwhals/_spark_like/namespace.py b/narwhals/_spark_like/namespace.py index d34867b00..2ca711350 100644 --- a/narwhals/_spark_like/namespace.py +++ b/narwhals/_spark_like/namespace.py @@ -62,6 +62,26 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: kwargs={"exprs": exprs}, ) + def any_horizontal(self, *exprs: IntoSparkLikeExpr) -> 
SparkLikeExpr: + parsed_exprs = parse_into_exprs(*exprs, namespace=self) + + def func(df: SparkLikeLazyFrame) -> list[Column]: + cols = [c for _expr in parsed_exprs for c in _expr(df)] + col_name = get_column_name(df, cols[0]) + return [reduce(operator.or_, cols).alias(col_name)] + + return SparkLikeExpr( # type: ignore[abstract] + call=func, + depth=max(x._depth for x in parsed_exprs) + 1, + function_name="any_horizontal", + root_names=combine_root_names(parsed_exprs), + output_names=reduce_output_names(parsed_exprs), + returns_scalar=False, + backend_version=self._backend_version, + version=self._version, + kwargs={"exprs": exprs}, + ) + def col(self, *column_names: str) -> SparkLikeExpr: return SparkLikeExpr.from_column_names( *column_names, backend_version=self._backend_version, version=self._version From c11a133b32dac8d1283a4e58da6a360cb414c25c Mon Sep 17 00:00:00 2001 From: lucas-nelson-uiuc Date: Sat, 4 Jan 2025 15:09:11 -0600 Subject: [PATCH 03/18] test: add any_horizontal test, update any_all reference --- tests/spark_like_test.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/tests/spark_like_test.py b/tests/spark_like_test.py index 46ed65a10..cc9f6f509 100644 --- a/tests/spark_like_test.py +++ b/tests/spark_like_test.py @@ -286,6 +286,35 @@ def test_allh_all(pyspark_constructor: Constructor) -> None: assert_equal_data(result, expected) +# copied from tests/expr_and_series/any_horizontal_test.py +@pytest.mark.parametrize("expr1", ["a", nw.col("a")]) +@pytest.mark.parametrize("expr2", ["b", nw.col("b")]) +def test_anyh(constructor: Constructor, expr1: Any, expr2: Any) -> None: + data = { + "a": [False, False, True], + "b": [False, True, True], + } + df = nw.from_native(constructor(data)) + result = df.select(any=nw.any_horizontal(expr1, expr2)) + + expected = {"any": [False, True, True]} + assert_equal_data(result, expected) + + +def test_anyh_all(constructor: Constructor) -> None: + data = { + "a": [False, False, True], + "b": [False, True, True], + } + df = nw.from_native(constructor(data)) + result = df.select(any=nw.any_horizontal(nw.all())) + expected = {"any": [False, True, True]} + assert_equal_data(result, expected) + result = df.select(nw.any_horizontal(nw.all())) + expected = {"a": [False, True, True]} + assert_equal_data(result, expected) + + # copied from tests/expr_and_series/sum_horizontal_test.py @pytest.mark.parametrize("col_expr", [nw.col("a"), "a"]) def test_sumh(pyspark_constructor: Constructor, col_expr: Any) -> None: @@ -325,6 +354,7 @@ def test_sumh_all(pyspark_constructor: Constructor) -> None: assert_equal_data(result, expected) +# copied from tests/expr_and_series/any_all_test.py def test_any_all(pyspark_constructor: Constructor) -> None: df = nw.from_native( pyspark_constructor( From 77b7c4f5a387900c7eef6f88bb9e7ee94a819372 Mon Sep 17 00:00:00 2001 From: lucas-nelson-uiuc Date: Sat, 4 Jan 2025 15:09:39 -0600 Subject: [PATCH 04/18] dev: correct bool_any to bool_or --- narwhals/_spark_like/expr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index 43e3977ee..497c540f0 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -213,7 +213,7 @@ def any(self) -> Self: def _any(_input: Column) -> Column: from pyspark.sql import functions as F # noqa: N812 - return F.bool_any(_input) + return F.bool_or(_input) return self._from_call(_any, "any", returns_scalar=True) From 3383be43bfb5a6b238b704b101460f1d1ae05892 Mon Sep 17 
00:00:00 2001 From: lucas-nelson-uiuc Date: Sat, 4 Jan 2025 16:00:47 -0600 Subject: [PATCH 05/18] feat: add null_count expr --- narwhals/_spark_like/expr.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index 497c540f0..9fb0f831f 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -249,6 +249,14 @@ def _min(_input: Column) -> Column: return self._from_call(_min, "min", returns_scalar=True) + def null_count(self) -> Self: + def _null_count(_input: Column) -> Column: + from pyspark.sql import functions as F # noqa: N812 + + return F.count_if(F.isnull(_input)) + + return self._from_call(_null_count, "null_count", returns_scalar=True) + def sum(self) -> Self: def _sum(_input: Column) -> Column: from pyspark.sql import functions as F # noqa: N812 From ed775145c81ec7770ded2fae51a6826394ce1a04 Mon Sep 17 00:00:00 2001 From: lucas-nelson-uiuc Date: Sat, 4 Jan 2025 16:02:32 -0600 Subject: [PATCH 06/18] test: add tests for null_count expr --- tests/spark_like_test.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests/spark_like_test.py b/tests/spark_like_test.py index cc9f6f509..03419109e 100644 --- a/tests/spark_like_test.py +++ b/tests/spark_like_test.py @@ -442,6 +442,28 @@ def test_expr_min_expr(pyspark_constructor: Constructor) -> None: assert_equal_data(result, expected) +# copied from tests/expr_and_series/null_count_test.py +def test_null_count_expr(constructor: Constructor) -> None: + data = { + "a": [1.0, None, None, 3.0], + "b": [1.0, None, 4, 5.0], + } + df = nw.from_native(constructor(data)) + result = df.select(nw.all().null_count()) + expected = { + "a": [2], + "b": [1], + } + assert_equal_data(result, expected) + + +def test_null_count_series(constructor_eager: ConstructorEager) -> None: + data = [1, 2, None] + series = nw.from_native(constructor_eager({"a": data}), eager_only=True)["a"] + result = series.null_count() + assert result == 1 + + # copied from tests/expr_and_series/min_test.py @pytest.mark.parametrize("expr", [nw.col("a", "b", "z").sum(), nw.sum("a", "b", "z")]) def test_expr_sum_expr(pyspark_constructor: Constructor, expr: nw.Expr) -> None: From 85110bffc8c5b582f0f150d650bb748303d9d3af Mon Sep 17 00:00:00 2001 From: lucas-nelson-uiuc Date: Sun, 5 Jan 2025 16:10:39 -0600 Subject: [PATCH 07/18] tests: update constructor to pyspark_constructor --- tests/spark_like_test.py | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/tests/spark_like_test.py b/tests/spark_like_test.py index 03419109e..4aae65409 100644 --- a/tests/spark_like_test.py +++ b/tests/spark_like_test.py @@ -289,24 +289,24 @@ def test_allh_all(pyspark_constructor: Constructor) -> None: # copied from tests/expr_and_series/any_horizontal_test.py @pytest.mark.parametrize("expr1", ["a", nw.col("a")]) @pytest.mark.parametrize("expr2", ["b", nw.col("b")]) -def test_anyh(constructor: Constructor, expr1: Any, expr2: Any) -> None: +def test_anyh(pyspark_constructor: Constructor, expr1: Any, expr2: Any) -> None: data = { "a": [False, False, True], "b": [False, True, True], } - df = nw.from_native(constructor(data)) + df = nw.from_native(pyspark_constructor(data)) result = df.select(any=nw.any_horizontal(expr1, expr2)) expected = {"any": [False, True, True]} assert_equal_data(result, expected) -def test_anyh_all(constructor: Constructor) -> None: +def test_anyh_all(pyspark_constructor: Constructor) -> None: data = { "a": [False, False, True], "b": [False, 
True, True], } - df = nw.from_native(constructor(data)) + df = nw.from_native(pyspark_constructor(data)) result = df.select(any=nw.any_horizontal(nw.all())) expected = {"any": [False, True, True]} assert_equal_data(result, expected) @@ -443,12 +443,12 @@ def test_expr_min_expr(pyspark_constructor: Constructor) -> None: # copied from tests/expr_and_series/null_count_test.py -def test_null_count_expr(constructor: Constructor) -> None: +def test_null_count_expr(pyspark_constructor: Constructor) -> None: data = { "a": [1.0, None, None, 3.0], "b": [1.0, None, 4, 5.0], } - df = nw.from_native(constructor(data)) + df = nw.from_native(pyspark_constructor(data)) result = df.select(nw.all().null_count()) expected = { "a": [2], @@ -457,13 +457,6 @@ def test_null_count_expr(constructor: Constructor) -> None: assert_equal_data(result, expected) -def test_null_count_series(constructor_eager: ConstructorEager) -> None: - data = [1, 2, None] - series = nw.from_native(constructor_eager({"a": data}), eager_only=True)["a"] - result = series.null_count() - assert result == 1 - - # copied from tests/expr_and_series/min_test.py @pytest.mark.parametrize("expr", [nw.col("a", "b", "z").sum(), nw.sum("a", "b", "z")]) def test_expr_sum_expr(pyspark_constructor: Constructor, expr: nw.Expr) -> None: From eda44db1864a79c048449119d3199b0e99d15d56 Mon Sep 17 00:00:00 2001 From: lucas-nelson-uiuc Date: Sun, 5 Jan 2025 16:13:27 -0600 Subject: [PATCH 08/18] tests: remove eager tests --- tests/spark_like_test.py | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/tests/spark_like_test.py b/tests/spark_like_test.py index 4aae65409..9fb69ce5d 100644 --- a/tests/spark_like_test.py +++ b/tests/spark_like_test.py @@ -15,7 +15,6 @@ import narwhals.stable.v1 as nw from narwhals.exceptions import ColumnNotFoundError -from tests.utils import ConstructorEager from tests.utils import assert_equal_data if TYPE_CHECKING: @@ -373,25 +372,6 @@ def test_any_all(pyspark_constructor: Constructor) -> None: assert_equal_data(result, expected) -def test_any_all_series(constructor_eager: ConstructorEager) -> None: - df = nw.from_native( - constructor_eager( - { - "a": [True, False, True], - "b": [True, True, True], - "c": [False, False, False], - } - ), - eager_only=True, - ) - result = {"a": [df["a"].all()], "b": [df["b"].all()], "c": [df["c"].all()]} - expected = {"a": [False], "b": [True], "c": [False]} - assert_equal_data(result, expected) - result = {"a": [df["a"].any()], "b": [df["b"].any()], "c": [df["c"].any()]} - expected = {"a": [True], "b": [True], "c": [False]} - assert_equal_data(result, expected) - - # copied from tests/expr_and_series/count_test.py def test_count(pyspark_constructor: Constructor) -> None: data = {"a": [1, 2, 3], "b": [4, None, 6], "z": [7.0, None, None]} From 94e9a0433c34285e92e4e73a381dbbdab790efff Mon Sep 17 00:00:00 2001 From: lucas-nelson-uiuc Date: Sun, 5 Jan 2025 23:31:33 -0600 Subject: [PATCH 09/18] feat: initial draft of replace_strict method --- narwhals/_spark_like/expr.py | 43 ++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index 9fb0f831f..0379ec0ab 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -1,14 +1,21 @@ from __future__ import annotations from copy import copy +from itertools import chain from typing import TYPE_CHECKING from typing import Any from typing import Callable +from typing import Mapping from typing import Sequence from 
narwhals._spark_like.utils import get_column_name from narwhals._spark_like.utils import maybe_evaluate +from narwhals.exceptions import InvalidOperationError from narwhals.typing import CompliantExpr + +if TYPE_CHECKING: + from narwhals.dtypes import DType + from narwhals.utils import Implementation from narwhals.utils import parse_version @@ -257,6 +264,42 @@ def _null_count(_input: Column) -> Column: return self._from_call(_null_count, "null_count", returns_scalar=True) + def replace_strict( + self, + old: Sequence[Any] | Mapping[Any, Any], + new: Sequence[Any], + default: Any | None = None, + return_dtype: DType | None = None, + ) -> Self: + def _replace_strict(_input: Column) -> Column: + from pyspark.sql import functions as F # noqa: N812 + + if isinstance(old, Mapping): + mapping = old + else: + if new is None: + msg = "`new` argument is required if `old` argument is not a Mapping type" + raise InvalidOperationError(msg) + mapping = dict(zip(old, new)) # QUESTION: check len(old) == len(new)? + + mapping_expr = F.create_map([F.lit(x) for x in chain(*mapping.items())]) + replacements = mapping_expr[_input] + + if default: + replacements = F.coalesce(replacements, F.lit(default)) + + # QUESTION: check all values mapped? + # we can check that all values are mapped using: F.bool_and(replacements.isNotNull()) + # however, I'm not sure how to validate this as an expression - F.assert_true looked promising + # until I realized it will convert the expression to NULL if the condition is True + + if return_dtype: + replacements = replacements.cast(return_dtype) + + return replacements + + return self._from_call(_replace_strict, "replace_strict", returns_scalar=False) + def sum(self) -> Self: def _sum(_input: Column) -> Column: from pyspark.sql import functions as F # noqa: N812 From 51f7b2c67f949a8bdaf7d662ac85454b391e23bb Mon Sep 17 00:00:00 2001 From: lucas-nelson-uiuc Date: Sun, 5 Jan 2025 23:32:47 -0600 Subject: [PATCH 10/18] feat: initial draft of replace_strict method --- narwhals/_spark_like/expr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index 0379ec0ab..87ce56cff 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -267,7 +267,7 @@ def _null_count(_input: Column) -> Column: def replace_strict( self, old: Sequence[Any] | Mapping[Any, Any], - new: Sequence[Any], + new: Sequence[Any] | None = None, default: Any | None = None, return_dtype: DType | None = None, ) -> Self: From 972df870c50181480b12573f1e9da271d958dac3 Mon Sep 17 00:00:00 2001 From: lucas-nelson-uiuc Date: Sun, 5 Jan 2025 23:34:54 -0600 Subject: [PATCH 11/18] test: add lazy tests for replace_strict method --- tests/spark_like_test.py | 73 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) diff --git a/tests/spark_like_test.py b/tests/spark_like_test.py index 9fb69ce5d..99d282f16 100644 --- a/tests/spark_like_test.py +++ b/tests/spark_like_test.py @@ -15,11 +15,13 @@ import narwhals.stable.v1 as nw from narwhals.exceptions import ColumnNotFoundError +from tests.utils import POLARS_VERSION from tests.utils import assert_equal_data if TYPE_CHECKING: from pyspark.sql import SparkSession + from narwhals.dtypes import DType from narwhals.typing import IntoFrame from tests.utils import Constructor @@ -437,6 +439,77 @@ def test_null_count_expr(pyspark_constructor: Constructor) -> None: assert_equal_data(result, expected) +@pytest.mark.skipif( + POLARS_VERSION < (1, 0), reason="replace_strict only 
available after 1.0" +) +@pytest.mark.parametrize("return_dtype", [nw.String(), None]) +def test_replace_strict( + pyspark_constructor: Constructor, + request: pytest.FixtureRequest, + return_dtype: DType | None, +) -> None: + if "dask" in str(pyspark_constructor): # QUESTION: remove? + request.applymarker(pytest.mark.xfail) + df = nw.from_native(pyspark_constructor({"a": [1, 2, 3]})) + result = df.select( + nw.col("a").replace_strict( + [1, 2, 3], ["one", "two", "three"], return_dtype=return_dtype + ) + ) + assert_equal_data(result, {"a": ["one", "two", "three"]}) + + +@pytest.mark.skipif( + POLARS_VERSION < (1, 0), reason="replace_strict only available after 1.0" +) +def test_replace_non_full( + pyspark_constructor: Constructor, request: pytest.FixtureRequest +) -> None: + from polars.exceptions import PolarsError + + if "dask" in str(pyspark_constructor): # QUESTION: remove? + request.applymarker(pytest.mark.xfail) + df = nw.from_native(pyspark_constructor({"a": [1, 2, 3]})) + if isinstance(df, nw.LazyFrame): + with pytest.raises((ValueError, PolarsError)): + df.select( + nw.col("a").replace_strict([1, 3], [3, 4], return_dtype=nw.Int64) + ).collect() + else: + with pytest.raises((ValueError, PolarsError)): + df.select(nw.col("a").replace_strict([1, 3], [3, 4], return_dtype=nw.Int64)) + + +@pytest.mark.skipif( + POLARS_VERSION < (1, 0), reason="replace_strict only available after 1.0" +) +def test_replace_strict_mapping( + pyspark_constructor: Constructor, request: pytest.FixtureRequest +) -> None: + if "dask" in str(pyspark_constructor): # QUESTION: remove? + request.applymarker(pytest.mark.xfail) + + df = nw.from_native(pyspark_constructor({"a": [1, 2, 3]})) + result = df.select( + nw.col("a").replace_strict( + {1: "one", 2: "two", 3: "three"}, return_dtype=nw.String() + ) + ) + assert_equal_data(result, {"a": ["one", "two", "three"]}) + + +@pytest.mark.skipif( + POLARS_VERSION < (1, 0), reason="replace_strict only available after 1.0" +) +def test_replace_strict_invalid(pyspark_constructor: Constructor) -> None: + df = nw.from_native(pyspark_constructor({"a": [1, 2, 3]})) + with pytest.raises( + TypeError, + match="`new` argument is required if `old` argument is not a Mapping type", + ): + df.select(nw.col("a").replace_strict(old=[1, 2, 3])) + + # copied from tests/expr_and_series/min_test.py @pytest.mark.parametrize("expr", [nw.col("a", "b", "z").sum(), nw.sum("a", "b", "z")]) def test_expr_sum_expr(pyspark_constructor: Constructor, expr: nw.Expr) -> None: From b4dac6580153791b59ebc7d7e9468b9842f1da53 Mon Sep 17 00:00:00 2001 From: Lucas Nelson Date: Thu, 9 Jan 2025 14:38:15 -0600 Subject: [PATCH 12/18] Update expr.py Co-authored-by: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> --- narwhals/_spark_like/expr.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index 87ce56cff..b26a88a6f 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -209,12 +209,9 @@ def _alias(df: SparkLikeLazyFrame) -> list[Column]: ) def all(self) -> Self: - def _all(_input: Column) -> Column: - from pyspark.sql import functions as F # noqa: N812 - - return F.bool_and(_input) + from pyspark.sql import functions as F # noqa: N812 - return self._from_call(_all, "all", returns_scalar=True) + return self._from_call(F.bool_and, "all", returns_scalar=True) def any(self) -> Self: def _any(_input: Column) -> Column: From f6a7312eaa3d014fa8ec3d796061bd24913b0f5c Mon Sep 17 00:00:00 2001 From: 
Lucas Nelson Date: Thu, 9 Jan 2025 14:38:24 -0600 Subject: [PATCH 13/18] Update expr.py Co-authored-by: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> --- narwhals/_spark_like/expr.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index b26a88a6f..d1f096cd8 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -214,12 +214,9 @@ def all(self) -> Self: return self._from_call(F.bool_and, "all", returns_scalar=True) def any(self) -> Self: - def _any(_input: Column) -> Column: - from pyspark.sql import functions as F # noqa: N812 - - return F.bool_or(_input) + from pyspark.sql import functions as F # noqa: N812 - return self._from_call(_any, "any", returns_scalar=True) + return self._from_call(F.bool_or, "any", returns_scalar=True) def count(self) -> Self: def _count(_input: Column) -> Column: From 1bd8cafee257ad89b742bb85c1990f724d34947a Mon Sep 17 00:00:00 2001 From: lucas-nelson-uiuc Date: Thu, 9 Jan 2025 14:59:59 -0600 Subject: [PATCH 14/18] remove replace_strict method --- narwhals/_spark_like/expr.py | 43 ------------------------------------ 1 file changed, 43 deletions(-) diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index d1f096cd8..d33e9e3c7 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -1,21 +1,14 @@ from __future__ import annotations from copy import copy -from itertools import chain from typing import TYPE_CHECKING from typing import Any from typing import Callable -from typing import Mapping from typing import Sequence from narwhals._spark_like.utils import get_column_name from narwhals._spark_like.utils import maybe_evaluate -from narwhals.exceptions import InvalidOperationError from narwhals.typing import CompliantExpr - -if TYPE_CHECKING: - from narwhals.dtypes import DType - from narwhals.utils import Implementation from narwhals.utils import parse_version @@ -258,42 +251,6 @@ def _null_count(_input: Column) -> Column: return self._from_call(_null_count, "null_count", returns_scalar=True) - def replace_strict( - self, - old: Sequence[Any] | Mapping[Any, Any], - new: Sequence[Any] | None = None, - default: Any | None = None, - return_dtype: DType | None = None, - ) -> Self: - def _replace_strict(_input: Column) -> Column: - from pyspark.sql import functions as F # noqa: N812 - - if isinstance(old, Mapping): - mapping = old - else: - if new is None: - msg = "`new` argument is required if `old` argument is not a Mapping type" - raise InvalidOperationError(msg) - mapping = dict(zip(old, new)) # QUESTION: check len(old) == len(new)? - - mapping_expr = F.create_map([F.lit(x) for x in chain(*mapping.items())]) - replacements = mapping_expr[_input] - - if default: - replacements = F.coalesce(replacements, F.lit(default)) - - # QUESTION: check all values mapped? 
- # we can check that all values are mapped using: F.bool_and(replacements.isNotNull()) - # however, I'm not sure how to validate this as an expression - F.assert_true looked promising - # until I realized it will convert the expression to NULL if the condition is True - - if return_dtype: - replacements = replacements.cast(return_dtype) - - return replacements - - return self._from_call(_replace_strict, "replace_strict", returns_scalar=False) - def sum(self) -> Self: def _sum(_input: Column) -> Column: from pyspark.sql import functions as F # noqa: N812 From 3b2b88a6950b2b1306dc8c0e9cf45ab2d637c0f1 Mon Sep 17 00:00:00 2001 From: lucas-nelson-uiuc Date: Thu, 9 Jan 2025 15:00:38 -0600 Subject: [PATCH 15/18] remove replace_strict tests --- tests/spark_like_test.py | 73 ---------------------------------------- 1 file changed, 73 deletions(-) diff --git a/tests/spark_like_test.py b/tests/spark_like_test.py index 99d282f16..9fb69ce5d 100644 --- a/tests/spark_like_test.py +++ b/tests/spark_like_test.py @@ -15,13 +15,11 @@ import narwhals.stable.v1 as nw from narwhals.exceptions import ColumnNotFoundError -from tests.utils import POLARS_VERSION from tests.utils import assert_equal_data if TYPE_CHECKING: from pyspark.sql import SparkSession - from narwhals.dtypes import DType from narwhals.typing import IntoFrame from tests.utils import Constructor @@ -439,77 +437,6 @@ def test_null_count_expr(pyspark_constructor: Constructor) -> None: assert_equal_data(result, expected) -@pytest.mark.skipif( - POLARS_VERSION < (1, 0), reason="replace_strict only available after 1.0" -) -@pytest.mark.parametrize("return_dtype", [nw.String(), None]) -def test_replace_strict( - pyspark_constructor: Constructor, - request: pytest.FixtureRequest, - return_dtype: DType | None, -) -> None: - if "dask" in str(pyspark_constructor): # QUESTION: remove? - request.applymarker(pytest.mark.xfail) - df = nw.from_native(pyspark_constructor({"a": [1, 2, 3]})) - result = df.select( - nw.col("a").replace_strict( - [1, 2, 3], ["one", "two", "three"], return_dtype=return_dtype - ) - ) - assert_equal_data(result, {"a": ["one", "two", "three"]}) - - -@pytest.mark.skipif( - POLARS_VERSION < (1, 0), reason="replace_strict only available after 1.0" -) -def test_replace_non_full( - pyspark_constructor: Constructor, request: pytest.FixtureRequest -) -> None: - from polars.exceptions import PolarsError - - if "dask" in str(pyspark_constructor): # QUESTION: remove? - request.applymarker(pytest.mark.xfail) - df = nw.from_native(pyspark_constructor({"a": [1, 2, 3]})) - if isinstance(df, nw.LazyFrame): - with pytest.raises((ValueError, PolarsError)): - df.select( - nw.col("a").replace_strict([1, 3], [3, 4], return_dtype=nw.Int64) - ).collect() - else: - with pytest.raises((ValueError, PolarsError)): - df.select(nw.col("a").replace_strict([1, 3], [3, 4], return_dtype=nw.Int64)) - - -@pytest.mark.skipif( - POLARS_VERSION < (1, 0), reason="replace_strict only available after 1.0" -) -def test_replace_strict_mapping( - pyspark_constructor: Constructor, request: pytest.FixtureRequest -) -> None: - if "dask" in str(pyspark_constructor): # QUESTION: remove? 
- request.applymarker(pytest.mark.xfail) - - df = nw.from_native(pyspark_constructor({"a": [1, 2, 3]})) - result = df.select( - nw.col("a").replace_strict( - {1: "one", 2: "two", 3: "three"}, return_dtype=nw.String() - ) - ) - assert_equal_data(result, {"a": ["one", "two", "three"]}) - - -@pytest.mark.skipif( - POLARS_VERSION < (1, 0), reason="replace_strict only available after 1.0" -) -def test_replace_strict_invalid(pyspark_constructor: Constructor) -> None: - df = nw.from_native(pyspark_constructor({"a": [1, 2, 3]})) - with pytest.raises( - TypeError, - match="`new` argument is required if `old` argument is not a Mapping type", - ): - df.select(nw.col("a").replace_strict(old=[1, 2, 3])) - - # copied from tests/expr_and_series/min_test.py @pytest.mark.parametrize("expr", [nw.col("a", "b", "z").sum(), nw.sum("a", "b", "z")]) def test_expr_sum_expr(pyspark_constructor: Constructor, expr: nw.Expr) -> None: From d2772eb0ca2001ea74e8d819a3e326c4d87a43e9 Mon Sep 17 00:00:00 2001 From: lucas-nelson-uiuc Date: Thu, 9 Jan 2025 15:49:09 -0600 Subject: [PATCH 16/18] remove any_h references --- narwhals/_spark_like/namespace.py | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/narwhals/_spark_like/namespace.py b/narwhals/_spark_like/namespace.py index 2ca711350..d34867b00 100644 --- a/narwhals/_spark_like/namespace.py +++ b/narwhals/_spark_like/namespace.py @@ -62,26 +62,6 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: kwargs={"exprs": exprs}, ) - def any_horizontal(self, *exprs: IntoSparkLikeExpr) -> SparkLikeExpr: - parsed_exprs = parse_into_exprs(*exprs, namespace=self) - - def func(df: SparkLikeLazyFrame) -> list[Column]: - cols = [c for _expr in parsed_exprs for c in _expr(df)] - col_name = get_column_name(df, cols[0]) - return [reduce(operator.or_, cols).alias(col_name)] - - return SparkLikeExpr( # type: ignore[abstract] - call=func, - depth=max(x._depth for x in parsed_exprs) + 1, - function_name="any_horizontal", - root_names=combine_root_names(parsed_exprs), - output_names=reduce_output_names(parsed_exprs), - returns_scalar=False, - backend_version=self._backend_version, - version=self._version, - kwargs={"exprs": exprs}, - ) - def col(self, *column_names: str) -> SparkLikeExpr: return SparkLikeExpr.from_column_names( *column_names, backend_version=self._backend_version, version=self._version From 6e59b1b867b54c0a2033ce3acae8e9959f90003c Mon Sep 17 00:00:00 2001 From: lucas-nelson-uiuc Date: Thu, 9 Jan 2025 15:49:21 -0600 Subject: [PATCH 17/18] remove any_h references --- tests/spark_like_test.py | 29 ----------------------------- 1 file changed, 29 deletions(-) diff --git a/tests/spark_like_test.py b/tests/spark_like_test.py index 9fb69ce5d..7ea7addac 100644 --- a/tests/spark_like_test.py +++ b/tests/spark_like_test.py @@ -285,35 +285,6 @@ def test_allh_all(pyspark_constructor: Constructor) -> None: assert_equal_data(result, expected) -# copied from tests/expr_and_series/any_horizontal_test.py -@pytest.mark.parametrize("expr1", ["a", nw.col("a")]) -@pytest.mark.parametrize("expr2", ["b", nw.col("b")]) -def test_anyh(pyspark_constructor: Constructor, expr1: Any, expr2: Any) -> None: - data = { - "a": [False, False, True], - "b": [False, True, True], - } - df = nw.from_native(pyspark_constructor(data)) - result = df.select(any=nw.any_horizontal(expr1, expr2)) - - expected = {"any": [False, True, True]} - assert_equal_data(result, expected) - - -def test_anyh_all(pyspark_constructor: Constructor) -> None: - data = { - "a": [False, False, True], - "b": [False, True, 
True], - } - df = nw.from_native(pyspark_constructor(data)) - result = df.select(any=nw.any_horizontal(nw.all())) - expected = {"any": [False, True, True]} - assert_equal_data(result, expected) - result = df.select(nw.any_horizontal(nw.all())) - expected = {"a": [False, True, True]} - assert_equal_data(result, expected) - - # copied from tests/expr_and_series/sum_horizontal_test.py @pytest.mark.parametrize("col_expr", [nw.col("a"), "a"]) def test_sumh(pyspark_constructor: Constructor, col_expr: Any) -> None: From a2f4993bb91ee2464a0ef027a083e1525d40a598 Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Fri, 10 Jan 2025 08:31:08 +0100 Subject: [PATCH 18/18] pyspark test --- tests/expr_and_series/any_all_test.py | 9 +- tests/expr_and_series/null_count_test.py | 4 +- tests/spark_like_test.py | 945 ----------------------- 3 files changed, 4 insertions(+), 954 deletions(-) delete mode 100644 tests/spark_like_test.py diff --git a/tests/expr_and_series/any_all_test.py b/tests/expr_and_series/any_all_test.py index 7fd81f04d..e8554316e 100644 --- a/tests/expr_and_series/any_all_test.py +++ b/tests/expr_and_series/any_all_test.py @@ -1,17 +1,12 @@ from __future__ import annotations -import pytest - import narwhals.stable.v1 as nw from tests.utils import Constructor from tests.utils import ConstructorEager from tests.utils import assert_equal_data -def test_any_all(request: pytest.FixtureRequest, constructor: Constructor) -> None: - if "pyspark" in str(constructor): - request.applymarker(pytest.mark.xfail) - +def test_any_all(constructor: Constructor) -> None: df = nw.from_native( constructor( { @@ -24,7 +19,7 @@ def test_any_all(request: pytest.FixtureRequest, constructor: Constructor) -> No result = df.select(nw.col("a", "b", "c").all()) expected = {"a": [False], "b": [True], "c": [False]} assert_equal_data(result, expected) - result = df.select(nw.all().any()) + result = df.select(nw.col("a", "b", "c").any()) expected = {"a": [True], "b": [True], "c": [False]} assert_equal_data(result, expected) diff --git a/tests/expr_and_series/null_count_test.py b/tests/expr_and_series/null_count_test.py index 3bd15c66c..a49fd79c8 100644 --- a/tests/expr_and_series/null_count_test.py +++ b/tests/expr_and_series/null_count_test.py @@ -16,10 +16,10 @@ def test_null_count_expr( constructor: Constructor, request: pytest.FixtureRequest ) -> None: - if ("pyspark" in str(constructor)) or "duckdb" in str(constructor): + if "duckdb" in str(constructor): request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) - result = df.select(nw.all().null_count()) + result = df.select(nw.col("a", "b").null_count()) expected = { "a": [2], "b": [1], diff --git a/tests/spark_like_test.py b/tests/spark_like_test.py deleted file mode 100644 index 3997f409a..000000000 --- a/tests/spark_like_test.py +++ /dev/null @@ -1,945 +0,0 @@ -"""PySpark support in Narwhals is still _very_ limited. - -Start with a simple test file whilst we develop the basics. -Once we're a bit further along, we can integrate PySpark tests into the main test suite. 
-""" - -from __future__ import annotations - -from contextlib import nullcontext as does_not_raise -from typing import TYPE_CHECKING -from typing import Any - -import pandas as pd -import pytest - -import narwhals.stable.v1 as nw -from narwhals.exceptions import ColumnNotFoundError -from tests.utils import assert_equal_data - -if TYPE_CHECKING: - from pyspark.sql import SparkSession - - from narwhals.typing import IntoFrame - from tests.utils import Constructor - - -def _pyspark_constructor_with_session(obj: Any, spark_session: SparkSession) -> IntoFrame: - # NaN and NULL are not the same in PySpark - pd_df = pd.DataFrame(obj).replace({float("nan"): None}).reset_index() - return ( # type: ignore[no-any-return] - spark_session.createDataFrame(pd_df).repartition(2).orderBy("index").drop("index") - ) - - -@pytest.fixture(params=[_pyspark_constructor_with_session]) -def pyspark_constructor( - request: pytest.FixtureRequest, spark_session: SparkSession -) -> Constructor: - def _constructor(obj: Any) -> IntoFrame: - return request.param(obj, spark_session) # type: ignore[no-any-return] - - return _constructor - - -# copied from tests/translate/from_native_test.py -def test_series_only(pyspark_constructor: Constructor) -> None: - obj = pyspark_constructor({"a": [1, 2, 3]}) - with pytest.raises(TypeError, match="Cannot only use `series_only`"): - _ = nw.from_native(obj, series_only=True) - - -def test_eager_only_lazy(pyspark_constructor: Constructor) -> None: - dframe = pyspark_constructor({"a": [1, 2, 3]}) - with pytest.raises(TypeError, match="Cannot only use `eager_only`"): - _ = nw.from_native(dframe, eager_only=True) - - -# copied from tests/frame/with_columns_test.py -def test_columns(pyspark_constructor: Constructor) -> None: - data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} - df = nw.from_native(pyspark_constructor(data)) - result = df.columns - expected = ["a", "b", "z"] - assert result == expected - - -# copied from tests/frame/with_columns_test.py -def test_with_columns_order(pyspark_constructor: Constructor) -> None: - data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} - df = nw.from_native(pyspark_constructor(data)) - result = df.with_columns(nw.col("a") + 1, d=nw.col("a") - 1) - assert result.collect_schema().names() == ["a", "b", "z", "d"] - expected = {"a": [2, 4, 3], "b": [4, 4, 6], "z": [7.0, 8, 9], "d": [0, 2, 1]} - assert_equal_data(result, expected) - - -@pytest.mark.filterwarnings("ignore:If `index_col` is not specified for `to_spark`") -def test_with_columns_empty(pyspark_constructor: Constructor) -> None: - data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} - df = nw.from_native(pyspark_constructor(data)) - result = df.select().with_columns() - assert_equal_data(result, {}) - - -def test_with_columns_order_single_row(pyspark_constructor: Constructor) -> None: - data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9], "i": [0, 1, 2]} - df = nw.from_native(pyspark_constructor(data)).filter(nw.col("i") < 1).drop("i") - result = df.with_columns(nw.col("a") + 1, d=nw.col("a") - 1) - assert result.collect_schema().names() == ["a", "b", "z", "d"] - expected = {"a": [2], "b": [4], "z": [7.0], "d": [0]} - assert_equal_data(result, expected) - - -# copied from tests/frame/select_test.py -def test_select(pyspark_constructor: Constructor) -> None: - data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} - df = nw.from_native(pyspark_constructor(data)) - result = df.select("a") - expected = {"a": [1, 3, 2]} - assert_equal_data(result, expected) - - 
-@pytest.mark.filterwarnings("ignore:If `index_col` is not specified for `to_spark`") -def test_empty_select(pyspark_constructor: Constructor) -> None: - result = nw.from_native(pyspark_constructor({"a": [1, 2, 3]})).lazy().select() - assert result.collect().shape == (0, 0) - - -# copied from tests/frame/filter_test.py -def test_filter(pyspark_constructor: Constructor) -> None: - data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} - df = nw.from_native(pyspark_constructor(data)) - result = df.filter(nw.col("a") > 1) - expected = {"a": [3, 2], "b": [4, 6], "z": [8.0, 9.0]} - assert_equal_data(result, expected) - - -# copied from tests/frame/schema_test.py -@pytest.mark.filterwarnings("ignore:Determining|Resolving.*") -def test_schema(pyspark_constructor: Constructor) -> None: - df = nw.from_native( - pyspark_constructor({"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.1, 8, 9]}) - ) - result = df.schema - expected = {"a": nw.Int64, "b": nw.Int64, "z": nw.Float64} - - result = df.schema - assert result == expected - result = df.lazy().collect().schema - assert result == expected - - -def test_collect_schema(pyspark_constructor: Constructor) -> None: - df = nw.from_native( - pyspark_constructor({"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.1, 8, 9]}) - ) - expected = {"a": nw.Int64, "b": nw.Int64, "z": nw.Float64} - - result = df.collect_schema() - assert result == expected - result = df.lazy().collect().collect_schema() - assert result == expected - - -# copied from tests/frame/drop_test.py -@pytest.mark.parametrize( - ("to_drop", "expected"), - [ - ("abc", ["b", "z"]), - (["abc"], ["b", "z"]), - (["abc", "b"], ["z"]), - ], -) -def test_drop( - pyspark_constructor: Constructor, to_drop: list[str], expected: list[str] -) -> None: - data = {"abc": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} - df = nw.from_native(pyspark_constructor(data)) - assert df.drop(to_drop).collect_schema().names() == expected - if not isinstance(to_drop, str): - assert df.drop(*to_drop).collect_schema().names() == expected - - -@pytest.mark.parametrize( - ("strict", "context"), - [ - (True, pytest.raises(ColumnNotFoundError, match="z")), - (False, does_not_raise()), - ], -) -def test_drop_strict( - pyspark_constructor: Constructor, context: Any, *, strict: bool -) -> None: - data = {"a": [1, 3, 2], "b": [4, 4, 6]} - to_drop = ["a", "z"] - - df = nw.from_native(pyspark_constructor(data)) - - with context: - names_out = df.drop(to_drop, strict=strict).collect_schema().names() - assert names_out == ["b"] - - -# copied from tests/frame/head_test.py -def test_head(pyspark_constructor: Constructor) -> None: - data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} - expected = {"a": [1, 3], "b": [4, 4], "z": [7.0, 8.0]} - - df_raw = pyspark_constructor(data) - df = nw.from_native(df_raw) - - result = df.head(2) - assert_equal_data(result, expected) - - result = df.head(2) - assert_equal_data(result, expected) - - # negative indices not allowed for lazyframes - result = df.lazy().collect().head(-1) - assert_equal_data(result, expected) - - -# copied from tests/frame/sort_test.py -def test_sort(pyspark_constructor: Constructor) -> None: - data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} - df = nw.from_native(pyspark_constructor(data)) - result = df.sort("a", "b") - expected = { - "a": [1, 2, 3], - "b": [4, 6, 4], - "z": [7.0, 9.0, 8.0], - } - assert_equal_data(result, expected) - result = df.sort("a", "b", descending=[True, False]).lazy().collect() - expected = { - "a": [3, 2, 1], - "b": [4, 6, 4], - "z": [8.0, 9.0, 7.0], 
- } - assert_equal_data(result, expected) - - -@pytest.mark.parametrize( - ("nulls_last", "expected"), - [ - (True, {"a": [0, 2, 0, -1], "b": [3, 2, 1, None]}), - (False, {"a": [-1, 0, 2, 0], "b": [None, 3, 2, 1]}), - ], -) -def test_sort_nulls( - pyspark_constructor: Constructor, *, nulls_last: bool, expected: dict[str, float] -) -> None: - data = {"a": [0, 0, 2, -1], "b": [1, 3, 2, None]} - df = nw.from_native(pyspark_constructor(data)) - result = df.sort("b", descending=True, nulls_last=nulls_last).lazy().collect() - assert_equal_data(result, expected) - - -# copied from tests/frame/add_test.py -def test_add(pyspark_constructor: Constructor) -> None: - data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} - df = nw.from_native(pyspark_constructor(data)) - result = df.with_columns( - c=nw.col("a") + nw.col("b"), - d=nw.col("a") - nw.col("a").mean(), - e=nw.col("a") - nw.col("a").std(), - ) - expected = { - "a": [1, 3, 2], - "b": [4, 4, 6], - "z": [7.0, 8.0, 9.0], - "c": [5, 7, 8], - "d": [-1.0, 1.0, 0.0], - "e": [0.0, 2.0, 1.0], - } - assert_equal_data(result, expected) - - -# copied from tests/expr_and_series/all_horizontal_test.py -@pytest.mark.parametrize("expr1", ["a", nw.col("a")]) -@pytest.mark.parametrize("expr2", ["b", nw.col("b")]) -def test_allh(pyspark_constructor: Constructor, expr1: Any, expr2: Any) -> None: - data = { - "a": [False, False, True], - "b": [False, True, True], - } - df = nw.from_native(pyspark_constructor(data)) - result = df.select(all=nw.all_horizontal(expr1, expr2)) - - expected = {"all": [False, False, True]} - assert_equal_data(result, expected) - - -def test_allh_all(pyspark_constructor: Constructor) -> None: - data = { - "a": [False, False, True], - "b": [False, True, True], - } - df = nw.from_native(pyspark_constructor(data)) - result = df.select(all=nw.all_horizontal(nw.all())) - expected = {"all": [False, False, True]} - assert_equal_data(result, expected) - result = df.select(nw.all_horizontal(nw.all())) - expected = {"a": [False, False, True]} - assert_equal_data(result, expected) - - -# copied from tests/expr_and_series/sum_horizontal_test.py -@pytest.mark.parametrize("col_expr", [nw.col("a"), "a"]) -def test_sumh(pyspark_constructor: Constructor, col_expr: Any) -> None: - data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} - df = nw.from_native(pyspark_constructor(data)) - result = df.with_columns(horizontal_sum=nw.sum_horizontal(col_expr, nw.col("b"))) - expected = { - "a": [1, 3, 2], - "b": [4, 4, 6], - "z": [7.0, 8.0, 9.0], - "horizontal_sum": [5, 7, 8], - } - assert_equal_data(result, expected) - - -def test_sumh_nullable(pyspark_constructor: Constructor) -> None: - data = {"a": [1, 8, 3], "b": [4, 5, None], "idx": [0, 1, 2]} - expected = {"hsum": [5, 13, 3]} - - df = nw.from_native(pyspark_constructor(data)) - result = df.select("idx", hsum=nw.sum_horizontal("a", "b")).sort("idx").drop("idx") - assert_equal_data(result, expected) - - -def test_sumh_all(pyspark_constructor: Constructor) -> None: - data = {"a": [1, 2, 3], "b": [10, 20, 30]} - df = nw.from_native(pyspark_constructor(data)) - result = df.select(nw.sum_horizontal(nw.all())) - expected = { - "a": [11, 22, 33], - } - assert_equal_data(result, expected) - result = df.select(c=nw.sum_horizontal(nw.all())) - expected = { - "c": [11, 22, 33], - } - assert_equal_data(result, expected) - - -# copied from tests/expr_and_series/count_test.py -def test_count(pyspark_constructor: Constructor) -> None: - data = {"a": [1, 2, 3], "b": [4, None, 6], "z": [7.0, None, None]} - df = 
nw.from_native(pyspark_constructor(data)) - result = df.select(nw.col("a", "b", "z").count()) - expected = {"a": [3], "b": [2], "z": [1]} - assert_equal_data(result, expected) - - -# copied from tests/expr_and_series/double_test.py -def test_double(pyspark_constructor: Constructor) -> None: - data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} - df = nw.from_native(pyspark_constructor(data)) - result = df.with_columns(nw.all() * 2) - expected = {"a": [2, 6, 4], "b": [8, 8, 12], "z": [14.0, 16.0, 18.0]} - assert_equal_data(result, expected) - - -def test_double_alias(pyspark_constructor: Constructor) -> None: - data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} - df = nw.from_native(pyspark_constructor(data)) - result = df.with_columns(nw.col("a").alias("o"), nw.all() * 2) - expected = { - "a": [2, 6, 4], - "b": [8, 8, 12], - "z": [14.0, 16.0, 18.0], - "o": [1, 3, 2], - } - assert_equal_data(result, expected) - - -# copied from tests/expr_and_series/max_test.py -def test_expr_max_expr(pyspark_constructor: Constructor) -> None: - data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} - - df = nw.from_native(pyspark_constructor(data)) - result = df.select(nw.col("a", "b", "z").max()) - expected = {"a": [3], "b": [6], "z": [9.0]} - assert_equal_data(result, expected) - - -# copied from tests/expr_and_series/min_test.py -def test_expr_min_expr(pyspark_constructor: Constructor) -> None: - data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} - df = nw.from_native(pyspark_constructor(data)) - result = df.select(nw.col("a", "b", "z").min()) - expected = {"a": [1], "b": [4], "z": [7.0]} - assert_equal_data(result, expected) - - -# copied from tests/expr_and_series/null_count_test.py -def test_null_count_expr(pyspark_constructor: Constructor) -> None: - data = { - "a": [1.0, None, None, 3.0], - "b": [1.0, None, 4, 5.0], - } - df = nw.from_native(pyspark_constructor(data)) - result = df.select(nw.all().null_count()) - expected = { - "a": [2], - "b": [1], - } - assert_equal_data(result, expected) - - -# copied from tests/expr_and_series/min_test.py -@pytest.mark.parametrize("expr", [nw.col("a", "b", "z").sum(), nw.sum("a", "b", "z")]) -def test_expr_sum_expr(pyspark_constructor: Constructor, expr: nw.Expr) -> None: - data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} - df = nw.from_native(pyspark_constructor(data)) - result = df.select(expr) - expected = {"a": [6], "b": [14], "z": [24.0]} - assert_equal_data(result, expected) - - -# copied from tests/expr_and_series/std_test.py -def test_std(pyspark_constructor: Constructor) -> None: - data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} - - df = nw.from_native(pyspark_constructor(data)) - result = df.select( - nw.col("a").std().alias("a_ddof_default"), - nw.col("a").std(ddof=1).alias("a_ddof_1"), - nw.col("a").std(ddof=0).alias("a_ddof_0"), - nw.col("b").std(ddof=2).alias("b_ddof_2"), - nw.col("z").std(ddof=0).alias("z_ddof_0"), - ) - expected = { - "a_ddof_default": [1.0], - "a_ddof_1": [1.0], - "a_ddof_0": [0.816497], - "b_ddof_2": [1.632993], - "z_ddof_0": [0.816497], - } - assert_equal_data(result, expected) - - -# copied from tests/expr_and_series/var_test.py -def test_var(pyspark_constructor: Constructor) -> None: - data = {"a": [1, 3, 2, None], "b": [4, 4, 6, None], "z": [7.0, 8, 9, None]} - - expected_results = { - "a_ddof_1": [1.0], - "a_ddof_0": [0.6666666666666666], - "b_ddof_2": [2.666666666666667], - "z_ddof_0": [0.6666666666666666], - } - - df = nw.from_native(pyspark_constructor(data)) - result = 
df.select( - nw.col("a").var(ddof=1).alias("a_ddof_1"), - nw.col("a").var(ddof=0).alias("a_ddof_0"), - nw.col("b").var(ddof=2).alias("b_ddof_2"), - nw.col("z").var(ddof=0).alias("z_ddof_0"), - ) - assert_equal_data(result, expected_results) - - -# copied from tests/group_by_test.py -def test_group_by_std(pyspark_constructor: Constructor) -> None: - data = {"a": [1, 1, 2, 2], "b": [5, 4, 3, 2]} - result = ( - nw.from_native(pyspark_constructor(data)) - .group_by("a") - .agg(nw.col("b").std()) - .sort("a") - ) - expected = {"a": [1, 2], "b": [0.707107] * 2} - assert_equal_data(result, expected) - - -def test_group_by_simple_named(pyspark_constructor: Constructor) -> None: - data = {"a": [1, 1, 2], "b": [4, 5, 6], "c": [7, 2, 1]} - df = nw.from_native(pyspark_constructor(data)).lazy() - result = ( - df.group_by("a") - .agg( - b_min=nw.col("b").min(), - b_max=nw.col("b").max(), - ) - .collect() - .sort("a") - ) - expected = { - "a": [1, 2], - "b_min": [4, 6], - "b_max": [5, 6], - } - assert_equal_data(result, expected) - - -def test_group_by_simple_unnamed(pyspark_constructor: Constructor) -> None: - data = {"a": [1, 1, 2], "b": [4, 5, 6], "c": [7, 2, 1]} - df = nw.from_native(pyspark_constructor(data)).lazy() - result = ( - df.group_by("a") - .agg( - nw.col("b").min(), - nw.col("c").max(), - ) - .collect() - .sort("a") - ) - expected = { - "a": [1, 2], - "b": [4, 6], - "c": [7, 1], - } - assert_equal_data(result, expected) - - -def test_group_by_multiple_keys(pyspark_constructor: Constructor) -> None: - data = {"a": [1, 1, 2], "b": [4, 4, 6], "c": [7, 2, 1]} - df = nw.from_native(pyspark_constructor(data)).lazy() - result = ( - df.group_by("a", "b") - .agg( - c_min=nw.col("c").min(), - c_max=nw.col("c").max(), - ) - .collect() - .sort("a") - ) - expected = { - "a": [1, 2], - "b": [4, 6], - "c_min": [2, 1], - "c_max": [7, 1], - } - assert_equal_data(result, expected) - - -# copied from tests/group_by_test.py -@pytest.mark.parametrize( - ("attr", "ddof"), - [ - ("std", 0), - ("var", 0), - ("std", 2), - ("var", 2), - ], -) -def test_group_by_depth_1_std_var( - pyspark_constructor: Constructor, - attr: str, - ddof: int, -) -> None: - data = {"a": [1, 1, 1, 2, 2, 2], "b": [4, 5, 6, 0, 5, 5]} - _pow = 0.5 if attr == "std" else 1 - expected = { - "a": [1, 2], - "b": [ - (sum((v - 5) ** 2 for v in [4, 5, 6]) / (3 - ddof)) ** _pow, - (sum((v - 10 / 3) ** 2 for v in [0, 5, 5]) / (3 - ddof)) ** _pow, - ], - } - expr = getattr(nw.col("b"), attr)(ddof=ddof) - result = nw.from_native(pyspark_constructor(data)).group_by("a").agg(expr).sort("a") - assert_equal_data(result, expected) - - -# copied from tests/frame/drop_nulls_test.py -def test_drop_nulls(pyspark_constructor: Constructor) -> None: - data = { - "a": [1.0, 2.0, None, 4.0], - "b": [None, 3.0, None, 5.0], - } - - result = nw.from_native(pyspark_constructor(data)).drop_nulls() - expected = { - "a": [2.0, 4.0], - "b": [3.0, 5.0], - } - assert_equal_data(result, expected) - - -@pytest.mark.parametrize( - ("subset", "expected"), - [ - ("a", {"a": [1, 2.0, 4.0], "b": [None, 3.0, 5.0]}), - (["a"], {"a": [1, 2.0, 4.0], "b": [None, 3.0, 5.0]}), - (["a", "b"], {"a": [2.0, 4.0], "b": [3.0, 5.0]}), - ], -) -def test_drop_nulls_subset( - pyspark_constructor: Constructor, subset: str | list[str], expected: dict[str, float] -) -> None: - data = { - "a": [1.0, 2.0, None, 4.0], - "b": [None, 3.0, None, 5.0], - } - - result = nw.from_native(pyspark_constructor(data)).drop_nulls(subset=subset) - assert_equal_data(result, expected) - - -# copied from 
tests/frame/rename_test.py -def test_rename(pyspark_constructor: Constructor) -> None: - data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} - df = nw.from_native(pyspark_constructor(data)) - result = df.rename({"a": "x", "b": "y"}) - expected = {"x": [1, 3, 2], "y": [4, 4, 6], "z": [7.0, 8, 9]} - assert_equal_data(result, expected) - - -# adapted from tests/frame/unique_test.py -@pytest.mark.parametrize("subset", ["b", ["b"]]) -@pytest.mark.parametrize( - ("keep", "expected"), - [ - ("first", {"a": [1, 2], "b": [4, 6], "z": [7.0, 9.0]}), - ("last", {"a": [3, 2], "b": [4, 6], "z": [8.0, 9.0]}), - ("any", {"a": [1, 2], "b": [4, 6], "z": [7.0, 9.0]}), - ("none", {"a": [2], "b": [6], "z": [9]}), - ], -) -def test_unique( - pyspark_constructor: Constructor, - subset: str | list[str] | None, - keep: str, - expected: dict[str, list[float]], -) -> None: - if keep == "any": - context: Any = does_not_raise() - elif keep == "none": - context = pytest.raises( - ValueError, - match=r"`LazyFrame.unique` with PySpark backend only supports `keep='any'`.", - ) - else: - context = pytest.raises(ValueError, match=f": {keep}") - - with context: - data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} - df = nw.from_native(pyspark_constructor(data)) - - result = df.unique(subset, keep=keep).sort("z") # type: ignore[arg-type] - assert_equal_data(result, expected) - - -def test_unique_none(pyspark_constructor: Constructor) -> None: - data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} - df = nw.from_native(pyspark_constructor(data)) - result = df.unique().sort("z") - assert_equal_data(result, data) - - -def test_inner_join_two_keys(pyspark_constructor: Constructor) -> None: - data = { - "antananarivo": [1, 3, 2], - "bob": [4, 4, 6], - "zorro": [7.0, 8, 9], - "idx": [0, 1, 2], - } - df = nw.from_native(pyspark_constructor(data)) - df_right = nw.from_native(pyspark_constructor(data)) - result = df.join( - df_right, # type: ignore[arg-type] - left_on=["antananarivo", "bob"], - right_on=["antananarivo", "bob"], - how="inner", - ) - result = result.sort("idx").drop("idx_right") - - df = nw.from_native(pyspark_constructor(data)) - df_right = nw.from_native(pyspark_constructor(data)) - - result_on = df.join(df_right, on=["antananarivo", "bob"], how="inner") # type: ignore[arg-type] - result_on = result_on.sort("idx").drop("idx_right") - expected = { - "antananarivo": [1, 3, 2], - "bob": [4, 4, 6], - "zorro": [7.0, 8, 9], - "idx": [0, 1, 2], - "zorro_right": [7.0, 8, 9], - } - assert_equal_data(result, expected) - assert_equal_data(result_on, expected) - - -def test_inner_join_single_key(pyspark_constructor: Constructor) -> None: - data = { - "antananarivo": [1, 3, 2], - "bob": [4, 4, 6], - "zorro": [7.0, 8, 9], - "idx": [0, 1, 2], - } - df = nw.from_native(pyspark_constructor(data)) - df_right = nw.from_native(pyspark_constructor(data)) - result = ( - df.join( - df_right, # type: ignore[arg-type] - left_on="antananarivo", - right_on="antananarivo", - how="inner", - ) - .sort("idx") - .drop("idx_right") - ) - - df = nw.from_native(pyspark_constructor(data)) - df_right = nw.from_native(pyspark_constructor(data)) - result_on = ( - df.join( - df_right, # type: ignore[arg-type] - on="antananarivo", - how="inner", - ) - .sort("idx") - .drop("idx_right") - ) - - expected = { - "antananarivo": [1, 3, 2], - "bob": [4, 4, 6], - "zorro": [7.0, 8, 9], - "idx": [0, 1, 2], - "bob_right": [4, 4, 6], - "zorro_right": [7.0, 8, 9], - } - assert_equal_data(result, expected) - assert_equal_data(result_on, expected) - - 
-def test_cross_join(pyspark_constructor: Constructor) -> None: - data = {"antananarivo": [1, 3, 2]} - df = nw.from_native(pyspark_constructor(data)) - other = nw.from_native(pyspark_constructor(data)) - result = df.join(other, how="cross").sort("antananarivo", "antananarivo_right") # type: ignore[arg-type] - expected = { - "antananarivo": [1, 1, 1, 2, 2, 2, 3, 3, 3], - "antananarivo_right": [1, 2, 3, 1, 2, 3, 1, 2, 3], - } - assert_equal_data(result, expected) - - with pytest.raises( - ValueError, match="Can not pass `left_on`, `right_on` or `on` keys for cross join" - ): - df.join(other, how="cross", left_on="antananarivo") # type: ignore[arg-type] - - -@pytest.mark.parametrize("how", ["inner", "left"]) -@pytest.mark.parametrize("suffix", ["_right", "_custom_suffix"]) -def test_suffix(pyspark_constructor: Constructor, how: str, suffix: str) -> None: - data = { - "antananarivo": [1, 3, 2], - "bob": [4, 4, 6], - "zorro": [7.0, 8, 9], - } - df = nw.from_native(pyspark_constructor(data)) - df_right = nw.from_native(pyspark_constructor(data)) - result = df.join( - df_right, # type: ignore[arg-type] - left_on=["antananarivo", "bob"], - right_on=["antananarivo", "bob"], - how=how, # type: ignore[arg-type] - suffix=suffix, - ) - result_cols = result.collect_schema().names() - assert result_cols == ["antananarivo", "bob", "zorro", f"zorro{suffix}"] - - -@pytest.mark.parametrize("suffix", ["_right", "_custom_suffix"]) -def test_cross_join_suffix(pyspark_constructor: Constructor, suffix: str) -> None: - data = {"antananarivo": [1, 3, 2]} - df = nw.from_native(pyspark_constructor(data)) - other = nw.from_native(pyspark_constructor(data)) - result = df.join(other, how="cross", suffix=suffix).sort( # type: ignore[arg-type] - "antananarivo", f"antananarivo{suffix}" - ) - expected = { - "antananarivo": [1, 1, 1, 2, 2, 2, 3, 3, 3], - f"antananarivo{suffix}": [1, 2, 3, 1, 2, 3, 1, 2, 3], - } - assert_equal_data(result, expected) - - -@pytest.mark.parametrize( - ("join_key", "filter_expr", "expected"), - [ - ( - ["antananarivo", "bob"], - (nw.col("bob") < 5), - {"antananarivo": [2], "bob": [6], "zorro": [9]}, - ), - (["bob"], (nw.col("bob") < 5), {"antananarivo": [2], "bob": [6], "zorro": [9]}), - ( - ["bob"], - (nw.col("bob") > 5), - {"antananarivo": [1, 3], "bob": [4, 4], "zorro": [7.0, 8.0]}, - ), - ], -) -def test_anti_join( - pyspark_constructor: Constructor, - join_key: list[str], - filter_expr: nw.Expr, - expected: dict[str, list[Any]], -) -> None: - data = {"antananarivo": [1, 3, 2], "bob": [4, 4, 6], "zorro": [7.0, 8, 9]} - df = nw.from_native(pyspark_constructor(data)) - other = df.filter(filter_expr) - result = df.join(other, how="anti", left_on=join_key, right_on=join_key) # type: ignore[arg-type] - assert_equal_data(result, expected) - - -@pytest.mark.parametrize( - ("join_key", "filter_expr", "expected"), - [ - ( - "antananarivo", - (nw.col("bob") > 5), - {"antananarivo": [2], "bob": [6], "zorro": [9]}, - ), - ( - ["antananarivo"], - (nw.col("bob") > 5), - {"antananarivo": [2], "bob": [6], "zorro": [9]}, - ), - ( - ["bob"], - (nw.col("bob") < 5), - {"antananarivo": [1, 3], "bob": [4, 4], "zorro": [7, 8]}, - ), - ( - ["antananarivo", "bob"], - (nw.col("bob") < 5), - {"antananarivo": [1, 3], "bob": [4, 4], "zorro": [7, 8]}, - ), - ], -) -def test_semi_join( - pyspark_constructor: Constructor, - join_key: list[str], - filter_expr: nw.Expr, - expected: dict[str, list[Any]], -) -> None: - data = {"antananarivo": [1, 3, 2], "bob": [4, 4, 6], "zorro": [7.0, 8, 9]} - df = 
nw.from_native(pyspark_constructor(data)) - other = df.filter(filter_expr) - result = df.join(other, how="semi", left_on=join_key, right_on=join_key).sort( # type: ignore[arg-type] - "antananarivo" - ) - assert_equal_data(result, expected) - - -@pytest.mark.filterwarnings("ignore:the default coalesce behavior") -def test_left_join(pyspark_constructor: Constructor) -> None: - data_left = { - "antananarivo": [1.0, 2, 3], - "bob": [4.0, 5, 6], - "idx": [0.0, 1.0, 2.0], - } - data_right = { - "antananarivo": [1.0, 2, 3], - "co": [4.0, 5, 7], - "idx": [0.0, 1.0, 2.0], - } - df_left = nw.from_native(pyspark_constructor(data_left)) - df_right = nw.from_native(pyspark_constructor(data_right)) - result = ( - df_left.join(df_right, left_on="bob", right_on="co", how="left") # type: ignore[arg-type] - .sort("idx") - .drop("idx_right") - ) - expected = { - "antananarivo": [1, 2, 3], - "bob": [4, 5, 6], - "idx": [0, 1, 2], - "antananarivo_right": [1, 2, None], - } - assert_equal_data(result, expected) - - df_left = nw.from_native(pyspark_constructor(data_left)) - df_right = nw.from_native(pyspark_constructor(data_right)) - result_on_list = df_left.join( - df_right, # type: ignore[arg-type] - on=["antananarivo", "idx"], - how="left", - ) - result_on_list = result_on_list.sort("idx") - expected_on_list = { - "antananarivo": [1, 2, 3], - "bob": [4, 5, 6], - "idx": [0, 1, 2], - "co": [4, 5, 7], - } - assert_equal_data(result_on_list, expected_on_list) - - -@pytest.mark.filterwarnings("ignore: the default coalesce behavior") -def test_left_join_multiple_column(pyspark_constructor: Constructor) -> None: - data_left = {"antananarivo": [1, 2, 3], "bob": [4, 5, 6], "idx": [0, 1, 2]} - data_right = {"antananarivo": [1, 2, 3], "c": [4, 5, 6], "idx": [0, 1, 2]} - df_left = nw.from_native(pyspark_constructor(data_left)) - df_right = nw.from_native(pyspark_constructor(data_right)) - result = ( - df_left.join( - df_right, # type: ignore[arg-type] - left_on=["antananarivo", "bob"], - right_on=["antananarivo", "c"], - how="left", - ) - .sort("idx") - .drop("idx_right") - ) - expected = {"antananarivo": [1, 2, 3], "bob": [4, 5, 6], "idx": [0, 1, 2]} - assert_equal_data(result, expected) - - -@pytest.mark.filterwarnings("ignore: the default coalesce behavior") -def test_left_join_overlapping_column(pyspark_constructor: Constructor) -> None: - data_left = { - "antananarivo": [1.0, 2, 3], - "bob": [4.0, 5, 6], - "d": [1.0, 4, 2], - "idx": [0.0, 1.0, 2.0], - } - data_right = { - "antananarivo": [1.0, 2, 3], - "c": [4.0, 5, 6], - "d": [1.0, 4, 2], - "idx": [0.0, 1.0, 2.0], - } - df_left = nw.from_native(pyspark_constructor(data_left)) - df_right = nw.from_native(pyspark_constructor(data_right)) - result = df_left.join(df_right, left_on="bob", right_on="c", how="left").sort("idx") # type: ignore[arg-type] - result = result.drop("idx_right") - expected: dict[str, list[Any]] = { - "antananarivo": [1, 2, 3], - "bob": [4, 5, 6], - "d": [1, 4, 2], - "idx": [0, 1, 2], - "antananarivo_right": [1, 2, 3], - "d_right": [1, 4, 2], - } - assert_equal_data(result, expected) - - df_left = nw.from_native(pyspark_constructor(data_left)) - df_right = nw.from_native(pyspark_constructor(data_right)) - result = ( - df_left.join( - df_right, # type: ignore[arg-type] - left_on="antananarivo", - right_on="d", - how="left", - ) - .sort("idx") - .drop("idx_right") - ) - expected = { - "antananarivo": [1, 2, 3], - "bob": [4, 5, 6], - "d": [1, 4, 2], - "idx": [0, 1, 2], - "antananarivo_right": [1.0, 3.0, None], - "c": [4.0, 6.0, None], - } - 
assert_equal_data(result, expected)
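
The expression methods added in patches 01-06 (`all`, `any`, `null_count`) each map a narwhals expression onto a single PySpark aggregate: `F.bool_and`, `F.bool_or`, and `F.count_if(F.isnull(...))`. Below is a minimal, standalone sketch in plain PySpark -- not part of the patches, and assuming pyspark>=3.5 where these aggregates exist -- showing the semantics the tests above rely on.

    from pyspark.sql import SparkSession
    from pyspark.sql import functions as F  # noqa: N812

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    df = spark.createDataFrame(
        [(True, 1.0), (False, None), (True, None)], schema=["a", "b"]
    )

    result = df.select(
        F.bool_and("a").alias("a_all"),                   # False: not every row is True
        F.bool_or("a").alias("a_any"),                    # True: at least one row is True
        F.count_if(F.isnull("b")).alias("b_null_count"),  # 2: two NULLs in "b"
    )
    result.show()

Running this yields a_all=False, a_any=True and b_null_count=2, the same kind of single-row results asserted in test_any_all and test_null_count_expr.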
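
Patches 09-11 drafted `replace_strict` on top of `F.create_map`: the old/new pairs become a map literal, the input column indexes into it, and `F.coalesce` supplies the default for unmapped values (the draft was removed again in patches 14-15). A hypothetical, self-contained sketch of that lookup technique, independent of narwhals and using illustrative names only:

    from itertools import chain

    from pyspark.sql import SparkSession
    from pyspark.sql import functions as F  # noqa: N812

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    df = spark.createDataFrame([(1,), (2,), (3,), (4,)], schema=["a"])

    mapping = {1: "one", 2: "two", 3: "three"}
    # flatten {1: "one", ...} into F.create_map(lit(1), lit("one"), ...)
    mapping_expr = F.create_map([F.lit(x) for x in chain(*mapping.items())])

    result = df.select(
        # indexing the map column with "a" returns NULL for unmapped values
        # (here: 4); coalesce swaps in a default instead
        F.coalesce(mapping_expr[F.col("a")], F.lit("unknown")).alias("a")
    )
    result.show()

Note that the map lookup silently yields NULL for unmapped keys, which is exactly the case the QUESTION comments in patch 09 were concerned with when no default is given.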