Skip to content

Commit

Permalink
sort out np.datetime64 pyscalar
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Nov 2, 2024
1 parent 5c3db5b commit 7e3300e
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 23 deletions.
6 changes: 2 additions & 4 deletions narwhals/_pandas_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,10 +223,8 @@ def native_to_narwhals_dtype(
) -> DType:
dtype = str(native_column.dtype)

pd_datetime_rgx = (
r"^datetime64\[(?P<time_unit>s|ms|us|ns)(?:, (?P<time_zone>[a-zA-Z\/]+))?\]$"
)
pa_datetime_rgx = r"^timestamp\[(?P<time_unit>s|ms|us|ns)(?:, tz=(?P<time_zone>[a-zA-Z\/]+))?\]\[pyarrow\]$"
pd_datetime_rgx = r"^datetime64\[(?P<time_unit>s|ms|us|ns)(?:, (?P<time_zone>[a-zA-Z\/]+(?:[+-]\d{2}:\d{2})?))?\]$"
pa_datetime_rgx = r"^timestamp\[(?P<time_unit>s|ms|us|ns)(?:, tz=(?P<time_zone>[a-zA-Z\/]*(?:[+-]\d{2}:\d{2})?))?\]\[pyarrow\]$"

pd_duration_rgx = r"^timedelta64\[(?P<time_unit>s|ms|us|ns)\]$"
pa_duration_rgx = r"^duration\[(?P<time_unit>s|ms|us|ns)\]\[pyarrow\]$"
Expand Down
7 changes: 7 additions & 0 deletions narwhals/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -857,6 +857,13 @@ def to_py_scalar(scalar_like: Any) -> Any:
return scalar_like

np = get_numpy()
if (
np
and isinstance(scalar_like, np.datetime64)
and scalar_like.dtype == "datetime64[ns]"
):
return datetime(1970, 1, 1) + timedelta(microseconds=scalar_like.item() // 1000)

if np and np.isscalar(scalar_like) and hasattr(scalar_like, "item"):
return scalar_like.item()

Expand Down
15 changes: 15 additions & 0 deletions tests/dtypes_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,3 +176,18 @@ def test_pandas_inplace_modification_1267(request: pytest.FixtureRequest) -> Non
assert snw.dtype == nw.Int64
s[0] = 999.5
assert snw.dtype == nw.Float64


def test_pandas_fixed_offset_1302() -> None:
    """Check the Narwhals dtype reported for a fixed-offset (+01:00) timestamp.

    The same single-element series is inspected twice: once as constructed by
    pandas, and once after ``convert_dtypes(dtype_backend="pyarrow")``. The two
    variants are expected to spell the time zone differently
    (``"UTC+01:00"`` versus ``"+01:00"``).
    """
    timestamps = ["2020-01-01T00:00:00.000000000+01:00"]

    # Series exactly as pandas builds it from the parsed timestamps.
    pandas_series = pd.Series(pd.to_datetime(timestamps))
    result = nw.from_native(pandas_series, series_only=True).dtype
    assert result == nw.Datetime("ns", "UTC+01:00")

    # Same data, converted to the pyarrow dtype backend.
    pyarrow_series = pd.Series(pd.to_datetime(timestamps)).convert_dtypes(
        dtype_backend="pyarrow"
    )
    result = nw.from_native(pyarrow_series, series_only=True).dtype
    assert result == nw.Datetime("ns", "+01:00")
8 changes: 8 additions & 0 deletions tests/selectors_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,14 @@ def test_selectors(constructor: Constructor) -> None:
assert_equal_data(result, expected)


def test_by_dtype_with_datetime(constructor: Constructor) -> None:
    """Select datetime columns with ``by_dtype(nw.Datetime)`` and take their year.

    Column ``"a"`` holds datetimes in 2020 so that the selector has a column to
    match; column ``"b"`` is integer and must be left out of the result.
    """
    # Imported locally so this test does not depend on the module's import block.
    from datetime import datetime

    data = {"a": [datetime(2020, 1, 1), datetime(2020, 1, 2)], "b": [1, 2]}
    df = nw.from_native(constructor(data))
    result = df.select(nw.selectors.by_dtype(nw.Datetime).dt.year())
    expected = {"a": [2020, 2020]}
    assert_equal_data(result, expected)


def test_numeric(constructor: Constructor) -> None:
df = nw.from_native(constructor(data))
result = df.select(numeric() + 1)
Expand Down
27 changes: 8 additions & 19 deletions tests/translate/to_py_scalar_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from datetime import datetime
from datetime import timedelta
from typing import TYPE_CHECKING
from typing import Any

import numpy as np
Expand All @@ -12,9 +11,6 @@
import narwhals.stable.v1 as nw
from narwhals.dependencies import get_cudf

if TYPE_CHECKING:
from tests.utils import ConstructorEager


@pytest.mark.parametrize(
("input_value", "expected"),
Expand All @@ -28,28 +24,21 @@
(b"a", b"a"),
(datetime(2021, 1, 1), datetime(2021, 1, 1)),
(timedelta(days=1), timedelta(days=1)),
(pd.Timestamp("2020-01-01"), datetime(2020, 1, 1)),
(pd.Timedelta(days=3), timedelta(days=3)),
(np.datetime64("2020-01-01", "s"), datetime(2020, 1, 1)),
(np.datetime64("2020-01-01", "ms"), datetime(2020, 1, 1)),
(np.datetime64("2020-01-01", "us"), datetime(2020, 1, 1)),
(np.datetime64("2020-01-01", "ns"), datetime(2020, 1, 1)),
],
)
def test_to_py_scalar(
constructor_eager: ConstructorEager,
input_value: Any,
expected: Any,
request: pytest.FixtureRequest,
) -> None:
if isinstance(input_value, bytes) and "cudf" in str(constructor_eager):
request.applymarker(pytest.mark.xfail)
df = nw.from_native(constructor_eager({"a": [input_value]}))
output = nw.to_py_scalar(df["a"].item(0))
if expected == 1 and constructor_eager.__name__.startswith("pandas"):
output = nw.to_py_scalar(input_value)
if expected == 1:
assert not isinstance(output, np.int64)
elif isinstance(expected, datetime) and constructor_eager.__name__.startswith(
"pandas"
):
assert not isinstance(output, pd.Timestamp)
elif isinstance(expected, timedelta) and constructor_eager.__name__.startswith(
"pandas"
):
assert not isinstance(output, pd.Timedelta)
assert output == expected


Expand Down

0 comments on commit 7e3300e

Please sign in to comment.