chore: increase PySpark min version to 3.5.0 (#1744)
EdAbati authored Jan 7, 2025
1 parent 17546f2 commit 3e42edd
Showing 5 changed files with 14 additions and 41 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/extremes.yml
@@ -61,7 +61,7 @@ jobs:
           cache-suffix: ${{ matrix.python-version }}
           cache-dependency-glob: "pyproject.toml"
       - name: install-pretty-old-versions
-        run: uv pip install pipdeptree tox virtualenv setuptools pandas==1.1.5 polars==0.20.3 numpy==1.17.5 pyarrow==11.0.0 "pyarrow-stubs<17" pyspark==3.3.0 scipy==1.5.0 scikit-learn==1.1.0 tzdata --system
+        run: uv pip install pipdeptree tox virtualenv setuptools pandas==1.1.5 polars==0.20.3 numpy==1.17.5 pyarrow==11.0.0 "pyarrow-stubs<17" pyspark==3.5.0 scipy==1.5.0 scikit-learn==1.1.0 tzdata --system
       - name: install-reqs
         run: uv pip install -e ".[dev]" --system
       - name: show-deps
@@ -75,7 +75,7 @@ jobs:
           echo "$DEPS" | grep 'polars==0.20.3'
           echo "$DEPS" | grep 'numpy==1.17.5'
           echo "$DEPS" | grep 'pyarrow==11.0.0'
-          echo "$DEPS" | grep 'pyspark==3.3.0'
+          echo "$DEPS" | grep 'pyspark==3.5.0'
           echo "$DEPS" | grep 'scipy==1.5.0'
           echo "$DEPS" | grep 'scikit-learn==1.1.0'
       - name: Run pytest
@@ -99,7 +99,7 @@ jobs:
           cache-suffix: ${{ matrix.python-version }}
           cache-dependency-glob: "pyproject.toml"
       - name: install-not-so-old-versions
-        run: uv pip install tox virtualenv setuptools pandas==2.0.3 polars==0.20.8 numpy==1.24.4 pyarrow==15.0.0 "pyarrow-stubs<17" pyspark==3.4.0 scipy==1.8.0 scikit-learn==1.3.0 dask[dataframe]==2024.10 tzdata --system
+        run: uv pip install tox virtualenv setuptools pandas==2.0.3 polars==0.20.8 numpy==1.24.4 pyarrow==15.0.0 "pyarrow-stubs<17" pyspark==3.5.0 scipy==1.8.0 scikit-learn==1.3.0 dask[dataframe]==2024.10 tzdata --system
       - name: install-reqs
         run: uv pip install -e ".[dev]" --system
       - name: show-deps
@@ -111,7 +111,7 @@
           echo "$DEPS" | grep 'polars==0.20.8'
           echo "$DEPS" | grep 'numpy==1.24.4'
           echo "$DEPS" | grep 'pyarrow==15.0.0'
-          echo "$DEPS" | grep 'pyspark==3.4.0'
+          echo "$DEPS" | grep 'pyspark==3.5.0'
           echo "$DEPS" | grep 'scipy==1.8.0'
           echo "$DEPS" | grep 'scikit-learn==1.3.0'
           echo "$DEPS" | grep 'dask==2024.10'
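Both extremes jobs now pin pyspark==3.5.0: the "pretty-old" matrix moves up from 3.3.0 and the "not-so-old" matrix from 3.4.0, so the oldest PySpark exercised in CI matches the new floor. A hypothetical local stand-in for the grep assertions above, in Python:

    # Hypothetical local equivalent of the CI check `grep 'pyspark==3.5.0'`:
    # confirm the environment resolved exactly the pinned PySpark version.
    from importlib.metadata import version

    assert version("pyspark") == "3.5.0", f"unexpected PySpark: {version('pyspark')}"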
14 changes: 2 additions & 12 deletions narwhals/_spark_like/expr.py
@@ -225,12 +225,7 @@ def std(self: Self, ddof: int) -> Self:
 
         from narwhals._spark_like.utils import _std
 
-        func = partial(
-            _std,
-            ddof=ddof,
-            backend_version=self._backend_version,
-            np_version=parse_version(np.__version__),
-        )
+        func = partial(_std, ddof=ddof, np_version=parse_version(np.__version__))
 
         return self._from_call(func, "std", returns_scalar=True, ddof=ddof)
 
@@ -241,11 +236,6 @@ def var(self: Self, ddof: int) -> Self:
 
         from narwhals._spark_like.utils import _var
 
-        func = partial(
-            _var,
-            ddof=ddof,
-            backend_version=self._backend_version,
-            np_version=parse_version(np.__version__),
-        )
+        func = partial(_var, ddof=ddof, np_version=parse_version(np.__version__))
 
         return self._from_call(func, "var", returns_scalar=True, ddof=ddof)
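With 3.5.0 guaranteed, `std` and `var` stop threading `self._backend_version` into the helpers; only the NumPy version still matters at call time. A minimal sketch of the surviving binding, assuming the import paths shown (they are assumptions, not taken verbatim from expr.py):

    # Minimal sketch of the simplified binding; import paths are assumed.
    from functools import partial

    import numpy as np  # ignore-banned-import

    from narwhals._spark_like.utils import _std
    from narwhals.utils import parse_version

    # Only ddof and the NumPy version are passed now; the PySpark version
    # branch is gone because pyspark>=3.5.0 is guaranteed at install time.
    func = partial(_std, ddof=1, np_version=parse_version(np.__version__))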
13 changes: 3 additions & 10 deletions narwhals/_spark_like/group_by.py
@@ -79,16 +79,13 @@ def _from_native_frame(self, df: SparkLikeLazyFrame) -> SparkLikeLazyFrame:
         )
 
 
-def get_spark_function(
-    function_name: str, backend_version: tuple[int, ...], **kwargs: Any
-) -> Column:
+def get_spark_function(function_name: str, **kwargs: Any) -> Column:
     if function_name in {"std", "var"}:
         import numpy as np  # ignore-banned-import
 
         return partial(
             _std if function_name == "std" else _var,
             ddof=kwargs.get("ddof", 1),
-            backend_version=backend_version,
             np_version=parse_version(np.__version__),
         )
     from pyspark.sql import functions as F  # noqa: N812
@@ -127,9 +124,7 @@ def agg_pyspark(
         function_name = POLARS_TO_PYSPARK_AGGREGATIONS.get(
             expr._function_name, expr._function_name
         )
-        agg_func = get_spark_function(
-            function_name, backend_version=expr._backend_version, **expr._kwargs
-        )
+        agg_func = get_spark_function(function_name, **expr._kwargs)
         simple_aggregations.update(
             {output_name: agg_func(keys[0]) for output_name in expr._output_names}
         )
@@ -146,9 +141,7 @@ def agg_pyspark(
         pyspark_function = POLARS_TO_PYSPARK_AGGREGATIONS.get(
             function_name, function_name
         )
-        agg_func = get_spark_function(
-            pyspark_function, backend_version=expr._backend_version, **expr._kwargs
-        )
+        agg_func = get_spark_function(pyspark_function, **expr._kwargs)
 
         simple_aggregations.update(
             {
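With `backend_version` gone from `get_spark_function`, both `agg_pyspark` call sites collapse to one line. A hedged sketch of the resulting dispatch; the non-std/var fallback is elided in the diff, so the final `getattr` line is an assumption rather than narwhals' confirmed implementation:

    from functools import partial
    from typing import Any

    import numpy as np  # ignore-banned-import

    from narwhals._spark_like.utils import _std, _var
    from narwhals.utils import parse_version  # assumed import path


    def get_spark_function(function_name: str, **kwargs: Any) -> Any:
        # std/var need ddof plus the NumPy version; everything else is a
        # plain lookup of a pyspark.sql.functions callable.
        if function_name in {"std", "var"}:
            return partial(
                _std if function_name == "std" else _var,
                ddof=kwargs.get("ddof", 1),
                np_version=parse_version(np.__version__),
            )
        from pyspark.sql import functions as F  # noqa: N812

        return getattr(F, function_name)  # assumption: elided in the diff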
18 changes: 4 additions & 14 deletions narwhals/_spark_like/utils.py
@@ -120,13 +120,8 @@ def maybe_evaluate(df: SparkLikeLazyFrame, obj: Any) -> Any:
     return obj
 
 
-def _std(
-    _input: Column | str,
-    ddof: int,
-    backend_version: tuple[int, ...],
-    np_version: tuple[int, ...],
-) -> Column:
-    if backend_version < (3, 5) or np_version > (2, 0):
+def _std(_input: Column | str, ddof: int, np_version: tuple[int, ...]) -> Column:
+    if np_version > (2, 0):
         from pyspark.sql import functions as F  # noqa: N812
 
         if ddof == 1:
@@ -142,13 +137,8 @@ def _std(
     return stddev(input_col, ddof=ddof)
 
 
-def _var(
-    _input: Column | str,
-    ddof: int,
-    backend_version: tuple[int, ...],
-    np_version: tuple[int, ...],
-) -> Column:
-    if backend_version < (3, 5) or np_version > (2, 0):
+def _var(_input: Column | str, ddof: int, np_version: tuple[int, ...]) -> Column:
+    if np_version > (2, 0):
         from pyspark.sql import functions as F  # noqa: N812
 
         if ddof == 1:
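This is the payoff of the bump: `_std` and `_var` previously took a `backend_version` so that PySpark 3.3/3.4, whose `pyspark.pandas.spark.functions` helpers do not accept an arbitrary `ddof`, would fall back to a hand-rolled `pyspark.sql.functions` expression. With 3.5.0 as the floor, only the NumPy 2 guard remains. The fallback body is truncated in the diff; below is a hedged sketch of one ddof-general standard deviation built from plain `pyspark.sql.functions`, not necessarily narwhals' exact expression:

    from __future__ import annotations

    from pyspark.sql import Column
    from pyspark.sql import functions as F  # noqa: N812


    def _std_fallback(_input: Column | str, ddof: int) -> Column:
        # Illustrative only: one correct ddof-general estimator.
        input_col = F.col(_input) if isinstance(_input, str) else _input
        if ddof == 1:
            return F.stddev_samp(input_col)
        n = F.count(input_col)
        # var_samp is SS / (n - 1), so SS = var_samp * (n - 1); dividing
        # the sum of squares SS by (n - ddof) gives the requested estimator.
        return F.sqrt(F.var_samp(input_col) * (n - 1) / (n - ddof))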
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -31,7 +31,7 @@ pandas = ["pandas>=0.25.3"]
 modin = ["modin"]
 cudf = ["cudf>=24.10.0"]
 pyarrow = ["pyarrow>=11.0.0"]
-pyspark = ["pyspark>=3.3.0"]
+pyspark = ["pyspark>=3.5.0"]
 polars = ["polars>=0.20.3"]
 dask = ["dask[dataframe]>=2024.8"]
 duckdb = ["duckdb>=1.0"]
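The `pyspark` extra is the single place the floor is declared: installing `narwhals[pyspark]` now resolves `pyspark>=3.5.0`, which is what makes the runtime `backend_version` checks above safe to delete. A hypothetical restatement of the constraint in Python, for illustration only; narwhals relies on the package resolver, not a check like this:

    from importlib.metadata import version

    # Hypothetical guard mirroring the pyproject.toml constraint.
    pyspark_version = tuple(int(part) for part in version("pyspark").split(".")[:3])
    if pyspark_version < (3, 5, 0):
        msg = f"narwhals requires pyspark>=3.5.0, found {version('pyspark')}"
        raise ModuleNotFoundError(msg)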
