Skip to content

Commit

Permalink
feature: weekly predictions ahead in time
Browse files Browse the repository at this point in the history
  • Loading branch information
knutdrand committed Sep 20, 2024
1 parent 9ef2430 commit 7dfdc7d
Show file tree
Hide file tree
Showing 13 changed files with 194 additions and 69 deletions.
16 changes: 11 additions & 5 deletions climate_health/assessment/forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from climate_health.assessment.dataset_splitting import train_test_split_with_weather
from climate_health.assessment.prediction_evaluator import Estimator, Predictor
from climate_health.climate_predictor import MonthlyClimatePredictor
from climate_health.climate_predictor import MonthlyClimatePredictor, get_climate_predictor
from climate_health.data.gluonts_adaptor.dataset import ForecastAdaptor
from climate_health.plotting.prediction_plot import plot_forecast_from_summaries
from climate_health.spatio_temporal_data.temporal_dataclass import DataSet
Expand Down Expand Up @@ -60,10 +60,16 @@ def forecast_ahead(estimator: Estimator, dataset: DataSet, prediction_length: in


def forecast_with_predicted_weather(predictor: Predictor, historic_data: DataSet, prediction_length: int, ):
prediction_range = PeriodRange.from_start_and_n_periods(
Month(historic_data.end_timestamp).to_string(), prediction_length)
climate_predictor = MonthlyClimatePredictor()
climate_predictor.train(historic_data)
delta = historic_data.period_range[0].time_delta
prediction_range = PeriodRange(historic_data.end_timestamp,
historic_data.end_timestamp + delta * prediction_length,
delta)

#prediction_range = PeriodRange.from_start_and_n_periods(
# Month(historic_data.end_timestamp).to_string(), prediction_length)
#climate_predictor = MonthlyClimatePredictor()
#climate_predictor.train(historic_data)
climate_predictor= get_climate_predictor(historic_data)
future_weather = climate_predictor.predict(prediction_range)
predictions = predictor.predict(historic_data, future_weather)
return predictions
7 changes: 5 additions & 2 deletions climate_health/assessment/prediction_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ def plot_forecasts(predictors: list[Predictor], test_instance, truth, pdf_filena


def plot_forecasts(predictor, test_instance, truth, pdf_filename):

forecast_dict = _get_forecast_dict(predictor, test_instance)
with PdfPages(pdf_filename) as pdf:
for location, forecasts in forecast_dict.items():
Expand All @@ -243,7 +244,8 @@ def plot_forecasts(predictor, test_instance, truth, pdf_filename):
plt.figure(figsize=(8, 4)) # Set the figure size
t = _t[_t.index <= forecast.index[-1]]
forecast.plot(show_label=True)
plt.plot(t[-150:].to_timestamp())
plotting_context = 52*6
plt.plot(t[-plotting_context:].to_timestamp())
plt.title(location)
plt.legend()
pdf.savefig()
Expand All @@ -260,7 +262,8 @@ def plot_predictions(predictions: DataSet[Samples], truth: DataSet, pdf_filename
plt.figure(figsize=(8, 4)) # Set the figure size
# t = _t[_t.index <= prediction.index[-1]]
prediction.plot(show_label=True)
plt.plot(t[-150:].to_timestamp())
context_length = 52*6
plt.plot(t[-context_length:].to_timestamp())
plt.title(location)
plt.legend()
pdf.savefig()
Expand Down
24 changes: 18 additions & 6 deletions climate_health/climate_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,26 @@

from .datatypes import ClimateData
from climate_health.spatio_temporal_data.temporal_dataclass import DataSet
from climate_health.time_period import PeriodRange
from climate_health.time_period import PeriodRange, Month, Week


def get_climate_predictor(train_data: DataSet[ClimateData]):
if isinstance(train_data.period_range[0], Month):
estimator = MonthlyClimatePredictor()
else:
assert isinstance(train_data.period_range[0], Week)
estimator = WeeklyClimatePredictor()
estimator.train(train_data)
return estimator


class MonthlyClimatePredictor:
def __init__(self):
self._models = defaultdict(dict)
self._cls = None

def _feature_matrix(self, time_period: PeriodRange):
return time_period.month[:,None] == np.arange(1, 13)
return time_period.month[:, None] == np.arange(1, 13)

def train(self, train_data: DataSet[ClimateData]):
train_data = train_data.remove_field('disease_cases')
Expand All @@ -35,10 +44,13 @@ def predict(self, time_period: PeriodRange):
x = self._feature_matrix(time_period)
prediction_dict = {}
for location, models in self._models.items():
prediction_dict[location] = self._cls(time_period, **{field: model.predict(x).ravel() for field, model in models.items()})
prediction_dict[location] = self._cls(time_period, **{field: model.predict(x).ravel() for field, model in
models.items()})
return DataSet(prediction_dict)





class WeeklyClimatePredictor(MonthlyClimatePredictor):
def _feature_matrix(self, time_period: PeriodRange):
t = time_period.week[:, None] == np.arange(1, 53)
t[..., -1] |= time_period.week == 53
return t
5 changes: 4 additions & 1 deletion climate_health/spatio_temporal_data/temporal_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,4 +296,7 @@ def from_fields(cls, dataclass: type[TimeSeriesData], fields: dict[str, 'DataSet
new_dict[location] = dataclass(period_range, **{field: fields[field][location].fill_to_range(start_timestamp, end_timestamp).value for field in field_names})
return cls(new_dict)


def plot(self):
for location, data in self.items():
df = data.to_pandas()
df.plot(x='time_period', title=location)
2 changes: 1 addition & 1 deletion climate_health/time_period/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .date_util_wrapper import TimePeriod, Year, Month, Day, PeriodRange, delta_month, delta_week
from .date_util_wrapper import TimePeriod, Year, Month, Day, PeriodRange, delta_month, delta_week, Week
#from ._legacy_implementation import TimePeriod, Year, Month, Day
from .period_range import period_range as get_period_range
get_period_range = PeriodRange.from_time_periods
71 changes: 56 additions & 15 deletions climate_health/time_period/date_util_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
import logging
from datetime import datetime
from numbers import Number
from typing import Union, Iterable
from typing import Union, Iterable, Tuple

import dateutil
import numpy as np
import pandas as pd
from dateutil.parser import parse
from dateutil.relativedelta import relativedelta
from pytz import utc


class DateUtilWrapper:
Expand All @@ -25,6 +27,10 @@ def __getattr__(self, item: str):
class TimeStamp(DateUtilWrapper):
_used_attributes = ('year', 'month', 'day', '__str__', '__repr__')

@property
def week(self):
return self._date.isocalendar()[1]

def __init__(self, date: datetime):
self._date = date

Expand Down Expand Up @@ -60,7 +66,7 @@ def __sub__(self, other: 'TimeStamp'):
return TimeDelta(relativedelta(self._date, other._date))

def _comparison(self, other: 'TimeStamp', func_name: str):
return getattr(self._date, func_name)(other._date)
return getattr(self._date.replace(tzinfo=utc), func_name)(other._date.replace(tzinfo=utc))


class TimePeriod:
Expand Down Expand Up @@ -150,7 +156,7 @@ def time_delta(self) -> 'TimeDelta':

@classmethod
def parse(cls, text_repr: str):
if 'W' in text_repr:
if 'W' in text_repr or '/' in text_repr:
return cls.parse_week(text_repr)
try:
year = int(text_repr)
Expand All @@ -172,9 +178,15 @@ def from_pandas(cls, period: pd.Period):

@classmethod
def parse_week(cls, week: str):
year, weeknr = week.split('W')
print('########', week)
return Week(int(year), int(weeknr))
if 'W' in week:
year, weeknr = week.split('W')
return Week(int(year), int(weeknr))
elif '/' in week:
start, end = week.split('/')
start_date = dateutil.parser.parse(start)
end_date = dateutil.parser.parse(end)
assert relativedelta(end_date, start_date).days == 6, f'Week must be 7 days {start_date} {end_date}'
return Week(start_date) # type: ignore

@property
def start_timestamp(self):
Expand Down Expand Up @@ -206,10 +218,20 @@ def id(self):
return self._date.strftime('%Y%m%d')


class WeekNumbering:
@staticmethod
def get_week_info(date: datetime) -> Tuple[int, int, int]:
return date.isocalendar()

@staticmethod
def get_date(year: int, week: int, day: int) -> datetime:
return datetime.strptime(f'{year}-W{week}-{day}', "%G-W%V-%w")


class Week(TimePeriod):
_used_attributes = []#'year']
_extension = relativedelta(weeks=1)
_week_numbering = WeekNumbering

@property
def id(self):
Expand All @@ -224,12 +246,15 @@ def __init__(self, date, *args, **kwargs):
week_nr = args[0] if args else kwargs['week']
self._date = self.__date_from_numbers(year, week_nr)
self.week = week_nr
self.year = self._date.year
self.year = year
#self.year = self._date.year
else:
if isinstance(date, TimeStamp):
date = date._date
self.week = date.isocalendar()[1]
self.year = date.isocalendar()[0]
year, week, day = date.isocalendar()
self.week = week
self.year = year

self._date = date

def __sub__(self, other: 'TimePeriod'):
Expand All @@ -243,11 +268,18 @@ def __str__(self):

__repr__ = __str__

def __date_from_numbers(cls, year: int, week_nr: int):
return datetime.strptime(f'{year}-W{week_nr}-1', "%Y-W%W-%w")
def __date_from_numbers(self, year: int, week_nr: int):
date = self._week_numbering.get_date(year, week_nr, 1)
#date = datetime.strptime(f'{year}-W{week_nr}-1', "%Y-W%W-%w")
assert date.isocalendar()[:2] == (year, week_nr), (date.isocalendar()[:2], year, week_nr)
return date

@classmethod
def _isocalendar_week_to_date(cls, year: int, week_nr: int, day: int):
return datetime.strptime(f'{year}-W{week_nr}-{day}', "%Y-W%V-%w")

def topandas(self):
return self.__str__()
#return self.__str__()
return pd.Period(self._date, freq='W-MON')


Expand Down Expand Up @@ -365,6 +397,10 @@ def month(self):
def year(self):
return np.array([p.start_timestamp.year for p in self])

@property
def week(self):
return np.array([p.start_timestamp.week for p in self])

@property
def delta(self):
return self._time_delta
Expand Down Expand Up @@ -418,7 +454,8 @@ def _period_class(self):
raise ValueError(f'Unknown time delta {self._time_delta}')

def __iter__(self):
return (self._period_class((self._start_timestamp + self._time_delta * i)._date) for i in range(len(self)))
return (self._period_class((self._start_timestamp + self._time_delta * i)._date)
for i in range(len(self)))

def __getitem__(self, item: slice | int):
''' Slice by numeric index in the period range'''
Expand Down Expand Up @@ -517,7 +554,10 @@ def from_ids(cls, ids: Iterable[str], fill_missing=False):

@classmethod
def from_start_and_n_periods(cls, start_period: pd.Period, n_periods: int):
period = TimePeriod.from_pandas(start_period)
if not isinstance(start_period, TimePeriod):
period = TimePeriod.from_pandas(start_period)
else:
period = start_period
delta = period.time_delta
return cls.from_time_periods(period, period + delta * (n_periods-1))

Expand All @@ -543,7 +583,8 @@ def searchsorted(self, period: TimePeriod, side='left'):
if side not in ('left', 'right'):
raise ValueError(f'Invalid side {side}')
assert period.time_delta == self._time_delta, (period, self._time_delta)
n_steps = TimeDelta(relativedelta(period._date, self._start_timestamp._date)) // self._time_delta
n_steps = self._time_delta.n_periods(self._start_timestamp, period.start_timestamp)
# n_steps = TimeDelta(relativedelta(period._date, self._start_timestamp._date)) // self._time_delta
if side == 'right':
n_steps += 1
n_steps = min(max(0, n_steps), len(self)) # if period is outside
Expand Down
2 changes: 1 addition & 1 deletion pytest.ini
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[pytest]
#norecursedirs = tests/data_wrangling tests/external tests/spatio_temporal_data tests/
norecursedirs = tests/data_wrangling tests/spatio_temporal_data tests/
norecursedirs = tests/data_wrangling tests/spatio_temporal_data tests/ .mypy_cache
ignore = ['tests/test_meteostat_wrapper']
log_cli = True

Expand Down
Loading

0 comments on commit 7dfdc7d

Please sign in to comment.