diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index 1cc9d3327..7940ddeb0 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -11,6 +11,7 @@ from typing import cast from typing import overload +from narwhals._pandas_like.utils import Implementation from narwhals.dependencies import get_numpy from narwhals.utils import flatten @@ -192,8 +193,12 @@ def func(df: CompliantDataFrame) -> list[CompliantSeries]: if expr._output_names is not None and ( [s.name for s in out] != expr._output_names ): # pragma: no cover - msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues" - raise AssertionError(msg) + if not ( + hasattr(expr, "_implementation") + and expr._implementation is Implementation.DASK + ): + msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues" + raise AssertionError(msg) return out # Try tracking root and output names by combining them from all diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index fa6e707ca..97691dd8b 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -19,6 +19,7 @@ from narwhals._pandas_like.utils import validate_dataframe_comparand from narwhals._pandas_like.utils import validate_indices from narwhals.dependencies import get_cudf +from narwhals.dependencies import get_dask from narwhals.dependencies import get_modin from narwhals.dependencies import get_numpy from narwhals.dependencies import get_pandas @@ -66,7 +67,9 @@ def __native_namespace__(self) -> Any: return get_modin() if self._implementation is Implementation.CUDF: # pragma: no cover return get_cudf() - msg = f"Expected pandas/modin/cudf, got: {type(self._implementation)}" # pragma: no cover + if self._implementation is Implementation.DASK: # pragma: no cover + return get_dask() + msg = f"Expected pandas/modin/cudf/dask, got: {type(self._implementation)}" # pragma: no cover raise AssertionError(msg) def __len__(self) -> int: @@ -200,6 +203,9 @@ def select( new_series = evaluate_into_exprs(self, *exprs, **named_exprs) if not new_series: # return empty dataframe, like Polars does + if self._implementation is Implementation.DASK: + dd = get_dask() + return self._from_native_dataframe(dd.from_dict({}, npartitions=1)) return self._from_native_dataframe(self._native_dataframe.__class__()) new_series = validate_indices(new_series) df = horizontal_concat( @@ -312,9 +318,15 @@ def sort( # --- convert --- def collect(self) -> PandasLikeDataFrame: + if self._implementation is Implementation.DASK: + return_df = self._native_dataframe.compute() + return_implementation = Implementation.PANDAS + else: + return_df = self._native_dataframe + return_implementation = self._implementation return PandasLikeDataFrame( - self._native_dataframe, - implementation=self._implementation, + return_df, + implementation=return_implementation, backend_version=self._backend_version, ) @@ -487,6 +499,8 @@ def to_numpy(self) -> Any: import numpy as np return np.hstack([self[col].to_numpy()[:, None] for col in self.columns]) + if self._implementation is Implementation.DASK: + return self._native_dataframe.compute().to_numpy() return self._native_dataframe.to_numpy() def to_pandas(self) -> Any: @@ -494,6 +508,8 @@ def to_pandas(self) -> Any: return self._native_dataframe if self._implementation is Implementation.MODIN: # pragma: no cover return self._native_dataframe._to_pandas() + if self._implementation is Implementation.DASK: # pragma: no cover + return self._native_dataframe.compute() return self._native_dataframe.to_pandas() # pragma: no cover def write_parquet(self, file: Any) -> Any: diff --git a/narwhals/_pandas_like/group_by.py b/narwhals/_pandas_like/group_by.py index 22a2d8621..9a4f61043 100644 --- a/narwhals/_pandas_like/group_by.py +++ b/narwhals/_pandas_like/group_by.py @@ -28,11 +28,15 @@ class PandasLikeGroupBy: def __init__(self, df: PandasLikeDataFrame, keys: list[str]) -> None: self._df = df self._keys = list(keys) + keywords: dict[str, bool] = {} + if df._implementation is not Implementation.DASK: + keywords |= {"as_index": True} self._grouped = self._df._native_dataframe.groupby( list(self._keys), sort=False, as_index=True, dropna=False, + **keywords, ) def agg( @@ -57,13 +61,18 @@ def agg( raise ValueError(msg) output_names.extend(expr._output_names) + dataframe_is_empty = ( + self._df._native_dataframe.empty + if self._df._implementation != Implementation.DASK + else len(self._df._native_dataframe) == 0 + ) return agg_pandas( self._grouped, exprs, self._keys, output_names, self._from_native_dataframe, - dataframe_is_empty=self._df._native_dataframe.empty, + dataframe_is_empty=dataframe_is_empty, implementation=implementation, backend_version=self._df._backend_version, ) diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index dbae3bbb7..6a250fdc3 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -12,6 +12,7 @@ from narwhals._pandas_like.expr import PandasLikeExpr from narwhals._pandas_like.selectors import PandasSelectorNamespace from narwhals._pandas_like.series import PandasLikeSeries +from narwhals._pandas_like.utils import Implementation from narwhals._pandas_like.utils import create_native_series from narwhals._pandas_like.utils import horizontal_concat from narwhals._pandas_like.utils import vertical_concat @@ -19,7 +20,6 @@ if TYPE_CHECKING: from narwhals._pandas_like.typing import IntoPandasLikeExpr - from narwhals._pandas_like.utils import Implementation class PandasLikeNamespace: @@ -78,10 +78,15 @@ def _create_expr_from_callable( def _create_series_from_scalar( self, value: Any, series: PandasLikeSeries ) -> PandasLikeSeries: + index = ( + series._native_series.index[0:1] + if self._implementation is not Implementation.DASK + else None + ) return PandasLikeSeries._from_iterable( [value], name=series._native_series.name, - index=series._native_series.index[0:1], + index=index, implementation=self._implementation, backend_version=self._backend_version, ) diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py index 042a5eb0f..8c43f30f5 100644 --- a/narwhals/_pandas_like/series.py +++ b/narwhals/_pandas_like/series.py @@ -9,11 +9,13 @@ from narwhals._pandas_like.utils import Implementation from narwhals._pandas_like.utils import int_dtype_mapper from narwhals._pandas_like.utils import native_series_from_iterable +from narwhals._pandas_like.utils import not_implemented_in from narwhals._pandas_like.utils import reverse_translate_dtype from narwhals._pandas_like.utils import to_datetime from narwhals._pandas_like.utils import translate_dtype from narwhals._pandas_like.utils import validate_column_comparand from narwhals.dependencies import get_cudf +from narwhals.dependencies import get_dask from narwhals.dependencies import get_modin from narwhals.dependencies import get_pandas from narwhals.dependencies import get_pyarrow_compute @@ -107,12 +109,15 @@ def __native_namespace__(self) -> Any: return get_modin() if self._implementation is Implementation.CUDF: # pragma: no cover return get_cudf() - msg = f"Expected pandas/modin/cudf, got: {type(self._implementation)}" # pragma: no cover + if self._implementation is Implementation.DASK: # pragma: no cover + return get_dask() + msg = f"Expected pandas/modin/cudf/dask, got: {type(self._implementation)}" # pragma: no cover raise AssertionError(msg) def __narwhals_series__(self) -> Self: return self + @not_implemented_in(Implementation.DASK) def __getitem__(self, idx: int | slice | Sequence[int]) -> Any: if isinstance(idx, int): return self._native_series.iloc[idx] @@ -152,7 +157,7 @@ def _from_iterable( ) def __len__(self) -> int: - return self.shape[0] + return len(self._native_series) @property def name(self) -> str: @@ -183,7 +188,11 @@ def item(self: Self, index: int | None = None) -> Any: f" or an explicit index is provided (Series is of length {len(self)})" ) raise ValueError(msg) + if self._implementation is Implementation.DASK: + return self._native_series.max() # hack: taking aggregation of 1 item return self._native_series.iloc[0] + if self._implementation is Implementation.DASK: + raise NotImplementedError("Dask does not support index locating") return self._native_series.iloc[index] def to_frame(self) -> Any: @@ -196,6 +205,8 @@ def to_frame(self) -> Any: ) def to_list(self) -> Any: + if self._implementation is Implementation.DASK: + return self._native_series.compute().to_list() return self._native_series.to_list() def is_between( @@ -504,10 +515,13 @@ def to_pandas(self) -> Any: return self._native_series.to_pandas() elif self._implementation is Implementation.MODIN: # pragma: no cover return self._native_series._to_pandas() + elif self._implementation is Implementation.DASK: # pragma: no cover + return self._native_series.compute() msg = f"Unknown implementation: {self._implementation}" # pragma: no cover raise AssertionError(msg) # --- descriptive --- + @not_implemented_in(Implementation.DASK) def is_duplicated(self: Self) -> Self: return self._from_native_series(self._native_series.duplicated(keep=False)) @@ -520,9 +534,11 @@ def is_unique(self: Self) -> Self: def null_count(self: Self) -> int: return self._native_series.isna().sum() # type: ignore[no-any-return] + @not_implemented_in(Implementation.DASK) def is_first_distinct(self: Self) -> Self: return self._from_native_series(~self._native_series.duplicated(keep="first")) + @not_implemented_in(Implementation.DASK) def is_last_distinct(self: Self) -> Self: return self._from_native_series(~self._native_series.duplicated(keep="last")) @@ -559,6 +575,15 @@ def quantile( quantile: float, interpolation: Literal["nearest", "higher", "lower", "midpoint", "linear"], ) -> Any: + if self._implementation is Implementation.DASK: + if interpolation == "linear": + return self._native_series.quantile(q=quantile) + message = ( + "Dask performs approximate quantile calculations " + "and does not support specific interpolations methods. " + "Interpolation keywords other than 'linear' are not supported" + ) + raise NotImplementedError(message) return self._native_series.quantile(q=quantile, interpolation=interpolation) def zip_with(self: Self, mask: Any, other: Any) -> PandasLikeSeries: @@ -594,6 +619,13 @@ def __init__(self, series: PandasLikeSeries) -> None: def get_categories(self) -> PandasLikeSeries: s = self._pandas_series._native_series + if self._pandas_series._implementation is Implementation.DASK: + pd = get_pandas() + dd = get_dask() + native_series = pd.Series(s.cat.as_known().cat.categories, name=s.name).pipe( + dd.from_pandas + ) + return self._pandas_series._from_native_series(native_series) return self._pandas_series._from_native_series( s.__class__(s.cat.categories, name=s.name) ) diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index 3548966ca..316495aa4 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -3,12 +3,15 @@ import secrets from enum import Enum from enum import auto +from functools import wraps from typing import TYPE_CHECKING from typing import Any +from typing import Callable from typing import Iterable from typing import TypeVar from narwhals.dependencies import get_cudf +from narwhals.dependencies import get_dask from narwhals.dependencies import get_modin from narwhals.dependencies import get_pandas from narwhals.utils import isinstance_or_issubclass @@ -27,6 +30,7 @@ class Implementation(Enum): PANDAS = auto() MODIN = auto() CUDF = auto() + DASK = auto() def validate_column_comparand(index: Any, other: Any) -> Any: @@ -53,7 +57,10 @@ def validate_column_comparand(index: Any, other: Any) -> Any: if other.len() == 1: # broadcast return other.item() - if other._native_series.index is not index: + if ( + other._native_series.index is not index + and other._implementation is not Implementation.DASK + ): return set_axis( other._native_series, index, @@ -79,7 +86,10 @@ def validate_dataframe_comparand(index: Any, other: Any) -> Any: if other.len() == 1: # broadcast return other._native_series.iloc[0] - if other._native_series.index is not index: + if ( + other._native_series.index is not index + and other._implementation is not Implementation.DASK + ): return set_axis( other._native_series, index, @@ -87,7 +97,7 @@ def validate_dataframe_comparand(index: Any, other: Any) -> Any: backend_version=other._backend_version, ) return other._native_series - msg = "Please report a bug" # pragma: no cover + msg = "Please report a bug" raise AssertionError(msg) @@ -109,6 +119,10 @@ def create_native_series( elif implementation is Implementation.CUDF: cudf = get_cudf() series = cudf.Series(iterable, index=index, name="") + elif implementation is Implementation.DASK: + pd = get_pandas() + dd = get_dask() + series = pd.Series(iterable, index=index, name="").pipe(dd.from_pandas) return PandasLikeSeries( series, implementation=implementation, backend_version=backend_version ) @@ -136,6 +150,12 @@ def horizontal_concat( mpd = get_modin() return mpd.concat(dfs, axis=1) + if implementation is Implementation.DASK: # pragma: no cover + dd = get_dask() + pd = get_pandas() + if isinstance(dfs[0], pd.Series): + return dd.concat([i.pipe(dd.from_pandas) for i in dfs], axis=1) + return dd.concat(dfs, axis=1) msg = f"Unknown implementation: {implementation}" # pragma: no cover raise TypeError(msg) # pragma: no cover @@ -171,6 +191,10 @@ def vertical_concat( mpd = get_modin() return mpd.concat(dfs, axis=0) + if implementation is Implementation.DASK: # pragma: no cover + dd = get_dask() + + return dd.concat(dfs, axis=0) msg = f"Unknown implementation: {implementation}" # pragma: no cover raise TypeError(msg) # pragma: no cover @@ -194,6 +218,17 @@ def native_series_from_iterable( mpd = get_modin() return mpd.Series(data, name=name, index=index) + if implementation is Implementation.DASK: # pragma: no cover + dd = get_dask() + pd = get_pandas() + if hasattr(data[0], "compute"): # type: ignore[index] + return dd.concat([i.to_series() for i in data]).rename(name) + return pd.Series( + data, + name=name, + index=index, + copy=False, + ).pipe(dd.from_pandas) msg = f"Unknown implementation: {implementation}" # pragma: no cover raise TypeError(msg) # pragma: no cover @@ -205,6 +240,10 @@ def set_axis( implementation: Implementation, backend_version: tuple[int, ...], ) -> T: + if implementation is Implementation.DASK: + return ( + obj # HACK: dask doesn't really reset indexes so much, so assuming its fine + ) if implementation is Implementation.PANDAS and backend_version < ( 1, ): # pragma: no cover @@ -281,6 +320,12 @@ def translate_dtype(column: Any) -> DType: if str(dtype) == "date32[day][pyarrow]": return dtypes.Date() if str(dtype) == "object": + if (dd := get_dask()) is not None and isinstance(column, dd.Series): + # below we'll try to infer strings or objects from values but + # with dask we can only do this if we compute so we'll avoid and + # treat as a string (this *may* be a bad call, as it does not allow + # for object types which are potentially valid) + return dtypes.String() if (idx := column.first_valid_index()) is not None and isinstance( column.loc[idx], str ): @@ -449,6 +494,8 @@ def to_datetime(implementation: Implementation) -> Any: return get_modin().to_datetime if implementation is Implementation.CUDF: return get_cudf().to_datetime + if implementation is Implementation.DASK: + return get_dask().to_datetime raise AssertionError @@ -486,3 +533,26 @@ def generate_unique_token(n_bytes: int, columns: list[str]) -> str: # pragma: n "join operation" ) raise AssertionError(msg) + + +def not_implemented_in( + *implementations: Implementation, +) -> Callable[[Callable], Callable]: # type: ignore[type-arg] + """ + Produces method decorator to raise not implemented warnings for given implementations + """ + + def check_implementation_wrapper(func: Callable) -> Callable: # type: ignore[type-arg] + """Wraps function to return same function + implementation check""" + + @wraps(func) + def wrapped_func(self, *args, **kwargs): # type: ignore[no-untyped-def] # noqa: ANN001, ANN002, ANN003, ANN202 + """Checks implementation then carries out wrapped call""" + if (implementation := self._implementation) in implementations: + msg = f"Not implemented in {implementation}" + raise NotImplementedError(msg) + return func(self, *args, **kwargs) + + return wrapped_func + + return check_implementation_wrapper diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index 85130affc..43aab0a77 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -559,10 +559,13 @@ def __getitem__( @overload def to_dict(self, *, as_series: Literal[True] = ...) -> dict[str, Series]: ... + @overload def to_dict(self, *, as_series: Literal[False]) -> dict[str, list[Any]]: ... + @overload def to_dict(self, *, as_series: bool) -> dict[str, Series] | dict[str, list[Any]]: ... + def to_dict( self, *, as_series: bool = True ) -> dict[str, Series] | dict[str, list[Any]]: diff --git a/narwhals/dependencies.py b/narwhals/dependencies.py index cbb5eca3c..bf3a11b03 100644 --- a/narwhals/dependencies.py +++ b/narwhals/dependencies.py @@ -32,6 +32,11 @@ def get_modin() -> Any: # pragma: no cover return None +def get_dask() -> Any: + """Get dask.dataframe module (if already imported - else return None).""" + return sys.modules.get("dask.dataframe", None) + + def get_cudf() -> Any: """Get cudf module (if already imported - else return None).""" return sys.modules.get("cudf", None) diff --git a/narwhals/stable/v1.py b/narwhals/stable/v1.py index 45468d782..324109271 100644 --- a/narwhals/stable/v1.py +++ b/narwhals/stable/v1.py @@ -119,10 +119,13 @@ def lazy(self) -> LazyFrame[Any]: # thing that I need to understand category theory for @overload # type: ignore[override] def to_dict(self, *, as_series: Literal[True] = ...) -> dict[str, Series]: ... + @overload def to_dict(self, *, as_series: Literal[False]) -> dict[str, list[Any]]: ... + @overload def to_dict(self, *, as_series: bool) -> dict[str, Series] | dict[str, list[Any]]: ... + def to_dict( self, *, as_series: bool = True ) -> dict[str, Series] | dict[str, list[Any]]: @@ -450,12 +453,20 @@ class Schema(NwSchema): @overload def _stableify(obj: NwDataFrame[IntoFrameT]) -> DataFrame[IntoFrameT]: ... + + @overload def _stableify(obj: NwLazyFrame[IntoFrameT]) -> LazyFrame[IntoFrameT]: ... + + @overload def _stableify(obj: NwSeries) -> Series: ... + + @overload def _stableify(obj: NwExpr) -> Expr: ... + + @overload def _stableify(obj: Any) -> Any: ... @@ -685,6 +696,7 @@ def from_native( - pandas.DataFrame - polars.DataFrame - polars.LazyFrame + - dask.dataframe.DataFrame - anything with a `__narwhals_dataframe__` or `__narwhals_lazyframe__` method - pandas.Series - polars.Series diff --git a/narwhals/translate.py b/narwhals/translate.py index e13fb9b3c..c934f8711 100644 --- a/narwhals/translate.py +++ b/narwhals/translate.py @@ -9,6 +9,7 @@ from typing import overload from narwhals.dependencies import get_cudf +from narwhals.dependencies import get_dask from narwhals.dependencies import get_modin from narwhals.dependencies import get_pandas from narwhals.dependencies import get_polars @@ -28,12 +29,18 @@ def to_native( narwhals_object: DataFrame[IntoDataFrameT], *, strict: Literal[True] = ... ) -> IntoDataFrameT: ... + + @overload def to_native( narwhals_object: LazyFrame[IntoFrameT], *, strict: Literal[True] = ... ) -> IntoFrameT: ... + + @overload def to_native(narwhals_object: Series, *, strict: Literal[True] = ...) -> Any: ... + + @overload def to_native(narwhals_object: Any, *, strict: bool) -> Any: ... @@ -271,6 +278,7 @@ def from_native( # noqa: PLR0915 - pandas.DataFrame - polars.DataFrame - polars.LazyFrame + - dask.dataframe.DataFrame - anything with a `__narwhals_dataframe__` or `__narwhals_lazyframe__` method - pandas.Series - polars.Series @@ -389,6 +397,22 @@ def from_native( # noqa: PLR0915 backend_version=parse_version(pa.__version__), level="full", ) + elif (dd := get_dask()) is not None and isinstance(native_object, dd.DataFrame): + if series_only: # pragma: no cover (todo) + msg = "Cannot only use `series_only` with dask.dataframe.DataFrame" + raise TypeError(msg) + import dask + + return DataFrame( + PandasLikeDataFrame( + native_object, + implementation=Implementation.DASK, + backend_version=parse_version(dask.__version__), + ), + is_polars=False, + backend_version=parse_version(dask.__version__), + level="full", + ) elif hasattr(native_object, "__dataframe__"): if eager_only or series_only: msg = ( @@ -494,6 +518,22 @@ def from_native( # noqa: PLR0915 backend_version=parse_version(pa.__version__), level="full", ) + elif (dd := get_dask()) is not None and isinstance(native_object, dd.Series): + if not allow_series: # pragma: no cover (todo) + msg = "Please set `allow_series=True`" + raise TypeError(msg) + import dask + + return Series( + PandasLikeSeries( + native_object, + implementation=Implementation.DASK, + backend_version=parse_version(dask.__version__), + ), + is_polars=False, + backend_version=parse_version(dask.__version__), + level="full", + ) elif hasattr(native_object, "__narwhals_series__"): if not allow_series: msg = "Please set `allow_series=True`" diff --git a/narwhals/utils.py b/narwhals/utils.py index 1f6a9074a..2d667c11a 100644 --- a/narwhals/utils.py +++ b/narwhals/utils.py @@ -312,9 +312,7 @@ def is_ordered_categorical(series: Series) -> bool: isinstance(series._compliant_series, InterchangeSeries) and series.dtype == dtypes.Categorical ): - return series._compliant_series._native_series.describe_categorical[ # type: ignore[no-any-return] - "is_ordered" - ] + return series._compliant_series._native_series.describe_categorical["is_ordered"] # type: ignore[no-any-return] if series.dtype == dtypes.Enum: return True if series.dtype != dtypes.Categorical: diff --git a/requirements-dev.txt b/requirements-dev.txt index a9d6f04d8..5d6820791 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -8,4 +8,4 @@ pytest pytest-cov hypothesis scikit-learn - +dask[dataframe] diff --git a/tests/conftest.py b/tests/conftest.py index 3b1646657..bed7ab26f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,4 @@ +import contextlib from typing import Any from typing import Callable @@ -6,10 +7,15 @@ import pyarrow as pa import pytest +from narwhals.dependencies import get_dask from narwhals.dependencies import get_modin from narwhals.typing import IntoDataFrame from narwhals.utils import parse_version +with contextlib.suppress(ImportError): + import dask.dataframe # noqa: F401 + import modin # noqa: F401 + def pytest_addoption(parser: Any) -> None: parser.addoption( @@ -48,6 +54,11 @@ def modin_constructor(obj: Any) -> IntoDataFrame: # pragma: no cover return mpd.DataFrame(obj).convert_dtypes(dtype_backend="pyarrow") # type: ignore[no-any-return] +def dask_constructor(obj: Any) -> IntoDataFrame: + dd = get_dask() + return pd.DataFrame(obj).pipe(dd.from_pandas, npartitions=1) # type: ignore[no-any-return] + + def polars_eager_constructor(obj: Any) -> IntoDataFrame: return pl.DataFrame(obj) @@ -74,6 +85,9 @@ def pyarrow_table_constructor(obj: Any) -> IntoDataFrame: if get_modin() is not None: # pragma: no cover eager_constructors.append(modin_constructor) +if get_dask() is not None: + eager_constructors.append(dask_constructor) + @pytest.fixture(params=eager_constructors) def constructor(request: Any) -> Callable[[Any], IntoDataFrame]: @@ -102,6 +116,11 @@ def modin_series_constructor(obj: Any) -> Any: # pragma: no cover return mpd.Series(obj).convert_dtypes(dtype_backend="pyarrow") +def dask_series_constructor(obj: Any) -> Any: # pragma: no cover + dd = get_dask() + return dd.Series(obj) + + def polars_series_constructor(obj: Any) -> Any: return pl.Series(obj) @@ -120,6 +139,8 @@ def pyarrow_series_constructor(obj: Any) -> Any: params_series = [pandas_series_constructor] if get_modin() is not None: # pragma: no cover params_series.append(modin_series_constructor) +if get_dask() is not None: # pragma: no cover + params_series.append(dask_series_constructor) params_series.extend([polars_series_constructor, pyarrow_series_constructor]) diff --git a/tests/frame/test_common.py b/tests/frame/test_common.py index 03efae10a..d9003ecdd 100644 --- a/tests/frame/test_common.py +++ b/tests/frame/test_common.py @@ -10,9 +10,11 @@ import pytest import narwhals.stable.v1 as nw +from narwhals.dependencies import get_dask from narwhals.functions import _get_deps_info from narwhals.functions import _get_sys_info from narwhals.functions import show_versions +from tests.conftest import dask_constructor from tests.utils import compare_dicts data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} @@ -21,8 +23,10 @@ def test_empty_select(constructor: Any) -> None: - result = nw.from_native(constructor({"a": [1, 2, 3]}), eager_only=True).select() - assert result.shape == (0, 0) + result = nw.from_native(constructor({"a": [1, 2, 3]}), eager_only=True).select().shape + if constructor == dask_constructor and (dd := get_dask()) is not None: + result = dd.compute(result)[0] + assert result == (0, 0) def test_std(constructor: Any) -> None: diff --git a/tests/series_only/test_common.py b/tests/series_only/test_common.py index 5f53c6ad7..52855e77a 100644 --- a/tests/series_only/test_common.py +++ b/tests/series_only/test_common.py @@ -11,13 +11,25 @@ from pandas.testing import assert_series_equal import narwhals.stable.v1 as nw +from narwhals._pandas_like.utils import Implementation from narwhals.utils import parse_version +from tests.conftest import dask_series_constructor data = [1, 3, 2] data_dups = [4, 4, 6] data_sorted = [7.0, 8, 9] +def compute_if_dask(result: Any) -> Any: + if ( + hasattr(result, "_native_series") + and hasattr(result._native_series, "_implementation") + and result._series._implementation is Implementation.DASK + ): + return result.to_pandas() + return result + + def test_len(constructor_series: Any) -> None: series = nw.from_native(constructor_series(data), series_only=True) @@ -120,6 +132,7 @@ def test_is_duplicated(request: Any, constructor_series: Any) -> None: series = nw.from_native(constructor_series(data_dups), series_only=True) result = series.is_duplicated() + result = compute_if_dask(result) expected = np.array([True, True, False]) assert (result.to_numpy() == expected).all() @@ -193,6 +206,8 @@ def test_is_sorted( series = nw.from_native(constructor_series(input_data), series_only=True) result = series.is_sorted(descending=descending) + if constructor_series == dask_series_constructor: + result = result.compute() # type: ignore[attr-defined] assert result == expected @@ -227,9 +242,14 @@ def test_quantile( request.applymarker(pytest.mark.xfail) q = 0.3 + if is_dask_test := constructor_series == dask_series_constructor: + interpolation = "linear" # other interpolation unsupported in dask series = nw.from_native(constructor_series(data_sorted), allow_series=True) + result = series.quantile(quantile=q, interpolation=interpolation) # type: ignore[union-attr] + if is_dask_test: + result = result.compute() assert result == expected diff --git a/tests/tpch_q1_test.py b/tests/tpch_q1_test.py index da540e465..c38aae677 100644 --- a/tests/tpch_q1_test.py +++ b/tests/tpch_q1_test.py @@ -11,13 +11,14 @@ import pytest import narwhals.stable.v1 as nw +from narwhals.dependencies import get_dask from narwhals.utils import parse_version from tests.utils import compare_dicts @pytest.mark.parametrize( "library", - ["pandas", "polars", "pyarrow"], + ["pandas", "polars", "pyarrow", "dask"], ) @pytest.mark.filterwarnings("ignore:.*Passing a BlockManager.*:DeprecationWarning") def test_q1(library: str, request: Any) -> None: @@ -28,6 +29,8 @@ def test_q1(library: str, request: Any) -> None: df_raw["l_shipdate"] = pd.to_datetime(df_raw["l_shipdate"]) elif library == "polars": df_raw = pl.scan_parquet("tests/data/lineitem.parquet") + elif library == "dask" and (dd := get_dask()) is not None: + df_raw = dd.read_parquet("tests/data/lineitem.parquet") else: df_raw = pq.read_table("tests/data/lineitem.parquet") var_1 = datetime(1998, 9, 2) @@ -85,7 +88,7 @@ def test_q1(library: str, request: Any) -> None: @pytest.mark.parametrize( "library", - ["pandas", "polars"], + ["pandas", "polars", "dask"], ) @pytest.mark.filterwarnings( "ignore:.*Passing a BlockManager.*:DeprecationWarning", @@ -97,8 +100,15 @@ def test_q1_w_generic_funcs(library: str, request: Any) -> None: elif library == "pandas": df_raw = pd.read_parquet("tests/data/lineitem.parquet") df_raw["l_shipdate"] = pd.to_datetime(df_raw["l_shipdate"]) - else: + elif library == "polars": df_raw = pl.read_parquet("tests/data/lineitem.parquet") + elif library == "dask" and (dd := get_dask()) is not None: + df_raw = dd.read_parquet("tests/data/lineitem.parquet") + df_raw["l_shipdate"] = dd.to_datetime(df_raw["l_shipdate"]) + else: + request.applymarker(pytest.mark.xfail) + df_raw = pq.read_table("tests/data/lineitem.parquet") + var_1 = datetime(1998, 9, 2) df = nw.from_native(df_raw, eager_only=True) query_result = ( diff --git a/tests/utils.py b/tests/utils.py index 6ab703c3b..47e9c1592 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -9,6 +9,8 @@ import pandas as pd +from narwhals._pandas_like.utils import Implementation + def zip_strict(left: Sequence[Any], right: Sequence[Any]) -> Iterator[Any]: if len(left) != len(right): @@ -20,6 +22,15 @@ def zip_strict(left: Sequence[Any], right: Sequence[Any]) -> Iterator[Any]: def compare_dicts(result: Any, expected: dict[str, Any]) -> None: if hasattr(result, "collect"): result = result.collect() + if ( + hasattr(result, "_native_dataframe") + and hasattr(result._native_dataframe, "_implementation") + and result._dataframe._implementation is Implementation.DASK + ) or ( + hasattr(result, "__native_namespace__") + and "dask" in str(result.__native_namespace__()) + ): + result = result.to_pandas() if hasattr(result, "columns"): for key in result.columns: assert key in expected @@ -51,6 +62,17 @@ def maybe_get_modin_df(df_pandas: pd.DataFrame) -> Any: return mpd.DataFrame(df_pandas.to_dict(orient="list")) +def maybe_get_dask_df(df_pandas: pd.DataFrame) -> Any: + """Convert a pandas DataFrame to a Dask Dataframe if Dask is available.""" + try: + import dask.dataframe as dd + + except ImportError: + return df_pandas.copy() + else: + return dd.from_pandas(df_pandas, npartitions=1) + + def is_windows() -> bool: """Check if the current platform is Windows.""" return sys.platform in ["win32", "cygwin"]