Skip to content

Commit

Permalink
feat: allow inspecting the inner type of List (#1104)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored Oct 1, 2024
1 parent 7fb0c5d commit 9fb12be
Show file tree
Hide file tree
Showing 7 changed files with 53 additions and 8 deletions.
2 changes: 1 addition & 1 deletion narwhals/_arrow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def native_to_narwhals_dtype(dtype: Any, dtypes: DTypes) -> DType:
if pa.types.is_struct(dtype):
return dtypes.Struct()
if pa.types.is_list(dtype) or pa.types.is_large_list(dtype):
return dtypes.List()
return dtypes.List(native_to_narwhals_dtype(dtype.value_type, dtypes))
if pa.types.is_fixed_size_list(dtype):
return dtypes.Array()
return dtypes.Unknown() # pragma: no cover
Expand Down
4 changes: 2 additions & 2 deletions narwhals/_duckdb/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ def map_duckdb_dtype_to_narwhals_dtype(duckdb_dtype: Any, dtypes: DTypes) -> DTy
return dtypes.Duration()
if duckdb_dtype.startswith("STRUCT"):
return dtypes.Struct()
if re.match(r"\w+\[\]", duckdb_dtype):
return dtypes.List()
if match_ := re.match(r"(.*)\[\]$", duckdb_dtype):
return dtypes.List(map_duckdb_dtype_to_narwhals_dtype(match_.group(1), dtypes))
if re.match(r"\w+\[\d+\]", duckdb_dtype):
return dtypes.Array()
return dtypes.Unknown()
Expand Down
4 changes: 3 additions & 1 deletion narwhals/_ibis/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ def map_ibis_dtype_to_narwhals_dtype(ibis_dtype: Any, dtypes: DTypes) -> DType:
if ibis_dtype.is_timestamp():
return dtypes.Datetime()
if ibis_dtype.is_array():
return dtypes.List()
return dtypes.List(
map_ibis_dtype_to_narwhals_dtype(ibis_dtype.value_type, dtypes)
)
if ibis_dtype.is_struct():
return dtypes.Struct()
return dtypes.Unknown() # pragma: no cover
Expand Down
7 changes: 6 additions & 1 deletion narwhals/_pandas_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
from typing import Literal
from typing import TypeVar

from narwhals._arrow.utils import (
native_to_narwhals_dtype as arrow_native_to_narwhals_dtype,
)
from narwhals.utils import Implementation
from narwhals.utils import isinstance_or_issubclass

Expand Down Expand Up @@ -276,7 +279,9 @@ def native_to_narwhals_dtype(column: Any, dtypes: DTypes) -> DType:
if dtype == "date32[day][pyarrow]":
return dtypes.Date()
if dtype.startswith(("large_list", "list")):
return dtypes.List()
return dtypes.List(
arrow_native_to_narwhals_dtype(column.dtype.pyarrow_dtype.value_type, dtypes)
)
if dtype.startswith("fixed_size_list"):
return dtypes.Array()
if dtype.startswith("struct"):
Expand Down
2 changes: 1 addition & 1 deletion narwhals/_polars/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def native_to_narwhals_dtype(dtype: Any, dtypes: DTypes) -> DType:
if dtype == pl.Struct:
return dtypes.Struct()
if dtype == pl.List:
return dtypes.List()
return dtypes.List(native_to_narwhals_dtype(dtype.inner, dtypes))
if dtype == pl.Array:
return dtypes.Array()
return dtypes.Unknown()
Expand Down
26 changes: 25 additions & 1 deletion narwhals/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,31 @@ class Enum(DType): ...
class Struct(DType): ...


class List(DType): ...
class List(DType):
    """List dtype that records the dtype of its elements in ``inner``.

    Equality is deliberately lenient towards the bare class: an instance
    compares equal to ``List`` itself (or to a subclass of the instance's
    class), whatever its inner dtype is. Two instances are equal only when
    their ``inner`` dtypes are equal.
    """

    def __init__(self, inner: DType | type[DType]) -> None:
        # ``inner`` may be either a dtype class or a dtype instance.
        self.inner = inner

    def __eq__(self, other: DType | type[DType]) -> bool:  # type: ignore[override]
        # Examples of the intended semantics:
        #   list[i64] == list[i64] -> True
        #   list[i64] == list[f32] -> False
        #   list[i64] == list      -> True   (unspecified inner matches anything)
        own_class = self.__class__

        # Right-hand side is a plain class object: match on the class alone.
        if type(other) is type and issubclass(other, own_class):
            return True

        # Anything that is not an instance of our class can never be equal.
        if not isinstance(other, own_class):
            return False

        # Both sides are instances: equality is decided by the inner dtypes.
        return self.inner == other.inner

    def __hash__(self) -> int:
        # Hash on the same pieces that instance-to-instance equality looks at.
        key = (self.__class__, self.inner)
        return hash(key)

    def __repr__(self) -> str:
        # Rendered as ``<ClassName>(<repr of inner>)``.
        return f"{self.__class__.__name__}({self.inner!r})"


class Array(DType): ...
Expand Down
16 changes: 15 additions & 1 deletion tests/dtypes_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,21 @@ def test_duration_invalid(time_unit: str) -> None:
nw.Duration(time_unit=time_unit) # type: ignore[arg-type]


def test_second_tu() -> None:
def test_list_valid() -> None:
    """Exercise equality, inequality, repr and hashing of ``nw.List``."""
    flat = nw.List(nw.Int64)
    # Equal to a list with the same inner dtype, and to the bare List class.
    assert flat == nw.List(nw.Int64)
    assert flat == nw.List
    # Different inner dtype, or an unrelated dtype class, is not equal.
    assert flat != nw.List(nw.Float32)
    assert flat != nw.Duration
    assert repr(flat) == "List(<class 'narwhals.dtypes.Int64'>)"

    nested = nw.List(nw.List(nw.Int64))
    # The same rules hold when the inner dtype is itself a List.
    assert nested == nw.List(nw.List(nw.Int64))
    assert nested == nw.List
    assert nested != nw.List(nw.List(nw.Float32))
    # Set membership relies on __hash__ as well as __eq__.
    assert nested in {nw.List(nw.List(nw.Int64))}


def test_second_time_unit() -> None:
s = pd.Series(np.array([np.datetime64("2020-01-01", "s")]))
result = nw.from_native(s, series_only=True)
if parse_version(pd.__version__) < (2,): # pragma: no cover
Expand Down

0 comments on commit 9fb12be

Please sign in to comment.