Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
FBruzzesi committed Sep 12, 2024
1 parent aed2d51 commit 121f6f8
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 18 deletions.
8 changes: 5 additions & 3 deletions narwhals/_arrow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def translate_dtype(dtype: Any) -> dtypes.DType:
if pa.types.is_date32(dtype):
return dtypes.Date()
if pa.types.is_timestamp(dtype):
return dtypes.Datetime()
return dtypes.Datetime(time_unit=dtype.unit, time_zone=dtype.tz)
if pa.types.is_duration(dtype):
return dtypes.Duration()
if pa.types.is_dictionary(dtype):
Expand Down Expand Up @@ -88,8 +88,10 @@ def narwhals_to_native_dtype(dtype: dtypes.DType | type[dtypes.DType]) -> Any:
# with Polars for now
return pa.dictionary(pa.uint32(), pa.string())
if isinstance_or_issubclass(dtype, dtypes.Datetime):
# Use Polars' default
return pa.timestamp("us")
time_unit = getattr(dtype, "time_unit", "us")
time_zone = getattr(dtype, "time_zone", None)
return pa.timestamp(time_unit, tz=time_zone)

if isinstance_or_issubclass(dtype, dtypes.Duration):
# Use Polars' default
return pa.duration("us")
Expand Down
33 changes: 23 additions & 10 deletions narwhals/_pandas_like/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations

import re
from typing import TYPE_CHECKING
from typing import Any
from typing import Iterable
from typing import Literal
from typing import TypeVar

from narwhals.dependencies import get_cudf
Expand Down Expand Up @@ -221,6 +223,12 @@ def translate_dtype(column: Any) -> DType:
from narwhals import dtypes

dtype = column.dtype

pd_datetime_rgx = (
r"^datetime64\[(?P<time_unit>ms|us|ns)(?:, (?P<time_zone>[a-zA-Z\/]+))?\]$"
)
pa_datetime_rgx = r"^timestamp\[(?P<time_unit>ms|us|ns)(?:, tz=(?P<time_zone>[a-zA-Z\/]+))?\]\[pyarrow\]$"

if str(dtype) in ("int64", "Int64", "Int64[pyarrow]", "int64[pyarrow]"):
return dtypes.Int64()
if str(dtype) in ("int32", "Int32", "Int32[pyarrow]", "int32[pyarrow]"):
Expand Down Expand Up @@ -264,16 +272,15 @@ def translate_dtype(column: Any) -> DType:
return dtypes.Boolean()
if str(dtype) in ("category",) or str(dtype).startswith("dictionary<"):
return dtypes.Categorical()
if str(dtype).startswith("datetime64"):
# TODO(Unassigned): different time units and time zones
return dtypes.Datetime()
if (match_ := re.match(pd_datetime_rgx, str(dtype))) or (
match_ := re.match(pa_datetime_rgx, str(dtype))
):
time_unit: Literal["us", "ns", "ms"] = match_.group("time_unit") # type: ignore[assignment]
time_zone: str | None = match_.group("time_zone")
return dtypes.Datetime(time_unit, time_zone)
if str(dtype).startswith("timedelta64") or str(dtype).startswith("duration"):
# TODO(Unassigned): different time units
return dtypes.Duration()
if str(dtype).startswith("timestamp["):
# pyarrow-backed datetime
# TODO(Unassigned): different time units and time zones
return dtypes.Datetime()
if str(dtype) == "date32[day][pyarrow]":
return dtypes.Date()
if str(dtype) == "object":
Expand Down Expand Up @@ -425,10 +432,16 @@ def narwhals_to_native_dtype( # noqa: PLR0915
# convert to it?
return "category"
if isinstance_or_issubclass(dtype, dtypes.Datetime):
# TODO(Unassigned): different time units and time zones
time_unit = getattr(dtype, "time_unit", "us")
time_zone = getattr(dtype, "time_zone", None)

if dtype_backend == "pyarrow-nullable":
return "timestamp[ns][pyarrow]"
return "datetime64[ns]"
tz_part = f", tz={time_zone}" if time_zone else ""
return f"timestamp[{time_unit}{tz_part}][pyarrow]"
else:
tz_part = f", {time_zone}" if time_zone else ""
return f"datetime64[{time_unit}{tz_part}]"

if isinstance_or_issubclass(dtype, dtypes.Duration):
# TODO(Unassigned): different time units and time zones
if dtype_backend == "pyarrow-nullable":
Expand Down
14 changes: 10 additions & 4 deletions narwhals/_polars/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from __future__ import annotations

from typing import Any
from typing import Literal

from narwhals import dtypes
from narwhals.dependencies import get_polars
from narwhals.utils import isinstance_or_issubclass


def extract_native(obj: Any) -> Any:
Expand Down Expand Up @@ -59,8 +61,10 @@ def translate_dtype(dtype: Any) -> dtypes.DType:
return dtypes.Categorical()
if dtype == pl.Enum:
return dtypes.Enum()
if dtype == pl.Datetime:
return dtypes.Datetime()
if isinstance_or_issubclass(dtype, pl.Datetime):
time_unit: Literal["us", "ns", "ms"] = getattr(dtype, "time_unit", "us")
time_zone = getattr(dtype, "time_zone", None)
return dtypes.Datetime(time_unit=time_unit, time_zone=time_zone)
if dtype == pl.Duration:
return dtypes.Duration()
if dtype == pl.Date:
Expand Down Expand Up @@ -103,8 +107,10 @@ def narwhals_to_native_dtype(dtype: dtypes.DType | type[dtypes.DType]) -> Any:
if dtype == dtypes.Enum:
msg = "Converting to Enum is not (yet) supported"
raise NotImplementedError(msg)
if dtype == dtypes.Datetime:
return pl.Datetime()
if isinstance_or_issubclass(dtype, dtypes.Datetime):
time_unit = getattr(dtype, "time_unit", "us")
time_zone = getattr(dtype, "time_zone", None)
return pl.Datetime(time_unit, time_zone)
if dtype == dtypes.Duration:
return pl.Duration()
if dtype == dtypes.Date:
Expand Down
36 changes: 35 additions & 1 deletion narwhals/dtypes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

from datetime import timezone
from typing import TYPE_CHECKING
from typing import Literal

if TYPE_CHECKING:
from typing_extensions import Self
Expand Down Expand Up @@ -71,7 +73,39 @@ class Object(DType): ...
class Unknown(DType): ...


class Datetime(TemporalType): ...
class Datetime(TemporalType):
"""
Data type representing a calendar date and time of day.
Arguments:
time_unit: Unit of time. Defaults to `'us'` (microseconds).
time_zone: Time zone string, as defined in zoneinfo (to see valid strings run
`import zoneinfo; zoneinfo.available_timezones()` for a full list).
When used to match dtypes, can set this to "*" to check for Datetime
columns that have any (non-null) timezone.
Notes:
Adapted from Polars implementation at:
https://github.com/pola-rs/polars/blob/py-1.7.1/py-polars/polars/datatypes/classes.py#L398-L457
"""

def __init__(
self: Self,
time_unit: Literal["us", "ns", "ms"] = "us",
time_zone: str | timezone | None = None,
) -> None:
if time_unit not in {"ms", "us", "ns"}:
msg = (
"invalid `time_unit`"
f"\n\nExpected one of {{'ns','us','ms'}}, got {time_unit!r}."
)
raise ValueError(msg)

if isinstance(time_zone, timezone):
time_zone = str(time_zone)

self.time_unit = time_unit
self.time_zone = time_zone


class Duration(TemporalType): ...
Expand Down
30 changes: 30 additions & 0 deletions tests/expr_and_series/cast_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
from __future__ import annotations

from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import Any

import pandas as pd
Expand All @@ -6,6 +11,7 @@

import narwhals.stable.v1 as nw
from narwhals.utils import parse_version
from tests.utils import compare_dicts

data = {
"a": [1],
Expand Down Expand Up @@ -179,3 +185,27 @@ class Banana:

with pytest.raises(AssertionError, match=r"Unknown dtype"):
df.select(nw.col("a").cast(Banana))


def test_cast_datetime_tz_aware(constructor: Any, request: Any) -> None:
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)

data = {
"date": [
datetime(2024, 1, 1, tzinfo=timezone.utc) + timedelta(days=i)
for i in range(3)
]
}
expected = {
"date": ["2024-01-01 01:00:00", "2024-01-02 01:00:00", "2024-01-03 01:00:00"]
}

df = nw.from_native(constructor(data))
result = df.select(
nw.col("date")
.cast(nw.Datetime("ms", time_zone="Europe/Rome"))
.cast(nw.String())
.str.slice(offset=0, length=19)
)
compare_dicts(result, expected)

0 comments on commit 121f6f8

Please sign in to comment.