Skip to content

Commit

Permalink
downstream?
Browse files Browse the repository at this point in the history
  • Loading branch information
FBruzzesi committed Sep 14, 2024
1 parent 20e36a1 commit ec1cb5e
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 18 deletions.
24 changes: 8 additions & 16 deletions narwhals/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,14 +107,12 @@ def __init__(
self.time_unit = time_unit
self.time_zone = time_zone

def __eq__(self: Self, other: type[DType] | DType) -> bool: # type: ignore[override]
# allow comparing object instances to class
if type(other) is type and issubclass(other, Datetime):
return True
elif isinstance(other, Datetime):
return self.time_unit == other.time_unit and self.time_zone == other.time_zone
else:
return False
def __eq__(self: Self, other: object) -> bool:
return (
isinstance(other, Datetime)
and self.time_unit == other.time_unit
and self.time_zone == other.time_zone
)

def __hash__(self: Self) -> int: # pragma: no cover
return hash((self.__class__, self.time_unit, self.time_zone))
Expand Down Expand Up @@ -149,14 +147,8 @@ def __init__(

self.time_unit = time_unit

def __eq__(self: Self, other: type[DType] | DType) -> bool: # type: ignore[override]
# allow comparing object instances to class
if type(other) is type and issubclass(other, Duration):
return True
elif isinstance(other, Duration):
return self.time_unit == other.time_unit
else:
return False
def __eq__(self: Self, other: object) -> bool:
return isinstance(other, Duration) and self.time_unit == other.time_unit

def __hash__(self: Self) -> int: # pragma: no cover
return hash((self.__class__, self.time_unit))
Expand Down
2 changes: 0 additions & 2 deletions tests/dtypes_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ def test_datetime_valid(
dtype = nw.Datetime(time_unit=time_unit, time_zone=time_zone)

assert dtype == nw.Datetime(time_unit=time_unit, time_zone=time_zone)
assert dtype == nw.Datetime

if time_zone:
assert dtype != nw.Datetime(time_unit=time_unit)
Expand All @@ -35,7 +34,6 @@ def test_duration_valid(time_unit: Literal["us", "ns", "ms"]) -> None:
dtype = nw.Duration(time_unit=time_unit)

assert dtype == nw.Duration(time_unit=time_unit)
assert dtype == nw.Duration

if time_unit != "ms":
assert dtype != nw.Duration(time_unit="ms")
Expand Down

0 comments on commit ec1cb5e

Please sign in to comment.