Skip to content

Commit

Permalink
merge
Browse files Browse the repository at this point in the history
  • Loading branch information
ivargr committed Sep 4, 2024
2 parents c49aa56 + f55b1ea commit 27d6717
Show file tree
Hide file tree
Showing 11 changed files with 120 additions and 19 deletions.
2 changes: 0 additions & 2 deletions climate_health/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
import numpy as np
import pandas as pd
from cyclopts import App

from climate_health.external.external_model import get_model_from_directory_or_github_url, get_model_maybe_yaml
from climate_health.spatio_temporal_data.multi_country_dataset import MultiCountryDataSet
from . import api
from climate_health.dhis2_interface.ChapProgram import ChapPullPost
Expand Down
1 change: 1 addition & 0 deletions climate_health/data/datasets.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from climate_health.spatio_temporal_data.multi_country_dataset import MultiCountryDataSet

# Built eagerly when this module is imported: MultiCountryDataSet.from_tar
# retrieves the archive from the URL below and parses it into a dataset.
ISIMIP_dengue_harmonized = MultiCountryDataSet.from_tar('https://github.com/dhis2/chap-core/raw/dev/example_data/full_data.tar.gz')
68 changes: 57 additions & 11 deletions climate_health/data/gluonts_adaptor/dataset.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,43 @@
import dataclasses
from pathlib import Path
from typing import Iterable
from typing import Iterable, TypeVar

import numpy as np

from climate_health.assessment.dataset_splitting import train_test_split
from climate_health.file_io.example_data_set import datasets
from climate_health.datatypes import TimeSeriesData, remove_field
from climate_health.spatio_temporal_data.temporal_dataclass import DataSet
from climate_health.spatio_temporal_data.multi_country_dataset import MultiCountryDataSet
from climate_health.time_period import delta_month

from climate_health.time_period import delta_month, PeriodRange
import logging
logger = logging.getLogger(__name__)
GlunTSDataSet = Iterable[dict]


T = TypeVar('T', bound=TimeSeriesData)
class DataSetAdaptor:

@staticmethod
def _from_single_gluonts_series(series: dict, dataclass: type[T]) -> T:
    """Rebuild one ``dataclass`` instance from a single GluonTS entry.

    Every field of ``dataclass`` other than ``disease_cases`` and
    ``time_period`` is filled, in declaration order, from the covariates in
    ``series['feat_dynamic_real']``. The target becomes ``disease_cases`` and
    the period range is reconstructed from the entry's start period and the
    length of the target.
    """
    reserved = ['disease_cases', 'time_period']
    covariate_names = [f.name for f in dataclasses.fields(dataclass) if f.name not in reserved]
    kwargs = {}
    for column, name in enumerate(covariate_names):
        # After transposing, column i holds the values of the i-th covariate.
        kwargs[name] = series['feat_dynamic_real'].T[:, column]
    target = series['target']
    kwargs['disease_cases'] = target
    kwargs['time_period'] = PeriodRange.from_start_and_n_periods(series['start'], len(target))
    return dataclass(**kwargs)

@staticmethod
def from_gluonts(self, gluonts_dataset: GlunTSDataSet, dataclass: type[TimeSeriesData]) -> DataSet:
raise NotImplementedError
def from_gluonts(gluonts_dataset: GlunTSDataSet, dataclass: type[T]) -> DataSet[T]:
return DataSet(
{series['feat_static_cat'][0]:
DataSetAdaptor._from_single_gluonts_series(series, dataclass) for series in gluonts_dataset})

to_dataset = from_gluonts

@staticmethod
def get_metadata(dataset: DataSet):
    """Describe the static categorical feature used for the dataset's locations.

    Returns a dict whose ``'static_cat'`` entry is a one-element list holding
    the mapping from positional index to location key, in the order given by
    ``dataset.keys()``.
    """
    index_to_location = dict(enumerate(dataset.keys()))
    return {'static_cat': [index_to_location]}

@staticmethod
def to_gluonts(dataset: DataSet, start_index=0, static=None, real=None) -> GlunTSDataSet:
if isinstance(dataset, MultiCountryDataSet):
Expand All @@ -28,15 +47,37 @@ def to_gluonts(dataset: DataSet, start_index=0, static=None, real=None) -> GlunT
assert real is None
for i, (location, data) in enumerate(dataset.items(), start=start_index):
period = data.time_period[0]

yield {
'start': period.topandas(),
'target': data.disease_cases,
'feat_dynamic_real': remove_field(data, 'disease_cases').to_array(), # exclude the target
'feat_dynamic_real': remove_field(data, 'disease_cases').to_array().T, # exclude the target
'feat_static_cat': [i]+static,
}

from_dataset = to_gluonts

@staticmethod
def to_gluonts_testinstances(history: DataSet, future: DataSet, prediction_length):
    """Yield one GluonTS prediction entry per location in ``history``.

    Each entry carries the historic target together with the covariates for
    the historic periods followed by the covariates for the future periods.

    Parameters
    ----------
    history:
        Historic data per location. ``disease_cases`` is used as the target
        and is removed from the covariates.
    future:
        Data for the periods to predict, looked up by the same location keys.
        All of its fields are used as covariates, so presumably the caller
        has already removed ``disease_cases`` from it (TODO confirm).
    prediction_length:
        Currently unused; kept so the signature stays unchanged.

    Yields
    ------
    dict
        An entry with the keys ``'start'``, ``'target'``,
        ``'feat_dynamic_real'`` and ``'feat_static_cat'``.
    """
    # The static category is just the enumeration index, so it only matches
    # the training entries if the locations are iterated in the same order.
    # Logged once rather than once per location.
    logger.warning(
        'Assuming location order is the same for test data')
    for i, (location, historic_data) in enumerate(history.items()):
        future_data = future[location]
        assert future_data.start_timestamp == historic_data.end_timestamp

        historic_predictors = remove_field(historic_data, 'disease_cases').to_array()
        future_predictors = future_data.to_array()
        # Join history and future along axis 0, then transpose so the
        # covariates have the same orientation as the ones produced by
        # to_gluonts (which also transposes the result of to_array()).
        predictors = np.concatenate([historic_predictors, future_predictors], axis=0).T

        yield {
            'start': historic_data.time_period[0].topandas(),
            'target': historic_data.disease_cases,
            'feat_dynamic_real': predictors,
            'feat_static_cat': [i],
        }


@staticmethod
def to_gluonts_multicountry(dataset: MultiCountryDataSet) -> GlunTSDataSet:
offset = 0
Expand All @@ -45,11 +86,16 @@ def to_gluonts_multicountry(dataset: MultiCountryDataSet) -> GlunTSDataSet:
offset += len(data.keys())


def get_dataset(name):
def get_dataset(name, with_metadata=False):
if name == 'full':
data_set = MultiCountryDataSet.from_folder(Path('/home/knut/Data/ch_data/full_data'))
return DataSetAdaptor.to_gluonts(data_set)
return DataSetAdaptor.to_gluonts(datasets[name].load())
dataset = MultiCountryDataSet.from_folder(Path('/home/knut/Data/ch_data/full_data'))
ds = DataSetAdaptor.to_gluonts(dataset)
else:
dataset = datasets[name].load()
ds = DataSetAdaptor.to_gluonts(dataset)
if with_metadata:
return ds, DataSetAdaptor.get_metadata(dataset)
return ds

def get_split_dataset(name, n_periods=6) -> tuple[GlunTSDataSet, GlunTSDataSet]:
if name == 'full':
Expand Down
16 changes: 16 additions & 0 deletions climate_health/data/gluonts_adaptor/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from typing import Iterable
from .dataset import DataSetAdaptor
from climate_health.spatio_temporal_data.temporal_dataclass import DataSet

GluonTSDataSet = Iterable[dict]

class GluonTSModel:
    """Wrapper around a model that is fed GluonTS-style datasets.

    NOTE(review): ``train`` only converts its input and ``predict`` does
    nothing yet; the wrapped model is never called.
    """

    def __init__(self, model, dataclass=None):
        """Store the wrapped model.

        ``dataclass`` is the time-series dataclass that GluonTS entries are
        converted back into by ``train``. It is optional so that existing
        ``GluonTSModel(model)`` calls keep working.
        """
        self._model = model
        self._dataclass = dataclass

    def train(self, dataset: GluonTSDataSet):
        """Convert ``dataset`` back into a ``DataSet``; the model is not fitted yet.

        Raises
        ------
        TypeError
            If no ``dataclass`` was supplied when the wrapper was created.
        """
        if self._dataclass is None:
            # DataSetAdaptor.to_dataset requires the target dataclass; calling
            # it without one failed with an opaque "missing argument" TypeError.
            raise TypeError(
                'GluonTSModel.train needs a dataclass; create the model with dataclass=...')
        dataset = DataSetAdaptor.to_dataset(dataset, self._dataclass)

    def predict(self, dataset: GluonTSDataSet):
        """Not implemented yet; returns ``None``."""
        pass

1 change: 0 additions & 1 deletion climate_health/rest_api_src/worker_functions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import plotly.express as px
import json
import os
from pathlib import Path
Expand Down
4 changes: 4 additions & 0 deletions climate_health/spatio_temporal_data/multi_country_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ def __getitem__(self, item):
def countries(self):
return list(self._data.keys())

def keys(self):
    """Return the keys view of the underlying per-country mapping."""
    country_keys = self._data.keys()
    return country_keys

@classmethod
def from_tar(cls, url, dataclass=FullData):
tar_gz_file_name = pooch.retrieve(url, known_hash=None)
Expand All @@ -26,6 +29,7 @@ def from_tar(cls, url, dataclass=FullData):
extracted_files = {Path(member.name).stem: tar_file.extractfile(member) for member in members}
print({name: ef.name for name, ef in extracted_files.items() if ef is not None})
data = {name: DataSet.from_csv(ef, dataclass) for name, ef in extracted_files.items() if ef is not None}

return MultiCountryDataSet(data)

def items(self):
Expand Down
9 changes: 9 additions & 0 deletions climate_health/time_period/date_util_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,10 @@ def parse(cls, text_repr: str):
return Month(date)
return Year(date)

@classmethod
def from_pandas(cls, period: pd.Period):
    """Build a time period from a pandas ``Period`` by parsing its string form."""
    text_repr = str(period)
    return cls.parse(text_repr)

@classmethod
def parse_week(cls, week: str):
year, weeknr = week.split('W')
Expand Down Expand Up @@ -496,6 +500,11 @@ def from_ids(cls, ids: Iterable[str], fill_missing=False):
periods = [TimePeriod.from_id(id) for id in ids]
return cls.from_period_list(fill_missing, periods)

@classmethod
def from_start_and_n_periods(cls, start_period: pd.Period, n_periods: int):
    """Create a range of ``n_periods`` periods beginning at ``start_period``.

    The pandas period is converted with ``TimePeriod.from_pandas``, and the
    last period is found by stepping ``n_periods - 1`` time deltas forward
    from the first one.
    """
    first = TimePeriod.from_pandas(start_period)
    n_steps = n_periods - 1
    last = first + first.time_delta * n_steps
    return cls.from_time_periods(first, last)

@classmethod
def from_period_list(cls, fill_missing, periods):
Expand Down
Empty file.
Empty file added tests/test_dataset.py
Empty file.
30 changes: 25 additions & 5 deletions tests/test_gluonts_adaptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,47 @@

import pytest

from climate_health.assessment.dataset_splitting import train_test_split
from climate_health.data.gluonts_adaptor.dataset import DataSetAdaptor, get_dataset, get_split_dataset
from climate_health.datatypes import FullData, remove_field
from climate_health.spatio_temporal_data.multi_country_dataset import MultiCountryDataSet
from climate_health.file_io.example_data_set import datasets
from climate_health.spatio_temporal_data.temporal_dataclass import DataSet
from .data_fixtures import train_data_pop, full_data
from climate_health.data.datasets import ISIMIP_dengue_harmonized


@pytest.fixture
def full_dataset():
foldername = Path('/home/knut/Data/ch_data/full_data')
if not foldername.exists():
pytest.skip()
dataset = MultiCountryDataSet.from_folder(foldername)
dataset = ISIMIP_dengue_harmonized
return dataset


@pytest.fixture
def gluonts_vietnam_dataset():
    """GluonTS-formatted entries for the 'vietnam' part of the harmonized dataset."""
    vietnam = ISIMIP_dengue_harmonized['vietnam']
    gluonts_entries = DataSetAdaptor.from_dataset(vietnam)
    return gluonts_entries


def test_to_dataset(gluonts_vietnam_dataset):
    """Converting GluonTS entries back yields a DataSet with more than three keys."""
    converted = DataSetAdaptor.to_dataset(gluonts_vietnam_dataset, FullData)
    assert isinstance(converted, DataSet)
    n_locations = len(converted.keys())
    assert n_locations > 3


def test_to_testinstances(train_data_pop: DataSet):
    """One test instance is produced per training location, indexed in order."""
    train, test = train_test_split(train_data_pop, prediction_start_period=train_data_pop.period_range[-3])
    future = test.remove_field('disease_cases')
    instances = list(DataSetAdaptor().to_gluonts_testinstances(train, future, 3))
    # Previously the result was only printed, so the test could never fail on
    # wrong output; check the structure of what is yielded instead.
    assert len(instances) == len(list(train.items()))
    for i, instance in enumerate(instances):
        assert set(instance) == {'start', 'target', 'feat_dynamic_real', 'feat_static_cat'}
        assert instance['feat_static_cat'] == [i]



def test_to_gluonts(train_data_pop):
dataset = DataSetAdaptor().to_gluonts(train_data_pop)
dataset = list(dataset)
assert len(dataset) == 2
assert dataset[0]['target'].shape == (7,)
assert dataset[0]['feat_dynamic_real'].shape == (7, 3)
assert dataset[0]['feat_dynamic_real'].shape == (3, 7)
for i, data in enumerate(dataset):
assert data['feat_static_cat'] == [i]

Expand Down
8 changes: 8 additions & 0 deletions tests/time_period/test_dateutil_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,3 +215,11 @@ def test_searchsorted(period_range, period2):
assert period_range.searchsorted(period2) == array_comparison.searchsorted(1)
assert period_range.searchsorted(period2, side='right') == array_comparison.searchsorted(1, side='right')
assert period_range.searchsorted(period2, side='left') == array_comparison.searchsorted(1, side='left')


def test_from_start_and_n_periods():
    """A range built from a pandas start period has the requested length and start."""
    n_periods = 3
    start_period = pd.Period('2020-01')
    period_range = PeriodRange.from_start_and_n_periods(start_period, n_periods)
    assert len(period_range) == n_periods
    expected_first = TimePeriod.from_pandas(start_period)
    assert period_range[0] == expected_first

0 comments on commit 27d6717

Please sign in to comment.