diff --git a/.github/workflows/extremes.yml b/.github/workflows/extremes.yml index 6a9882846..4a44f8963 100644 --- a/.github/workflows/extremes.yml +++ b/.github/workflows/extremes.yml @@ -27,7 +27,10 @@ jobs: - name: install-minimum-versions run: uv pip install tox virtualenv setuptools pandas==0.25.3 polars==0.20.3 numpy==1.17.5 pyarrow==11.0.0 scipy==1.5.0 scikit-learn==1.1.0 tzdata --system - name: install-reqs - run: uv pip install -r requirements-dev.txt --system + run: | + uv pip install -r requirements-dev.txt --system + : # pyspark >= 3.3.0 is not compatible with pandas==0.25.3 + uv pip uninstall pyspark --system - name: show-deps run: uv pip freeze - name: Run pytest @@ -52,7 +55,7 @@ jobs: cache-suffix: ${{ matrix.python-version }} cache-dependency-glob: "**requirements*.txt" - name: install-minimum-versions - run: uv pip install tox virtualenv setuptools pandas==1.1.5 polars==0.20.3 numpy==1.17.5 pyarrow==11.0.0 scipy==1.5.0 scikit-learn==1.1.0 tzdata --system + run: uv pip install tox virtualenv setuptools pandas==1.1.5 polars==0.20.3 numpy==1.17.5 pyarrow==11.0.0 pyspark==3.3.0 scipy==1.5.0 scikit-learn==1.1.0 tzdata --system - name: install-reqs run: uv pip install -r requirements-dev.txt --system - name: show-deps @@ -79,7 +82,7 @@ jobs: cache-suffix: ${{ matrix.python-version }} cache-dependency-glob: "**requirements*.txt" - name: install-not-so-old-versions - run: uv pip install tox virtualenv setuptools pandas==2.0.3 polars==0.20.8 numpy==1.24.4 pyarrow==14.0.0 scipy==1.8.0 scikit-learn==1.3.0 dask[dataframe]==2024.7 tzdata --system + run: uv pip install tox virtualenv setuptools pandas==2.0.3 polars==0.20.8 numpy==1.24.4 pyarrow==14.0.0 pyspark==3.4.0 scipy==1.8.0 scikit-learn==1.3.0 dask[dataframe]==2024.7 tzdata --system - name: install-reqs run: uv pip install -r requirements-dev.txt --system - name: show-deps diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index 87a91dfa9..e10a4730f 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -35,13 +35,21 @@ from narwhals._polars.namespace import PolarsNamespace from narwhals._polars.series import PolarsSeries from narwhals._polars.typing import IntoPolarsExpr + from narwhals._spark_like.dataframe import SparkLikeLazyFrame + from narwhals._spark_like.expr import SparkLikeExpr + from narwhals._spark_like.namespace import SparkLikeNamespace + from narwhals._spark_like.typing import IntoSparkLikeExpr CompliantNamespace = Union[ - PandasLikeNamespace, ArrowNamespace, DaskNamespace, PolarsNamespace + PandasLikeNamespace, + ArrowNamespace, + DaskNamespace, + PolarsNamespace, + SparkLikeNamespace, ] - CompliantExpr = Union[PandasLikeExpr, ArrowExpr, DaskExpr, PolarsExpr] + CompliantExpr = Union[PandasLikeExpr, ArrowExpr, DaskExpr, PolarsExpr, SparkLikeExpr] IntoCompliantExpr = Union[ - IntoPandasLikeExpr, IntoArrowExpr, IntoDaskExpr, IntoPolarsExpr + IntoPandasLikeExpr, IntoArrowExpr, IntoDaskExpr, IntoPolarsExpr, IntoSparkLikeExpr ] IntoCompliantExprT = TypeVar("IntoCompliantExprT", bound=IntoCompliantExpr) CompliantExprT = TypeVar("CompliantExprT", bound=CompliantExpr) @@ -50,9 +58,15 @@ list[PandasLikeSeries], list[ArrowSeries], list[DaskExpr], list[PolarsSeries] ] ListOfCompliantExpr = Union[ - list[PandasLikeExpr], list[ArrowExpr], list[DaskExpr], list[PolarsExpr] + list[PandasLikeExpr], + list[ArrowExpr], + list[DaskExpr], + list[PolarsExpr], + list[SparkLikeExpr], + ] + CompliantDataFrame = Union[ + PandasLikeDataFrame, ArrowDataFrame, DaskLazyFrame, 
SparkLikeLazyFrame ] - CompliantDataFrame = Union[PandasLikeDataFrame, ArrowDataFrame, DaskLazyFrame] T = TypeVar("T") @@ -152,6 +166,14 @@ def parse_into_exprs( ) -> list[PolarsExpr]: ... +@overload +def parse_into_exprs( + *exprs: IntoSparkLikeExpr, + namespace: SparkLikeNamespace, + **named_exprs: IntoSparkLikeExpr, +) -> list[SparkLikeExpr]: ... + + def parse_into_exprs( *exprs: IntoCompliantExpr, namespace: CompliantNamespace, diff --git a/narwhals/_spark_like/__init__.py b/narwhals/_spark_like/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/narwhals/_spark_like/dataframe.py b/narwhals/_spark_like/dataframe.py new file mode 100644 index 000000000..d488ed7f2 --- /dev/null +++ b/narwhals/_spark_like/dataframe.py @@ -0,0 +1,188 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING +from typing import Any +from typing import Iterable +from typing import Sequence + +from narwhals._spark_like.utils import native_to_narwhals_dtype +from narwhals._spark_like.utils import parse_exprs_and_named_exprs +from narwhals.utils import Implementation +from narwhals.utils import flatten +from narwhals.utils import parse_columns_to_drop +from narwhals.utils import parse_version + +if TYPE_CHECKING: + from pyspark.sql import DataFrame + from typing_extensions import Self + + from narwhals._spark_like.expr import SparkLikeExpr + from narwhals._spark_like.group_by import SparkLikeLazyGroupBy + from narwhals._spark_like.namespace import SparkLikeNamespace + from narwhals._spark_like.typing import IntoSparkLikeExpr + from narwhals.dtypes import DType + from narwhals.utils import Version + + +class SparkLikeLazyFrame: + def __init__( + self, + native_dataframe: DataFrame, + *, + backend_version: tuple[int, ...], + version: Version, + ) -> None: + self._native_frame = native_dataframe + self._backend_version = backend_version + self._implementation = Implementation.PYSPARK + self._version = version + + def __native_namespace__(self) -> Any: # pragma: no cover + if self._implementation is Implementation.PYSPARK: + return self._implementation.to_native_namespace() + + msg = f"Expected pyspark, got: {type(self._implementation)}" # pragma: no cover + raise AssertionError(msg) + + def __narwhals_namespace__(self) -> SparkLikeNamespace: + from narwhals._spark_like.namespace import SparkLikeNamespace + + return SparkLikeNamespace( + backend_version=self._backend_version, version=self._version + ) + + def __narwhals_lazyframe__(self) -> Self: + return self + + def _change_version(self, version: Version) -> Self: + return self.__class__( + self._native_frame, backend_version=self._backend_version, version=version + ) + + def _from_native_frame(self, df: DataFrame) -> Self: + return self.__class__( + df, backend_version=self._backend_version, version=self._version + ) + + @property + def columns(self) -> list[str]: + return self._native_frame.columns # type: ignore[no-any-return] + + def collect(self) -> Any: + import pandas as pd # ignore-banned-import() + + from narwhals._pandas_like.dataframe import PandasLikeDataFrame + + return PandasLikeDataFrame( + native_dataframe=self._native_frame.toPandas(), + implementation=Implementation.PANDAS, + backend_version=parse_version(pd.__version__), + version=self._version, + ) + + def select( + self: Self, + *exprs: IntoSparkLikeExpr, + **named_exprs: IntoSparkLikeExpr, + ) -> Self: + if exprs and all(isinstance(x, str) for x in exprs) and not named_exprs: + # This is a simple select + return 
self._from_native_frame(self._native_frame.select(*exprs)) + + new_columns = parse_exprs_and_named_exprs(self, *exprs, **named_exprs) + + if not new_columns: + # return empty dataframe, like Polars does + from pyspark.sql.types import StructType + + spark_session = self._native_frame.sparkSession + spark_df = spark_session.createDataFrame([], StructType([])) + + return self._from_native_frame(spark_df) + + new_columns_list = [col.alias(col_name) for col_name, col in new_columns.items()] + return self._from_native_frame(self._native_frame.select(*new_columns_list)) + + def filter(self, *predicates: SparkLikeExpr) -> Self: + from narwhals._spark_like.namespace import SparkLikeNamespace + + if ( + len(predicates) == 1 + and isinstance(predicates[0], list) + and all(isinstance(x, bool) for x in predicates[0]) + ): + msg = "`LazyFrame.filter` is not supported for PySpark backend with boolean masks." + raise NotImplementedError(msg) + plx = SparkLikeNamespace( + backend_version=self._backend_version, version=self._version + ) + expr = plx.all_horizontal(*predicates) + # Safety: all_horizontal's expression only returns a single column. + condition = expr._call(self)[0] + spark_df = self._native_frame.where(condition) + return self._from_native_frame(spark_df) + + @property + def schema(self) -> dict[str, DType]: + return { + field.name: native_to_narwhals_dtype( + dtype=field.dataType, version=self._version + ) + for field in self._native_frame.schema + } + + def collect_schema(self) -> dict[str, DType]: + return self.schema + + def with_columns( + self: Self, + *exprs: IntoSparkLikeExpr, + **named_exprs: IntoSparkLikeExpr, + ) -> Self: + new_columns_map = parse_exprs_and_named_exprs(self, *exprs, **named_exprs) + return self._from_native_frame(self._native_frame.withColumns(new_columns_map)) + + def drop(self: Self, columns: list[str], strict: bool) -> Self: # noqa: FBT001 + columns_to_drop = parse_columns_to_drop( + compliant_frame=self, columns=columns, strict=strict + ) + return self._from_native_frame(self._native_frame.drop(*columns_to_drop)) + + def head(self: Self, n: int) -> Self: + spark_session = self._native_frame.sparkSession + + return self._from_native_frame( + spark_session.createDataFrame(self._native_frame.take(num=n)) + ) + + def group_by(self: Self, *keys: str, drop_null_keys: bool) -> SparkLikeLazyGroupBy: + from narwhals._spark_like.group_by import SparkLikeLazyGroupBy + + return SparkLikeLazyGroupBy( + df=self, keys=list(keys), drop_null_keys=drop_null_keys + ) + + def sort( + self: Self, + by: str | Iterable[str], + *more_by: str, + descending: bool | Sequence[bool] = False, + nulls_last: bool = False, + ) -> Self: + import pyspark.sql.functions as F # noqa: N812 + + flat_by = flatten([*flatten([by]), *more_by]) + if isinstance(descending, bool): + descending = [descending] + + if nulls_last: + sort_funcs = [ + F.desc_nulls_last if d else F.asc_nulls_last for d in descending + ] + else: + sort_funcs = [ + F.desc_nulls_first if d else F.asc_nulls_first for d in descending + ] + + sort_cols = [sort_f(col) for col, sort_f in zip(flat_by, sort_funcs)] + return self._from_native_frame(self._native_frame.sort(*sort_cols)) diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py new file mode 100644 index 000000000..c98c79a5a --- /dev/null +++ b/narwhals/_spark_like/expr.py @@ -0,0 +1,219 @@ +from __future__ import annotations + +import operator +from copy import copy +from typing import TYPE_CHECKING +from typing import Callable + +from narwhals._spark_like.utils 
import get_column_name +from narwhals._spark_like.utils import maybe_evaluate +from narwhals.utils import Implementation +from narwhals.utils import parse_version + +if TYPE_CHECKING: + from pyspark.sql import Column + from typing_extensions import Self + + from narwhals._spark_like.dataframe import SparkLikeLazyFrame + from narwhals._spark_like.namespace import SparkLikeNamespace + from narwhals.utils import Version + + +class SparkLikeExpr: + _implementation = Implementation.PYSPARK + + def __init__( + self, + call: Callable[[SparkLikeLazyFrame], list[Column]], + *, + depth: int, + function_name: str, + root_names: list[str] | None, + output_names: list[str] | None, + # Whether the expression is a length-1 Column resulting from + # a reduction, such as `nw.col('a').sum()` + returns_scalar: bool, + backend_version: tuple[int, ...], + version: Version, + ) -> None: + self._call = call + self._depth = depth + self._function_name = function_name + self._root_names = root_names + self._output_names = output_names + self._returns_scalar = returns_scalar + self._backend_version = backend_version + self._version = version + + def __narwhals_expr__(self) -> None: ... + + def __narwhals_namespace__(self) -> SparkLikeNamespace: # pragma: no cover + # Unused, just for compatibility with PandasLikeExpr + from narwhals._spark_like.namespace import SparkLikeNamespace + + return SparkLikeNamespace( + backend_version=self._backend_version, version=self._version + ) + + @classmethod + def from_column_names( + cls: type[Self], + *column_names: str, + backend_version: tuple[int, ...], + version: Version, + ) -> Self: + def func(_: SparkLikeLazyFrame) -> list[Column]: + from pyspark.sql import functions as F # noqa: N812 + + return [F.col(col_name) for col_name in column_names] + + return cls( + func, + depth=0, + function_name="col", + root_names=list(column_names), + output_names=list(column_names), + returns_scalar=False, + backend_version=backend_version, + version=version, + ) + + def _from_call( + self, + call: Callable[..., Column], + expr_name: str, + *args: SparkLikeExpr, + returns_scalar: bool, + **kwargs: SparkLikeExpr, + ) -> Self: + def func(df: SparkLikeLazyFrame) -> list[Column]: + results = [] + inputs = self._call(df) + _args = [maybe_evaluate(df, arg) for arg in args] + _kwargs = {key: maybe_evaluate(df, value) for key, value in kwargs.items()} + for _input in inputs: + input_col_name = get_column_name(df, _input) + column_result = call(_input, *_args, **_kwargs) + if not returns_scalar: + column_result = column_result.alias(input_col_name) + results.append(column_result) + return results + + # Try tracking root and output names by combining them from all + # expressions appearing in args and kwargs. If any anonymous + # expression appears (e.g. nw.all()), then give up on tracking root names + # and just set it to None. 
+ root_names = copy(self._root_names) + output_names = self._output_names + for arg in list(args) + list(kwargs.values()): + if root_names is not None and isinstance(arg, self.__class__): + if arg._root_names is not None: + root_names.extend(arg._root_names) + else: # pragma: no cover + root_names = None + output_names = None + break + elif root_names is None: + output_names = None + break + + if not ( + (output_names is None and root_names is None) + or (output_names is not None and root_names is not None) + ): # pragma: no cover + msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues" + raise AssertionError(msg) + + return self.__class__( + func, + depth=self._depth + 1, + function_name=f"{self._function_name}->{expr_name}", + root_names=root_names, + output_names=output_names, + returns_scalar=self._returns_scalar or returns_scalar, + backend_version=self._backend_version, + version=self._version, + ) + + def __add__(self, other: SparkLikeExpr) -> Self: + return self._from_call(operator.add, "__add__", other, returns_scalar=False) + + def __sub__(self, other: SparkLikeExpr) -> Self: + return self._from_call(operator.sub, "__sub__", other, returns_scalar=False) + + def __mul__(self, other: SparkLikeExpr) -> Self: + return self._from_call(operator.mul, "__mul__", other, returns_scalar=False) + + def __lt__(self, other: SparkLikeExpr) -> Self: + return self._from_call(operator.lt, "__lt__", other, returns_scalar=False) + + def __gt__(self, other: SparkLikeExpr) -> Self: + return self._from_call(operator.gt, "__gt__", other, returns_scalar=False) + + def alias(self, name: str) -> Self: + def _alias(df: SparkLikeLazyFrame) -> list[Column]: + return [col.alias(name) for col in self._call(df)] + + # Define this one manually, so that we can + # override `output_names` and not increase depth + return self.__class__( + _alias, + depth=self._depth, + function_name=self._function_name, + root_names=self._root_names, + output_names=[name], + returns_scalar=self._returns_scalar, + backend_version=self._backend_version, + version=self._version, + ) + + def count(self) -> Self: + def _count(_input: Column) -> Column: + from pyspark.sql import functions as F # noqa: N812 + + return F.count(_input) + + return self._from_call(_count, "count", returns_scalar=True) + + def max(self) -> Self: + def _max(_input: Column) -> Column: + from pyspark.sql import functions as F # noqa: N812 + + return F.max(_input) + + return self._from_call(_max, "max", returns_scalar=True) + + def mean(self) -> Self: + def _mean(_input: Column) -> Column: + from pyspark.sql import functions as F # noqa: N812 + + return F.mean(_input) + + return self._from_call(_mean, "mean", returns_scalar=True) + + def min(self) -> Self: + def _min(_input: Column) -> Column: + from pyspark.sql import functions as F # noqa: N812 + + return F.min(_input) + + return self._from_call(_min, "min", returns_scalar=True) + + def std(self, ddof: int = 1) -> Self: + import numpy as np # ignore-banned-import + + def _std(_input: Column) -> Column: # pragma: no cover + if self._backend_version < (3, 5) or parse_version(np.__version__) > (2, 0): + from pyspark.sql import functions as F # noqa: N812 + + if ddof == 1: + return F.stddev_samp(_input) + + n_rows = F.count(_input) + return F.stddev_samp(_input) * F.sqrt((n_rows - 1) / (n_rows - ddof)) + + from pyspark.pandas.spark.functions import stddev + + return stddev(_input, ddof=ddof) + + return self._from_call(_std, "std", returns_scalar=True) diff --git 
a/narwhals/_spark_like/group_by.py b/narwhals/_spark_like/group_by.py new file mode 100644 index 000000000..3b7ad78fd --- /dev/null +++ b/narwhals/_spark_like/group_by.py @@ -0,0 +1,140 @@ +from __future__ import annotations + +from copy import copy +from typing import TYPE_CHECKING +from typing import Any +from typing import Callable + +from narwhals._expression_parsing import is_simple_aggregation +from narwhals._expression_parsing import parse_into_exprs +from narwhals.utils import remove_prefix + +if TYPE_CHECKING: + from pyspark.sql import Column + from pyspark.sql import GroupedData + + from narwhals._spark_like.dataframe import SparkLikeLazyFrame + from narwhals._spark_like.expr import SparkLikeExpr + from narwhals._spark_like.typing import IntoSparkLikeExpr + +POLARS_TO_PYSPARK_AGGREGATIONS = { + "len": "count", + "std": "stddev", +} + + +class SparkLikeLazyGroupBy: + def __init__( + self, + df: SparkLikeLazyFrame, + keys: list[str], + drop_null_keys: bool, # noqa: FBT001 + ) -> None: + self._df = df + self._keys = keys + if drop_null_keys: + self._grouped = self._df._native_frame.dropna(subset=self._keys).groupBy( + *self._keys + ) + else: + self._grouped = self._df._native_frame.groupBy(*self._keys) + + def agg( + self, + *aggs: IntoSparkLikeExpr, + **named_aggs: IntoSparkLikeExpr, + ) -> SparkLikeLazyFrame: + exprs = parse_into_exprs( + *aggs, + namespace=self._df.__narwhals_namespace__(), + **named_aggs, + ) + output_names: list[str] = copy(self._keys) + for expr in exprs: + if expr._output_names is None: # pragma: no cover + msg = ( + "Anonymous expressions are not supported in group_by.agg.\n" + "Instead of `nw.all()`, try using a named expression, such as " + "`nw.col('a', 'b')`\n" + ) + raise ValueError(msg) + + output_names.extend(expr._output_names) + + return agg_pyspark( + self._grouped, + exprs, + self._keys, + self._from_native_frame, + ) + + def _from_native_frame(self, df: SparkLikeLazyFrame) -> SparkLikeLazyFrame: + from narwhals._spark_like.dataframe import SparkLikeLazyFrame + + return SparkLikeLazyFrame( + df, backend_version=self._df._backend_version, version=self._df._version + ) + + +def get_spark_function(function_name: str) -> Column: + from pyspark.sql import functions as F # noqa: N812 + + return getattr(F, function_name) + + +def agg_pyspark( + grouped: GroupedData, + exprs: list[SparkLikeExpr], + keys: list[str], + from_dataframe: Callable[[Any], SparkLikeLazyFrame], +) -> SparkLikeLazyFrame: + for expr in exprs: + if not is_simple_aggregation(expr): # pragma: no cover + msg = ( + "Non-trivial complex aggregation found.\n\n" + "Hint: you were probably trying to apply a non-elementary aggregation with a " + "PySpark dataframe.\n" + "Please rewrite your query such that group-by aggregations " + "are elementary. For example, instead of:\n\n" + " df.group_by('a').agg(nw.col('b').round(2).mean())\n\n" + "use:\n\n" + " df.with_columns(nw.col('b').round(2)).group_by('a').agg(nw.col('b').mean())\n\n" + ) + raise ValueError(msg) + + simple_aggregations: dict[str, Column] = {} + for expr in exprs: + if expr._depth == 0: # pragma: no cover + # e.g.
agg(nw.len()) # noqa: ERA001 + if expr._output_names is None: # pragma: no cover + msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues" + raise AssertionError(msg) + + function_name = POLARS_TO_PYSPARK_AGGREGATIONS.get( + expr._function_name, expr._function_name + ) + for output_name in expr._output_names: + agg_func = get_spark_function(function_name) + simple_aggregations[output_name] = agg_func(keys[0]) + continue + + # e.g. agg(nw.mean('a')) # noqa: ERA001 + if ( + expr._depth != 1 or expr._root_names is None or expr._output_names is None + ): # pragma: no cover + msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues" + raise AssertionError(msg) + + function_name = remove_prefix(expr._function_name, "col->") + function_name = POLARS_TO_PYSPARK_AGGREGATIONS.get(function_name, function_name) + + for root_name, output_name in zip(expr._root_names, expr._output_names): + agg_func = get_spark_function(function_name) + simple_aggregations[output_name] = agg_func(root_name) + agg_columns = [col_.alias(name) for name, col_ in simple_aggregations.items()] + try: + result_simple = grouped.agg(*agg_columns) + except ValueError as exc: # pragma: no cover + msg = "Failed to aggregate - does your aggregation function return a scalar?" + raise RuntimeError(msg) from exc + return from_dataframe(result_simple) diff --git a/narwhals/_spark_like/namespace.py b/narwhals/_spark_like/namespace.py new file mode 100644 index 000000000..a762c26c8 --- /dev/null +++ b/narwhals/_spark_like/namespace.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +import operator +from functools import reduce +from typing import TYPE_CHECKING +from typing import Any +from typing import Callable +from typing import NoReturn + +from narwhals._expression_parsing import combine_root_names +from narwhals._expression_parsing import parse_into_exprs +from narwhals._expression_parsing import reduce_output_names +from narwhals._spark_like.expr import SparkLikeExpr +from narwhals._spark_like.utils import get_column_name + +if TYPE_CHECKING: + from pyspark.sql import Column + + from narwhals._spark_like.dataframe import SparkLikeLazyFrame + from narwhals._spark_like.typing import IntoSparkLikeExpr + from narwhals.utils import Version + + +class SparkLikeNamespace: + def __init__(self, *, backend_version: tuple[int, ...], version: Version) -> None: + self._backend_version = backend_version + self._version = version + + def _create_expr_from_series(self, _: Any) -> NoReturn: + msg = "`_create_expr_from_series` for PySparkNamespace exists only for compatibility" + raise NotImplementedError(msg) + + def _create_compliant_series(self, _: Any) -> NoReturn: + msg = "`_create_compliant_series` for PySparkNamespace exists only for compatibility" + raise NotImplementedError(msg) + + def _create_series_from_scalar( + self, value: Any, *, reference_series: SparkLikeExpr + ) -> NoReturn: + msg = "`_create_series_from_scalar` for PySparkNamespace exists only for compatibility" + raise NotImplementedError(msg) + + def _create_expr_from_callable( # pragma: no cover + self, + func: Callable[[SparkLikeLazyFrame], list[SparkLikeExpr]], + *, + depth: int, + function_name: str, + root_names: list[str] | None, + output_names: list[str] | None, + ) -> SparkLikeExpr: + msg = "`_create_expr_from_callable` for PySparkNamespace exists only for compatibility" + raise NotImplementedError(msg) + + def all(self) -> SparkLikeExpr: + def _all(df:
SparkLikeLazyFrame) -> list[Column]: + import pyspark.sql.functions as F # noqa: N812 + + return [F.col(col_name) for col_name in df.columns] + + return SparkLikeExpr( + call=_all, + depth=0, + function_name="all", + root_names=None, + output_names=None, + returns_scalar=False, + backend_version=self._backend_version, + version=self._version, + ) + + def all_horizontal(self, *exprs: IntoSparkLikeExpr) -> SparkLikeExpr: + parsed_exprs = parse_into_exprs(*exprs, namespace=self) + + def func(df: SparkLikeLazyFrame) -> list[Column]: + cols = [c for _expr in parsed_exprs for c in _expr._call(df)] + col_name = get_column_name(df, cols[0]) + return [reduce(operator.and_, cols).alias(col_name)] + + return SparkLikeExpr( + call=func, + depth=max(x._depth for x in parsed_exprs) + 1, + function_name="all_horizontal", + root_names=combine_root_names(parsed_exprs), + output_names=reduce_output_names(parsed_exprs), + returns_scalar=False, + backend_version=self._backend_version, + version=self._version, + ) + + def col(self, *column_names: str) -> SparkLikeExpr: + return SparkLikeExpr.from_column_names( + *column_names, backend_version=self._backend_version, version=self._version + ) diff --git a/narwhals/_spark_like/typing.py b/narwhals/_spark_like/typing.py new file mode 100644 index 000000000..fb343044b --- /dev/null +++ b/narwhals/_spark_like/typing.py @@ -0,0 +1,16 @@ +from __future__ import annotations # pragma: no cover + +from typing import TYPE_CHECKING # pragma: no cover +from typing import Union # pragma: no cover + +if TYPE_CHECKING: + import sys + + if sys.version_info >= (3, 10): + from typing import TypeAlias + else: + from typing_extensions import TypeAlias + + from narwhals._spark_like.expr import SparkLikeExpr + + IntoSparkLikeExpr: TypeAlias = Union[SparkLikeExpr, str] diff --git a/narwhals/_spark_like/utils.py b/narwhals/_spark_like/utils.py new file mode 100644 index 000000000..4a22bff7e --- /dev/null +++ b/narwhals/_spark_like/utils.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING +from typing import Any + +from narwhals.exceptions import InvalidIntoExprError +from narwhals.utils import import_dtypes_module + +if TYPE_CHECKING: + from pyspark.sql import Column + from pyspark.sql import types as pyspark_types + + from narwhals._spark_like.dataframe import SparkLikeLazyFrame + from narwhals._spark_like.typing import IntoSparkLikeExpr + from narwhals.dtypes import DType + from narwhals.utils import Version + + +def native_to_narwhals_dtype( + dtype: pyspark_types.DataType, + version: Version, +) -> DType: # pragma: no cover + dtypes = import_dtypes_module(version=version) + from pyspark.sql import types as pyspark_types + + if isinstance(dtype, pyspark_types.DoubleType): + return dtypes.Float64() + if isinstance(dtype, pyspark_types.FloatType): + return dtypes.Float32() + if isinstance(dtype, pyspark_types.LongType): + return dtypes.Int64() + if isinstance(dtype, pyspark_types.IntegerType): + return dtypes.Int32() + if isinstance(dtype, pyspark_types.ShortType): + return dtypes.Int16() + string_types = [ + pyspark_types.StringType, + pyspark_types.VarcharType, + pyspark_types.CharType, + ] + if any(isinstance(dtype, t) for t in string_types): + return dtypes.String() + if isinstance(dtype, pyspark_types.BooleanType): + return dtypes.Boolean() + if isinstance(dtype, pyspark_types.DateType): + return dtypes.Date() + datetime_types = [ + pyspark_types.TimestampType, + pyspark_types.TimestampNTZType, + ] + if any(isinstance(dtype, t) for t in 
datetime_types): + return dtypes.Datetime() + return dtypes.Unknown() + + +def get_column_name(df: SparkLikeLazyFrame, column: Column) -> str: + return str(df._native_frame.select(column).columns[0]) + + +def _columns_from_expr(df: SparkLikeLazyFrame, expr: IntoSparkLikeExpr) -> list[Column]: + if isinstance(expr, str): # pragma: no cover + from pyspark.sql import functions as F # noqa: N812 + + return [F.col(expr)] + elif hasattr(expr, "__narwhals_expr__"): + col_output_list = expr._call(df) + if expr._output_names is not None and ( + len(col_output_list) != len(expr._output_names) + ): # pragma: no cover + msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues" + raise AssertionError(msg) + return col_output_list + else: + raise InvalidIntoExprError.from_invalid_type(type(expr)) + + +def parse_exprs_and_named_exprs( + df: SparkLikeLazyFrame, *exprs: IntoSparkLikeExpr, **named_exprs: IntoSparkLikeExpr +) -> dict[str, Column]: + result_columns: dict[str, list[Column]] = {} + for expr in exprs: + column_list = _columns_from_expr(df, expr) + if isinstance(expr, str): # pragma: no cover + output_names = [expr] + elif expr._output_names is None: + output_names = [get_column_name(df, col) for col in column_list] + else: + output_names = expr._output_names + result_columns.update(zip(output_names, column_list)) + for col_alias, expr in named_exprs.items(): + columns_list = _columns_from_expr(df, expr) + if len(columns_list) != 1: # pragma: no cover + msg = "Named expressions must return a single column" + raise AssertionError(msg) + result_columns[col_alias] = columns_list[0] + return result_columns + + +def maybe_evaluate(df: SparkLikeLazyFrame, obj: Any) -> Any: + from narwhals._spark_like.expr import SparkLikeExpr + + if isinstance(obj, SparkLikeExpr): + column_results = obj._call(df) + if len(column_results) != 1: # pragma: no cover + msg = "Multi-output expressions (e.g. 
`nw.all()` or `nw.col('a', 'b')`) not supported in this context" + raise NotImplementedError(msg) + column_result = column_results[0] + if obj._returns_scalar: + # Return scalar, let PySpark do its broadcasting + from pyspark.sql import functions as F # noqa: N812 + from pyspark.sql.window import Window + + return column_result.over(Window.partitionBy(F.lit(1))) + return column_result + return obj diff --git a/narwhals/dependencies.py b/narwhals/dependencies.py index d5f0a6c6f..463a64a67 100644 --- a/narwhals/dependencies.py +++ b/narwhals/dependencies.py @@ -22,6 +22,7 @@ import pandas as pd import polars as pl import pyarrow as pa + import pyspark.sql as pyspark_sql from narwhals.typing import IntoSeries @@ -93,6 +94,16 @@ def get_ibis() -> Any: return sys.modules.get("ibis", None) +def get_pyspark() -> Any: # pragma: no cover + """Get pyspark module (if already imported - else return None).""" + return sys.modules.get("pyspark", None) + + +def get_pyspark_sql() -> Any: + """Get pyspark.sql module (if already imported - else return None).""" + return sys.modules.get("pyspark.sql", None) + + def is_pandas_dataframe(df: Any) -> TypeGuard[pd.DataFrame]: """Check whether `df` is a pandas DataFrame without importing pandas.""" return ((pd := get_pandas()) is not None and isinstance(df, pd.DataFrame)) or any( @@ -196,6 +207,14 @@ def is_pyarrow_table(df: Any) -> TypeGuard[pa.Table]: return (pa := get_pyarrow()) is not None and isinstance(df, pa.Table) +def is_pyspark_dataframe(df: Any) -> TypeGuard[pyspark_sql.DataFrame]: + """Check whether `df` is a PySpark DataFrame without importing PySpark.""" + return bool( + (pyspark_sql := get_pyspark_sql()) is not None + and isinstance(df, pyspark_sql.DataFrame) + ) + + def is_numpy_array(arr: Any) -> TypeGuard[np.ndarray]: """Check whether `arr` is a NumPy Array without importing NumPy.""" return (np := get_numpy()) is not None and isinstance(arr, np.ndarray) diff --git a/narwhals/translate.py b/narwhals/translate.py index dcc2e1d66..2bd0949f9 100644 --- a/narwhals/translate.py +++ b/narwhals/translate.py @@ -19,6 +19,7 @@ from narwhals.dependencies import get_pandas from narwhals.dependencies import get_polars from narwhals.dependencies import get_pyarrow +from narwhals.dependencies import get_pyspark from narwhals.dependencies import is_cudf_dataframe from narwhals.dependencies import is_cudf_series from narwhals.dependencies import is_dask_dataframe @@ -33,6 +34,7 @@ from narwhals.dependencies import is_polars_series from narwhals.dependencies import is_pyarrow_chunked_array from narwhals.dependencies import is_pyarrow_table +from narwhals.dependencies import is_pyspark_dataframe from narwhals.utils import Version if TYPE_CHECKING: @@ -408,6 +410,7 @@ def _from_native_impl( # noqa: PLR0915 from narwhals._polars.dataframe import PolarsDataFrame from narwhals._polars.dataframe import PolarsLazyFrame from narwhals._polars.series import PolarsSeries + from narwhals._spark_like.dataframe import SparkLikeLazyFrame from narwhals.dataframe import DataFrame from narwhals.dataframe import LazyFrame from narwhals.series import Series @@ -708,6 +711,23 @@ def _from_native_impl( # noqa: PLR0915 level="interchange", ) + # PySpark + elif is_pyspark_dataframe(native_object): # pragma: no cover + if series_only: + msg = "Cannot only use `series_only` with pyspark DataFrame" + raise TypeError(msg) + if eager_only or eager_or_interchange_only: + msg = "Cannot only use `eager_only` or `eager_or_interchange_only` with pyspark DataFrame" + raise TypeError(msg) + return 
LazyFrame( + SparkLikeLazyFrame( + native_object, + backend_version=parse_version(get_pyspark().__version__), + version=version, + ), + level="lazy", + ) + # Interchange protocol elif hasattr(native_object, "__dataframe__"): if eager_only or series_only: diff --git a/narwhals/utils.py b/narwhals/utils.py index 817c2bc9a..7e0c142ce 100644 --- a/narwhals/utils.py +++ b/narwhals/utils.py @@ -19,6 +19,7 @@ from narwhals.dependencies import get_pandas from narwhals.dependencies import get_polars from narwhals.dependencies import get_pyarrow +from narwhals.dependencies import get_pyspark_sql from narwhals.dependencies import is_cudf_series from narwhals.dependencies import is_modin_series from narwhals.dependencies import is_pandas_dataframe @@ -58,6 +59,7 @@ class Implementation(Enum): MODIN = auto() CUDF = auto() PYARROW = auto() + PYSPARK = auto() POLARS = auto() DASK = auto() @@ -80,6 +82,7 @@ def from_native_namespace( get_modin(): Implementation.MODIN, get_cudf(): Implementation.CUDF, get_pyarrow(): Implementation.PYARROW, + get_pyspark_sql(): Implementation.PYSPARK, get_polars(): Implementation.POLARS, get_dask_dataframe(): Implementation.DASK, } @@ -96,6 +99,7 @@ def to_native_namespace(self: Self) -> ModuleType: Implementation.MODIN: get_modin(), Implementation.CUDF: get_cudf(), Implementation.PYARROW: get_pyarrow(), + Implementation.PYSPARK: get_pyspark_sql(), Implementation.POLARS: get_polars(), Implementation.DASK: get_dask_dataframe(), } diff --git a/noxfile.py b/noxfile.py index e560479ad..ccb6bae27 100644 --- a/noxfile.py +++ b/noxfile.py @@ -55,6 +55,8 @@ def min_and_old_versions(session: Session, pandas_version: str) -> None: "scikit-learn==1.1.0", "tzdata", ) + if pandas_version == "1.1.5": + session.install("pyspark==3.3.0") run_common(session, coverage_threshold=50) diff --git a/pyproject.toml b/pyproject.toml index 281884d8b..f48f9ce69 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,15 +6,15 @@ build-backend = "hatchling.build" name = "narwhals" version = "1.15.2" authors = [ - { name="Marco Gorelli", email="33491632+MarcoGorelli@users.noreply.github.com" }, + { name = "Marco Gorelli", email = "33491632+MarcoGorelli@users.noreply.github.com" }, ] description = "Extremely lightweight compatibility layer between dataframe libraries" readme = "README.md" requires-python = ">=3.8" classifiers = [ - "Programming Language :: Python :: 3", - "License :: OSI Approved :: MIT License", - "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", ] [tool.hatch.build] @@ -38,6 +38,7 @@ pandas = ["pandas>=0.25.3"] polars = ["polars>=0.20.3"] pyarrow = ["pyarrow>=11.0.0"] dask = ["dask[dataframe]>=2024.7"] +pyspark = ["pyspark>=3.3.0"] [project.urls] "Homepage" = "https://github.com/narwhals-dev/narwhals" @@ -124,11 +125,14 @@ filterwarnings = [ 'ignore:.*but when imported by', 'ignore:Distributing .*This may take some time', 'ignore:.*The default coalesce behavior', + 'ignore:is_datetime64tz_dtype is deprecated', + 'ignore: unclosed POLARS_VERSION", @@ -159,8 +166,8 @@ strict = true [[tool.mypy.overrides]] # the pandas API is just too inconsistent for type hinting to be useful. 
module = [ - "pandas.*", - "cudf.*", - "modin.*", + "pandas.*", + "cudf.*", + "modin.*", ] ignore_missing_imports = true diff --git a/requirements-dev.txt b/requirements-dev.txt index fb5afb257..ff8429207 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -5,6 +5,7 @@ pandas polars pre-commit pyarrow +pyspark; python_version < '3.12' pyarrow-stubs pytest pytest-cov diff --git a/tests/conftest.py b/tests/conftest.py index 68e03449e..cb8a982a3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,6 +3,7 @@ import contextlib from typing import TYPE_CHECKING from typing import Any +from typing import Generator import pandas as pd import polars as pl @@ -26,6 +27,15 @@ import dask.dataframe # noqa: F401 with contextlib.suppress(ImportError): import cudf # noqa: F401 +with contextlib.suppress(ImportError): + from pyspark.sql import SparkSession + +if TYPE_CHECKING: + from pyspark.sql import SparkSession + + from narwhals.typing import IntoDataFrame + from narwhals.typing import IntoFrame + from tests.utils import Constructor def pytest_addoption(parser: Any) -> None: @@ -92,6 +102,37 @@ def pyarrow_table_constructor(obj: Any) -> IntoDataFrame: return pa.table(obj) # type: ignore[no-any-return] +@pytest.fixture(scope="session") +def spark_session() -> Generator[SparkSession, None, None]: # pragma: no cover + try: + from pyspark.sql import SparkSession + except ImportError: # pragma: no cover + pytest.skip("pyspark is not installed") + return + + import os + import warnings + + os.environ["PYARROW_IGNORE_TIMEZONE"] = "1" + with warnings.catch_warnings(): + # The spark session seems to trigger a polars warning. + # Polars is imported in the tests, but not used in the spark operations + warnings.filterwarnings( + "ignore", r"Using fork\(\) can cause Polars", category=RuntimeWarning + ) + session = ( + SparkSession.builder.appName("unit-tests") + .master("local[1]") + .config("spark.ui.enabled", "false") + # executing one task at a time makes the tests faster + .config("spark.default.parallelism", "1") + .config("spark.sql.shuffle.partitions", "2") + .getOrCreate() + ) + yield session + session.stop() + + if PANDAS_VERSION >= (2, 0, 0): eager_constructors = [ pandas_constructor, diff --git a/tests/no_imports_test.py b/tests/no_imports_test.py index a6fe26e31..c89a92567 100644 --- a/tests/no_imports_test.py +++ b/tests/no_imports_test.py @@ -16,6 +16,7 @@ def test_polars(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.delitem(sys.modules, "pyarrow") monkeypatch.delitem(sys.modules, "dask", raising=False) monkeypatch.delitem(sys.modules, "ibis", raising=False) + monkeypatch.delitem(sys.modules, "pyspark", raising=False) df = pl.DataFrame({"a": [1, 1, 2], "b": [4, 5, 6]}) nw.from_native(df, eager_only=True).group_by("a").agg(nw.col("b").mean()).filter( nw.col("a") > 1 @@ -26,6 +27,7 @@ def test_polars(monkeypatch: pytest.MonkeyPatch) -> None: assert "pyarrow" not in sys.modules assert "dask" not in sys.modules assert "ibis" not in sys.modules + assert "pyspark" not in sys.modules def test_pandas(monkeypatch: pytest.MonkeyPatch) -> None: @@ -33,6 +35,7 @@ def test_pandas(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.delitem(sys.modules, "pyarrow") monkeypatch.delitem(sys.modules, "dask", raising=False) monkeypatch.delitem(sys.modules, "ibis", raising=False) + monkeypatch.delitem(sys.modules, "pyspark", raising=False) df = pd.DataFrame({"a": [1, 1, 2], "b": [4, 5, 6]}) nw.from_native(df, eager_only=True).group_by("a").agg(nw.col("b").mean()).filter( nw.col("a") > 1 @@ 
-43,6 +46,7 @@ def test_pandas(monkeypatch: pytest.MonkeyPatch) -> None: assert "pyarrow" not in sys.modules assert "dask" not in sys.modules assert "ibis" not in sys.modules + assert "pyspark" not in sys.modules def test_dask(monkeypatch: pytest.MonkeyPatch) -> None: @@ -52,6 +56,7 @@ def test_dask(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.delitem(sys.modules, "polars") monkeypatch.delitem(sys.modules, "pyarrow") + monkeypatch.delitem(sys.modules, "pyspark", raising=False) df = dd.from_pandas(pd.DataFrame({"a": [1, 1, 2], "b": [4, 5, 6]})) nw.from_native(df).group_by("a").agg(nw.col("b").mean()).filter(nw.col("a") > 1) assert "polars" not in sys.modules @@ -59,6 +64,7 @@ def test_dask(monkeypatch: pytest.MonkeyPatch) -> None: assert "numpy" in sys.modules assert "pyarrow" not in sys.modules assert "dask" in sys.modules + assert "pyspark" not in sys.modules def test_pyarrow(monkeypatch: pytest.MonkeyPatch) -> None: @@ -66,6 +72,7 @@ def test_pyarrow(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.delitem(sys.modules, "pandas") monkeypatch.delitem(sys.modules, "dask", raising=False) monkeypatch.delitem(sys.modules, "ibis", raising=False) + monkeypatch.delitem(sys.modules, "pyspark", raising=False) df = pa.table({"a": [1, 2, 3], "b": [4, 5, 6]}) nw.from_native(df).group_by("a").agg(nw.col("b").mean()).filter(nw.col("a") > 1) assert "polars" not in sys.modules @@ -74,3 +81,4 @@ def test_pyarrow(monkeypatch: pytest.MonkeyPatch) -> None: assert "pyarrow" in sys.modules assert "dask" not in sys.modules assert "ibis" not in sys.modules + assert "pyspark" not in sys.modules diff --git a/tests/spark_like_test.py b/tests/spark_like_test.py new file mode 100644 index 000000000..c4eb040c3 --- /dev/null +++ b/tests/spark_like_test.py @@ -0,0 +1,443 @@ +"""PySpark support in Narwhals is still _very_ limited. + +Start with a simple test file whilst we develop the basics. +Once we're a bit further along, we can integrate PySpark tests into the main test suite. 
+""" + +from __future__ import annotations + +from contextlib import nullcontext as does_not_raise +from typing import TYPE_CHECKING +from typing import Any + +import pandas as pd +import pytest + +import narwhals.stable.v1 as nw +from narwhals.exceptions import ColumnNotFoundError +from tests.utils import assert_equal_data + +if TYPE_CHECKING: + from pyspark.sql import SparkSession + + from narwhals.typing import IntoFrame + from tests.utils import Constructor + + +def _pyspark_constructor_with_session(obj: Any, spark_session: SparkSession) -> IntoFrame: + # NaN and NULL are not the same in PySpark + pd_df = pd.DataFrame(obj).replace({float("nan"): None}).reset_index() + return ( # type: ignore[no-any-return] + spark_session.createDataFrame(pd_df).repartition(2).orderBy("index").drop("index") + ) + + +@pytest.fixture(params=[_pyspark_constructor_with_session]) +def pyspark_constructor( + request: pytest.FixtureRequest, spark_session: SparkSession +) -> Constructor: + def _constructor(obj: Any) -> IntoFrame: + return request.param(obj, spark_session) # type: ignore[no-any-return] + + return _constructor + + +# copied from tests/translate/from_native_test.py +def test_series_only(pyspark_constructor: Constructor) -> None: + obj = pyspark_constructor({"a": [1, 2, 3]}) + with pytest.raises(TypeError, match="Cannot only use `series_only`"): + _ = nw.from_native(obj, series_only=True) + + +def test_eager_only_lazy(pyspark_constructor: Constructor) -> None: + dframe = pyspark_constructor({"a": [1, 2, 3]}) + with pytest.raises(TypeError, match="Cannot only use `eager_only`"): + _ = nw.from_native(dframe, eager_only=True) + + +# copied from tests/frame/with_columns_test.py +def test_columns(pyspark_constructor: Constructor) -> None: + data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} + df = nw.from_native(pyspark_constructor(data)) + result = df.columns + expected = ["a", "b", "z"] + assert result == expected + + +# copied from tests/frame/with_columns_test.py +def test_with_columns_order(pyspark_constructor: Constructor) -> None: + data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} + df = nw.from_native(pyspark_constructor(data)) + result = df.with_columns(nw.col("a") + 1, d=nw.col("a") - 1) + assert result.collect_schema().names() == ["a", "b", "z", "d"] + expected = {"a": [2, 4, 3], "b": [4, 4, 6], "z": [7.0, 8, 9], "d": [0, 2, 1]} + assert_equal_data(result, expected) + + +@pytest.mark.filterwarnings("ignore:If `index_col` is not specified for `to_spark`") +def test_with_columns_empty(pyspark_constructor: Constructor) -> None: + data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} + df = nw.from_native(pyspark_constructor(data)) + result = df.select().with_columns() + assert_equal_data(result, {}) + + +def test_with_columns_order_single_row(pyspark_constructor: Constructor) -> None: + data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9], "i": [0, 1, 2]} + df = nw.from_native(pyspark_constructor(data)).filter(nw.col("i") < 1).drop("i") + result = df.with_columns(nw.col("a") + 1, d=nw.col("a") - 1) + assert result.collect_schema().names() == ["a", "b", "z", "d"] + expected = {"a": [2], "b": [4], "z": [7.0], "d": [0]} + assert_equal_data(result, expected) + + +# copied from tests/frame/select_test.py +def test_select(pyspark_constructor: Constructor) -> None: + data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} + df = nw.from_native(pyspark_constructor(data)) + result = df.select("a") + expected = {"a": [1, 3, 2]} + assert_equal_data(result, expected) + + 
+@pytest.mark.filterwarnings("ignore:If `index_col` is not specified for `to_spark`") +def test_empty_select(pyspark_constructor: Constructor) -> None: + result = nw.from_native(pyspark_constructor({"a": [1, 2, 3]})).lazy().select() + assert result.collect().shape == (0, 0) + + +# copied from tests/frame/filter_test.py +def test_filter(pyspark_constructor: Constructor) -> None: + data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} + df = nw.from_native(pyspark_constructor(data)) + result = df.filter(nw.col("a") > 1) + expected = {"a": [3, 2], "b": [4, 6], "z": [8.0, 9.0]} + assert_equal_data(result, expected) + + +@pytest.mark.filterwarnings("ignore:If `index_col` is not specified for `to_spark`") +def test_filter_with_boolean_list(pyspark_constructor: Constructor) -> None: + data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} + df = nw.from_native(pyspark_constructor(data)) + + with pytest.raises( + NotImplementedError, + match="`LazyFrame.filter` is not supported for PySpark backend with boolean masks.", + ): + _ = df.filter([False, True, True]) + + +# copied from tests/frame/schema_test.py +@pytest.mark.filterwarnings("ignore:Determining|Resolving.*") +def test_schema(pyspark_constructor: Constructor) -> None: + df = nw.from_native( + pyspark_constructor({"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.1, 8, 9]}) + ) + result = df.schema + expected = {"a": nw.Int64, "b": nw.Int64, "z": nw.Float64} + + result = df.schema + assert result == expected + result = df.lazy().collect().schema + assert result == expected + + +def test_collect_schema(pyspark_constructor: Constructor) -> None: + df = nw.from_native( + pyspark_constructor({"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.1, 8, 9]}) + ) + expected = {"a": nw.Int64, "b": nw.Int64, "z": nw.Float64} + + result = df.collect_schema() + assert result == expected + result = df.lazy().collect().collect_schema() + assert result == expected + + +# copied from tests/frame/drop_test.py +@pytest.mark.parametrize( + ("to_drop", "expected"), + [ + ("abc", ["b", "z"]), + (["abc"], ["b", "z"]), + (["abc", "b"], ["z"]), + ], +) +def test_drop( + pyspark_constructor: Constructor, to_drop: list[str], expected: list[str] +) -> None: + data = {"abc": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} + df = nw.from_native(pyspark_constructor(data)) + assert df.drop(to_drop).collect_schema().names() == expected + if not isinstance(to_drop, str): + assert df.drop(*to_drop).collect_schema().names() == expected + + +@pytest.mark.parametrize( + ("strict", "context"), + [ + (True, pytest.raises(ColumnNotFoundError, match="z")), + (False, does_not_raise()), + ], +) +def test_drop_strict( + pyspark_constructor: Constructor, context: Any, *, strict: bool +) -> None: + data = {"a": [1, 3, 2], "b": [4, 4, 6]} + to_drop = ["a", "z"] + + df = nw.from_native(pyspark_constructor(data)) + + with context: + names_out = df.drop(to_drop, strict=strict).collect_schema().names() + assert names_out == ["b"] + + +# copied from tests/frame/head_test.py +def test_head(pyspark_constructor: Constructor) -> None: + data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} + expected = {"a": [1, 3], "b": [4, 4], "z": [7.0, 8.0]} + + df_raw = pyspark_constructor(data) + df = nw.from_native(df_raw) + + result = df.head(2) + assert_equal_data(result, expected) + + result = df.head(2) + assert_equal_data(result, expected) + + # negative indices not allowed for lazyframes + result = df.lazy().collect().head(-1) + assert_equal_data(result, expected) + + +# copied from tests/frame/sort_test.py +def 
test_sort(pyspark_constructor: Constructor) -> None: + data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} + df = nw.from_native(pyspark_constructor(data)) + result = df.sort("a", "b") + expected = { + "a": [1, 2, 3], + "b": [4, 6, 4], + "z": [7.0, 9.0, 8.0], + } + assert_equal_data(result, expected) + result = df.sort("a", "b", descending=[True, False]).lazy().collect() + expected = { + "a": [3, 2, 1], + "b": [4, 6, 4], + "z": [8.0, 9.0, 7.0], + } + assert_equal_data(result, expected) + + +@pytest.mark.parametrize( + ("nulls_last", "expected"), + [ + (True, {"a": [0, 2, 0, -1], "b": [3, 2, 1, float("nan")]}), + (False, {"a": [-1, 0, 2, 0], "b": [float("nan"), 3, 2, 1]}), + ], +) +def test_sort_nulls( + pyspark_constructor: Constructor, *, nulls_last: bool, expected: dict[str, float] +) -> None: + data = {"a": [0, 0, 2, -1], "b": [1, 3, 2, None]} + df = nw.from_native(pyspark_constructor(data)) + result = df.sort("b", descending=True, nulls_last=nulls_last).lazy().collect() + assert_equal_data(result, expected) + + +# copied from tests/frame/add_test.py +def test_add(pyspark_constructor: Constructor) -> None: + data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} + df = nw.from_native(pyspark_constructor(data)) + result = df.with_columns( + c=nw.col("a") + nw.col("b"), + d=nw.col("a") - nw.col("a").mean(), + e=nw.col("a") - nw.col("a").std(), + ) + expected = { + "a": [1, 3, 2], + "b": [4, 4, 6], + "z": [7.0, 8.0, 9.0], + "c": [5, 7, 8], + "d": [-1.0, 1.0, 0.0], + "e": [0.0, 2.0, 1.0], + } + assert_equal_data(result, expected) + + +# copied from tests/expr_and_series/all_horizontal_test.py +@pytest.mark.parametrize("expr1", ["a", nw.col("a")]) +@pytest.mark.parametrize("expr2", ["b", nw.col("b")]) +def test_allh(pyspark_constructor: Constructor, expr1: Any, expr2: Any) -> None: + data = { + "a": [False, False, True], + "b": [False, True, True], + } + df = nw.from_native(pyspark_constructor(data)) + result = df.select(all=nw.all_horizontal(expr1, expr2)) + + expected = {"all": [False, False, True]} + assert_equal_data(result, expected) + + +def test_allh_all(pyspark_constructor: Constructor) -> None: + data = { + "a": [False, False, True], + "b": [False, True, True], + } + df = nw.from_native(pyspark_constructor(data)) + result = df.select(all=nw.all_horizontal(nw.all())) + expected = {"all": [False, False, True]} + assert_equal_data(result, expected) + result = df.select(nw.all_horizontal(nw.all())) + expected = {"a": [False, False, True]} + assert_equal_data(result, expected) + + +# copied from tests/expr_and_series/count_test.py +def test_count(pyspark_constructor: Constructor) -> None: + data = {"a": [1, 3, 2], "b": [4, None, 6], "z": [7.0, None, None]} + df = nw.from_native(pyspark_constructor(data)) + result = df.select(nw.col("a", "b", "z").count()) + expected = {"a": [3], "b": [2], "z": [1]} + assert_equal_data(result, expected) + + +# copied from tests/expr_and_series/double_test.py +def test_double(pyspark_constructor: Constructor) -> None: + data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} + df = nw.from_native(pyspark_constructor(data)) + result = df.with_columns(nw.all() * 2) + expected = {"a": [2, 6, 4], "b": [8, 8, 12], "z": [14.0, 16.0, 18.0]} + assert_equal_data(result, expected) + + +def test_double_alias(pyspark_constructor: Constructor) -> None: + data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} + df = nw.from_native(pyspark_constructor(data)) + result = df.with_columns(nw.col("a").alias("o"), nw.all() * 2) + expected = { + "o": [1, 3, 2], + 
"a": [2, 6, 4], + "b": [8, 8, 12], + "z": [14.0, 16.0, 18.0], + } + assert_equal_data(result, expected) + + +# copied from tests/expr_and_series/max_test.py +def test_expr_max_expr(pyspark_constructor: Constructor) -> None: + data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} + + df = nw.from_native(pyspark_constructor(data)) + result = df.select(nw.col("a", "b", "z").max()) + expected = {"a": [3], "b": [6], "z": [9.0]} + assert_equal_data(result, expected) + + +# copied from tests/expr_and_series/min_test.py +def test_expr_min_expr(pyspark_constructor: Constructor) -> None: + data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} + df = nw.from_native(pyspark_constructor(data)) + result = df.select(nw.col("a", "b", "z").min()) + expected = {"a": [1], "b": [4], "z": [7.0]} + assert_equal_data(result, expected) + + +# copied from tests/expr_and_series/std_test.py +def test_std(pyspark_constructor: Constructor) -> None: + data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} + + df = nw.from_native(pyspark_constructor(data)) + result = df.select( + nw.col("a").std().alias("a_ddof_default"), + nw.col("a").std(ddof=1).alias("a_ddof_1"), + nw.col("a").std(ddof=0).alias("a_ddof_0"), + nw.col("b").std(ddof=2).alias("b_ddof_2"), + nw.col("z").std(ddof=0).alias("z_ddof_0"), + ) + expected = { + "a_ddof_default": [1.0], + "a_ddof_1": [1.0], + "a_ddof_0": [0.816497], + "b_ddof_2": [1.632993], + "z_ddof_0": [0.816497], + } + assert_equal_data(result, expected) + + +# copied from tests/group_by_test.py +def test_group_by_std(pyspark_constructor: Constructor) -> None: + data = {"a": [1, 1, 2, 2], "b": [5, 4, 3, 2]} + result = ( + nw.from_native(pyspark_constructor(data)) + .group_by("a") + .agg(nw.col("b").std()) + .sort("a") + ) + expected = {"a": [1, 2], "b": [0.707107] * 2} + assert_equal_data(result, expected) + + +def test_group_by_simple_named(pyspark_constructor: Constructor) -> None: + data = {"a": [1, 1, 2], "b": [4, 5, 6], "c": [7, 2, 1]} + df = nw.from_native(pyspark_constructor(data)).lazy() + result = ( + df.group_by("a") + .agg( + b_min=nw.col("b").min(), + b_max=nw.col("b").max(), + ) + .collect() + .sort("a") + ) + expected = { + "a": [1, 2], + "b_min": [4, 6], + "b_max": [5, 6], + } + assert_equal_data(result, expected) + + +def test_group_by_simple_unnamed(pyspark_constructor: Constructor) -> None: + data = {"a": [1, 1, 2], "b": [4, 5, 6], "c": [7, 2, 1]} + df = nw.from_native(pyspark_constructor(data)).lazy() + result = ( + df.group_by("a") + .agg( + nw.col("b").min(), + nw.col("c").max(), + ) + .collect() + .sort("a") + ) + expected = { + "a": [1, 2], + "b": [4, 6], + "c": [7, 1], + } + assert_equal_data(result, expected) + + +def test_group_by_multiple_keys(pyspark_constructor: Constructor) -> None: + data = {"a": [1, 1, 2], "b": [4, 4, 6], "c": [7, 2, 1]} + df = nw.from_native(pyspark_constructor(data)).lazy() + result = ( + df.group_by("a", "b") + .agg( + c_min=nw.col("c").min(), + c_max=nw.col("c").max(), + ) + .collect() + .sort("a") + ) + expected = { + "a": [1, 2], + "b": [4, 6], + "c_min": [2, 1], + "c_max": [7, 1], + } + assert_equal_data(result, expected) diff --git a/tests/utils.py b/tests/utils.py index fb8b28a92..73ba50164 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -33,6 +33,7 @@ def get_module_version_as_tuple(module_name: str) -> tuple[int, ...]: PANDAS_VERSION: tuple[int, ...] = get_module_version_as_tuple("pandas") POLARS_VERSION: tuple[int, ...] = get_module_version_as_tuple("polars") PYARROW_VERSION: tuple[int, ...] 
= get_module_version_as_tuple("pyarrow") +PYSPARK_VERSION: tuple[int, ...] = get_module_version_as_tuple("pyspark") Constructor: TypeAlias = Callable[[Any], IntoFrame] ConstructorEager: TypeAlias = Callable[[Any], IntoDataFrame] @@ -56,13 +57,29 @@ def _to_comparable_list(column_values: Any) -> Any: return list(column_values) +def _sort_dict_by_key( + data_dict: dict[str, list[Any]], key: str +) -> dict[str, list[Any]]: # pragma: no cover + sort_list = data_dict[key] + sorted_indices = sorted(range(len(sort_list)), key=lambda i: sort_list[i]) + return {key: [value[i] for i in sorted_indices] for key, value in data_dict.items()} + + def assert_equal_data(result: Any, expected: dict[str, Any]) -> None: + is_pyspark = ( + hasattr(result, "_compliant_frame") + and result._compliant_frame._implementation is Implementation.PYSPARK + ) if hasattr(result, "collect"): result = result.collect() if hasattr(result, "columns"): for key in result.columns: assert key in expected, (key, expected) result = {key: _to_comparable_list(result[key]) for key in expected} + if is_pyspark and expected: # pragma: no cover + sort_key = next(iter(expected.keys())) + expected = _sort_dict_by_key(expected, sort_key) + result = _sort_dict_by_key(result, sort_key) for key, expected_value in expected.items(): result_value = result[key] for i, (lhs, rhs) in enumerate(zip_strict(result_value, expected_value)):
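A minimal usage sketch of the new PySpark backend, assuming a local SparkSession can be started and mirroring the group-by tests in tests/spark_like_test.py:

import narwhals.stable.v1 as nw
from pyspark.sql import SparkSession

# Hypothetical local session; the test suite above uses a session-scoped fixture instead.
spark = SparkSession.builder.appName("narwhals-pyspark-sketch").master("local[1]").getOrCreate()
native_df = spark.createDataFrame([(1, 4), (1, 5), (2, 6)], ["a", "b"])

# from_native wraps the PySpark DataFrame as a narwhals LazyFrame (level="lazy").
lf = nw.from_native(native_df)
result = (
    lf.group_by("a")
    .agg(b_min=nw.col("b").min(), b_max=nw.col("b").max())
    .sort("a")
    .collect()  # SparkLikeLazyFrame.collect() converts to pandas via toPandas()
)
print(result.to_native())
spark.stop()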