Skip to content

Commit

Permalink
rename spatiotemporaldict
Browse files Browse the repository at this point in the history
  • Loading branch information
knutdrand committed Aug 30, 2024
1 parent 646650f commit e8bc805
Show file tree
Hide file tree
Showing 47 changed files with 198 additions and 198 deletions.
12 changes: 6 additions & 6 deletions climate_health/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from .geojson import geojson_to_shape, geojson_to_graph, NeighbourGraph
from .plotting.prediction_plot import plot_forecast_from_summaries
from .predictor import get_model
from .spatio_temporal_data.temporal_dataclass import SpatioTemporalDict
from .spatio_temporal_data.temporal_dataclass import DataSet
import dataclasses

from .time_period.date_util_wrapper import Week, delta_week, delta_month, Month
Expand Down Expand Up @@ -48,9 +48,9 @@ class AreaPolygons:
@dataclasses.dataclass
class PredictionData:
area_polygons: AreaPolygons = None
health_data: SpatioTemporalDict[HealthData] = None
climate_data: SpatioTemporalDict[ClimateData] = None
population_data: SpatioTemporalDict[HealthPopulationData] = None
health_data: DataSet[HealthData] = None
climate_data: DataSet[ClimateData] = None
population_data: DataSet[HealthPopulationData] = None
disease_id: Optional[str] = None
features : List[object] = None

Expand Down Expand Up @@ -102,7 +102,7 @@ def read_zip_folder(zip_file_path: str) -> PredictionData:
temperature["mean_temperature"] = temperature["mean_temperature"].astype(float)

features = json.load(ziparchive.open(expected_files["area_polygons"]))["features"]
climate = SpatioTemporalDict.from_pandas(temperature, dataclass=SimpleClimateData)
climate = DataSet.from_pandas(temperature, dataclass=SimpleClimateData)

population_json = json.load(ziparchive.open(expected_files["population"]))
population = parse_population_data(population_json)
Expand Down Expand Up @@ -199,7 +199,7 @@ def train_on_prediction_data(data, model_name=None, n_months=4, docker_filename=
population = data.population_data[location]
new_dict[location] = FullData.combine(health.data(), climate.data(), population)

climate_health_data = SpatioTemporalDict(new_dict)
climate_health_data = DataSet(new_dict)
prediction_start = Month(climate_health_data.end_timestamp) - n_months * delta_month
train_data, _, future_weather = train_test_split_with_weather(climate_health_data, prediction_start)
logger.info(f"Training model {model_name} on {len(train_data.items())} locations")
Expand Down
4 changes: 2 additions & 2 deletions climate_health/assessment/dataset_splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from climate_health.dataset import IsSpatioTemporalDataSet
from climate_health.datatypes import ClimateHealthData, ClimateData, HealthData
from climate_health.spatio_temporal_data.temporal_dataclass import SpatioTemporalDict
from climate_health.spatio_temporal_data.temporal_dataclass import DataSet
from climate_health.time_period import Year, Month, TimePeriod
from climate_health.time_period.relationships import previous
import dataclasses
Expand Down Expand Up @@ -50,7 +50,7 @@ def train_test_split(data_set: IsSpatioTemporalDataSet, prediction_start_period:
return train_data, test_data


def train_test_split_with_weather(data_set: SpatioTemporalDict, prediction_start_period: TimePeriod,
def train_test_split_with_weather(data_set: DataSet, prediction_start_period: TimePeriod,
extension: Optional[IsTimeDelta] = None,
future_weather_class: Type[ClimateData] = ClimateData):
train_set, test_set = train_test_split(data_set, prediction_start_period, extension)
Expand Down
6 changes: 3 additions & 3 deletions climate_health/assessment/forecast.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from climate_health.assessment.dataset_splitting import train_test_split_with_weather
from climate_health.plotting.prediction_plot import plot_forecast_from_summaries
from climate_health.spatio_temporal_data.temporal_dataclass import SpatioTemporalDict
from climate_health.spatio_temporal_data.temporal_dataclass import DataSet
from climate_health.time_period.date_util_wrapper import TimeDelta, Month
import logging

logger = logging.getLogger(__name__)


def forecast(model, dataset: SpatioTemporalDict, prediction_length: TimeDelta, graph=None):
def forecast(model, dataset: DataSet, prediction_length: TimeDelta, graph=None):
'''
Forecast prediction_length into the future using the model
'''
Expand All @@ -27,7 +27,7 @@ def forecast(model, dataset: SpatioTemporalDict, prediction_length: TimeDelta, g
return predictions


def multi_forecast(model, dataset: SpatioTemporalDict, prediction_lenght: TimeDelta, pre_train_delta: TimeDelta):
def multi_forecast(model, dataset: DataSet, prediction_lenght: TimeDelta, pre_train_delta: TimeDelta):
'''
Forecast n_months into the future using the model
'''
Expand Down
2 changes: 1 addition & 1 deletion climate_health/climate_data/gee_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from .external import ee
from ..datatypes import ClimateData, Location, Shape, SimpleClimateData
from ..spatio_temporal_data.temporal_dataclass import SpatioTemporalDict
from ..spatio_temporal_data.temporal_dataclass import DataSet
from ..time_period import TimePeriod, PeriodRange
from ..services.cache_manager import get_cache
from ..time_period import Month, Day
Expand Down
6 changes: 3 additions & 3 deletions climate_health/climate_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from sklearn import linear_model

from .datatypes import ClimateData
from climate_health.spatio_temporal_data.temporal_dataclass import SpatioTemporalDict
from climate_health.spatio_temporal_data.temporal_dataclass import DataSet
from climate_health.time_period import PeriodRange


Expand All @@ -18,7 +18,7 @@ def __init__(self):
def _feature_matrix(self, time_period: PeriodRange):
return time_period.month[:,None] == np.arange(1, 13)

def train(self, train_data: SpatioTemporalDict[ClimateData]):
def train(self, train_data: DataSet[ClimateData]):
for location, data in train_data.items():
data = data.data()
self._cls = data.__class__
Expand All @@ -36,7 +36,7 @@ def predict(self, time_period: PeriodRange):
prediction_dict = {}
for location, models in self._models.items():
prediction_dict[location] = self._cls(time_period, **{field: model.predict(x) for field, model in models.items()})
return SpatioTemporalDict(prediction_dict)
return DataSet(prediction_dict)



Expand Down
6 changes: 3 additions & 3 deletions climate_health/dhis2_interface/ChapProgram.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from climate_health.dhis2_interface.json_parsing import parse_climate_data, parse_disease_data, parse_population_data, predictions_to_datavalue
from climate_health.dhis2_interface.src.PushResult import DataValue, push_result
from climate_health.dhis2_interface.src.create_data_element_if_not_exists import create_data_element_if_not_exists
from climate_health.spatio_temporal_data.temporal_dataclass import SpatioTemporalDict
from climate_health.spatio_temporal_data.temporal_dataclass import DataSet
from climate_health.dhis2_interface.src.PullAnalytics import pull_analytics_elements
from climate_health.dhis2_interface.src.Config import DHIS2AnalyticRequest, ProgramConfig

Expand Down Expand Up @@ -49,7 +49,7 @@ def startModelling(self):
# do the fancy modelling here?
return

def pushDataToDHIS2(self, data : SpatioTemporalDict[HealthData], model_name : str, do_dict=True):
def pushDataToDHIS2(self, data : DataSet[HealthData], model_name : str, do_dict=True):
# TODO: do we need to delete previous models, or would we overwrite existing values?

#used to prefix CHAP-dataElements in DHIS2
Expand Down Expand Up @@ -86,6 +86,6 @@ def pushDataToDHIS2(self, data : SpatioTemporalDict[HealthData], model_name : st
process.startModelling()

d = {"" : ""}
sp = SpatioTemporalDict(d)
sp = DataSet(d)

process.pushDataToDHIS2(sp, "dengue")
8 changes: 4 additions & 4 deletions climate_health/dhis2_interface/json_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from climate_health.datatypes import HealthData, HealthPopulationData
from climate_health.dhis2_interface.periods import get_period_id, convert_time_period_string
from climate_health.dhis2_interface.src.PushResult import DataValue
from climate_health.spatio_temporal_data.temporal_dataclass import SpatioTemporalDict
from climate_health.spatio_temporal_data.temporal_dataclass import DataSet
import logging

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -41,7 +41,7 @@ def parse_disease_data(json_data, disease_name='IDS - Dengue Fever (Suspected ca
name_mapping={'time_period': 1, 'disease_cases': 3, 'location': 2}):
# meta_data = MetadDataLookup(json_data['metaData'])
df = json_to_pandas(json_data, name_mapping)
return SpatioTemporalDict.from_pandas(df, dataclass=HealthData, fill_missing=True)
return DataSet.from_pandas(df, dataclass=HealthData, fill_missing=True)


def parse_json_rows(rows, name_mapping):
Expand Down Expand Up @@ -85,10 +85,10 @@ def add_population_data(disease_data, population_lookup):
np.full(len(data.data()), population_lookup[location])
)
for location, data in disease_data.items()}
return SpatioTemporalDict(new_dict)
return DataSet(new_dict)


def predictions_to_datavalue(data: SpatioTemporalDict[HealthData], attribute_mapping: dict[str, str]):
def predictions_to_datavalue(data: DataSet[HealthData], attribute_mapping: dict[str, str]):
entries = []
for location, data in data.items():
data = data.data()
Expand Down
6 changes: 3 additions & 3 deletions climate_health/dhis2_interface/pydantic_to_spatiotemporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@
from climate_health.api_types import DataElement
from climate_health.datatypes import TimeSeriesArray
from climate_health.dhis2_interface.periods import convert_time_period_string
from climate_health.spatio_temporal_data.temporal_dataclass import SpatioTemporalDict
from climate_health.spatio_temporal_data.temporal_dataclass import DataSet


def v1_conversion(data_list: list[DataElement], fill_missing=False) -> SpatioTemporalDict[TimeSeriesArray]:
def v1_conversion(data_list: list[DataElement], fill_missing=False) -> DataSet[TimeSeriesArray]:
'''
Convert a list of DataElement objects to a DataSet[TimeSeriesArray] object.
'''
df = pd.DataFrame([d.dict() for d in data_list])
df.sort_values(by=['ou', 'pe'], inplace=True)
d = dict(time_period=[convert_time_period_string(row) for row in df['pe']], location=df.ou, value=df.value)
converted_df = pd.DataFrame(d)
return SpatioTemporalDict.from_pandas(converted_df, TimeSeriesArray, fill_missing=fill_missing)
return DataSet.from_pandas(converted_df, TimeSeriesArray, fill_missing=fill_missing)
12 changes: 6 additions & 6 deletions climate_health/external/external_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from climate_health.runners.command_line_runner import CommandLineRunner
from climate_health.runners.docker_runner import DockerImageRunner, DockerRunner
from climate_health.runners.runner import Runner
from climate_health.spatio_temporal_data.temporal_dataclass import SpatioTemporalDict
from climate_health.spatio_temporal_data.temporal_dataclass import DataSet
from climate_health.time_period.date_util_wrapper import TimeDelta, delta_month, TimePeriod

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -220,17 +220,17 @@ def predict(self, future_data: IsSpatioTemporalDataSet[FeatureType]) -> IsSpatio
time_periods = [TimePeriod.parse(s) for s in df.time_period.astype(str)]
mask = [start_time <= time_period.start_timestamp for time_period in time_periods]
df = df[mask]
return SpatioTemporalDict.from_pandas(df, result_class)
return DataSet.from_pandas(df, result_class)

def forecast(self, future_data: SpatioTemporalDict[FeatureType], n_samples=1000,
def forecast(self, future_data: DataSet[FeatureType], n_samples=1000,
forecast_delta: TimeDelta = 3 * delta_month):
time_period = next(iter(future_data.data())).data().time_period
n_periods = forecast_delta // time_period.delta
future_data = SpatioTemporalDict({key: value.data()[:n_periods] for key, value in future_data.items()})
future_data = DataSet({key: value.data()[:n_periods] for key, value in future_data.items()})
return self.predict(future_data)

def prediction_summary(self, future_data: SpatioTemporalDict[FeatureType], n_samples=1000):
future_data = SpatioTemporalDict({key: value.data()[:1] for key, value in future_data.items()})
def prediction_summary(self, future_data: DataSet[FeatureType], n_samples=1000):
future_data = DataSet({key: value.data()[:1] for key, value in future_data.items()})
return self.predict(future_data)

def _provide_temp_file(self):
Expand Down
14 changes: 7 additions & 7 deletions climate_health/external/models/flax_models/flax_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from climate_health.external.models.flax_models.rnn_model import RNNModel
from climate_health.external.models.jax_models.model_spec import skip_nan_distribution, Poisson, Normal, \
NegativeBinomial, NegativeBinomial2, NegativeBinomial3
from climate_health.spatio_temporal_data.temporal_dataclass import SpatioTemporalDict
from climate_health.spatio_temporal_data.temporal_dataclass import DataSet

PoissonSkipNaN = skip_nan_distribution(Poisson)

Expand Down Expand Up @@ -54,12 +54,12 @@ def model(self):
self._model = RNNModel(n_locations=self._saved_x.shape[0])
return self._model

def set_validation_data(self, data: SpatioTemporalDict[FullData]):
def set_validation_data(self, data: DataSet[FullData]):
x, y = self._get_series(data)
self._validation_x = x
self._validation_y = y

def _get_series(self, data: SpatioTemporalDict[FullData]):
def _get_series(self, data: DataSet[FullData]):
x = []
y = []
for series in data.values():
Expand All @@ -83,7 +83,7 @@ def get_validation_y(self, params):
#print(y_pred.shape)
return y_pred[:, self._saved_x.shape[1]:]

def train(self, data: SpatioTemporalDict[ClimateHealthTimeSeries]):
def train(self, data: DataSet[ClimateHealthTimeSeries]):
x, y = self._get_series(data)
self._mu = np.mean(x, axis=(0, 1))
self._std = np.std(x, axis=(0, 1))
Expand Down Expand Up @@ -153,7 +153,7 @@ def loss_func(params):

self._params = state.params

def forecast(self, data: SpatioTemporalDict[FullData], n_samples=1000, forecast_delta=1):
def forecast(self, data: DataSet[FullData], n_samples=1000, forecast_delta=1):
#print('Forecasting with params:', self._params)
x, y = self._get_series(data)
x = (x - self._mu) / self._std
Expand All @@ -169,7 +169,7 @@ def forecast(self, data: SpatioTemporalDict[FullData], n_samples=1000, forecast_
median = self._get_q(eta, 0.5)[:, self._saved_x.shape[1]:]

time_period = next(iter(data.values())).time_period
return SpatioTemporalDict(
return DataSet(
{key: SummaryStatistics(time_period, *([row.ravel()] * 5 + [q_low.ravel(), q_high.ravel()]))
for key, row, q_high, q_low in zip(data.keys(), y_pred, q_highs, q_lows)})

Expand All @@ -188,7 +188,7 @@ def diagnose(self):
plt.plot(self._losses)
plt.show()

def predict(self, data: SpatioTemporalDict[FullData]):
def predict(self, data: DataSet[FullData]):
x, y = self._get_series(data)
return np.exp(self.model.apply(self._params, x))

Expand Down
28 changes: 14 additions & 14 deletions climate_health/external/models/jax_models/hierarchical_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
HiearchicalLogProbFuncWithDistrictStates
from climate_health.external.models.jax_models.utii import get_state_transform, state_or_param, tree_sample, index_tree, \
PydanticTree
from climate_health.spatio_temporal_data.temporal_dataclass import SpatioTemporalDict
from climate_health.spatio_temporal_data.temporal_dataclass import DataSet
from .model_spec import Poisson, PoissonSkipNaN, Normal, distributionclass
from .protoype_annotated_spec import Positive
from .simple_ssm import get_summary
Expand Down Expand Up @@ -98,13 +98,13 @@ def __init__(self, key: PRNGKey = PRNGKey(0), params: Optional[dict[str, Any]] =
def set_training_control(self, training_control):
self._training_control = training_control

def _get_standardization_func(self, data: SpatioTemporalDict[ClimateHealthTimeSeries]):
def _get_standardization_func(self, data: DataSet[ClimateHealthTimeSeries]):
values = np.concatenate([value.mean_temperature for value in data.values()])
mean = np.mean(values)
std = np.std(values)
return lambda x: (x - mean) / std

def _set_model(self, data_dict: SpatioTemporalDict[SeasonalClimateHealthData]):
def _set_model(self, data_dict: DataSet[SeasonalClimateHealthData]):
min_year = min([min(value.year) for value in data_dict.values()])
max_year = max([max(value.year) for value in data_dict.values()])
n_years = max_year - min_year + 1
Expand All @@ -125,7 +125,7 @@ def ch_regression(params: 'ParamClass', given: SeasonalClimateHealthData) -> Hea

self._regression_model = ch_regression

def train(self, data: SpatioTemporalDict[FullData]):
def train(self, data: DataSet[FullData]):
random_key, self._key = jax.random.split(self._key)
data_dict = {key: create_seasonal_data(value.data()) for key, value in data.items()}
self._set_model(data_dict)
Expand Down Expand Up @@ -155,30 +155,30 @@ def _get_log_prob_func(self, data_dict):
self._param_class, SeasonalDistrictParams, data_dict,
self._regression_model, observed_name='disease_cases')

def sample(self, data: SpatioTemporalDict[ClimateData], n=1) -> SpatioTemporalDict[HealthData]:
def sample(self, data: DataSet[ClimateData], n=1) -> DataSet[HealthData]:
params = index_tree(self.params, -1)
random_key, self._key = jax.random.split(self._key)
data_dict = {key: create_seasonal_data(value.data()) for key, value in data.items()}
true_params = {name: join_global_and_district(params[0],
params[1][name])
for name in data_dict.keys()}
return SpatioTemporalDict({key: self._regression_model(true_params[key], data_dict[key]).sample(random_key)
for key in data_dict.keys()})
return DataSet({key: self._regression_model(true_params[key], data_dict[key]).sample(random_key)
for key in data_dict.keys()})

def _adapt_params(self, params, data_dict):
return params

def prediction_summary(self, future_weather: SpatioTemporalDict[ClimateData], n_samples=1000) -> SpatioTemporalDict[SummaryStatistics]:
def prediction_summary(self, future_weather: DataSet[ClimateData], n_samples=1000) -> DataSet[SummaryStatistics]:
time_delta = next(iter(future_weather.data())).data().time_period.delta
future_weather = SpatioTemporalDict({key: value.data()[:1] for key, value in future_weather.items()})
future_weather = DataSet({key: value.data()[:1] for key, value in future_weather.items()})
return self.forecast(future_weather, n_samples=n_samples, forecast_delta=1*time_delta)

def forecast(self, future_weather: SpatioTemporalDict[ClimateData], n_samples=1000,
forecast_delta=6 * delta_month) -> SpatioTemporalDict[SummaryStatistics]:
def forecast(self, future_weather: DataSet[ClimateData], n_samples=1000,
forecast_delta=6 * delta_month) -> DataSet[SummaryStatistics]:

time_period = next(iter(future_weather.data())).data().time_period
n_periods = forecast_delta // time_period.delta
future_weather = SpatioTemporalDict({key: value.data()[:n_periods] for key, value in future_weather.items()})
future_weather = DataSet({key: value.data()[:n_periods] for key, value in future_weather.items()})
time_period = next(iter(future_weather.data())).data().time_period
num_samples = n_samples
param_key, self._key = jax.random.split(self._key)
Expand All @@ -195,7 +195,7 @@ def forecast(self, future_weather: SpatioTemporalDict[ClimateData], n_samples=10
for key, value in data_dict.items():
new_key, random_key = jax.random.split(random_key)
samples[key].append(self._sample_from_model(key, new_key, params, true_params, value))
return SpatioTemporalDict(
return DataSet(
{key: get_summary(time_period, np.array(value)) for key, value in samples.items()})

def _sample_from_model(self, key, new_key, params, true_params, value):
Expand Down Expand Up @@ -270,7 +270,7 @@ def diagnose(self):
if val.ndim == 1:
px.line(val).show()

def _set_model(self, data_dict: SpatioTemporalDict[SeasonalClimateHealthDataState]):
def _set_model(self, data_dict: DataSet[SeasonalClimateHealthDataState]):
min_idx = min([min(value.time_index) for value in data_dict.values()])
max_idx = max([max(value.time_index) for value in data_dict.values()])
self._idx_range = (min_idx, max_idx)
Expand Down
Loading

0 comments on commit e8bc805

Please sign in to comment.