Skip to content

Commit

Permalink
fix: nw.lit(date, dtype=nw.Date), loosen Dask minimum back to 2024.8 (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored Jan 5, 2025
1 parent 31158b2 commit 19418cf
Show file tree
Hide file tree
Showing 13 changed files with 57 additions and 23 deletions.
2 changes: 2 additions & 0 deletions narwhals/_dask/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,8 @@ def narwhals_to_native_dtype(dtype: DType | type[DType], version: Version) -> An
return "category"
if isinstance_or_issubclass(dtype, dtypes.Datetime):
return "datetime64[us]"
if isinstance_or_issubclass(dtype, dtypes.Date):
return "date32[day][pyarrow]"
if isinstance_or_issubclass(dtype, dtypes.Duration):
return "timedelta64[ns]"
if isinstance_or_issubclass(dtype, dtypes.List): # pragma: no cover
Expand Down
9 changes: 5 additions & 4 deletions narwhals/_pandas_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,10 +637,11 @@ def narwhals_to_native_dtype( # noqa: PLR0915
else f"timedelta64[{du_time_unit}]"
)
if isinstance_or_issubclass(dtype, dtypes.Date):
if dtype_backend == "pyarrow-nullable":
return "date32[pyarrow]"
msg = "Date dtype only supported for pyarrow-backed data types in pandas"
raise NotImplementedError(msg)
try:
import pyarrow as pa # ignore-banned-import
except ModuleNotFoundError: # pragma: no cover
msg = "PyArrow>=11.0.0 is required for `Date` dtype."
return "date32[pyarrow]"
if isinstance_or_issubclass(dtype, dtypes.Enum):
msg = "Converting to Enum is not (yet) supported"
raise NotImplementedError(msg)
Expand Down
2 changes: 1 addition & 1 deletion narwhals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -7023,7 +7023,7 @@ def all_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr:
)


def lit(value: Any, dtype: DType | None = None) -> Expr:
def lit(value: Any, dtype: DType | type[DType] | None = None) -> Expr:
"""Return an expression representing a literal value.
Arguments:
Expand Down
2 changes: 1 addition & 1 deletion narwhals/stable/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2542,7 +2542,7 @@ def len() -> Expr:
return _stableify(nw.len())


def lit(value: Any, dtype: DType | None = None) -> Expr:
def lit(value: Any, dtype: DType | type[DType] | None = None) -> Expr:
"""Return an expression representing a literal value.
Arguments:
Expand Down
2 changes: 1 addition & 1 deletion narwhals/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def is_ibis(self) -> bool:
Implementation.PYARROW: (11,),
Implementation.PYSPARK: (3, 3),
Implementation.POLARS: (0, 20, 3),
Implementation.DASK: (2024, 10),
Implementation.DASK: (2024, 8),
Implementation.DUCKDB: (1,),
Implementation.IBIS: (6,),
}
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ cudf = ["cudf>=24.10.0"]
pyarrow = ["pyarrow>=11.0.0"]
pyspark = ["pyspark>=3.3.0"]
polars = ["polars>=0.20.3"]
dask = ["dask[dataframe]>=2024.10"]
dask = ["dask[dataframe]>=2024.8"]
duckdb = ["duckdb>=1.0"]
ibis = ["ibis-framework>=6.0.0", "rich", "packaging", "pyarrow_hotfix"]
dev = [
Expand Down
5 changes: 5 additions & 0 deletions tests/expr_and_series/arithmetic_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from hypothesis import given

import narwhals.stable.v1 as nw
from tests.utils import DASK_VERSION
from tests.utils import PANDAS_VERSION
from tests.utils import Constructor
from tests.utils import ConstructorEager
Expand Down Expand Up @@ -67,6 +68,8 @@ def test_right_arithmetic_expr(
constructor: Constructor,
request: pytest.FixtureRequest,
) -> None:
if "dask" in str(constructor) and DASK_VERSION < (2024, 10):
request.applymarker(pytest.mark.xfail)
if attr == "__rmod__" and any(
x in str(constructor) for x in ["pandas_pyarrow", "modin_pyarrow"]
):
Expand Down Expand Up @@ -241,6 +244,8 @@ def test_arithmetic_expr_left_literal(
constructor: Constructor,
request: pytest.FixtureRequest,
) -> None:
if "dask" in str(constructor) and DASK_VERSION < (2024, 10):
request.applymarker(pytest.mark.xfail)
if attr == "__mod__" and any(
x in str(constructor) for x in ["pandas_pyarrow", "modin_pyarrow"]
):
Expand Down
7 changes: 6 additions & 1 deletion tests/expr_and_series/binary_test.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
from __future__ import annotations

import pytest

import narwhals.stable.v1 as nw
from tests.utils import DASK_VERSION
from tests.utils import Constructor
from tests.utils import assert_equal_data


def test_expr_binary(constructor: Constructor) -> None:
def test_expr_binary(constructor: Constructor, request: pytest.FixtureRequest) -> None:
if "dask" in str(constructor) and DASK_VERSION < (2024, 10):
request.applymarker(pytest.mark.xfail)
data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}
df_raw = constructor(data)
result = nw.from_native(df_raw).with_columns(
Expand Down
20 changes: 20 additions & 0 deletions tests/frame/lit_test.py → tests/expr_and_series/lit_test.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from __future__ import annotations

from datetime import date
from typing import TYPE_CHECKING
from typing import Any

import numpy as np
import pytest

import narwhals.stable.v1 as nw
from tests.utils import DASK_VERSION
from tests.utils import PANDAS_VERSION
from tests.utils import Constructor
from tests.utils import assert_equal_data

Expand Down Expand Up @@ -82,10 +85,27 @@ def test_lit_operation(
col_name: str,
expr: nw.Expr,
expected_result: list[int],
request: pytest.FixtureRequest,
) -> None:
if (
"dask" in str(constructor)
and col_name in ("left_lit", "left_scalar")
and DASK_VERSION < (2024, 10)
):
request.applymarker(pytest.mark.xfail)
data = {"a": [1, 3, 2]}
df_raw = constructor(data)
df = nw.from_native(df_raw).lazy()
result = df.select(expr.alias(col_name))
expected = {col_name: expected_result}
assert_equal_data(result, expected)


@pytest.mark.skipif(PANDAS_VERSION < (1, 5), reason="too old for pyarrow")
def test_date_lit(constructor: Constructor, request: pytest.FixtureRequest) -> None:
    """A date literal created with ``dtype=nw.Date`` shows up as ``Date`` in the schema."""
    is_dask = "dask" in str(constructor)
    if is_dask:
        # https://github.com/dask/dask/issues/11637
        request.applymarker(pytest.mark.xfail)
    frame = nw.from_native(constructor({"a": [1]}))
    date_literal = nw.lit(date(2020, 1, 1), dtype=nw.Date)
    schema = frame.with_columns(date_literal).collect_schema()
    # The unnamed literal is expected under the column name "literal".
    assert schema == {"a": nw.Int64, "literal": nw.Date}
12 changes: 11 additions & 1 deletion tests/expr_and_series/operators_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest

import narwhals.stable.v1 as nw
from tests.utils import DASK_VERSION
from tests.utils import Constructor
from tests.utils import ConstructorEager
from tests.utils import assert_equal_data
Expand Down Expand Up @@ -75,8 +76,17 @@ def test_logic_operators_expr(
],
)
def test_logic_operators_expr_scalar(
constructor: Constructor, operator: str, expected: list[bool]
constructor: Constructor,
operator: str,
expected: list[bool],
request: pytest.FixtureRequest,
) -> None:
if (
"dask" in str(constructor)
and DASK_VERSION < (2024, 10)
and operator in ("__rand__", "__ror__")
):
request.applymarker(pytest.mark.xfail)
data = {"a": [True, True, False, False]}
df = nw.from_native(constructor(data))

Expand Down
2 changes: 1 addition & 1 deletion tests/frame/select_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def test_missing_columns(constructor: Constructor) -> None:
def test_left_to_right_broadcasting(
constructor: Constructor, request: pytest.FixtureRequest
) -> None:
if "dask" in str(constructor) and DASK_VERSION < (2024, 9):
if "dask" in str(constructor) and DASK_VERSION < (2024, 10):
request.applymarker(pytest.mark.xfail)
df = nw.from_native(constructor({"a": [1, 1, 2], "b": [4, 5, 6]}))
result = df.select(nw.col("a") + nw.col("b").sum())
Expand Down
12 changes: 0 additions & 12 deletions tests/series_only/cast_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,18 +98,6 @@ def test_cast_date_datetime_pandas() -> None:
assert df.schema == {"a": nw.Date}


@pytest.mark.skipif(
PANDAS_VERSION < (2, 0, 0),
reason="pyarrow dtype not available",
)
def test_cast_date_datetime_invalid() -> None:
# pandas: pyarrow datetime to date
dfpd = pd.DataFrame({"a": [datetime(2020, 1, 1), datetime(2020, 1, 2)]})
df = nw.from_native(dfpd)
with pytest.raises(NotImplementedError, match="pyarrow"):
df.select(nw.col("a").cast(nw.Date))


@pytest.mark.filterwarnings("ignore: casting period")
def test_unknown_to_int() -> None:
df = pd.DataFrame({"a": pd.period_range("2000", periods=3, freq="min")})
Expand Down
3 changes: 3 additions & 0 deletions tests/tpch_q1_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import pytest

import narwhals.stable.v1 as nw
from tests.utils import DASK_VERSION
from tests.utils import PANDAS_VERSION
from tests.utils import assert_equal_data

Expand All @@ -20,6 +21,8 @@
)
@pytest.mark.filterwarnings("ignore:.*Passing a BlockManager.*:DeprecationWarning")
def test_q1(library: str, request: pytest.FixtureRequest) -> None:
if library == "dask" and DASK_VERSION < (2024, 10):
request.applymarker(pytest.mark.xfail)
if library == "pandas" and PANDAS_VERSION < (1, 5):
request.applymarker(pytest.mark.xfail)
elif library == "pandas":
Expand Down

0 comments on commit 19418cf

Please sign in to comment.