Skip to content

Commit

Permalink
fix: nw.lit(date, dtype=nw.Date)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Jan 5, 2025
1 parent 31158b2 commit f4df861
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 18 deletions.
5 changes: 1 addition & 4 deletions narwhals/_pandas_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,10 +637,7 @@ def narwhals_to_native_dtype( # noqa: PLR0915
else f"timedelta64[{du_time_unit}]"
)
if isinstance_or_issubclass(dtype, dtypes.Date):
if dtype_backend == "pyarrow-nullable":
return "date32[pyarrow]"
msg = "Date dtype only supported for pyarrow-backed data types in pandas"
raise NotImplementedError(msg)
return "date32[pyarrow]"
if isinstance_or_issubclass(dtype, dtypes.Enum):
msg = "Converting to Enum is not (yet) supported"
raise NotImplementedError(msg)
Expand Down
2 changes: 1 addition & 1 deletion narwhals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -7023,7 +7023,7 @@ def all_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr:
)


def lit(value: Any, dtype: DType | None = None) -> Expr:
def lit(value: Any, dtype: DType | type[DType] | None = None) -> Expr:
"""Return an expression representing a literal value.
Arguments:
Expand Down
2 changes: 1 addition & 1 deletion narwhals/stable/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2542,7 +2542,7 @@ def len() -> Expr:
return _stableify(nw.len())


def lit(value: Any, dtype: DType | None = None) -> Expr:
def lit(value: Any, dtype: DType | type[DType] | None = None) -> Expr:
"""Return an expression representing a literal value.
Arguments:
Expand Down
7 changes: 7 additions & 0 deletions tests/frame/lit_test.py → tests/expr_and_series/lit_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from datetime import date
from typing import TYPE_CHECKING
from typing import Any

Expand Down Expand Up @@ -89,3 +90,9 @@ def test_lit_operation(
result = df.select(expr.alias(col_name))
expected = {col_name: expected_result}
assert_equal_data(result, expected)


def test_date_lit(constructor: Constructor) -> None:
df = nw.from_native(constructor({"a": [1]}))
result = df.with_columns(nw.lit(date(2020, 1, 1), dtype=nw.Date)).collect_schema()
assert result == {"a": nw.Int64, "literal": nw.Date}
12 changes: 0 additions & 12 deletions tests/series_only/cast_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,18 +98,6 @@ def test_cast_date_datetime_pandas() -> None:
assert df.schema == {"a": nw.Date}


@pytest.mark.skipif(
PANDAS_VERSION < (2, 0, 0),
reason="pyarrow dtype not available",
)
def test_cast_date_datetime_invalid() -> None:
# pandas: pyarrow datetime to date
dfpd = pd.DataFrame({"a": [datetime(2020, 1, 1), datetime(2020, 1, 2)]})
df = nw.from_native(dfpd)
with pytest.raises(NotImplementedError, match="pyarrow"):
df.select(nw.col("a").cast(nw.Date))


@pytest.mark.filterwarnings("ignore: casting period")
def test_unknown_to_int() -> None:
df = pd.DataFrame({"a": pd.period_range("2000", periods=3, freq="min")})
Expand Down

0 comments on commit f4df861

Please sign in to comment.