Skip to content

Commit

Permalink
fix: use PYSPARK_VERSION to select the median implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
Dhanunjaya-Elluri committed Jan 5, 2025
1 parent cafd092 commit 137b1be
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 37 deletions.
21 changes: 1 addition & 20 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,26 +80,7 @@ Here's how you can set up your local development environment to contribute.

#### Prerequisites for PySpark tests

If you want to run PySpark-related tests, you'll need to have Java installed:

- On Ubuntu/Debian:
```bash
sudo apt-get update
sudo apt-get install default-jdk
sudo apt-get install default-jre
```

- On macOS:
Follow the instructions [here](https://www.java.com/en/download/help/mac_install.html)

- On Windows:
Follow the instructions [here](https://www.java.com/en/download/help/windows_manual_download.html)
- Add JAVA_HOME to your environment variables

You can verify your Java installation by running:
```bash
java -version
```
If you want to run PySpark-related tests, you'll need to have Java installed. Refer to the [Spark documentation](https://spark.apache.org/docs/latest/#downloading) for more information.

#### Option 1: Use UV (recommended)

Expand Down
14 changes: 12 additions & 2 deletions narwhals/_spark_like/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from narwhals._spark_like.utils import maybe_evaluate
from narwhals.typing import CompliantExpr
from narwhals.utils import Implementation
from narwhals.utils import get_module_version_as_tuple
from narwhals.utils import parse_version

if TYPE_CHECKING:
Expand All @@ -20,6 +21,8 @@
from narwhals._spark_like.namespace import SparkLikeNamespace
from narwhals.utils import Version

PYSPARK_VERSION: tuple[int, ...] = get_module_version_as_tuple("pyspark")


class SparkLikeExpr(CompliantExpr["Column"]):
_implementation = Implementation.PYSPARK
Expand Down Expand Up @@ -222,9 +225,16 @@ def mean(self) -> Self:
return self._from_call(F.mean, "mean", returns_scalar=True)

def median(self) -> Self:
    """Return the median of the expression as a scalar aggregation."""

    def _median(_input: Column) -> Column:
        from pyspark.sql import functions as F  # noqa: N812

        if PYSPARK_VERSION >= (3, 4):
            return F.median(_input)

        # `F.median` only exists from PySpark 3.4 onwards; on older
        # versions approximate it with percentile_approx at the default
        # accuracy parameter (10000).
        return F.percentile_approx(_input.cast("double"), 0.5)

    return self._from_call(_median, "median", returns_scalar=True)

def min(self) -> Self:
from pyspark.sql import functions as F # noqa: N812
Expand Down
13 changes: 12 additions & 1 deletion narwhals/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,11 @@ def is_pandas_like(self) -> bool:
>>> df.implementation.is_pandas_like()
True
"""
return self in {Implementation.PANDAS, Implementation.MODIN, Implementation.CUDF}
return self in {
Implementation.PANDAS,
Implementation.MODIN,
Implementation.CUDF,
}

def is_polars(self) -> bool:
"""Return whether implementation is Polars.
Expand Down Expand Up @@ -1054,3 +1058,10 @@ def generate_repr(header: str, native_repr: str) -> str:
"| Use `.to_native` to see native output |\n└"
f"{'─' * 39}┘"
)


def get_module_version_as_tuple(module_name: str) -> tuple[int, ...]:
    """Return *module_name*'s version as a tuple of ints, or ``(0, 0, 0)``.

    The ``(0, 0, 0)`` sentinel is returned when the module is not installed
    and — fixed here — also when it imports fine but exposes no
    ``__version__`` attribute, so callers can always compare the result
    against a minimum required version without a crash.
    """
    try:
        return parse_version(__import__(module_name).__version__)
    except (ImportError, AttributeError):
        # Missing module, or module without __version__: no usable version.
        return (0, 0, 0)
5 changes: 0 additions & 5 deletions tests/spark_like_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

from __future__ import annotations

import sys
from contextlib import nullcontext as does_not_raise
from typing import TYPE_CHECKING
from typing import Any
Expand Down Expand Up @@ -958,10 +957,6 @@ def test_left_join_overlapping_column(pyspark_constructor: Constructor) -> None:


# Copied from tests/expr_and_series/median_test.py
@pytest.mark.xfail(
sys.version_info < (3, 9),
reason="median() not supported on Python 3.8",
)
def test_median(pyspark_constructor: Constructor) -> None:
data = {"a": [3, 8, 2, None], "b": [5, 5, None, 7], "z": [7.0, 8, 9, None]}
df = nw.from_native(pyspark_constructor(data))
Expand Down
9 changes: 1 addition & 8 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,14 @@
from narwhals.typing import IntoDataFrame
from narwhals.typing import IntoFrame
from narwhals.utils import Implementation
from narwhals.utils import parse_version
from narwhals.utils import get_module_version_as_tuple

if sys.version_info >= (3, 10):
from typing import TypeAlias # pragma: no cover
else:
from typing_extensions import TypeAlias # pragma: no cover


def get_module_version_as_tuple(module_name: str) -> tuple[int, ...]:
try:
return parse_version(__import__(module_name).__version__)
except ImportError:
return (0, 0, 0)


IBIS_VERSION: tuple[int, ...] = get_module_version_as_tuple("ibis")
NUMPY_VERSION: tuple[int, ...] = get_module_version_as_tuple("numpy")
PANDAS_VERSION: tuple[int, ...] = get_module_version_as_tuple("pandas")
Expand Down
2 changes: 1 addition & 1 deletion tests/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
from pandas.testing import assert_series_equal

import narwhals.stable.v1 as nw
from narwhals.utils import get_module_version_as_tuple
from tests.utils import PANDAS_VERSION
from tests.utils import get_module_version_as_tuple

if TYPE_CHECKING:
from narwhals.series import Series
Expand Down

0 comments on commit 137b1be

Please sign in to comment.