diff --git a/narwhals/_arrow/utils.py b/narwhals/_arrow/utils.py index b8294839c..55b5c360e 100644 --- a/narwhals/_arrow/utils.py +++ b/narwhals/_arrow/utils.py @@ -46,7 +46,7 @@ def translate_dtype(dtype: Any) -> dtypes.DType: if pa.types.is_date32(dtype): return dtypes.Date() if pa.types.is_timestamp(dtype): - return dtypes.Datetime() + return dtypes.Datetime(time_unit=dtype.unit, time_zone=dtype.tz) if pa.types.is_duration(dtype): return dtypes.Duration() if pa.types.is_dictionary(dtype): @@ -88,8 +88,10 @@ def narwhals_to_native_dtype(dtype: dtypes.DType | type[dtypes.DType]) -> Any: # with Polars for now return pa.dictionary(pa.uint32(), pa.string()) if isinstance_or_issubclass(dtype, dtypes.Datetime): - # Use Polars' default - return pa.timestamp("us") + time_unit = getattr(dtype, "time_unit", "us") + time_zone = getattr(dtype, "time_zone", None) + return pa.timestamp(time_unit, tz=time_zone) + if isinstance_or_issubclass(dtype, dtypes.Duration): # Use Polars' default return pa.duration("us") diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index 9e1d79ce9..0a3981734 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -1,8 +1,10 @@ from __future__ import annotations +import re from typing import TYPE_CHECKING from typing import Any from typing import Iterable +from typing import Literal from typing import TypeVar from narwhals.dependencies import get_cudf @@ -221,6 +223,12 @@ def translate_dtype(column: Any) -> DType: from narwhals import dtypes dtype = column.dtype + + pd_datetime_rgx = ( + r"^datetime64\[(?P<time_unit>ms|us|ns)(?:, (?P<time_zone>[a-zA-Z\/]+))?\]$" + ) + pa_datetime_rgx = r"^timestamp\[(?P<time_unit>ms|us|ns)(?:, tz=(?P<time_zone>[a-zA-Z\/]+))?\]\[pyarrow\]$" + if str(dtype) in ("int64", "Int64", "Int64[pyarrow]", "int64[pyarrow]"): return dtypes.Int64() if str(dtype) in ("int32", "Int32", "Int32[pyarrow]", "int32[pyarrow]"): @@ -264,16 +272,15 @@ def translate_dtype(column: Any) -> DType: return dtypes.Boolean() if
str(dtype) in ("category",) or str(dtype).startswith("dictionary<"): return dtypes.Categorical() - if str(dtype).startswith("datetime64"): - # TODO(Unassigned): different time units and time zones - return dtypes.Datetime() + if (match_ := re.match(pd_datetime_rgx, str(dtype))) or ( + match_ := re.match(pa_datetime_rgx, str(dtype)) + ): + time_unit: Literal["us", "ns", "ms"] = match_.group("time_unit") # type: ignore[assignment] + time_zone: str | None = match_.group("time_zone") + return dtypes.Datetime(time_unit, time_zone) if str(dtype).startswith("timedelta64") or str(dtype).startswith("duration"): # TODO(Unassigned): different time units return dtypes.Duration() - if str(dtype).startswith("timestamp["): - # pyarrow-backed datetime - # TODO(Unassigned): different time units and time zones - return dtypes.Datetime() if str(dtype) == "date32[day][pyarrow]": return dtypes.Date() if str(dtype) == "object": @@ -425,10 +432,16 @@ def narwhals_to_native_dtype( # noqa: PLR0915 # convert to it? 
return "category" if isinstance_or_issubclass(dtype, dtypes.Datetime): - # TODO(Unassigned): different time units and time zones + time_unit = getattr(dtype, "time_unit", "us") + time_zone = getattr(dtype, "time_zone", None) + if dtype_backend == "pyarrow-nullable": - return "timestamp[ns][pyarrow]" - return "datetime64[ns]" + tz_part = f", tz={time_zone}" if time_zone else "" + return f"timestamp[{time_unit}{tz_part}][pyarrow]" + else: + tz_part = f", {time_zone}" if time_zone else "" + return f"datetime64[{time_unit}{tz_part}]" + if isinstance_or_issubclass(dtype, dtypes.Duration): # TODO(Unassigned): different time units and time zones if dtype_backend == "pyarrow-nullable": diff --git a/narwhals/_polars/utils.py b/narwhals/_polars/utils.py index 51f0b1898..dc8696fa6 100644 --- a/narwhals/_polars/utils.py +++ b/narwhals/_polars/utils.py @@ -1,9 +1,11 @@ from __future__ import annotations from typing import Any +from typing import Literal from narwhals import dtypes from narwhals.dependencies import get_polars +from narwhals.utils import isinstance_or_issubclass def extract_native(obj: Any) -> Any: @@ -59,8 +61,10 @@ def translate_dtype(dtype: Any) -> dtypes.DType: return dtypes.Categorical() if dtype == pl.Enum: return dtypes.Enum() - if dtype == pl.Datetime: - return dtypes.Datetime() + if isinstance_or_issubclass(dtype, pl.Datetime): + time_unit: Literal["us", "ns", "ms"] = getattr(dtype, "time_unit", "us") + time_zone = getattr(dtype, "time_zone", None) + return dtypes.Datetime(time_unit=time_unit, time_zone=time_zone) if dtype == pl.Duration: return dtypes.Duration() if dtype == pl.Date: @@ -103,8 +107,10 @@ def narwhals_to_native_dtype(dtype: dtypes.DType | type[dtypes.DType]) -> Any: if dtype == dtypes.Enum: msg = "Converting to Enum is not (yet) supported" raise NotImplementedError(msg) - if dtype == dtypes.Datetime: - return pl.Datetime() + if isinstance_or_issubclass(dtype, dtypes.Datetime): + time_unit = getattr(dtype, "time_unit", "us") + time_zone = 
getattr(dtype, "time_zone", None) + return pl.Datetime(time_unit, time_zone) if dtype == dtypes.Duration: return pl.Duration() if dtype == dtypes.Date: diff --git a/narwhals/dtypes.py b/narwhals/dtypes.py index 4d8da4293..157967938 100644 --- a/narwhals/dtypes.py +++ b/narwhals/dtypes.py @@ -1,6 +1,8 @@ from __future__ import annotations +from datetime import timezone from typing import TYPE_CHECKING +from typing import Literal if TYPE_CHECKING: from typing_extensions import Self @@ -71,7 +73,39 @@ class Object(DType): ... class Unknown(DType): ... -class Datetime(TemporalType): ... +class Datetime(TemporalType): + """ + Data type representing a calendar date and time of day. + + Arguments: + time_unit: Unit of time. Defaults to `'us'` (microseconds). + time_zone: Time zone string, as defined in zoneinfo (to see valid strings run + `import zoneinfo; zoneinfo.available_timezones()` for a full list). + When used to match dtypes, can set this to "*" to check for Datetime + columns that have any (non-null) timezone. + + Notes: + Adapted from Polars implementation at: + https://github.com/pola-rs/polars/blob/py-1.7.1/py-polars/polars/datatypes/classes.py#L398-L457 + """ + + def __init__( + self: Self, + time_unit: Literal["us", "ns", "ms"] = "us", + time_zone: str | timezone | None = None, + ) -> None: + if time_unit not in {"ms", "us", "ns"}: + msg = ( + "invalid `time_unit`" + f"\n\nExpected one of {{'ns','us','ms'}}, got {time_unit!r}." + ) + raise ValueError(msg) + + if isinstance(time_zone, timezone): + time_zone = str(time_zone) + + self.time_unit = time_unit + self.time_zone = time_zone class Duration(TemporalType): ... 
diff --git a/tests/expr_and_series/cast_test.py b/tests/expr_and_series/cast_test.py index 0b496d7ae..f16f46ff9 100644 --- a/tests/expr_and_series/cast_test.py +++ b/tests/expr_and_series/cast_test.py @@ -1,3 +1,8 @@ +from __future__ import annotations + +from datetime import datetime +from datetime import timedelta +from datetime import timezone from typing import Any import pandas as pd @@ -6,6 +11,7 @@ import narwhals.stable.v1 as nw from narwhals.utils import parse_version +from tests.utils import compare_dicts data = { "a": [1], @@ -179,3 +185,27 @@ class Banana: with pytest.raises(AssertionError, match=r"Unknown dtype"): df.select(nw.col("a").cast(Banana)) + + +def test_cast_datetime_tz_aware(constructor: Any, request: Any) -> None: + if "dask" in str(constructor): + request.applymarker(pytest.mark.xfail) + + data = { + "date": [ + datetime(2024, 1, 1, tzinfo=timezone.utc) + timedelta(days=i) + for i in range(3) + ] + } + expected = { + "date": ["2024-01-01 01:00:00", "2024-01-02 01:00:00", "2024-01-03 01:00:00"] + } + + df = nw.from_native(constructor(data)) + result = df.select( + nw.col("date") + .cast(nw.Datetime("ms", time_zone="Europe/Rome")) + .cast(nw.String()) + .str.slice(offset=0, length=19) + ) + compare_dicts(result, expected)