feat: Add minimal PySpark support (#908)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
EdAbati and pre-commit-ci[bot] authored Dec 5, 2024
1 parent eebc89a commit ea278ae
Showing 19 changed files with 1,375 additions and 16 deletions.
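As context for the diffs below, here is a minimal usage sketch of what the new backend enables, assuming a running SparkSession and that the PySpark DataFrame is reached through narwhals' usual from_native entry point (column names and data are made up; only operations added in this commit, such as select, filter, with_columns, sort, head, group_by and collect, would be expected to work):

    import narwhals as nw
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    native_df = spark.createDataFrame([(1, 4.5), (2, 6.7), (3, 9.9)], ["a", "b"])

    # Wrap the native DataFrame; PySpark is handled as a lazy backend.
    lf = nw.from_native(native_df)
    result = (
        lf.filter(nw.col("a") > 1)
        .with_columns(c=nw.col("b") * 2)
        .sort("a", descending=True)
        .to_native()  # back to a pyspark.sql.DataFrame
    )
    result.show()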
9 changes: 6 additions & 3 deletions .github/workflows/extremes.yml
@@ -27,7 +27,10 @@ jobs:
      - name: install-minimum-versions
        run: uv pip install tox virtualenv setuptools pandas==0.25.3 polars==0.20.3 numpy==1.17.5 pyarrow==11.0.0 scipy==1.5.0 scikit-learn==1.1.0 tzdata --system
      - name: install-reqs
-        run: uv pip install -r requirements-dev.txt --system
+        run: |
+          uv pip install -r requirements-dev.txt --system
+          : # pyspark >= 3.3.0 is not compatible with pandas==0.25.3
+          uv pip uninstall pyspark --system
      - name: show-deps
        run: uv pip freeze
      - name: Run pytest
@@ -52,7 +55,7 @@ jobs:
          cache-suffix: ${{ matrix.python-version }}
          cache-dependency-glob: "**requirements*.txt"
      - name: install-minimum-versions
-        run: uv pip install tox virtualenv setuptools pandas==1.1.5 polars==0.20.3 numpy==1.17.5 pyarrow==11.0.0 scipy==1.5.0 scikit-learn==1.1.0 tzdata --system
+        run: uv pip install tox virtualenv setuptools pandas==1.1.5 polars==0.20.3 numpy==1.17.5 pyarrow==11.0.0 pyspark==3.3.0 scipy==1.5.0 scikit-learn==1.1.0 tzdata --system
      - name: install-reqs
        run: uv pip install -r requirements-dev.txt --system
      - name: show-deps
@@ -79,7 +82,7 @@ jobs:
          cache-suffix: ${{ matrix.python-version }}
          cache-dependency-glob: "**requirements*.txt"
      - name: install-not-so-old-versions
-        run: uv pip install tox virtualenv setuptools pandas==2.0.3 polars==0.20.8 numpy==1.24.4 pyarrow==14.0.0 scipy==1.8.0 scikit-learn==1.3.0 dask[dataframe]==2024.7 tzdata --system
+        run: uv pip install tox virtualenv setuptools pandas==2.0.3 polars==0.20.8 numpy==1.24.4 pyarrow==14.0.0 pyspark==3.4.0 scipy==1.8.0 scikit-learn==1.3.0 dask[dataframe]==2024.7 tzdata --system
      - name: install-reqs
        run: uv pip install -r requirements-dev.txt --system
      - name: show-deps
32 changes: 27 additions & 5 deletions narwhals/_expression_parsing.py
@@ -35,13 +35,21 @@
    from narwhals._polars.namespace import PolarsNamespace
    from narwhals._polars.series import PolarsSeries
    from narwhals._polars.typing import IntoPolarsExpr
+    from narwhals._spark_like.dataframe import SparkLikeLazyFrame
+    from narwhals._spark_like.expr import SparkLikeExpr
+    from narwhals._spark_like.namespace import SparkLikeNamespace
+    from narwhals._spark_like.typing import IntoSparkLikeExpr

    CompliantNamespace = Union[
-        PandasLikeNamespace, ArrowNamespace, DaskNamespace, PolarsNamespace
+        PandasLikeNamespace,
+        ArrowNamespace,
+        DaskNamespace,
+        PolarsNamespace,
+        SparkLikeNamespace,
    ]
-    CompliantExpr = Union[PandasLikeExpr, ArrowExpr, DaskExpr, PolarsExpr]
+    CompliantExpr = Union[PandasLikeExpr, ArrowExpr, DaskExpr, PolarsExpr, SparkLikeExpr]
    IntoCompliantExpr = Union[
-        IntoPandasLikeExpr, IntoArrowExpr, IntoDaskExpr, IntoPolarsExpr
+        IntoPandasLikeExpr, IntoArrowExpr, IntoDaskExpr, IntoPolarsExpr, IntoSparkLikeExpr
    ]
    IntoCompliantExprT = TypeVar("IntoCompliantExprT", bound=IntoCompliantExpr)
    CompliantExprT = TypeVar("CompliantExprT", bound=CompliantExpr)
@@ -50,9 +58,15 @@
        list[PandasLikeSeries], list[ArrowSeries], list[DaskExpr], list[PolarsSeries]
    ]
    ListOfCompliantExpr = Union[
-        list[PandasLikeExpr], list[ArrowExpr], list[DaskExpr], list[PolarsExpr]
+        list[PandasLikeExpr],
+        list[ArrowExpr],
+        list[DaskExpr],
+        list[PolarsExpr],
+        list[SparkLikeExpr],
    ]
-    CompliantDataFrame = Union[PandasLikeDataFrame, ArrowDataFrame, DaskLazyFrame]
+    CompliantDataFrame = Union[
+        PandasLikeDataFrame, ArrowDataFrame, DaskLazyFrame, SparkLikeLazyFrame
+    ]

    T = TypeVar("T")

Expand Down Expand Up @@ -152,6 +166,14 @@ def parse_into_exprs(
) -> list[PolarsExpr]: ...


+@overload
+def parse_into_exprs(
+    *exprs: IntoSparkLikeExpr,
+    namespace: SparkLikeNamespace,
+    **named_exprs: IntoSparkLikeExpr,
+) -> list[SparkLikeExpr]: ...


def parse_into_exprs(
    *exprs: IntoCompliantExpr,
    namespace: CompliantNamespace,
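The changes above widen each Union alias and add a parse_into_exprs overload so that a type-checker maps a SparkLikeNamespace argument to a list[SparkLikeExpr] result. A simplified, self-contained illustration of that overload pattern (toy classes, not narwhals' actual ones):

    from __future__ import annotations

    from typing import Union, overload


    class PolarsNs: ...
    class SparkNs: ...
    class PolarsExpr: ...
    class SparkExpr: ...


    @overload
    def parse(*names: str, namespace: PolarsNs) -> list[PolarsExpr]: ...
    @overload
    def parse(*names: str, namespace: SparkNs) -> list[SparkExpr]: ...


    def parse(*names: str, namespace: Union[PolarsNs, SparkNs]) -> list[Union[PolarsExpr, SparkExpr]]:
        # Each backend namespace turns a column name into its own expression type;
        # the overloads keep the return type tied to the namespace that was passed in.
        expr_cls = PolarsExpr if isinstance(namespace, PolarsNs) else SparkExpr
        return [expr_cls() for _ in names]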
Empty file.
188 changes: 188 additions & 0 deletions narwhals/_spark_like/dataframe.py
@@ -0,0 +1,188 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import Any
from typing import Iterable
from typing import Sequence

from narwhals._spark_like.utils import native_to_narwhals_dtype
from narwhals._spark_like.utils import parse_exprs_and_named_exprs
from narwhals.utils import Implementation
from narwhals.utils import flatten
from narwhals.utils import parse_columns_to_drop
from narwhals.utils import parse_version

if TYPE_CHECKING:
    from pyspark.sql import DataFrame
    from typing_extensions import Self

    from narwhals._spark_like.expr import SparkLikeExpr
    from narwhals._spark_like.group_by import SparkLikeLazyGroupBy
    from narwhals._spark_like.namespace import SparkLikeNamespace
    from narwhals._spark_like.typing import IntoSparkLikeExpr
    from narwhals.dtypes import DType
    from narwhals.utils import Version


class SparkLikeLazyFrame:
    def __init__(
        self,
        native_dataframe: DataFrame,
        *,
        backend_version: tuple[int, ...],
        version: Version,
    ) -> None:
        self._native_frame = native_dataframe
        self._backend_version = backend_version
        self._implementation = Implementation.PYSPARK
        self._version = version

    def __native_namespace__(self) -> Any:  # pragma: no cover
        if self._implementation is Implementation.PYSPARK:
            return self._implementation.to_native_namespace()

        msg = f"Expected pyspark, got: {type(self._implementation)}"  # pragma: no cover
        raise AssertionError(msg)

    def __narwhals_namespace__(self) -> SparkLikeNamespace:
        from narwhals._spark_like.namespace import SparkLikeNamespace

        return SparkLikeNamespace(
            backend_version=self._backend_version, version=self._version
        )

    def __narwhals_lazyframe__(self) -> Self:
        return self

    def _change_version(self, version: Version) -> Self:
        return self.__class__(
            self._native_frame, backend_version=self._backend_version, version=version
        )

    def _from_native_frame(self, df: DataFrame) -> Self:
        return self.__class__(
            df, backend_version=self._backend_version, version=self._version
        )

    @property
    def columns(self) -> list[str]:
        return self._native_frame.columns  # type: ignore[no-any-return]

    def collect(self) -> Any:
        import pandas as pd  # ignore-banned-import()

        from narwhals._pandas_like.dataframe import PandasLikeDataFrame

        return PandasLikeDataFrame(
            native_dataframe=self._native_frame.toPandas(),
            implementation=Implementation.PANDAS,
            backend_version=parse_version(pd.__version__),
            version=self._version,
        )
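collect materializes the Spark plan through DataFrame.toPandas and hands the result to narwhals' pandas-backed eager frame. At the native level that step is roughly the following (illustrative sketch; the data is made up):

    import pandas as pd
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    spark_df = spark.createDataFrame([(1, "x"), (2, "y")], ["a", "b"])

    # Run the Spark job and pull the rows into a local pandas DataFrame,
    # which collect() then wraps as a PandasLikeDataFrame.
    pdf: pd.DataFrame = spark_df.toPandas()
    print(pdf)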

    def select(
        self: Self,
        *exprs: IntoSparkLikeExpr,
        **named_exprs: IntoSparkLikeExpr,
    ) -> Self:
        if exprs and all(isinstance(x, str) for x in exprs) and not named_exprs:
            # This is a simple select
            return self._from_native_frame(self._native_frame.select(*exprs))

        new_columns = parse_exprs_and_named_exprs(self, *exprs, **named_exprs)

        if not new_columns:
            # return empty dataframe, like Polars does
            from pyspark.sql.types import StructType

            spark_session = self._native_frame.sparkSession
            spark_df = spark_session.createDataFrame([], StructType([]))

            return self._from_native_frame(spark_df)

        new_columns_list = [col.alias(col_name) for col_name, col in new_columns.items()]
        return self._from_native_frame(self._native_frame.select(*new_columns_list))

    def filter(self, *predicates: SparkLikeExpr) -> Self:
        from narwhals._spark_like.namespace import SparkLikeNamespace

        if (
            len(predicates) == 1
            and isinstance(predicates[0], list)
            and all(isinstance(x, bool) for x in predicates[0])
        ):
            msg = "`LazyFrame.filter` is not supported for PySpark backend with boolean masks."
            raise NotImplementedError(msg)
        plx = SparkLikeNamespace(
            backend_version=self._backend_version, version=self._version
        )
        expr = plx.all_horizontal(*predicates)
        # Safety: all_horizontal's expression only returns a single column.
        condition = expr._call(self)[0]
        spark_df = self._native_frame.where(condition)
        return self._from_native_frame(spark_df)
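filter rejects boolean masks and instead folds every predicate into a single boolean Column via all_horizontal before calling DataFrame.where. At the native PySpark level the combined condition behaves like AND-ing the individual predicate columns, roughly (illustrative, not the internal code):

    import operator
    from functools import reduce

    import pyspark.sql.functions as F
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([(1, 10), (2, 20), (3, 30)], ["a", "b"])

    # Combining predicates with & mirrors what all_horizontal produces.
    predicates = [F.col("a") > 1, F.col("b") < 30]
    condition = reduce(operator.and_, predicates)
    df.where(condition).show()  # keeps only the row (2, 20)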

    @property
    def schema(self) -> dict[str, DType]:
        return {
            field.name: native_to_narwhals_dtype(
                dtype=field.dataType, version=self._version
            )
            for field in self._native_frame.schema
        }

    def collect_schema(self) -> dict[str, DType]:
        return self.schema

    def with_columns(
        self: Self,
        *exprs: IntoSparkLikeExpr,
        **named_exprs: IntoSparkLikeExpr,
    ) -> Self:
        new_columns_map = parse_exprs_and_named_exprs(self, *exprs, **named_exprs)
        return self._from_native_frame(self._native_frame.withColumns(new_columns_map))

    def drop(self: Self, columns: list[str], strict: bool) -> Self:  # noqa: FBT001
        columns_to_drop = parse_columns_to_drop(
            compliant_frame=self, columns=columns, strict=strict
        )
        return self._from_native_frame(self._native_frame.drop(*columns_to_drop))

    def head(self: Self, n: int) -> Self:
        spark_session = self._native_frame.sparkSession

        return self._from_native_frame(
            spark_session.createDataFrame(self._native_frame.take(num=n))
        )

    def group_by(self: Self, *keys: str, drop_null_keys: bool) -> SparkLikeLazyGroupBy:
        from narwhals._spark_like.group_by import SparkLikeLazyGroupBy

        return SparkLikeLazyGroupBy(
            df=self, keys=list(keys), drop_null_keys=drop_null_keys
        )

    def sort(
        self: Self,
        by: str | Iterable[str],
        *more_by: str,
        descending: bool | Sequence[bool] = False,
        nulls_last: bool = False,
    ) -> Self:
        import pyspark.sql.functions as F  # noqa: N812

        flat_by = flatten([*flatten([by]), *more_by])
        if isinstance(descending, bool):
            descending = [descending]

        if nulls_last:
            sort_funcs = [
                F.desc_nulls_last if d else F.asc_nulls_last for d in descending
            ]
        else:
            sort_funcs = [
                F.desc_nulls_first if d else F.asc_nulls_first for d in descending
            ]

        sort_cols = [sort_f(col) for col, sort_f in zip(flat_by, sort_funcs)]
        return self._from_native_frame(self._native_frame.sort(*sort_cols))
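sort maps each descending flag to one of PySpark's ordering helpers, with null placement chosen by nulls_last: descending with nulls_last becomes F.desc_nulls_last, ascending without it becomes F.asc_nulls_first, and so on. A small native sketch of the two branches (column name made up):

    import pyspark.sql.functions as F
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([(3,), (None,), (1,)], "a int")

    # nulls_last=True, descending=True   -> F.desc_nulls_last
    df.sort(F.desc_nulls_last("a")).show()   # 3, 1, NULL
    # nulls_last=False, descending=False -> F.asc_nulls_first
    df.sort(F.asc_nulls_first("a")).show()   # NULL, 1, 3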