diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index 517aa0508..4930e0e75 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -36,7 +36,6 @@ from narwhals._pyspark.dataframe import PySparkLazyFrame from narwhals._pyspark.expr import PySparkExpr from narwhals._pyspark.namespace import PySparkNamespace - from narwhals._pyspark.series import PySparkSeries from narwhals._pyspark.typing import IntoPySparkExpr CompliantNamespace = Union[ @@ -52,13 +51,12 @@ ] IntoCompliantExprT = TypeVar("IntoCompliantExprT", bound=IntoCompliantExpr) CompliantExprT = TypeVar("CompliantExprT", bound=CompliantExpr) - CompliantSeries = Union[PandasLikeSeries, ArrowSeries, PolarsSeries, PySparkSeries] + CompliantSeries = Union[PandasLikeSeries, ArrowSeries, PolarsSeries] ListOfCompliantSeries = Union[ list[PandasLikeSeries], list[ArrowSeries], list[DaskExpr], list[PolarsSeries], - list[PySparkSeries], ] ListOfCompliantExpr = Union[ list[PandasLikeExpr], diff --git a/narwhals/_pyspark/expr.py b/narwhals/_pyspark/expr.py index 7f01c8d69..7367e655d 100644 --- a/narwhals/_pyspark/expr.py +++ b/narwhals/_pyspark/expr.py @@ -1,11 +1,11 @@ from __future__ import annotations +import operator from copy import copy from typing import TYPE_CHECKING -from typing import Any from typing import Callable -from narwhals._expression_parsing import maybe_evaluate_expr +from narwhals._pyspark.utils import maybe_evaluate if TYPE_CHECKING: from pyspark.sql import Column @@ -60,22 +60,23 @@ def func(df: PySparkLazyFrame) -> list[Column]: returns_scalar=False, ) - def _from_function( + def _from_call( self, - function: Callable[..., Column], + call: Callable[..., Column], expr_name: str, - *args: Any, - **kwargs: Any, + *args: PySparkExpr, + returns_scalar: bool, + **kwargs: PySparkExpr, ) -> Self: def func(df: PySparkLazyFrame) -> list[Column]: col_results = [] inputs = self._call(df) - _args = [maybe_evaluate_expr(df, x) for x in args] - _kwargs = { - key: maybe_evaluate_expr(df, value) for key, value in kwargs.items() - } + _args = [maybe_evaluate(df, arg) for arg in args] + _kwargs = {key: maybe_evaluate(df, value) for key, value in kwargs.items()} for _input in inputs: - col_result = function(_input, *_args, **_kwargs) + col_result = call(_input, *_args, **_kwargs) + if returns_scalar: + raise NotImplementedError col_results.append(col_result) return col_results @@ -112,18 +113,14 @@ def func(df: PySparkLazyFrame) -> list[Column]: function_name=f"{self._function_name}->{expr_name}", root_names=root_names, output_names=output_names, - returns_scalar=False, + returns_scalar=self._returns_scalar or returns_scalar, ) def __and__(self, other: PySparkExpr) -> Self: - return self._from_function( - lambda _input, other: _input.__and__(other), "__and__", other - ) + return self._from_call(operator.and_, "__and__", other, returns_scalar=False) def __gt__(self, other: PySparkExpr) -> Self: - return self._from_function( - lambda _input, other: _input.__gt__(other), "__gt__", other - ) + return self._from_call(operator.gt, "__gt__", other, returns_scalar=False) def alias(self, name: str) -> Self: def func(df: PySparkLazyFrame) -> list[Column]: @@ -137,5 +134,5 @@ def func(df: PySparkLazyFrame) -> list[Column]: function_name=self._function_name, root_names=self._root_names, output_names=[name], - returns_scalar=False, + returns_scalar=self._returns_scalar, ) diff --git a/narwhals/_pyspark/utils.py b/narwhals/_pyspark/utils.py index 5db7f1404..9f8198724 100644 --- a/narwhals/_pyspark/utils.py +++ b/narwhals/_pyspark/utils.py @@ -1,6 +1,7 @@ from __future__ import annotations from typing import TYPE_CHECKING +from typing import Any from narwhals import dtypes from narwhals._pandas_like.utils import translate_dtype @@ -81,3 +82,18 @@ def _cols_from_expr(expr: IntoPySparkExpr) -> list[Column]: raise AssertionError(msg) columns_list.extend([pyspark_cols[0].alias(col_alias)]) return columns_list + + +def maybe_evaluate(df: PySparkLazyFrame, obj: Any) -> Any: + from narwhals._pyspark.expr import PySparkExpr + + if isinstance(obj, PySparkExpr): + columns_result = obj._call(df) + if len(columns_result) != 1: # pragma: no cover + msg = "Multi-output expressions not supported in this context" + raise NotImplementedError(msg) + column = columns_result[0] + if obj._returns_scalar: + raise NotImplementedError + return column + return obj