Skip to content

Commit

Permalink
Simplify np.datetime64 subclass thanks to stack overflow
Browse files Browse the repository at this point in the history
  • Loading branch information
rlskoeser committed Jun 6, 2024
1 parent 220dc7e commit 32f8b52
Showing 1 changed file with 30 additions and 23 deletions.
53 changes: 30 additions & 23 deletions src/undate/undate.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,29 @@
ONE_DAY = np.timedelta64(1, "D")


class LocalDate:
class LocalDate(np.ndarray):
# extend np.datetime64 datatype
# adapted from https://stackoverflow.com/a/27129510/9706217

def __new__(cls, year: str, month: str = None, day: str = None):
if isinstance(year, np.datetime64):
data = year
else:
datestr = year
if month is not None:
datestr = f"{year}-{month:02d}"
if day is not None:
datestr = f"{datestr}-{day:02d}"
data = np.datetime64(datestr)

data = np.asarray(data, dtype="datetime64")
if data.dtype != "datetime64[D]":
raise Exception(
"Unable to parse dates adequately to datetime64[D]: %s" % data
)
obj = data.view(cls)
return obj

def __init__(self, year: str, month: str = None, day: str = None):
if isinstance(year, np.datetime64):
self._date = year
Expand All @@ -28,12 +50,11 @@ def __init__(self, year: str, month: str = None, day: str = None):
datestr = f"{datestr}-{day:02d}"
self._date = np.datetime64(datestr)

def __str__(self):
return str(self._date)

@property
def year(self):
return int(str(self._date.astype("datetime64[Y]")))
# alternate solution from stack overflow
# return self.astype('datetime64[Y]').astype(int) + 1970

@property
def month(self):
Expand All @@ -43,26 +64,12 @@ def month(self):
def day(self):
return int(str(self._date.astype("datetime64[D]")).split("-")[-1])

def __eq__(self, other: object) -> bool:
return self._date == other._date

def __gt__(self, other: object) -> bool:
# define gt ourselves so we can support > comparison with datetime.date,
# but rely on existing less than implementation.
# strictly greater than must rule out equals
return not (self._date < other._date or self._date == other._date)

def __le__(self, other: Union["Undate", datetime.date]) -> bool:
return self._date == other._date or self._date < other._date

def __add__(self, other):
if isinstance(other, LocalDate):
return LocalDate(self._date + other._date)
if isinstance(other, np.timedelta64):
return LocalDate(self._date + other)
def Export(self):
return self

def __sub__(self, other) -> np.timedelta64:
return self._date - other._date
def __array_finalize__(self, obj):
if obj is None:
return


class DatePrecision(IntEnum):
Expand Down

0 comments on commit 32f8b52

Please sign in to comment.