Commit: fix all_horizontal

EdAbati committed Sep 7, 2024
1 parent 12f62c1 commit 2b114eb
Showing 3 changed files with 33 additions and 22 deletions.
4 changes: 1 addition & 3 deletions narwhals/_expression_parsing.py
@@ -36,7 +36,6 @@
     from narwhals._pyspark.dataframe import PySparkLazyFrame
     from narwhals._pyspark.expr import PySparkExpr
     from narwhals._pyspark.namespace import PySparkNamespace
-    from narwhals._pyspark.series import PySparkSeries
     from narwhals._pyspark.typing import IntoPySparkExpr

     CompliantNamespace = Union[
@@ -52,13 +51,12 @@
     ]
     IntoCompliantExprT = TypeVar("IntoCompliantExprT", bound=IntoCompliantExpr)
     CompliantExprT = TypeVar("CompliantExprT", bound=CompliantExpr)
-    CompliantSeries = Union[PandasLikeSeries, ArrowSeries, PolarsSeries, PySparkSeries]
+    CompliantSeries = Union[PandasLikeSeries, ArrowSeries, PolarsSeries]
     ListOfCompliantSeries = Union[
         list[PandasLikeSeries],
         list[ArrowSeries],
         list[DaskExpr],
         list[PolarsSeries],
-        list[PySparkSeries],
     ]
     ListOfCompliantExpr = Union[
         list[PandasLikeExpr],
35 changes: 16 additions & 19 deletions narwhals/_pyspark/expr.py
@@ -1,11 +1,11 @@
 from __future__ import annotations

+import operator
 from copy import copy
 from typing import TYPE_CHECKING
-from typing import Any
 from typing import Callable

-from narwhals._expression_parsing import maybe_evaluate_expr
+from narwhals._pyspark.utils import maybe_evaluate

 if TYPE_CHECKING:
     from pyspark.sql import Column
@@ -60,22 +60,23 @@ def func(df: PySparkLazyFrame) -> list[Column]:
             returns_scalar=False,
         )

-    def _from_function(
+    def _from_call(
         self,
-        function: Callable[..., Column],
+        call: Callable[..., Column],
         expr_name: str,
-        *args: Any,
-        **kwargs: Any,
+        *args: PySparkExpr,
+        returns_scalar: bool,
+        **kwargs: PySparkExpr,
     ) -> Self:
         def func(df: PySparkLazyFrame) -> list[Column]:
             col_results = []
             inputs = self._call(df)
-            _args = [maybe_evaluate_expr(df, x) for x in args]
-            _kwargs = {
-                key: maybe_evaluate_expr(df, value) for key, value in kwargs.items()
-            }
+            _args = [maybe_evaluate(df, arg) for arg in args]
+            _kwargs = {key: maybe_evaluate(df, value) for key, value in kwargs.items()}
             for _input in inputs:
-                col_result = function(_input, *_args, **_kwargs)
+                col_result = call(_input, *_args, **_kwargs)
+                if returns_scalar:
+                    raise NotImplementedError
                 col_results.append(col_result)
             return col_results

@@ -112,18 +113,14 @@ def func(df: PySparkLazyFrame) -> list[Column]:
             function_name=f"{self._function_name}->{expr_name}",
             root_names=root_names,
             output_names=output_names,
-            returns_scalar=False,
+            returns_scalar=self._returns_scalar or returns_scalar,
         )

     def __and__(self, other: PySparkExpr) -> Self:
-        return self._from_function(
-            lambda _input, other: _input.__and__(other), "__and__", other
-        )
+        return self._from_call(operator.and_, "__and__", other, returns_scalar=False)

     def __gt__(self, other: PySparkExpr) -> Self:
-        return self._from_function(
-            lambda _input, other: _input.__gt__(other), "__gt__", other
-        )
+        return self._from_call(operator.gt, "__gt__", other, returns_scalar=False)

     def alias(self, name: str) -> Self:
         def func(df: PySparkLazyFrame) -> list[Column]:
@@ -137,5 +134,5 @@ def func(df: PySparkLazyFrame) -> list[Column]:
             function_name=self._function_name,
             root_names=self._root_names,
             output_names=[name],
-            returns_scalar=False,
+            returns_scalar=self._returns_scalar,
         )
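
Note (added for illustration, not part of the commit): replacing the ad-hoc lambdas with operator.and_ / operator.gt is behavior-preserving, since the operator functions dispatch to the same dunder methods. A minimal self-contained check, using plain Python values as stand-ins for pyspark Columns:

import operator

# The removed lambda and its operator.* replacement agree whenever the
# left operand implements the dunder; booleans and ints here are just
# stand-ins for pyspark Columns.
assert operator.and_(True, False) == (lambda a, b: a.__and__(b))(True, False)
assert operator.gt(3, 2) == (lambda a, b: a.__gt__(b))(3, 2)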
16 changes: 16 additions & 0 deletions narwhals/_pyspark/utils.py
@@ -1,6 +1,7 @@
 from __future__ import annotations

 from typing import TYPE_CHECKING
+from typing import Any

 from narwhals import dtypes
 from narwhals._pandas_like.utils import translate_dtype
@@ -81,3 +82,18 @@ def _cols_from_expr(expr: IntoPySparkExpr) -> list[Column]:
             raise AssertionError(msg)
         columns_list.extend([pyspark_cols[0].alias(col_alias)])
     return columns_list
+
+
+def maybe_evaluate(df: PySparkLazyFrame, obj: Any) -> Any:
+    from narwhals._pyspark.expr import PySparkExpr
+
+    if isinstance(obj, PySparkExpr):
+        columns_result = obj._call(df)
+        if len(columns_result) != 1:  # pragma: no cover
+            msg = "Multi-output expressions not supported in this context"
+            raise NotImplementedError(msg)
+        column = columns_result[0]
+        if obj._returns_scalar:
+            raise NotImplementedError
+        return column
+    return obj
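
For context, a standalone sketch of the contract the new maybe_evaluate helper enforces: an expression argument must collapse to exactly one column, while non-expression values pass through unchanged. FakeExpr and maybe_evaluate_sketch below are hypothetical stand-ins written for this note, not narwhals APIs:

from typing import Any


class FakeExpr:
    # Stand-in for PySparkExpr (hypothetical): evaluates to a list of
    # columns, modelled here as plain strings.
    def __init__(self, *names: str) -> None:
        self._names = list(names)
        self._returns_scalar = False

    def _call(self, df: Any) -> list[str]:
        return list(self._names)


def maybe_evaluate_sketch(df: Any, obj: Any) -> Any:
    # Same shape as the helper added above, minus the pyspark dependency.
    if isinstance(obj, FakeExpr):
        results = obj._call(df)
        if len(results) != 1:
            msg = "Multi-output expressions not supported in this context"
            raise NotImplementedError(msg)
        if obj._returns_scalar:
            raise NotImplementedError
        return results[0]
    return obj


assert maybe_evaluate_sketch(None, FakeExpr("a")) == "a"  # expr -> single column
assert maybe_evaluate_sketch(None, 42) == 42  # non-expression passes through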
