Skip to content

Commit

Permalink
fix: consistent to_numpy behaviour for tz-aware
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Nov 2, 2024
1 parent 5c3db5b commit 434bd5e
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 18 deletions.
28 changes: 12 additions & 16 deletions narwhals/_pandas_like/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,34 +511,30 @@ def to_numpy(self, dtype: Any = None, copy: bool | None = None) -> Any:
# the default is meant to be None, but pandas doesn't allow it?
# https://numpy.org/doc/stable/reference/generated/numpy.ndarray.__array__.html
copy = copy or self._implementation is Implementation.CUDF
if self.dtype == self._dtypes.Datetime and self.dtype.time_zone is not None: # type: ignore[attr-defined]
s = self.dt.convert_time_zone("UTC").dt.replace_time_zone(None)._native_series
else:
s = self._native_series

has_missing = self._native_series.isna().any()
if (
has_missing
and str(self._native_series.dtype) in PANDAS_TO_NUMPY_DTYPE_MISSING
):
has_missing = s.isna().any()
if has_missing and str(s.dtype) in PANDAS_TO_NUMPY_DTYPE_MISSING:
if self._implementation is Implementation.PANDAS and self._backend_version < (
1,
): # pragma: no cover
kwargs = {}
else:
kwargs = {"na_value": float("nan")}
return self._native_series.to_numpy(
dtype=dtype
or PANDAS_TO_NUMPY_DTYPE_MISSING[str(self._native_series.dtype)],
return s.to_numpy(
dtype=dtype or PANDAS_TO_NUMPY_DTYPE_MISSING[str(s.dtype)],
copy=copy,
**kwargs,
)
if (
not has_missing
and str(self._native_series.dtype) in PANDAS_TO_NUMPY_DTYPE_NO_MISSING
):
return self._native_series.to_numpy(
dtype=dtype
or PANDAS_TO_NUMPY_DTYPE_NO_MISSING[str(self._native_series.dtype)],
if not has_missing and str(s.dtype) in PANDAS_TO_NUMPY_DTYPE_NO_MISSING:
return s.to_numpy(
dtype=dtype or PANDAS_TO_NUMPY_DTYPE_NO_MISSING[str(s.dtype)],
copy=copy,
)
return self._native_series.to_numpy(dtype=dtype, copy=copy)
return s.to_numpy(dtype=dtype, copy=copy)

def to_pandas(self) -> Any:
if self._implementation is Implementation.PANDAS:
Expand Down
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import contextlib
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable

import pandas as pd
import polars as pl
Expand All @@ -19,6 +18,7 @@
from narwhals.typing import IntoDataFrame
from narwhals.typing import IntoFrame
from tests.utils import Constructor
from tests.utils import ConstructorEager

with contextlib.suppress(ImportError):
import modin.pandas # noqa: F401
Expand Down Expand Up @@ -117,7 +117,7 @@ def pyarrow_table_constructor(obj: Any) -> IntoDataFrame:
@pytest.fixture(params=eager_constructors)
def constructor_eager(
request: pytest.FixtureRequest,
) -> Callable[[Any], IntoDataFrame]:
) -> ConstructorEager:
return request.param # type: ignore[no-any-return]


Expand Down
19 changes: 19 additions & 0 deletions tests/series_only/to_numpy_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from datetime import datetime
from typing import TYPE_CHECKING

import numpy as np
Expand Down Expand Up @@ -30,3 +31,21 @@ def test_to_numpy(
assert s.shape == (3,)

assert_array_equal(s.to_numpy(), np.array(data, dtype=float))


def test_to_numpy_tz_aware(constructor_eager: ConstructorEager) -> None:
df = nw.from_native(
constructor_eager({"a": [datetime(2020, 1, 1), datetime(2020, 1, 2)]}),
eager_only=True,
)
df = df.select(nw.col("a").dt.replace_time_zone("Asia/Kathmandu"))
result = df["a"].to_numpy()
# for some reason, NumPy uses 'M' for datetimes
assert result.dtype.kind == "M"
assert (
result
== np.array(
["2019-12-31T18:15:00.000000", "2020-01-01T18:15:00.000000"],
dtype=result.dtype,
)
).all()

0 comments on commit 434bd5e

Please sign in to comment.