Skip to content

Commit

Permalink
feature: train test generator
Browse files Browse the repository at this point in the history
  • Loading branch information
knutdrand committed Sep 5, 2024
1 parent 96b0c98 commit a380d04
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 4 deletions.
17 changes: 17 additions & 0 deletions climate_health/assessment/dataset_splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,23 @@ def train_test_split(data_set: IsSpatioTemporalDataSet, prediction_start_period:
return train_data, test_data


def train_test_generator(dataset: DataSet, prediction_length: int, n_test_sets: int = 1,
                         target_field: str = 'disease_cases') -> tuple[
    DataSet, Iterable[tuple[DataSet, DataSet]]]:
    '''
    Generate a train set along with a lazy iterator of test pairs.

    The train set is the data up to and including ``period_range[-(prediction_length + n_test_sets)]``.
    Each test pair is ``(historic_data, masked_future_data)``: the data up to a split point, and the
    following ``prediction_length`` periods with ``target_field`` removed. The split point advances by
    one period per pair, so the last pair's future window ends at the last period of ``dataset``.

    :param dataset: the full data set; must provide ``period_range``, ``restrict_time_period``
        and (on the restricted result) ``remove_field``
    :param prediction_length: number of periods in every future window, at least 1
    :param n_test_sets: number of test pairs to produce, at least 0
    :param target_field: name of the field removed from the future data
    :raises ValueError: if ``prediction_length < 1`` or ``n_test_sets < 0``
    '''
    # The index arithmetic below relies on every index staying negative; a non-positive
    # prediction_length would wrap around to the start of the period range.
    if prediction_length < 1:
        raise ValueError(f'prediction_length must be at least 1, got {prediction_length}')
    if n_test_sets < 0:
        raise ValueError(f'n_test_sets must be non-negative, got {n_test_sets}')

    periods = dataset.period_range
    split_idx = -(prediction_length + n_test_sets)
    train_set = dataset.restrict_time_period(slice(None, periods[split_idx]))

    def test_pairs():
        # Built lazily so that no restricted copies are created until a pair is consumed.
        for offset in range(n_test_sets):
            last_historic = split_idx + offset
            historic = dataset.restrict_time_period(slice(None, periods[last_historic]))
            future = dataset.restrict_time_period(
                slice(periods[last_historic + 1], periods[last_historic + prediction_length]))
            yield historic, future.remove_field(target_field)

    return train_set, test_pairs()


def train_test_split_with_weather(data_set: DataSet, prediction_start_period: TimePeriod,
extension: Optional[IsTimeDelta] = None,
future_weather_class: Type[ClimateData] = ClimateData):
Expand Down
34 changes: 32 additions & 2 deletions climate_health/assessment/prediction_evaluator.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
from dataclasses import dataclass
from typing import Protocol, TypeVar

from sklearn.metrics import root_mean_squared_error
import pandas as pd
import plotly.express as px
from climate_health.assessment.dataset_splitting import get_split_points_for_data_set, split_test_train_on_period
from climate_health.assessment.dataset_splitting import get_split_points_for_data_set, split_test_train_on_period, \
train_test_split
from climate_health.assessment.multi_location_evaluator import MultiLocationEvaluator
from climate_health.datatypes import TimeSeriesData, Samples
from climate_health.predictor.naive_predictor import MultiRegionPoissonModel
from climate_health.reports import HTMLReport, HTMLSummaryReport
import logging

from climate_health.spatio_temporal_data.temporal_dataclass import DataSet

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -38,7 +45,7 @@ def plot_rmse(rmse_dict, do_show=True):


def evaluate_model(data_set, external_model, max_splits=5, start_offset=20,
return_table=False, naive_model_cls=None, callback=None, mode = 'predict',
return_table=False, naive_model_cls=None, callback=None, mode='predict',
run_naive_predictor=True):
'''
Evaluate a model on a dataset using forecast cross validation
Expand Down Expand Up @@ -78,3 +85,26 @@ def evaluate_model(data_set, external_model, max_splits=5, start_offset=20,
results = pd.concat(results.values())
return report, results
return report


# Type variable for the time series type carried by a DataSet. The variable name has to match
# the string given to TypeVar, otherwise static type checkers reject the declaration.
FeatureType = TypeVar('FeatureType', bound=TimeSeriesData)
# Backward-compatible alias for the original, misspelled name.
FetureType = FeatureType


def without_disease(t):
    '''
    Return ``t`` unchanged.

    Used as a marker in the annotation of ``Predictor.predict``; judging by its name it is meant to
    denote "the same feature type, minus the disease field", but at present it is a plain identity.
    '''
    return t


class Predictor(Protocol):
    '''Structural interface for an object that can forecast from historic and future data.'''

    def predict(self, historic_data: DataSet[FeatureType],
                future_data: DataSet[without_disease(FeatureType)]) -> Samples:
        '''Return forecast samples given ``historic_data`` and ``future_data``.'''
        ...


class Estimator(Protocol):
    '''Structural interface for an object whose ``train`` method yields a ``Predictor``.'''

    def train(self, data: DataSet) -> Predictor:
        '''Train on ``data`` and return the resulting predictor.'''
        ...


def evaluate_model(self, estimator: Estimator, data: DataSet, n_periods):
    '''
    Split ``data`` at ``data.period_range[-n_periods]``, train ``estimator`` on the first part
    and return the predictor's forecasts for the second part.

    NOTE(review): this is a module-level function that takes an unused ``self`` and reuses the name
    of the ``evaluate_model`` defined earlier in this module, so it replaces that function when the
    module is imported. It presumably belongs on a class or under another name -- confirm.

    :param self: unused
    :param estimator: object whose ``train`` method returns a predictor
    :param data: the full data set
    :param n_periods: index from the end of ``data.period_range`` at which prediction starts
    :return: whatever ``predictor.predict`` returns for the held-out part
    '''
    train, test = train_test_split(data, data.period_range[-n_periods])
    # Fit on the training part only; fitting on ``data`` would leak the held-out periods.
    predictor = estimator.train(train)
    # The predictor must not see the target values for the periods it is asked to forecast.
    future_data = test.remove_field('disease_cases')
    return predictor.predict(train, future_data)
11 changes: 9 additions & 2 deletions tests/test_dataset_splitting.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
from climate_health.time_period import Month
from climate_health.assessment.dataset_splitting import split_test_train_on_period, train_test_split, \
get_split_points_for_period_range
get_split_points_for_period_range, train_test_generator
from climate_health.time_period import PeriodRange
from .data_fixtures import full_data

Expand All @@ -25,9 +25,16 @@ def test_split_test_train_on_period(full_data):
assert len(test_table) == 12 - true_len



def test_get_split_points_for_period_range():
    '''One split point over the months of 2012 with start_offset=3 is expected to be August 2012.'''
    first_month, last_month = Month(2012, 1), Month(2012, 12)
    months_of_2012 = PeriodRange.from_time_periods(first_month, last_month)
    actual = get_split_points_for_period_range(1, months_of_2012, start_offset=3)
    expected = [Month(2012, 8)]
    assert actual == expected


def test_train_test_generator(full_data):
    '''
    train_test_generator yields the requested number of test pairs, every future window has
    the requested length, and the last window covers the final periods of the data.
    '''
    prediction_length, n_test_sets = 3, 2
    train_data, test_pairs = train_test_generator(
        full_data, prediction_length=prediction_length, n_test_sets=n_test_sets)
    # The pairs come from a one-shot iterator; materialize them before making several assertions.
    test_pairs = list(test_pairs)
    assert len(test_pairs) == n_test_sets
    assert all(len(future.period_range) == prediction_length for _, future in test_pairs)
    last_future = test_pairs[-1][1]
    # Comparing period ranges is assumed to be elementwise here, hence the all().
    assert all(last_future.period_range == full_data.period_range[-prediction_length:])

0 comments on commit a380d04

Please sign in to comment.