From 93d2fc793aaaa1f2e3f641b8767d6bbae1669347 Mon Sep 17 00:00:00 2001 From: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com> Date: Mon, 30 Sep 2024 19:45:16 +0200 Subject: [PATCH] feat: `Datetime(time_unit, time_zone)` and `Duration(time_unit)` types (#960) --- .github/workflows/extremes.yml | 6 +- narwhals/_arrow/utils.py | 13 ++-- narwhals/_pandas_like/series.py | 2 +- narwhals/_pandas_like/utils.py | 62 +++++++++++++++---- narwhals/_polars/utils.py | 24 +++++--- narwhals/dtypes.py | 95 ++++++++++++++++++++++++++++-- narwhals/functions.py | 13 ++-- narwhals/stable/v1/dtypes.py | 39 +++++++++++- tests/dtypes_test.py | 74 +++++++++++++++++++++++ tests/expr_and_series/cast_test.py | 36 +++++++++++ tests/series_only/cast_test.py | 4 +- tests/stable_api_test.py | 12 +++- utils/check_api_reference.py | 2 +- 13 files changed, 332 insertions(+), 50 deletions(-) create mode 100644 tests/dtypes_test.py diff --git a/.github/workflows/extremes.yml b/.github/workflows/extremes.yml index f11a4f4bb..858d0b6e2 100644 --- a/.github/workflows/extremes.yml +++ b/.github/workflows/extremes.yml @@ -119,6 +119,8 @@ jobs: kaggle kernels output "marcogorelli/variable-brink-glacier" - name: install-polars run: python -m pip install *.whl + - name: install-pandas-nightly + run: pip install --pre --extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple pandas - name: install-reqs run: uv pip install --upgrade tox virtualenv setuptools pip -r requirements-dev.txt --system - name: uninstall pyarrow @@ -127,8 +129,8 @@ jobs: # run: uv pip install --extra-index-url https://pypi.fury.io/arrow-nightlies/ --pre pyarrow --system - name: uninstall pandas run: uv pip uninstall pandas --system - - name: install-pandas-nightly - run: uv pip install --prerelease=allow --pre --extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple pandas --system + - name: show-deps + run: uv pip freeze - name: uninstall numpy run: uv pip uninstall numpy --system - name: install numpy nightly diff --git a/narwhals/_arrow/utils.py b/narwhals/_arrow/utils.py index d51a4b25d..e34e949d5 100644 --- a/narwhals/_arrow/utils.py +++ b/narwhals/_arrow/utils.py @@ -50,9 +50,9 @@ def native_to_narwhals_dtype(dtype: Any, dtypes: DTypes) -> DType: if pa.types.is_date32(dtype): return dtypes.Date() if pa.types.is_timestamp(dtype): - return dtypes.Datetime() + return dtypes.Datetime(time_unit=dtype.unit, time_zone=dtype.tz) if pa.types.is_duration(dtype): - return dtypes.Duration() + return dtypes.Duration(time_unit=dtype.unit) if pa.types.is_dictionary(dtype): return dtypes.Categorical() if pa.types.is_struct(dtype): @@ -94,11 +94,12 @@ def narwhals_to_native_dtype(dtype: DType | type[DType], dtypes: DTypes) -> Any: if isinstance_or_issubclass(dtype, dtypes.Categorical): return pa.dictionary(pa.uint32(), pa.string()) if isinstance_or_issubclass(dtype, dtypes.Datetime): - # Use Polars' default - return pa.timestamp("us") + time_unit = getattr(dtype, "time_unit", "us") + time_zone = getattr(dtype, "time_zone", None) + return pa.timestamp(time_unit, tz=time_zone) if isinstance_or_issubclass(dtype, dtypes.Duration): - # Use Polars' default - return pa.duration("us") + time_unit = getattr(dtype, "time_unit", "us") + return pa.duration(time_unit) if isinstance_or_issubclass(dtype, dtypes.Date): return pa.date32() if isinstance_or_issubclass(dtype, dtypes.List): # pragma: no cover diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py index dc9a00009..6569f8b5d 100644 --- a/narwhals/_pandas_like/series.py +++ b/narwhals/_pandas_like/series.py @@ -197,7 +197,7 @@ def cast( ) -> Self: ser = self._native_series dtype = narwhals_to_native_dtype( - dtype, ser.dtype, self._implementation, self._dtypes + dtype, ser.dtype, self._implementation, self._backend_version, self._dtypes ) return self._from_native_series(ser.astype(dtype)) diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index 726a07c56..92fbd5193 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -1,8 +1,10 @@ from __future__ import annotations +import re from typing import TYPE_CHECKING from typing import Any from typing import Iterable +from typing import Literal from typing import TypeVar from narwhals.utils import Implementation @@ -213,6 +215,15 @@ def set_axis( def native_to_narwhals_dtype(column: Any, dtypes: DTypes) -> DType: dtype = str(column.dtype) + + pd_datetime_rgx = ( + r"^datetime64\[(?Ps|ms|us|ns)(?:, (?P[a-zA-Z\/]+))?\]$" + ) + pa_datetime_rgx = r"^timestamp\[(?Ps|ms|us|ns)(?:, tz=(?P[a-zA-Z\/]+))?\]\[pyarrow\]$" + + pd_duration_rgx = r"^timedelta64\[(?Ps|ms|us|ns)\]$" + pa_duration_rgx = r"^duration\[(?Ps|ms|us|ns)\]\[pyarrow\]$" + if dtype in {"int64", "Int64", "Int64[pyarrow]", "int64[pyarrow]"}: return dtypes.Int64() if dtype in {"int32", "Int32", "Int32[pyarrow]", "int32[pyarrow]"}: @@ -251,12 +262,17 @@ def native_to_narwhals_dtype(column: Any, dtypes: DTypes) -> DType: return dtypes.Boolean() if dtype == "category" or dtype.startswith("dictionary<"): return dtypes.Categorical() - if dtype.startswith(("datetime64", "timestamp[")): - # TODO(Unassigned): different time units and time zones - return dtypes.Datetime() - if dtype.startswith(("timedelta64", "duration")): - # TODO(Unassigned): different time units - return dtypes.Duration() + if (match_ := re.match(pd_datetime_rgx, dtype)) or ( + match_ := re.match(pa_datetime_rgx, dtype) + ): + dt_time_unit: Literal["us", "ns", "ms", "s"] = match_.group("time_unit") # type: ignore[assignment] + dt_time_zone: str | None = match_.group("time_zone") + return dtypes.Datetime(dt_time_unit, dt_time_zone) + if (match_ := re.match(pd_duration_rgx, dtype)) or ( + match_ := re.match(pa_duration_rgx, dtype) + ): + du_time_unit: Literal["us", "ns", "ms", "s"] = match_.group("time_unit") # type: ignore[assignment] + return dtypes.Duration(du_time_unit) if dtype == "date32[day][pyarrow]": return dtypes.Date() if dtype.startswith(("large_list", "list")): @@ -314,6 +330,7 @@ def narwhals_to_native_dtype( # noqa: PLR0915 dtype: DType | type[DType], starting_dtype: Any, implementation: Implementation, + backend_version: tuple[int, ...], dtypes: DTypes, ) -> Any: if "polars" in str(type(dtype)): @@ -416,15 +433,34 @@ def narwhals_to_native_dtype( # noqa: PLR0915 # convert to it? return "category" if isinstance_or_issubclass(dtype, dtypes.Datetime): - # TODO(Unassigned): different time units and time zones + dt_time_unit = getattr(dtype, "time_unit", "us") + dt_time_zone = getattr(dtype, "time_zone", None) + + # Pandas does not support "ms" or "us" time units before version 2.0 + # Let's overwrite with "ns" + if implementation is Implementation.PANDAS and backend_version < ( + 2, + ): # pragma: no cover + dt_time_unit = "ns" + if dtype_backend == "pyarrow-nullable": - return "timestamp[ns][pyarrow]" - return "datetime64[ns]" + tz_part = f", tz={dt_time_zone}" if dt_time_zone else "" + return f"timestamp[{dt_time_unit}{tz_part}][pyarrow]" + else: + tz_part = f", {dt_time_zone}" if dt_time_zone else "" + return f"datetime64[{dt_time_unit}{tz_part}]" if isinstance_or_issubclass(dtype, dtypes.Duration): - # TODO(Unassigned): different time units and time zones - if dtype_backend == "pyarrow-nullable": - return "duration[ns][pyarrow]" - return "timedelta64[ns]" + du_time_unit = getattr(dtype, "time_unit", "us") + if implementation is Implementation.PANDAS and backend_version < ( + 2, + ): # pragma: no cover + dt_time_unit = "ns" + return ( + f"duration[{du_time_unit}][pyarrow]" + if dtype_backend == "pyarrow-nullable" + else f"timedelta64[{du_time_unit}]" + ) + if isinstance_or_issubclass(dtype, dtypes.Date): if dtype_backend == "pyarrow-nullable": return "date32[pyarrow]" diff --git a/narwhals/_polars/utils.py b/narwhals/_polars/utils.py index db5a4a96b..b2f060906 100644 --- a/narwhals/_polars/utils.py +++ b/narwhals/_polars/utils.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING from typing import Any +from typing import Literal if TYPE_CHECKING: from narwhals.dtypes import DType @@ -62,12 +63,15 @@ def native_to_narwhals_dtype(dtype: Any, dtypes: DTypes) -> DType: return dtypes.Categorical() if dtype == pl.Enum: return dtypes.Enum() - if dtype == pl.Datetime: - return dtypes.Datetime() - if dtype == pl.Duration: - return dtypes.Duration() if dtype == pl.Date: return dtypes.Date() + if dtype == pl.Datetime or isinstance(dtype, pl.Datetime): + dt_time_unit: Literal["us", "ns", "ms"] = getattr(dtype, "time_unit", "us") + dt_time_zone = getattr(dtype, "time_zone", None) + return dtypes.Datetime(time_unit=dt_time_unit, time_zone=dt_time_zone) + if dtype == pl.Duration or isinstance(dtype, pl.Duration): + du_time_unit: Literal["us", "ns", "ms"] = getattr(dtype, "time_unit", "us") + return dtypes.Duration(time_unit=du_time_unit) if dtype == pl.Struct: return dtypes.Struct() if dtype == pl.List: @@ -111,12 +115,16 @@ def narwhals_to_native_dtype(dtype: DType | type[DType], dtypes: DTypes) -> Any: if dtype == dtypes.Enum: msg = "Converting to Enum is not (yet) supported" raise NotImplementedError(msg) - if dtype == dtypes.Datetime: - return pl.Datetime() - if dtype == dtypes.Duration: - return pl.Duration() if dtype == dtypes.Date: return pl.Date() + if dtype == dtypes.Datetime or isinstance(dtype, dtypes.Datetime): + dt_time_unit = getattr(dtype, "time_unit", "us") + dt_time_zone = getattr(dtype, "time_zone", None) + return pl.Datetime(dt_time_unit, dt_time_zone) # type: ignore[arg-type] + if dtype == dtypes.Duration or isinstance(dtype, dtypes.Duration): + du_time_unit: Literal["us", "ns", "ms"] = getattr(dtype, "time_unit", "us") + return pl.Duration(time_unit=du_time_unit) + if dtype == dtypes.List: # pragma: no cover msg = "Converting to List dtype is not supported yet" return NotImplementedError(msg) diff --git a/narwhals/dtypes.py b/narwhals/dtypes.py index 2d5de0f16..730f69849 100644 --- a/narwhals/dtypes.py +++ b/narwhals/dtypes.py @@ -1,6 +1,8 @@ from __future__ import annotations +from datetime import timezone from typing import TYPE_CHECKING +from typing import Literal if TYPE_CHECKING: from typing_extensions import Self @@ -71,10 +73,95 @@ class Object(DType): ... class Unknown(DType): ... -class Datetime(TemporalType): ... - - -class Duration(TemporalType): ... +class Datetime(TemporalType): + """ + Data type representing a calendar date and time of day. + + Arguments: + time_unit: Unit of time. Defaults to `'us'` (microseconds). + time_zone: Time zone string, as defined in zoneinfo (to see valid strings run + `import zoneinfo; zoneinfo.available_timezones()` for a full list). + + Notes: + Adapted from Polars implementation at: + https://github.com/pola-rs/polars/blob/py-1.7.1/py-polars/polars/datatypes/classes.py#L398-L457 + """ + + def __init__( + self: Self, + time_unit: Literal["us", "ns", "ms", "s"] = "us", + time_zone: str | timezone | None = None, + ) -> None: + if time_unit not in {"s", "ms", "us", "ns"}: + msg = ( + "invalid `time_unit`" + f"\n\nExpected one of {{'ns','us','ms', 's'}}, got {time_unit!r}." + ) + raise ValueError(msg) + + if isinstance(time_zone, timezone): + time_zone = str(time_zone) + + self.time_unit = time_unit + self.time_zone = time_zone + + def __eq__(self: Self, other: object) -> bool: + # allow comparing object instances to class + if type(other) is type and issubclass(other, self.__class__): + return True + elif isinstance(other, self.__class__): + return self.time_unit == other.time_unit and self.time_zone == other.time_zone + else: # pragma: no cover + return False + + def __hash__(self: Self) -> int: # pragma: no cover + return hash((self.__class__, self.time_unit, self.time_zone)) + + def __repr__(self: Self) -> str: # pragma: no cover + class_name = self.__class__.__name__ + return f"{class_name}(time_unit={self.time_unit!r}, time_zone={self.time_zone!r})" + + +class Duration(TemporalType): + """ + Data type representing a time duration. + + Arguments: + time_unit: Unit of time. Defaults to `'us'` (microseconds). + + Notes: + Adapted from Polars implementation at: + https://github.com/pola-rs/polars/blob/py-1.7.1/py-polars/polars/datatypes/classes.py#L460-L502 + """ + + def __init__( + self: Self, + time_unit: Literal["us", "ns", "ms", "s"] = "us", + ) -> None: + if time_unit not in ("s", "ms", "us", "ns"): + msg = ( + "invalid `time_unit`" + f"\n\nExpected one of {{'ns','us','ms', 's'}}, got {time_unit!r}." + ) + raise ValueError(msg) + + self.time_unit = time_unit + + def __eq__(self: Self, other: object) -> bool: + # allow comparing object instances to class + if type(other) is type and issubclass(other, self.__class__): + return True + elif isinstance(other, self.__class__): + return self.time_unit == other.time_unit + else: # pragma: no cover + return False + + def __hash__(self: Self) -> int: # pragma: no cover + return hash((self.__class__, self.time_unit)) + + def __repr__(self: Self) -> str: # pragma: no cover + class_name = self.__class__.__name__ + return f"{class_name}(time_unit={self.time_unit!r})" class Categorical(DType): ... diff --git a/narwhals/functions.py b/narwhals/functions.py index f0bf5d4ad..e1505e78f 100644 --- a/narwhals/functions.py +++ b/narwhals/functions.py @@ -13,6 +13,7 @@ from narwhals.dataframe import LazyFrame from narwhals.translate import from_native from narwhals.utils import Implementation +from narwhals.utils import parse_version from narwhals.utils import validate_laziness # Missing type parameters for generic type "DataFrame" @@ -235,11 +236,9 @@ def _new_series_impl( narwhals_to_native_dtype as pandas_like_narwhals_to_native_dtype, ) + backend_version = parse_version(native_namespace.__version__) dtype = pandas_like_narwhals_to_native_dtype( - dtype, - None, - implementation, - dtypes, + dtype, None, implementation, backend_version, dtypes ) native_series = native_namespace.Series(values, name=name, dtype=dtype) @@ -374,12 +373,10 @@ def _from_dict_impl( narwhals_to_native_dtype as pandas_like_narwhals_to_native_dtype, ) + backend_version = parse_version(native_namespace.__version__) schema = { name: pandas_like_narwhals_to_native_dtype( - schema[name], - native_type, - implementation, - dtypes, + schema[name], native_type, implementation, backend_version, dtypes ) for name, native_type in native_frame.dtypes.items() } diff --git a/narwhals/stable/v1/dtypes.py b/narwhals/stable/v1/dtypes.py index 942881ba4..0d1e58468 100644 --- a/narwhals/stable/v1/dtypes.py +++ b/narwhals/stable/v1/dtypes.py @@ -2,8 +2,8 @@ from narwhals.dtypes import Boolean from narwhals.dtypes import Categorical from narwhals.dtypes import Date -from narwhals.dtypes import Datetime -from narwhals.dtypes import Duration +from narwhals.dtypes import Datetime as NwDatetime +from narwhals.dtypes import Duration as NwDuration from narwhals.dtypes import Enum from narwhals.dtypes import Float32 from narwhals.dtypes import Float64 @@ -21,6 +21,41 @@ from narwhals.dtypes import UInt64 from narwhals.dtypes import Unknown + +class Datetime(NwDatetime): + """ + Data type representing a calendar date and time of day. + + Arguments: + time_unit: Unit of time. Defaults to `'us'` (microseconds). + time_zone: Time zone string, as defined in zoneinfo (to see valid strings run + `import zoneinfo; zoneinfo.available_timezones()` for a full list). + + Notes: + Adapted from Polars implementation at: + https://github.com/pola-rs/polars/blob/py-1.7.1/py-polars/polars/datatypes/classes.py#L398-L457 + """ + + def __hash__(self) -> int: + return hash(self.__class__) + + +class Duration(NwDuration): + """ + Data type representing a time duration. + + Arguments: + time_unit: Unit of time. Defaults to `'us'` (microseconds). + + Notes: + Adapted from Polars implementation at: + https://github.com/pola-rs/polars/blob/py-1.7.1/py-polars/polars/datatypes/classes.py#L460-L502 + """ + + def __hash__(self) -> int: + return hash(self.__class__) + + __all__ = [ "Array", "Boolean", diff --git a/tests/dtypes_test.py b/tests/dtypes_test.py new file mode 100644 index 000000000..58061597f --- /dev/null +++ b/tests/dtypes_test.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +from datetime import datetime +from datetime import timedelta +from datetime import timezone +from typing import Literal + +import numpy as np +import pandas as pd +import pyarrow as pa +import pytest + +import narwhals.stable.v1 as nw +from narwhals.utils import parse_version + + +@pytest.mark.parametrize("time_unit", ["us", "ns", "ms"]) +@pytest.mark.parametrize("time_zone", ["Europe/Rome", timezone.utc, None]) +def test_datetime_valid( + time_unit: Literal["us", "ns", "ms"], time_zone: str | timezone | None +) -> None: + dtype = nw.Datetime(time_unit=time_unit, time_zone=time_zone) + + assert dtype == nw.Datetime(time_unit=time_unit, time_zone=time_zone) + assert dtype == nw.Datetime + + if time_zone: + assert dtype != nw.Datetime(time_unit=time_unit) + if time_unit != "ms": + assert dtype != nw.Datetime(time_unit="ms") + + +@pytest.mark.parametrize("time_unit", ["abc"]) +def test_datetime_invalid(time_unit: str) -> None: + with pytest.raises(ValueError, match="invalid `time_unit`"): + nw.Datetime(time_unit=time_unit) # type: ignore[arg-type] + + +@pytest.mark.parametrize("time_unit", ["us", "ns", "ms"]) +def test_duration_valid(time_unit: Literal["us", "ns", "ms"]) -> None: + dtype = nw.Duration(time_unit=time_unit) + + assert dtype == nw.Duration(time_unit=time_unit) + assert dtype == nw.Duration + + if time_unit != "ms": + assert dtype != nw.Duration(time_unit="ms") + + +@pytest.mark.parametrize("time_unit", ["abc"]) +def test_duration_invalid(time_unit: str) -> None: + with pytest.raises(ValueError, match="invalid `time_unit`"): + nw.Duration(time_unit=time_unit) # type: ignore[arg-type] + + +def test_second_tu() -> None: + s = pd.Series(np.array([np.datetime64("2020-01-01", "s")])) + result = nw.from_native(s, series_only=True) + if parse_version(pd.__version__) < (2,): # pragma: no cover + assert result.dtype == nw.Datetime("ns") + else: + assert result.dtype == nw.Datetime("s") + s = pa.chunked_array([pa.array([datetime(2020, 1, 1)], type=pa.timestamp("s"))]) + result = nw.from_native(s, series_only=True) + assert result.dtype == nw.Datetime("s") + s = pd.Series(np.array([np.timedelta64(1, "s")])) + result = nw.from_native(s, series_only=True) + if parse_version(pd.__version__) < (2,): # pragma: no cover + assert result.dtype == nw.Duration("ns") + else: + assert result.dtype == nw.Duration("s") + s = pa.chunked_array([pa.array([timedelta(1)], type=pa.duration("s"))]) + result = nw.from_native(s, series_only=True) + assert result.dtype == nw.Duration("s") diff --git a/tests/expr_and_series/cast_test.py b/tests/expr_and_series/cast_test.py index 00f242148..dafe876ab 100644 --- a/tests/expr_and_series/cast_test.py +++ b/tests/expr_and_series/cast_test.py @@ -1,3 +1,9 @@ +from __future__ import annotations + +from datetime import datetime +from datetime import timedelta +from datetime import timezone + import pandas as pd import pyarrow as pa import pytest @@ -5,6 +11,8 @@ import narwhals.stable.v1 as nw from narwhals.utils import parse_version from tests.utils import Constructor +from tests.utils import compare_dicts +from tests.utils import is_windows data = { "a": [1], @@ -180,3 +188,31 @@ class Banana: with pytest.raises(AssertionError, match=r"Unknown dtype"): df.select(nw.col("a").cast(Banana)) + + +def test_cast_datetime_tz_aware( + constructor: Constructor, request: pytest.FixtureRequest +) -> None: + if "dask" in str(constructor) or ( + "pyarrow_table" in str(constructor) and is_windows() + ): + request.applymarker(pytest.mark.xfail) + + data = { + "date": [ + datetime(2024, 1, 1, tzinfo=timezone.utc) + timedelta(days=i) + for i in range(3) + ] + } + expected = { + "date": ["2024-01-01 01:00:00", "2024-01-02 01:00:00", "2024-01-03 01:00:00"] + } + + df = nw.from_native(constructor(data)) + result = df.select( + nw.col("date") + .cast(nw.Datetime("ms", time_zone="Europe/Rome")) + .cast(nw.String()) + .str.slice(offset=0, length=19) + ) + compare_dicts(result, expected) diff --git a/tests/series_only/cast_test.py b/tests/series_only/cast_test.py index 37ae55a01..672cbebc2 100644 --- a/tests/series_only/cast_test.py +++ b/tests/series_only/cast_test.py @@ -75,13 +75,13 @@ def test_cast_date_datetime_pandas() -> None: df = df.select(nw.col("a").cast(nw.Datetime)) result = nw.to_native(df) expected = pd.DataFrame({"a": [datetime(2020, 1, 1), datetime(2020, 1, 2)]}).astype( - {"a": "timestamp[ns][pyarrow]"} + {"a": "timestamp[us][pyarrow]"} ) pd.testing.assert_frame_equal(result, expected) # pandas: pyarrow datetime to date dfpd = pd.DataFrame({"a": [datetime(2020, 1, 1), datetime(2020, 1, 2)]}).astype( - {"a": "timestamp[ns][pyarrow]"} + {"a": "timestamp[us][pyarrow]"} ) df = nw.from_native(dfpd) df = df.select(nw.col("a").cast(nw.Date)) diff --git a/tests/stable_api_test.py b/tests/stable_api_test.py index a12b20cc6..7a67f5723 100644 --- a/tests/stable_api_test.py +++ b/tests/stable_api_test.py @@ -1,4 +1,5 @@ from datetime import datetime +from datetime import timedelta from typing import Any import polars as pl @@ -139,7 +140,12 @@ def test_series_docstrings() -> None: def test_dtypes(constructor: Constructor) -> None: - df = nw.from_native(constructor({"a": [1], "b": [datetime(2020, 1, 1)]})) + df = nw_v1.from_native( + constructor({"a": [1], "b": [datetime(2020, 1, 1)], "c": [timedelta(1)]}) + ) dtype = df.collect_schema()["b"] - assert dtype in {nw.Datetime} - assert isinstance(dtype, nw.Datetime) + assert dtype in {nw_v1.Datetime} + assert isinstance(dtype, nw_v1.Datetime) + dtype = df.collect_schema()["c"] + assert dtype in {nw_v1.Duration} + assert isinstance(dtype, nw_v1.Duration) diff --git a/utils/check_api_reference.py b/utils/check_api_reference.py index 410956e04..d3f30aaa2 100644 --- a/utils/check_api_reference.py +++ b/utils/check_api_reference.py @@ -31,7 +31,7 @@ "zip_with", "__iter__", } -BASE_DTYPES = {"NumericType", "DType", "TemporalType"} +BASE_DTYPES = {"NumericType", "DType", "TemporalType", "Literal"} files = {remove_suffix(i, ".py") for i in os.listdir("narwhals")}