Skip to content

Commit

Permalink
Simplify np.datetime64 subclass thanks to stack overflow
Browse files Browse the repository at this point in the history
  • Loading branch information
rlskoeser committed Jun 6, 2024
1 parent 220dc7e commit 81299bb
Showing 1 changed file with 26 additions and 28 deletions.
54 changes: 26 additions & 28 deletions src/undate/undate.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,39 @@
ONE_DAY = np.timedelta64(1, "D")


class LocalDate:
def __init__(self, year: str, month: str = None, day: str = None):
class LocalDate(np.ndarray):
# shim to make np.datetime64 act more like datetime.date

# extend np.datetime64 datatype
# adapted from https://stackoverflow.com/a/27129510/9706217

def __new__(cls, year: str, month: str = None, day: str = None):
if isinstance(year, np.datetime64):
self._date = year
data = year
else:
datestr = year
if month is not None:
datestr = f"{year}-{month:02d}"
if day is not None:
datestr = f"{datestr}-{day:02d}"
self._date = np.datetime64(datestr)
data = np.datetime64(datestr)

def __str__(self):
return str(self._date)
data = np.asarray(data, dtype="datetime64")
if data.dtype != "datetime64[D]":
raise Exception(
"Unable to parse dates adequately to datetime64[D]: %s" % data
)
obj = data.view(cls)
return obj

def Export(self):
return self

def __array_finalize__(self, obj):
if obj is None:
return

# custom properties to access year, month, day

@property
def year(self):
Expand All @@ -43,27 +62,6 @@ def month(self):
def day(self):
return int(str(self._date.astype("datetime64[D]")).split("-")[-1])

def __eq__(self, other: object) -> bool:
return self._date == other._date

def __gt__(self, other: object) -> bool:
# define gt ourselves so we can support > comparison with datetime.date,
# but rely on existing less than implementation.
# strictly greater than must rule out equals
return not (self._date < other._date or self._date == other._date)

def __le__(self, other: Union["Undate", datetime.date]) -> bool:
return self._date == other._date or self._date < other._date

def __add__(self, other):
if isinstance(other, LocalDate):
return LocalDate(self._date + other._date)
if isinstance(other, np.timedelta64):
return LocalDate(self._date + other)

def __sub__(self, other) -> np.timedelta64:
return self._date - other._date


class DatePrecision(IntEnum):
"""date precision, to indicate date precision independent from how much
Expand All @@ -82,7 +80,7 @@ class DatePrecision(IntEnum):
def __str__(self):
return f"{self.name}"

# numpy date units are years (‘Y’), months (‘M’), weeks (‘W’), and days (‘D’),
# numpy date units are years (‘Y’), months (‘M’), weeks (‘W’), and days (‘D’)


class Undate:
Expand Down

0 comments on commit 81299bb

Please sign in to comment.