-
Notifications
You must be signed in to change notification settings - Fork 121
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Datetime(time_unit, time_zone)
and Duration(time_unit)
types
#960
Changes from 15 commits
121f6f8
4896df2
cd2ed40
eb1468e
e71f9c3
32385d0
3abeaf8
c5b7635
4415e3c
5309d4f
85fdd80
91bfb7a
20e36a1
ec1cb5e
2147ec6
0f69ec1
22836a0
a1f56bc
a84480d
80a574d
916eac5
e94b517
180b86e
da884e8
114be74
587d917
dd050a8
b4de1f7
458f2a2
34c27ef
0de71a6
d105911
a773d85
0149431
2249af0
942a77b
ad38667
38898a8
43da4c3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,10 @@ | ||
from __future__ import annotations | ||
|
||
import re | ||
from typing import TYPE_CHECKING | ||
from typing import Any | ||
from typing import Iterable | ||
from typing import Literal | ||
from typing import TypeVar | ||
|
||
from narwhals.dependencies import get_cudf | ||
|
@@ -221,6 +223,15 @@ def translate_dtype(column: Any) -> DType: | |
from narwhals import dtypes | ||
|
||
dtype = column.dtype | ||
|
||
pd_datetime_rgx = ( | ||
r"^datetime64\[(?P<time_unit>ms|us|ns)(?:, (?P<time_zone>[a-zA-Z\/]+))?\]$" | ||
) | ||
pa_datetime_rgx = r"^timestamp\[(?P<time_unit>ms|us|ns)(?:, tz=(?P<time_zone>[a-zA-Z\/]+))?\]\[pyarrow\]$" | ||
|
||
pd_duration_rgx = r"^timedelta64\[(?P<time_unit>ms|us|ns)\]$" | ||
pa_duration_rgx = r"^duration\[(?P<time_unit>ms|us|ns)\]\[pyarrow\]$" | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. pandas / pyarrow support the 'second' time unit; I think that should be allowed to pass through. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Just by passing it along, or by doing manipulation for the user? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think we can just pass it through - adding a commit soon. |
||
|
||
if str(dtype) in ("int64", "Int64", "Int64[pyarrow]", "int64[pyarrow]"): | ||
return dtypes.Int64() | ||
if str(dtype) in ("int32", "Int32", "Int32[pyarrow]", "int32[pyarrow]"): | ||
|
@@ -264,16 +275,17 @@ def translate_dtype(column: Any) -> DType: | |
return dtypes.Boolean() | ||
if str(dtype) in ("category",) or str(dtype).startswith("dictionary<"): | ||
return dtypes.Categorical() | ||
if str(dtype).startswith("datetime64"): | ||
# TODO(Unassigned): different time units and time zones | ||
return dtypes.Datetime() | ||
if str(dtype).startswith("timedelta64") or str(dtype).startswith("duration"): | ||
# TODO(Unassigned): different time units | ||
return dtypes.Duration() | ||
if str(dtype).startswith("timestamp["): | ||
# pyarrow-backed datetime | ||
# TODO(Unassigned): different time units and time zones | ||
return dtypes.Datetime() | ||
if (match_ := re.match(pd_datetime_rgx, str(dtype))) or ( | ||
match_ := re.match(pa_datetime_rgx, str(dtype)) | ||
): | ||
dt_time_unit: Literal["us", "ns", "ms"] = match_.group("time_unit") # type: ignore[assignment] | ||
dt_time_zone: str | None = match_.group("time_zone") | ||
return dtypes.Datetime(dt_time_unit, dt_time_zone) | ||
if (match_ := re.match(pd_duration_rgx, str(dtype))) or ( | ||
match_ := re.match(pa_duration_rgx, str(dtype)) | ||
): | ||
du_time_unit: Literal["us", "ns", "ms"] = match_.group("time_unit") # type: ignore[assignment] | ||
return dtypes.Duration(du_time_unit) | ||
if str(dtype) == "date32[day][pyarrow]": | ||
return dtypes.Date() | ||
if str(dtype) == "object": | ||
|
@@ -321,7 +333,10 @@ def get_dtype_backend(dtype: Any, implementation: Implementation) -> str: | |
|
||
|
||
def narwhals_to_native_dtype( # noqa: PLR0915 | ||
dtype: DType | type[DType], starting_dtype: Any, implementation: Implementation | ||
dtype: DType | type[DType], | ||
starting_dtype: Any, | ||
implementation: Implementation, | ||
backend_version: tuple[int, ...], | ||
) -> Any: | ||
from narwhals import dtypes | ||
|
||
|
@@ -425,15 +440,32 @@ def narwhals_to_native_dtype( # noqa: PLR0915 | |
# convert to it? | ||
return "category" | ||
if isinstance_or_issubclass(dtype, dtypes.Datetime): | ||
# TODO(Unassigned): different time units and time zones | ||
dt_time_unit = getattr(dtype, "time_unit", "us") | ||
dt_time_zone = getattr(dtype, "time_zone", None) | ||
|
||
# Pandas does not support "ms" or "us" time units before version 1.5.0 | ||
# Let's overwrite with "ns" | ||
if implementation is Implementation.PANDAS and backend_version < ( | ||
1, | ||
5, | ||
0, | ||
): # pragma: no cover | ||
dt_time_unit = "ns" | ||
MarcoGorelli marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
if dtype_backend == "pyarrow-nullable": | ||
return "timestamp[ns][pyarrow]" | ||
return "datetime64[ns]" | ||
tz_part = f", tz={dt_time_zone}" if dt_time_zone else "" | ||
return f"timestamp[{dt_time_unit}{tz_part}][pyarrow]" | ||
else: | ||
tz_part = f", {dt_time_zone}" if dt_time_zone else "" | ||
return f"datetime64[{dt_time_unit}{tz_part}]" | ||
if isinstance_or_issubclass(dtype, dtypes.Duration): | ||
# TODO(Unassigned): different time units and time zones | ||
if dtype_backend == "pyarrow-nullable": | ||
return "duration[ns][pyarrow]" | ||
return "timedelta64[ns]" | ||
du_time_unit = getattr(dtype, "time_unit", "us") | ||
return ( | ||
f"duration[{du_time_unit}][pyarrow]" | ||
if dtype_backend == "pyarrow-nullable" | ||
else f"timedelta64[{du_time_unit}]" | ||
) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We need to do the same pre-1.5.0 check here. |
||
|
||
if isinstance_or_issubclass(dtype, dtypes.Date): | ||
if dtype_backend == "pyarrow-nullable": | ||
return "date32[pyarrow]" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,6 +8,7 @@ pyarrow | |
pytest | ||
pytest-cov | ||
pytest-env | ||
pytz | ||
hypothesis | ||
scikit-learn | ||
typing_extensions | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
from __future__ import annotations | ||
|
||
from datetime import timezone | ||
from typing import Literal | ||
|
||
import pytest | ||
|
||
import narwhals.stable.v1 as nw | ||
|
||
|
||
@pytest.mark.parametrize("time_unit", ["us", "ns", "ms"]) | ||
@pytest.mark.parametrize("time_zone", ["Europe/Rome", timezone.utc, None]) | ||
def test_datetime_valid( | ||
time_unit: Literal["us", "ns", "ms"], time_zone: str | timezone | None | ||
) -> None: | ||
dtype = nw.Datetime(time_unit=time_unit, time_zone=time_zone) | ||
|
||
assert dtype == nw.Datetime(time_unit=time_unit, time_zone=time_zone) | ||
assert dtype == nw.Datetime | ||
|
||
if time_zone: | ||
assert dtype != nw.Datetime(time_unit=time_unit) | ||
if time_unit != "ms": | ||
assert dtype != nw.Datetime(time_unit="ms") | ||
|
||
|
||
@pytest.mark.parametrize("time_unit", ["abc", "s"]) | ||
def test_datetime_invalid(time_unit: str) -> None: | ||
with pytest.raises(ValueError, match="invalid `time_unit`"): | ||
nw.Datetime(time_unit=time_unit) # type: ignore[arg-type] | ||
|
||
|
||
@pytest.mark.parametrize("time_unit", ["us", "ns", "ms"]) | ||
def test_duration_valid(time_unit: Literal["us", "ns", "ms"]) -> None: | ||
dtype = nw.Duration(time_unit=time_unit) | ||
|
||
assert dtype == nw.Duration(time_unit=time_unit) | ||
assert dtype == nw.Duration | ||
|
||
if time_unit != "ms": | ||
assert dtype != nw.Duration(time_unit="ms") | ||
|
||
|
||
@pytest.mark.parametrize("time_unit", ["abc", "s"]) | ||
def test_duration_invalid(time_unit: str) -> None: | ||
with pytest.raises(ValueError, match="invalid `time_unit`"): | ||
nw.Duration(time_unit=time_unit) # type: ignore[arg-type] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please try to break these.