Commit

synch

knutdrand committed Sep 19, 2024
1 parent 8ab6620 commit 9ef2430
Showing 21 changed files with 466 additions and 57 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -130,3 +130,4 @@ climate_health/web_interface/yarn-error.log*

climate_health/web_interface/node_modules/

/scripts/runs/
28 changes: 27 additions & 1 deletion climate_health/assessment/forecast.py
@@ -1,7 +1,13 @@
import pandas as pd
from matplotlib import pyplot as plt

from climate_health.assessment.dataset_splitting import train_test_split_with_weather
from climate_health.assessment.prediction_evaluator import Estimator, Predictor
from climate_health.climate_predictor import MonthlyClimatePredictor
from climate_health.data.gluonts_adaptor.dataset import ForecastAdaptor
from climate_health.plotting.prediction_plot import plot_forecast_from_summaries
from climate_health.spatio_temporal_data.temporal_dataclass import DataSet
from climate_health.time_period.date_util_wrapper import TimeDelta, Month
from climate_health.time_period.date_util_wrapper import TimeDelta, Month, PeriodRange
import logging

logger = logging.getLogger(__name__)
@@ -41,3 +47,23 @@ def multi_forecast(model, dataset: DataSet, prediction_lenght: TimeDelta, pre_tr
cur_dataset, _, _ = train_test_split_with_weather(cur_dataset, split_period)
logger.info(f'Forecasting {prediction_lenght} months into the future on {len(datasets)} datasets')
return (forecast(model, dataset, prediction_lenght) for dataset in datasets[::-1])


def forecast_ahead(estimator: Estimator, dataset: DataSet, prediction_length: int):
    '''
    Train the estimator on the full dataset and forecast prediction_length periods into the future
    '''
    logger.info(f'Forecasting {prediction_length} months into the future')
    train_data = dataset
    predictor = estimator.train(train_data)
    return forecast_with_predicted_weather(predictor, train_data, prediction_length)


def forecast_with_predicted_weather(predictor: Predictor, historic_data: DataSet, prediction_length: int):
    prediction_range = PeriodRange.from_start_and_n_periods(
        Month(historic_data.end_timestamp).to_string(), prediction_length)
    climate_predictor = MonthlyClimatePredictor()
    climate_predictor.train(historic_data)
    future_weather = climate_predictor.predict(prediction_range)
    predictions = predictor.predict(historic_data, future_weather)
    return predictions
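
A minimal usage sketch for the two helpers above (not part of the diff; `my_estimator` and `my_dataset` are hypothetical stand-ins for an Estimator implementation and a training DataSet):

    from climate_health.assessment.forecast import forecast_ahead

    # Train on the full dataset, then predict three periods past its end;
    # future weather is filled in by MonthlyClimatePredictor.
    predictions = forecast_ahead(my_estimator, my_dataset, prediction_length=3)
    for location, samples in predictions.items():
        print(location, samples)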
85 changes: 76 additions & 9 deletions climate_health/assessment/prediction_evaluator.py
@@ -1,6 +1,6 @@
from collections import defaultdict
from dataclasses import dataclass
from typing import Protocol, TypeVar
from typing import Protocol, TypeVar, Iterable, Dict

from gluonts.evaluation import Evaluator
from gluonts.model import Forecast
@@ -111,6 +111,26 @@ def train(self, data: DataSet) -> Predictor:


def evaluate_model(estimator: Estimator, data: DataSet, prediction_length=3, n_test_sets=4, report_filename=None):
    '''
    Evaluate a model on held-out test sets, making multiple predictions with the same trained model.

    Parameters
    ----------
    estimator : Estimator
        The estimator to train and evaluate
    data : DataSet
        The data to train and evaluate on
    prediction_length : int
        The number of periods to predict ahead
    n_test_sets : int
        The number of test sets to evaluate on
    report_filename : str, optional
        Optional path for writing a PDF report of the forecasts

    Returns
    -------
    tuple
        Summary and individual evaluation results
    '''
    train, test_generator = train_test_generator(data, prediction_length, n_test_sets)
    predictor = estimator.train(train)
    truth_data = {
@@ -123,10 +143,42 @@ def evaluate_model(estimator: Estimator, data: DataSet, prediction_length=3, n_t
    evaluator = Evaluator(quantiles=[0.1, 0.5, 0.9])
    results = evaluator(tss, forecast_list)
    return results
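
A usage sketch for evaluate_model (not part of the diff; `my_estimator` and `my_dataset` are hypothetical stand-ins for an Estimator implementation and a DataSet):

    from climate_health.assessment.prediction_evaluator import evaluate_model

    summary, per_series = evaluate_model(my_estimator, my_dataset,
                                         prediction_length=3, n_test_sets=4)
    print(summary)  # aggregated GluonTS metrics over all test windows and locations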


def evaluate_multi_model(estimator: Estimator, data: list[DataSet], prediction_length=3, n_test_sets=4,
                         report_base_name=None):
    trains, test_generators = zip(*[train_test_generator(d, prediction_length, n_test_sets) for d in data])
    predictor = estimator.multi_train(trains)
    result_list = []
    for i, (dataset, test_generator) in enumerate(zip(data, test_generators)):
        truth_data = {
            location: pd.DataFrame(dataset[location].disease_cases,
                                   index=dataset[location].time_period.to_period_index())
            for location in dataset.keys()}
        if report_base_name is not None:
            _, plot_test_generator = train_test_generator(dataset, prediction_length, n_test_sets)
            plot_forecasts(predictor, plot_test_generator, truth_data, f'{report_base_name}_{i}.pdf')
        forecast_list, tss = _get_forecast_generators(predictor, test_generator, truth_data)
        evaluator = Evaluator(quantiles=[0.1, 0.5, 0.9])
        results = evaluator(tss, forecast_list)
        result_list.append(results)
    return result_list
# forecasts = ((predictor.predict(*test_pair[:2]), test_pair[2]) for test_pair in test_generator)
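
A corresponding sketch for the multi-dataset variant (assuming a hypothetical estimator that implements the multi_train method used above, and two hypothetical DataSets):

    from climate_health.assessment.prediction_evaluator import evaluate_multi_model

    results_per_dataset = evaluate_multi_model(my_multi_estimator, [dataset_a, dataset_b],
                                               prediction_length=3, n_test_sets=4,
                                               report_base_name='multi_report')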


def _get_forecast_generators(predictor, test_generator, truth_data) -> tuple[list[Forecast], list[pd.DataFrame]]:
def _get_forecast_generators(predictor: Predictor, test_generator: Iterable[tuple[DataSet, DataSet, DataSet]], truth_data: Dict[str, pd.DataFrame]) -> tuple[list[Forecast], list[pd.DataFrame]]:
    '''
    Get the forecasts and corresponding truth series for a predictor and a test generator.
    Each entry corresponds to one combination of prediction start period and location.

    Parameters
    ----------
    predictor : Predictor
        The predictor to evaluate
    test_generator : Iterable[tuple[DataSet, DataSet, DataSet]]
        The test generator yielding (historic data, future data, _) tuples
    truth_data : dict[str, pd.DataFrame]
        The truth data for the locations

    Returns
    -------
    tuple[list[Forecast], list[pd.DataFrame]]
        The list of forecasts and the matching list of truth series
    '''
    tss = []
    forecast_list = []
    for historic_data, future_data, _ in test_generator:
Expand All @@ -144,11 +196,14 @@ def _get_forecast_dict(predictor: Predictor, test_generator) -> dict[str, list[F
    forecast_dict = defaultdict(list)

    for historic_data, future_data, _ in test_generator:
        assert len(future_data.period_range) > 0, \
            f'Future data must have at least one period {historic_data.period_range}, {future_data.period_range}'
        forecasts = predictor.predict(historic_data, future_data)
        for location, samples in forecasts.items():
            forecast_dict[location].append(ForecastAdaptor.from_samples(samples))
    return forecast_dict


def get_forecast_df(predictor: Predictor, test_generator) -> pd.DataFrame:
forecast_dict = _get_forecast_dict(predictor, test_generator)
dfs = []
@@ -158,20 +213,17 @@ def get_forecast_df(predictor: Predictor, test_generator) -> pd.DataFrame:

return forecast_df


def plot_forecasts(predictors: list[Predictor], test_instance, truth, pdf_filename):
forecast_dicts = [_get_forecast_dict(predictor, test_instance) for predictor in predictors]
with PdfPages(pdf_filename) as pdf:
for location in forecast_dicts[0].keys():
_t = truth[location]
for forecast_dict in forecast_dicts:
fig = plt.subplots(figsize=(8, 4),ncols=len(forecast_dict))
fig = plt.subplots(figsize=(8, 4), ncols=len(forecast_dict))
for i in range(len(forecast_dict[location])):
forecast = forecast_dict[location][i]





# plt.figure(figsize=(8, 4)) # Set the figure size
# t = _t[_t.index <= forecast.index[-1]]
# forecast.plot(show_label=True)
@@ -182,7 +234,6 @@ def plot_forecasts(predictors: list[Predictor], test_instance, truth, pdf_filena
# plt.close() # Close the figure



def plot_forecasts(predictor, test_instance, truth, pdf_filename):
forecast_dict = _get_forecast_dict(predictor, test_instance)
with PdfPages(pdf_filename) as pdf:
@@ -199,13 +250,29 @@ def plot_forecasts(predictor, test_instance, truth, pdf_filename):
plt.close() # Close the figure


def plot_predictions(predictions: DataSet[Samples], truth: DataSet, pdf_filename):
    truth_dict = {location: pd.DataFrame(truth[location].disease_cases,
                                         index=truth[location].time_period.to_period_index())
                  for location in truth.keys()}
    with PdfPages(pdf_filename) as pdf:
        for location, prediction in predictions.items():
            prediction = ForecastAdaptor.from_samples(prediction)
            t = truth_dict[location]
            plt.figure(figsize=(8, 4))  # Set the figure size
            # t = _t[_t.index <= prediction.index[-1]]
            prediction.plot(show_label=True)
            plt.plot(t[-150:].to_timestamp())
            plt.title(location)
            plt.legend()
            pdf.savefig()
            plt.close()  # Close the figure


def plot_forecasts_list(predictor, test_instances, truth, pdf_filename):
    forecasts, tss = _get_forecast_generators(predictor, test_instances, truth)
    with PdfPages(pdf_filename) as pdf:
        for i, (forecast_entry, ts_entry) in enumerate(zip(forecasts, tss)):
            last_period = forecast_entry.index[-1]
            ts_entry = ts_entry[ts_entry.index <= last_period]
            offset = ts_entry
            plt.figure(figsize=(8, 4))  # Set the figure size
            plt.plot(ts_entry[-150:].to_timestamp())
            forecast_entry.plot(show_label=True)
40 changes: 40 additions & 0 deletions climate_health/climate_data/gridded_data.py
@@ -0,0 +1,40 @@
import ee
import xarray
from matplotlib import pyplot as plt
from matplotlib.colors import Normalize

from ..google_earth_engine.gee_raw import load_credentials
import geopandas as gpd

# Load the GeoJSON file using GeoPandas
def get_gridded_data(polygons_filename):
    '''
    Fetch a gridded ERA5-Land temperature field for the bounding box of the polygons
    in polygons_filename, plot it with the polygon boundaries, and return the values.
    '''
    gdf = gpd.read_file(polygons_filename)
    # Get the bounding box of all polygons in the GeoJSON
    lon1, lat1, lon2, lat2 = gdf.total_bounds
    print(lon1, lat1, lon2, lat2)
    credentials = load_credentials()
    ee.Initialize(ee.ServiceAccountCredentials(credentials.account, key_data=credentials.private_key))
    collection = ee.ImageCollection('ECMWF/ERA5_LAND/DAILY_AGGR').filterDate('2024-08-01', '2024-08-03').select('temperature_2m')
    # lon1 = 28.8
    # lon2 = 30.9
    # lat1 = -2.9
    # lat2 = -1.0
    country_bounds = ee.Geometry.Rectangle(*gdf.total_bounds)  # lon1, lat1, lon2, lat2
    projection = collection.first().select(0).projection()  # EPSG:4326
    dataset = xarray.open_dataset(
        collection,
        engine='ee',
        projection=projection,
        geometry=country_bounds
    )
    ds = dataset
    first_image = dataset.isel(time=0)
    temp_d = first_image['temperature_2m']
    temp_d.plot()
    temp = temp_d.values
    # plt.imshow(temp, extent=[ds.lon.min(), ds.lon.max(), ds.lat.min(), ds.lat.max()], origin='lower', cmap='viridis',
    #            norm=Normalize())
    # plt.imshow(temp, cmap='viridis')
    gdf.boundary.plot(ax=plt.gca(), edgecolor='red', linewidth=1)
    plt.show()
    return temp
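
A usage sketch (assuming valid Google Earth Engine service-account credentials on disk and a hypothetical GeoJSON file of region polygons):

    from climate_health.climate_data.gridded_data import get_gridded_data

    # Plots the first day's 2m temperature grid with the polygon boundaries overlaid
    # and returns the temperature values as a numpy array.
    temperature_grid = get_gridded_data('regions.geojson')
    print(temperature_grid.shape)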
8 changes: 4 additions & 4 deletions climate_health/climate_predictor.py
@@ -19,23 +19,23 @@ def _feature_matrix(self, time_period: PeriodRange):
return time_period.month[:,None] == np.arange(1, 13)

def train(self, train_data: DataSet[ClimateData]):
train_data = train_data.remove_field('disease_cases')
for location, data in train_data.items():
data = data.data()
self._cls = data.__class__
x = self._feature_matrix(data.time_period)
for field in dataclasses.fields(data):
if field.name == 'time_period':
if field.name in ('time_period',):
continue
y = getattr(data, field.name)
model = linear_model.LinearRegression()
model.fit(x, y[:,None])
model.fit(x, y[:, None])
self._models[location][field.name] = model

def predict(self, time_period: PeriodRange):
x = self._feature_matrix(time_period)
prediction_dict = {}
for location, models in self._models.items():
prediction_dict[location] = self._cls(time_period, **{field: model.predict(x) for field, model in models.items()})
prediction_dict[location] = self._cls(time_period, **{field: model.predict(x).ravel() for field, model in models.items()})
return DataSet(prediction_dict)
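
A usage sketch for the climate predictor (assuming a hypothetical `historic_climate` DataSet[ClimateData]; PeriodRange and Month are the same helpers used in forecast.py above):

    from climate_health.climate_predictor import MonthlyClimatePredictor
    from climate_health.time_period.date_util_wrapper import Month, PeriodRange

    predictor = MonthlyClimatePredictor()
    predictor.train(historic_climate)  # fits one linear regression per location and climate field
    future_periods = PeriodRange.from_start_and_n_periods(
        Month(historic_climate.end_timestamp).to_string(), 6)
    future_climate = predictor.predict(future_periods)  # DataSet of month-of-year based predictions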


1 change: 1 addition & 0 deletions climate_health/data/gluonts_adaptor/dataset.py
@@ -1,3 +1,4 @@
import warnings
import dataclasses
from pathlib import Path
from typing import Iterable, TypeVar
1 change: 1 addition & 0 deletions climate_health/external/models/flax_models/flax_model.py
@@ -36,6 +36,7 @@ class TrainState(train_state.TrainState):
class FlaxModel:
model: nn.Module# = RNNModel()
n_iter: int = 3000

def __init__(self, rng_key: jax.random.PRNGKey = jax.random.PRNGKey(100), n_iter: int = None):
self.rng_key = rng_key
self._losses = []
23 changes: 2 additions & 21 deletions climate_health/rest_api.py
@@ -1,8 +1,6 @@
import json
from contextlib import asynccontextmanager
import logging
from asyncio import CancelledError
from typing import List, Union
from typing import Union

from fastapi import BackgroundTasks, UploadFile, HTTPException
from pydantic import BaseModel
@@ -12,18 +10,14 @@
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware

from climate_health.api import read_zip_folder, train_on_prediction_data
from climate_health.api_types import RequestV1
from climate_health.google_earth_engine.gee_era5 import Era5LandGoogleEarthEngine
from climate_health.internal_state import Control, InternalState
from climate_health.model_spec import ModelSpec, model_spec_from_model
from climate_health.predictor import all_models
from climate_health.predictor.feature_spec import Feature, all_features
from climate_health.rest_api_src.data_models import FullPredictionResponse
from climate_health.rest_api_src.worker_functions import train_on_zip_file, train_on_json_data
from climate_health.training_control import TrainingControl
from dotenv import load_dotenv, find_dotenv

from climate_health.worker.background_tasks_worker import BGTaskWorker
from climate_health.worker.rq_worker import RedisQueue

logger = logging.getLogger(__name__)
@@ -97,23 +91,10 @@ def is_finished(self):
# worker = BGTaskWorker(BackgroundTasks(), internal_state, state)
worker = RedisQueue()


def set_cur_response(response):
state['response'] = response


class PredictionResponse(BaseModel):
value: float
orgUnit: str
dataElement: str
period: str


class FullPredictionResponse(BaseModel):
diseaseId: str
dataValues: List[PredictionResponse]


@app.get('/favicon.ico')
async def favicon() -> FileResponse:
return FileResponse('chap_icon.jpeg')
27 changes: 27 additions & 0 deletions climate_health/rest_api_src/data_models.py
@@ -0,0 +1,27 @@
from typing import List

from pydantic import BaseModel


class PredictionBase(BaseModel):
orgUnit: str
dataElement: str
period: str


class PredictionResponse(PredictionBase):
value: float


class PredictionSamplResponse(PredictionBase):
values: list[float]


class FullPredictionResponse(BaseModel):
diseaseId: str
dataValues: List[PredictionResponse]


class FullPredictionSampleResponse(BaseModel):
diseaseId: str
dataValues: List[PredictionSamplResponse]
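
A construction sketch for these response models (illustrative values only):

    from climate_health.rest_api_src.data_models import FullPredictionResponse, PredictionResponse

    response = FullPredictionResponse(
        diseaseId='dengue',
        dataValues=[PredictionResponse(orgUnit='OU123', dataElement='disease_cases',
                                       period='202409', value=12.5)])
    print(response.json())  # .model_dump_json() on pydantic v2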