feat: add SparkLikeExpr methods: median, clip, is_between, `is_duplicated`, `is_finite`, `is_in`, `is_unique`, `len`, `round` and `skew` (#1721)

* feat(spark): add missing methods to SparkLikeExpr

* feat(spark): add a few missing methods

* fix: add xfail to median when python<3.9

* fix: address review requests & update tests

* fix: fix `PYSPARK_VERSION` for `median` calculation

* fix: fix refactor issue

* fix: remove `is_nan` method

* fix: fix `is_duplicated` & `is_unique`, remove `n_unique`

---------

Co-authored-by: Francesco Bruzzesi <[email protected]>
Dhanunjaya-Elluri and FBruzzesi authored Jan 7, 2025
1 parent 92e3b87 commit 46a030a
Showing 4 changed files with 323 additions and 26 deletions.
4 changes: 4 additions & 0 deletions CONTRIBUTING.md
@@ -78,6 +78,10 @@ where `YOUR-GITHUB-USERNAME` will be your GitHub user name.

Here's how you can set up your local development environment to contribute.

#### Prerequisites for PySpark tests

If you want to run PySpark-related tests, you'll need to have Java installed. Refer to the [Spark documentation](https://spark.apache.org/docs/latest/#downloading) for more information.
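
A quick way to verify the setup (a minimal sketch, assuming `pyspark` is already installed in your environment): starting a local Spark session fails fast when Java is missing or misconfigured.

# Sanity check: creating a local SparkSession requires a working Java
# installation, so this fails immediately if Java is absent or broken.
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").appName("java-check").getOrCreate()
print(spark.version)  # e.g. "3.5.0"
spark.stop()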

#### Option 1: Use UV (recommended)

1. Make sure you have Python 3.12 installed, create a virtual environment,
179 changes: 156 additions & 23 deletions narwhals/_spark_like/expr.py
@@ -160,6 +160,11 @@ def __gt__(self, other: SparkLikeExpr) -> Self:
returns_scalar=False,
)

def abs(self) -> Self:
from pyspark.sql import functions as F # noqa: N812

return self._from_call(F.abs, "abs", returns_scalar=self._returns_scalar)

def alias(self, name: str) -> Self:
def _alias(df: SparkLikeLazyFrame) -> list[Column]:
return [col.alias(name) for col in self._call(df)]
@@ -179,44 +184,42 @@ def _alias(df: SparkLikeLazyFrame) -> list[Column]:
)

def count(self) -> Self:
    from pyspark.sql import functions as F  # noqa: N812

    return self._from_call(F.count, "count", returns_scalar=True)

def max(self) -> Self:
    from pyspark.sql import functions as F  # noqa: N812

    return self._from_call(F.max, "max", returns_scalar=True)

def mean(self) -> Self:
    from pyspark.sql import functions as F  # noqa: N812

    return self._from_call(F.mean, "mean", returns_scalar=True)

def median(self) -> Self:
    def _median(_input: Column) -> Column:
        import pyspark  # ignore-banned-import
        from pyspark.sql import functions as F  # noqa: N812

        if parse_version(pyspark.__version__) < (3, 4):
            # Use percentile_approx with default accuracy parameter (10000)
            return F.percentile_approx(_input.cast("double"), 0.5)

        return F.median(_input)

    return self._from_call(_median, "median", returns_scalar=True)

def min(self) -> Self:
    from pyspark.sql import functions as F  # noqa: N812

    return self._from_call(F.min, "min", returns_scalar=True)

def sum(self) -> Self:
    from pyspark.sql import functions as F  # noqa: N812

    return self._from_call(F.sum, "sum", returns_scalar=True)
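
The version gate in `_median` can be exercised standalone: `F.median` only exists from PySpark 3.4 onwards, so older versions fall back to an approximate percentile. A minimal sketch in plain PySpark (the local session, data, and column name are illustrative, not part of this change):

# Sketch of the version gate used by `_median` above.
import pyspark
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([(1.0,), (2.0,), (10.0,)], ["a"])

major, minor = (int(p) for p in pyspark.__version__.split(".")[:2])
if (major, minor) < (3, 4):
    # PySpark < 3.4: approximate median via percentile_approx
    expr = F.percentile_approx(F.col("a").cast("double"), 0.5)
else:
    expr = F.median("a")

df.select(expr.alias("median_a")).show()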

def std(self: Self, ddof: int) -> Self:
from functools import partial
@@ -239,3 +242,133 @@ def var(self: Self, ddof: int) -> Self:
func = partial(_var, ddof=ddof, np_version=parse_version(np.__version__))

return self._from_call(func, "var", returns_scalar=True, ddof=ddof)

def clip(
self,
lower_bound: Any | None = None,
upper_bound: Any | None = None,
) -> Self:
def _clip(_input: Column, lower_bound: Any, upper_bound: Any) -> Column:
from pyspark.sql import functions as F # noqa: N812

result = _input
if lower_bound is not None:
# Convert lower_bound to a literal Column
result = F.when(result < lower_bound, F.lit(lower_bound)).otherwise(
result
)
if upper_bound is not None:
# Convert upper_bound to a literal Column
result = F.when(result > upper_bound, F.lit(upper_bound)).otherwise(
result
)
return result

return self._from_call(
_clip,
"clip",
lower_bound=lower_bound,
upper_bound=upper_bound,
returns_scalar=self._returns_scalar,
)
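
The clamping pattern used by `_clip` reads naturally in plain PySpark as chained `when`/`otherwise` calls. A standalone sketch, with illustrative bounds (0 and 10) and data, independent of narwhals internals:

# Sketch of the clamping pattern in `_clip` above: first clamp below
# the lower bound, then above the upper bound.
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([(-5,), (3,), (99,)], ["x"])

clipped = F.when(F.col("x") < 0, F.lit(0)).otherwise(F.col("x"))
clipped = F.when(clipped > 10, F.lit(10)).otherwise(clipped)

df.select("x", clipped.alias("x_clipped")).show()
# -5 -> 0, 3 -> 3, 99 -> 10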

def is_between(
self,
lower_bound: Any,
upper_bound: Any,
closed: str,
) -> Self:
def _is_between(_input: Column, lower_bound: Any, upper_bound: Any) -> Column:
if closed == "both":
return (_input >= lower_bound) & (_input <= upper_bound)
if closed == "none":
return (_input > lower_bound) & (_input < upper_bound)
if closed == "left":
return (_input >= lower_bound) & (_input < upper_bound)
return (_input > lower_bound) & (_input <= upper_bound)

return self._from_call(
_is_between,
"is_between",
lower_bound=lower_bound,
upper_bound=upper_bound,
returns_scalar=self._returns_scalar,
)
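
Each `closed` variant maps onto a pair of column comparisons. A standalone sketch of the `closed="left"` case (bounds and data are illustrative):

# Sketch of the interval logic in `_is_between` above, for
# closed="left": lower bound inclusive, upper bound exclusive.
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([(0,), (1,), (5,), (10,)], ["x"])

left_closed = (F.col("x") >= 1) & (F.col("x") < 10)
df.select("x", left_closed.alias("in_range")).show()
# 0 -> false, 1 -> true, 5 -> true, 10 -> false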

def is_duplicated(self) -> Self:
def _is_duplicated(_input: Column) -> Column:
from pyspark.sql import Window
from pyspark.sql import functions as F # noqa: N812

# Create a window spec that treats each value separately.
return F.count("*").over(Window.partitionBy(_input)) > 1

return self._from_call(
_is_duplicated, "is_duplicated", returns_scalar=self._returns_scalar
)

def is_finite(self) -> Self:
def _is_finite(_input: Column) -> Column:
from pyspark.sql import functions as F # noqa: N812

# A value is finite if it's not NaN, not NULL, and not infinite
return (
~F.isnan(_input)
& ~F.isnull(_input)
& (_input != float("inf"))
& (_input != float("-inf"))
)

return self._from_call(
_is_finite, "is_finite", returns_scalar=self._returns_scalar
)
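
A standalone sketch of the finiteness predicate (data illustrative). Spark's three-valued logic makes comparisons against NULL yield NULL, but because `false AND NULL` evaluates to false, the `isnull` guard forces the overall result to a plain boolean:

# Sketch of the predicate in `_is_finite` above, over a column
# containing NaN, infinity, and NULL.
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame(
    [(1.0,), (float("nan"),), (float("inf"),), (None,)], ["x"]
)

finite = (
    ~F.isnan("x")
    & ~F.isnull("x")
    & (F.col("x") != float("inf"))
    & (F.col("x") != float("-inf"))
)
df.select("x", finite.alias("is_finite")).show()
# 1.0 -> true; NaN, Infinity, and null -> false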

def is_in(self, values: Sequence[Any]) -> Self:
def _is_in(_input: Column, values: Sequence[Any]) -> Column:
return _input.isin(values)

return self._from_call(
_is_in,
"is_in",
values=values,
returns_scalar=self._returns_scalar,
)

def is_unique(self) -> Self:
def _is_unique(_input: Column) -> Column:
from pyspark.sql import Window
from pyspark.sql import functions as F # noqa: N812

# Create a window spec that treats each value separately
return F.count("*").over(Window.partitionBy(_input)) == 1

return self._from_call(
_is_unique, "is_unique", returns_scalar=self._returns_scalar
)
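
`is_duplicated` and `is_unique` share one window trick: partitioning by the value itself groups equal values together, so a per-row count over that window tells whether each value occurs more than once. A standalone sketch (data illustrative):

# Sketch of the window trick shared by `_is_duplicated` and `_is_unique`.
from pyspark.sql import SparkSession, Window
from pyspark.sql import functions as F

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([(1,), (1,), (2,)], ["x"])

per_value_count = F.count("*").over(Window.partitionBy("x"))
df.select(
    "x",
    (per_value_count > 1).alias("is_duplicated"),
    (per_value_count == 1).alias("is_unique"),
).show()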

def len(self) -> Self:
def _len(_input: Column) -> Column:
from pyspark.sql import functions as F # noqa: N812

# Use count(*) to count all rows including nulls
return F.count("*")

return self._from_call(_len, "len", returns_scalar=True)
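
The `count("*")` form is deliberate: `F.count(column)` skips nulls, while `F.count("*")` counts every row. A quick standalone sketch of the difference (data illustrative):

# Sketch of why `_len` uses count(*) rather than count(column).
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([(1,), (None,), (3,)], ["x"])

df.select(
    F.count("x").alias("non_null"),  # 2: nulls skipped
    F.count("*").alias("all_rows"),  # 3: every row counted
).show()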

def round(self, decimals: int) -> Self:
def _round(_input: Column, decimals: int) -> Column:
from pyspark.sql import functions as F # noqa: N812

return F.round(_input, decimals)

return self._from_call(
_round,
"round",
decimals=decimals,
returns_scalar=self._returns_scalar,
)

def skew(self) -> Self:
from pyspark.sql import functions as F # noqa: N812

return self._from_call(F.skewness, "skew", returns_scalar=True)
6 changes: 5 additions & 1 deletion narwhals/utils.py
@@ -155,7 +155,11 @@ def is_pandas_like(self) -> bool:
>>> df.implementation.is_pandas_like()
True
"""
return self in {
    Implementation.PANDAS,
    Implementation.MODIN,
    Implementation.CUDF,
}

def is_polars(self) -> bool:
"""Return whether implementation is Polars.
