Merge branch 'main' into feat/collect-kwargs
FBruzzesi authored Jan 7, 2025
2 parents aedbff2 + 74dd9db · commit 23e814d
Showing 11 changed files with 26 additions and 49 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/extremes.yml
@@ -61,7 +61,7 @@ jobs:
cache-suffix: ${{ matrix.python-version }}
cache-dependency-glob: "pyproject.toml"
- name: install-pretty-old-versions
-run: uv pip install pipdeptree tox virtualenv setuptools pandas==1.1.5 polars==0.20.3 numpy==1.17.5 pyarrow==11.0.0 "pyarrow-stubs<17" pyspark==3.3.0 scipy==1.5.0 scikit-learn==1.1.0 tzdata --system
+run: uv pip install pipdeptree tox virtualenv setuptools pandas==1.1.5 polars==0.20.3 numpy==1.17.5 pyarrow==11.0.0 "pyarrow-stubs<17" pyspark==3.5.0 scipy==1.5.0 scikit-learn==1.1.0 tzdata --system
- name: install-reqs
run: uv pip install -e ".[dev]" --system
- name: show-deps
@@ -75,7 +75,7 @@ jobs:
echo "$DEPS" | grep 'polars==0.20.3'
echo "$DEPS" | grep 'numpy==1.17.5'
echo "$DEPS" | grep 'pyarrow==11.0.0'
echo "$DEPS" | grep 'pyspark==3.3.0'
echo "$DEPS" | grep 'pyspark==3.5.0'
echo "$DEPS" | grep 'scipy==1.5.0'
echo "$DEPS" | grep 'scikit-learn==1.1.0'
- name: Run pytest
@@ -99,7 +99,7 @@ jobs:
cache-suffix: ${{ matrix.python-version }}
cache-dependency-glob: "pyproject.toml"
- name: install-not-so-old-versions
-run: uv pip install tox virtualenv setuptools pandas==2.0.3 polars==0.20.8 numpy==1.24.4 pyarrow==15.0.0 "pyarrow-stubs<17" pyspark==3.4.0 scipy==1.8.0 scikit-learn==1.3.0 dask[dataframe]==2024.10 tzdata --system
+run: uv pip install tox virtualenv setuptools pandas==2.0.3 polars==0.20.8 numpy==1.24.4 pyarrow==15.0.0 "pyarrow-stubs<17" pyspark==3.5.0 scipy==1.8.0 scikit-learn==1.3.0 dask[dataframe]==2024.10 tzdata --system
- name: install-reqs
run: uv pip install -e ".[dev]" --system
- name: show-deps
@@ -111,7 +111,7 @@
echo "$DEPS" | grep 'polars==0.20.8'
echo "$DEPS" | grep 'numpy==1.24.4'
echo "$DEPS" | grep 'pyarrow==15.0.0'
echo "$DEPS" | grep 'pyspark==3.4.0'
echo "$DEPS" | grep 'pyspark==3.5.0'
echo "$DEPS" | grep 'scipy==1.8.0'
echo "$DEPS" | grep 'scikit-learn==1.3.0'
echo "$DEPS" | grep 'dask==2024.10'
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -3,7 +3,7 @@ ci:
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
-rev: 'v0.8.1'
+rev: 'v0.8.6'
hooks:
# Run the formatter.
- id: ruff-format
@@ -14,7 +14,7 @@ repos:
alias: check-docstrings
entry: python utils/check_docstrings.py
- repo: https://github.com/pre-commit/mirrors-mypy
-rev: 'v1.13.0'
+rev: 'v1.14.1'
hooks:
- id: mypy
additional_dependencies: ['polars==1.4.1', 'pytest==8.3.2']
4 changes: 3 additions & 1 deletion docs/basics/dataframe_conversion.md
@@ -53,7 +53,9 @@ which implements `__arrow_c_stream__`:
def df_to_polars(df_native: Any) -> pl.DataFrame:
if hasattr(df_native, "__arrow_c_stream__"):
return nw.from_arrow(df_native, native_namespace=pl).to_native()
msg = f"Expected object which implements '__arrow_c_stream__' got: {type(df)}"
msg = (
f"Expected object which implements '__arrow_c_stream__' got: {type(df_native)}"
)
raise TypeError(msg)


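Side note on the docs snippet above: the fix interpolates type(df_native) (the function's actual parameter) instead of the undefined name df. A minimal usage sketch of the documented helper, assuming pyarrow >= 14 (where pyarrow.Table exposes __arrow_c_stream__); the sample table is purely illustrative:

    from typing import Any

    import narwhals as nw
    import polars as pl
    import pyarrow as pa


    def df_to_polars(df_native: Any) -> pl.DataFrame:
        if hasattr(df_native, "__arrow_c_stream__"):
            return nw.from_arrow(df_native, native_namespace=pl).to_native()
        msg = (
            f"Expected object which implements '__arrow_c_stream__' got: {type(df_native)}"
        )
        raise TypeError(msg)


    # pyarrow.Table implements __arrow_c_stream__, so this round-trips to polars:
    print(df_to_polars(pa.table({"a": [1, 2, 3]})))
    # df_to_polars([1, 2, 3]) would raise the (now correctly formatted) TypeError.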
2 changes: 1 addition & 1 deletion narwhals/_arrow/utils.py
@@ -184,7 +184,7 @@ def broadcast_and_extract_native(
rhs = rhs[0]

if isinstance(rhs, ArrowDataFrame):
-return NotImplemented
+return NotImplemented # type: ignore[no-any-return]

if isinstance(rhs, ArrowSeries):
if len(rhs) == 1:
4 changes: 2 additions & 2 deletions narwhals/_pandas_like/series.py
@@ -300,13 +300,13 @@ def arg_true(self) -> PandasLikeSeries:
def arg_min(self) -> int:
ser = self._native_series
if self._implementation is Implementation.PANDAS and self._backend_version < (1,):
-return ser.values.argmin() # type: ignore[no-any-return] # noqa: PD011
+return ser.values.argmin() # type: ignore[no-any-return]
return ser.argmin() # type: ignore[no-any-return]

def arg_max(self) -> int:
ser = self._native_series
if self._implementation is Implementation.PANDAS and self._backend_version < (1,):
-return ser.values.argmax() # type: ignore[no-any-return] # noqa: PD011
+return ser.values.argmax() # type: ignore[no-any-return]
return ser.argmax() # type: ignore[no-any-return]

# Binary comparisons
2 changes: 1 addition & 1 deletion narwhals/_pandas_like/utils.py
@@ -118,7 +118,7 @@ def broadcast_align_and_extract_native(
lhs_index = lhs._native_series.index

if isinstance(rhs, PandasLikeDataFrame):
-return NotImplemented
+return NotImplemented # type: ignore[no-any-return]

if isinstance(rhs, PandasLikeSeries):
rhs_index = rhs._native_series.index
14 changes: 2 additions & 12 deletions narwhals/_spark_like/expr.py
@@ -225,12 +225,7 @@ def std(self: Self, ddof: int) -> Self:

from narwhals._spark_like.utils import _std

-func = partial(
-_std,
-ddof=ddof,
-backend_version=self._backend_version,
-np_version=parse_version(np.__version__),
-)
+func = partial(_std, ddof=ddof, np_version=parse_version(np.__version__))

return self._from_call(func, "std", returns_scalar=True, ddof=ddof)

@@ -241,11 +236,6 @@ def var(self: Self, ddof: int) -> Self:

from narwhals._spark_like.utils import _var

-func = partial(
-_var,
-ddof=ddof,
-backend_version=self._backend_version,
-np_version=parse_version(np.__version__),
-)
+func = partial(_var, ddof=ddof, np_version=parse_version(np.__version__))

return self._from_call(func, "var", returns_scalar=True, ddof=ddof)
13 changes: 3 additions & 10 deletions narwhals/_spark_like/group_by.py
@@ -79,16 +79,13 @@ def _from_native_frame(self, df: SparkLikeLazyFrame) -> SparkLikeLazyFrame:
)


-def get_spark_function(
-function_name: str, backend_version: tuple[int, ...], **kwargs: Any
-) -> Column:
+def get_spark_function(function_name: str, **kwargs: Any) -> Column:
if function_name in {"std", "var"}:
import numpy as np # ignore-banned-import

return partial(
_std if function_name == "std" else _var,
ddof=kwargs.get("ddof", 1),
-backend_version=backend_version,
np_version=parse_version(np.__version__),
)
from pyspark.sql import functions as F # noqa: N812
@@ -127,9 +124,7 @@ def agg_pyspark(
function_name = POLARS_TO_PYSPARK_AGGREGATIONS.get(
expr._function_name, expr._function_name
)
-agg_func = get_spark_function(
-function_name, backend_version=expr._backend_version, **expr._kwargs
-)
+agg_func = get_spark_function(function_name, **expr._kwargs)
simple_aggregations.update(
{output_name: agg_func(keys[0]) for output_name in expr._output_names}
)
@@ -146,9 +141,7 @@
pyspark_function = POLARS_TO_PYSPARK_AGGREGATIONS.get(
function_name, function_name
)
-agg_func = get_spark_function(
-pyspark_function, backend_version=expr._backend_version, **expr._kwargs
-)
+agg_func = get_spark_function(pyspark_function, **expr._kwargs)

simple_aggregations.update(
{
18 changes: 4 additions & 14 deletions narwhals/_spark_like/utils.py
@@ -120,13 +120,8 @@ def maybe_evaluate(df: SparkLikeLazyFrame, obj: Any) -> Any:
return obj


-def _std(
-_input: Column | str,
-ddof: int,
-backend_version: tuple[int, ...],
-np_version: tuple[int, ...],
-) -> Column:
-if backend_version < (3, 5) or np_version > (2, 0):
+def _std(_input: Column | str, ddof: int, np_version: tuple[int, ...]) -> Column:
+if np_version > (2, 0):
from pyspark.sql import functions as F # noqa: N812

if ddof == 1:
@@ -142,13 +137,8 @@ def _std(
return stddev(input_col, ddof=ddof)


-def _var(
-_input: Column | str,
-ddof: int,
-backend_version: tuple[int, ...],
-np_version: tuple[int, ...],
-) -> Column:
-if backend_version < (3, 5) or np_version > (2, 0):
+def _var(_input: Column | str, ddof: int, np_version: tuple[int, ...]) -> Column:
+if np_version > (2, 0):
from pyspark.sql import functions as F # noqa: N812

if ddof == 1:
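Note on the np_version > (2, 0) branch in _std / _var above: when only pyspark.sql.functions is available, a standard deviation with arbitrary ddof can be obtained by rescaling the built-in sample statistic. The sketch below is only an illustration of that identity under the assumption 0 <= ddof < n; std_with_ddof is a hypothetical helper name, not necessarily the exact expression narwhals builds:

    from pyspark.sql import Column
    from pyspark.sql import functions as F


    def std_with_ddof(col_name: str, ddof: int) -> Column:
        # stddev_samp divides by (n - 1); rescale so the divisor becomes (n - ddof),
        # where n is the number of non-null values in the column.
        n = F.count(col_name)
        return F.stddev_samp(col_name) * F.sqrt((n - F.lit(1)) / (n - F.lit(ddof)))

    # hypothetical usage as an aggregate expression: df.agg(std_with_ddof("a", ddof=0))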
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -31,7 +31,7 @@ pandas = ["pandas>=0.25.3"]
modin = ["modin"]
cudf = ["cudf>=24.10.0"]
pyarrow = ["pyarrow>=11.0.0"]
-pyspark = ["pyspark>=3.3.0"]
+pyspark = ["pyspark>=3.5.0"]
polars = ["polars>=0.20.3"]
dask = ["dask[dataframe]>=2024.8"]
duckdb = ["duckdb>=1.0"]
4 changes: 3 additions & 1 deletion utils/generate_backend_completeness.py
@@ -31,9 +31,10 @@ class Backend(NamedTuple):
MODULES = ["dataframe", "series", "expr"]

BACKENDS = [
Backend(name="pandas-like", module="_pandas_like", type_=BackendType.EAGER),
Backend(name="arrow", module="_arrow", type_=BackendType.EAGER),
Backend(name="dask", module="_dask", type_=BackendType.LAZY),
Backend(name="duckdb", module="_duckdb", type_=BackendType.LAZY),
Backend(name="pandas-like", module="_pandas_like", type_=BackendType.EAGER),
Backend(name="spark-like", module="_spark_like", type_=BackendType.LAZY),
]

@@ -55,6 +56,7 @@ def parse_module(module_name: str, backend: str, nw_class_name: str) -> list[str
inspect.isclass(c)
and c.__name__.endswith(nw_class_name)
and not c.__name__.startswith("Compliant") # Exclude protocols
and not c.__name__.startswith("DuckDBInterchange")
),
)

