Skip to content

Commit

Permalink
Merge pull request #340 from AIStream-Peelout/informer_shap
Browse files Browse the repository at this point in the history
Informer support
  • Loading branch information
isaacmg authored May 14, 2021
2 parents 44208e7 + 9d6602f commit 91393a8
Show file tree
Hide file tree
Showing 9 changed files with 205 additions and 21 deletions.
4 changes: 3 additions & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,10 @@ jobs:
name: Trainer tests
when: always
command: |
coverage run flood_forecast/trainer.py -p tests/test_inf_single.json
echo -e 'test informer single target'
coverage run flood_forecast/trainer.py -p tests/test_informer.json
echo -e 'test informer'
echo -e 'test multi informer'
coverage run flood_forecast/trainer.py -p tests/transformer_gaussian.json
coverage run flood_forecast/trainer.py -p tests/multi_decoder_test.json
echo -e 'training multi-task-decoder'
Expand Down
5 changes: 5 additions & 0 deletions docs/source/explain_model_output.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Explain Model Output
=================

.. automodule:: flood_forecast.explain_model_output
:members:
2 changes: 2 additions & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ of datasets.

basic_utils

.. automodule:: flood_forecast.explain_model_output

.. automodule:: flood_forecast.da_rnn

.. toctree::
Expand Down
8 changes: 5 additions & 3 deletions flood_forecast/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@ def evaluate_model(
from flood_forecast.evaluator import evaluate_model
forecast_model = PyTorchForecast(config_file)
evaluate_model(forecast_model, "PyTorch", ["cfs"], ["MSE", "MAPE"], {})
e_log, df_train_test, f_idx, df_preds = evaluate_model(forecast_model, "PyTorch", ["cfs"], ["MSE", "MAPE"], {})
print(e_log) # {"MSE":0.2, "MAPE":0.1}
print(df_train_test) #
...
'''
"""
Expand Down Expand Up @@ -341,7 +343,7 @@ def handle_ci_multi(prediction_samples: torch.Tensor, csv_test_loader: CSVTestLo
:type num_samples: int
:raises ValueError: [description]
:raises ValueError: [description]
:return: [description]
:return: Returns an array with different CI predictions
:rtype: List[pd.DataFrame]
"""
df_prediction_arr = []
Expand Down Expand Up @@ -454,7 +456,7 @@ def generate_predictions_non_decoded(
) -> torch.Tensor:
"""Generates predictions for the models that do not use a decoder
:param model: [description]
:param model: A PyTorchForecast
:type model: Type[TimeSeriesModel]
:param df: [description]
:type df: pd.DataFrame
Expand Down
79 changes: 64 additions & 15 deletions flood_forecast/explain_model_output.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import random
from datetime import datetime
from typing import Optional
from typing import Optional, Tuple
import numpy as np
import shap
import torch
Expand All @@ -17,6 +17,29 @@
BACKGROUND_BATCH_SIZE = 5


def handle_dl_output(dl, dl_class: str, datetime_start: datetime, device: str) -> Tuple[torch.Tensor, int]:
"""
:param dl: The test data-loader. Should be passed directly
:type dl: Union[CSVTestLoader, TemporalTestLoader]
:param dl_class: A string that is the name of DL passef from the params file.
:type dl_class: str
:param datetime_start: The start datetime for the forecast
:type datetime_start: datetime
:param device: Typical device should be either cpu or cuda
:type device: str
:return: Returns a tuple containing either a..
:rtype: Tuple[torch.Tensor, int]
"""
if dl_class == "TemporalLoader":
his, tar, _, forecast_start_idx = dl.get_from_start_date(datetime_start)
history = [his[0].unsqueeze(0), his[1].unsqueeze(0), tar[1].unsqueeze(0), tar[0].unsqueeze(0)]
else:
history, _, forecast_start_idx = dl.get_from_start_date(datetime_start)
history = history.to(device).unsqueeze(0)
return history, forecast_start_idx


def _prepare_background_tensor(
csv_test_loader: CSVTestLoader, backgound_batch_size: int = BACKGROUND_BATCH_SIZE
) -> torch.Tensor:
Expand Down Expand Up @@ -67,15 +90,23 @@ def deep_explain_model_summary_plot(
if datetime_start is None:
datetime_start = model.params["inference_params"]["datetime_start"]

history, _, forecast_start_idx = csv_test_loader.get_from_start_date(datetime_start)
history, forecast_start_idx = handle_dl_output(csv_test_loader, model.params["dataset_params"]["class"],
datetime_start, device)
background_tensor = _prepare_background_tensor(csv_test_loader)
background_tensor = background_tensor.to(device)
model.model.eval()

# background shape (L, N, M)
# L - batch size, N - history length, M - feature size
deep_explainer = shap.DeepExplainer(model.model, background_tensor)
shap_values = deep_explainer.shap_values(background_tensor)
s_values_list = []
if isinstance(history, list):
deep_explainer = shap.DeepExplainer(model.model, history)
shap_values = deep_explainer.shap_values(history)
s_values_list.append(shap_values)
else:
deep_explainer = shap.DeepExplainer(model.model, background_tensor)
shap_values = deep_explainer.shap_values(background_tensor)
shap_values = fix_shap_values(shap_values, history)
shap_values = np.stack(shap_values)
# shap_values needs to be 4-dimensional
if len(shap_values.shape) != 4:
Expand Down Expand Up @@ -103,13 +134,17 @@ def deep_explain_model_summary_plot(
wandb.log({"Overall feature ranking per prediction time-step": fig})

# summary plot for one prediction at datetime_start
if isinstance(history, list):
hist = history[0]
else:
hist = history

history = history.to(device).unsqueeze(0)
history_numpy = torch.tensor(
history.cpu().numpy(), names=["batches", "observations", "features"]
hist.cpu().numpy(), names=["batches", "observations", "features"]
)

shap_values = deep_explainer.shap_values(history)
shap_values = fix_shap_values(shap_values, history)
shap_values = np.stack(shap_values)
if len(shap_values.shape) != 4:
shap_values = np.expand_dims(shap_values, axis=0)
Expand All @@ -128,6 +163,13 @@ def deep_explain_model_summary_plot(
)


def fix_shap_values(shap_values, history):
if isinstance(history, list):
shap_values = list(zip(*shap_values))[0]
return shap_values
return shap_values


def deep_explain_model_heatmap(
model, csv_test_loader: CSVTestLoader, datetime_start: Optional[datetime] = None
) -> None:
Expand All @@ -150,13 +192,15 @@ def deep_explain_model_heatmap(
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if model.params["model_name"] == "DARNN" and device.type == "cuda":
# TO-DO check if this is still true
print("Currently DARNN doesn't work with shap on CUDA")
return

if datetime_start is None:
datetime_start = model.params["inference_params"]["datetime_start"]

history, _, forecast_start_idx = csv_test_loader.get_from_start_date(datetime_start)
history, forecast_start_idx = handle_dl_output(csv_test_loader, model.params["dataset_params"]["class"],
datetime_start, device)
background_tensor = _prepare_background_tensor(csv_test_loader)
background_tensor = background_tensor.to(device)
model.model.eval()
Expand All @@ -165,11 +209,16 @@ def deep_explain_model_heatmap(
# L - batch size, N - history length, M - feature size
# for each element in each N x M batch in L,
# attribute to each prediction in forecast len
deep_explainer = shap.DeepExplainer(model.model, background_tensor)
shap_values = deep_explainer.shap_values(
background_tensor
) # forecast_len x N x L x M
shap_values = np.stack(shap_values)
s_values_list = []
if isinstance(history, list):
deep_explainer = shap.DeepExplainer(model.model, history)
shap_values = deep_explainer.shap_values(history)
s_values_list.append(shap_values)
else:
deep_explainer = shap.DeepExplainer(model.model, background_tensor)
shap_values = deep_explainer.shap_values(background_tensor)
shap_values = fix_shap_values(shap_values, history)
shap_values = np.stack(shap_values) # forecast_len x N x L x M
if len(shap_values.shape) != 4:
shap_values = np.expand_dims(shap_values, axis=0)
shap_values = torch.tensor(
Expand All @@ -182,15 +231,15 @@ def deep_explain_model_heatmap(

# heatmap one prediction sequence at datetime_start
# (seq_len*forecast_len) per fop feature
to_explain = history.to(device).unsqueeze(0)
to_explain = history
shap_values = deep_explainer.shap_values(to_explain)
shap_values = fix_shap_values(shap_values, history)
shap_values = np.stack(shap_values)
if len(shap_values.shape) != 4:
shap_values = np.expand_dims(shap_values, axis=0)
shap_values = torch.tensor(
shap_values, names=["preds", "batches", "observations", "features"]
)

) # no fake ballo t
figs = plot_shap_value_heatmaps(shap_values)
if use_wandb:
for fig, feature in zip(figs, csv_test_loader.df.columns):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

setup(
name='flood_forecast',
version='0.956dev',
version='0.97dev',
packages=[
'flood_forecast',
'flood_forecast.transformer_xl',
Expand Down
25 changes: 24 additions & 1 deletion tests/test_explain_model_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
from flood_forecast.explain_model_output import (
deep_explain_model_heatmap,
deep_explain_model_summary_plot,
handle_dl_output
)
from flood_forecast.preprocessing.pytorch_loaders import CSVTestLoader
from flood_forecast.preprocessing.pytorch_loaders import CSVTestLoader, TemporalTestLoader
from flood_forecast.time_model import PyTorchForecast


Expand Down Expand Up @@ -136,6 +137,28 @@ def test_deep_explain_model_heatmap(self):
# dummy assert
self.assertEqual(1, 1)

def test_handle_dl(self):
params_dict = {}
params_dict["kwargs"] = {
"file_path": "tests/test_data/keag_small.csv",
"forecast_history": 5,
"forecast_length": 5,
"no_scale": True,
"relevant_cols": ["cfs", "precip", "temp"],
"sort_column": "datetime",
"feature_params": {
"datetime_params": {
"hour": "numerical"
}
},
"target_col": ["cfs"],
"interpolate_param": False}
params_dict["df_path"] = self.keag_file
params_dict["forecast_total"] = 35
t = TemporalTestLoader(["hour"], params_dict)
self.assertIsInstance(handle_dl_output(self.csv_test_loader, "normal", datetime(2014, 6, 2, 0), "cpu"), tuple)
self.assertIsInstance(handle_dl_output(t, "TemporalLoader", datetime(2014, 6, 2, 0), "cpu")[0], list)
# self.assertIsEqual(len(handle_dl_output(t, "TemporalLoader")), 3)

if __name__ == "__main__":
unittest.main()
100 changes: 100 additions & 0 deletions tests/test_inf_single.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
{
"model_name": "Informer",
"use_decoder": true,
"model_type": "PyTorch",
"model_params": {
"n_time_series":3,
"dec_in":3,
"c_out": 1,
"seq_len":10,
"label_len":10,
"out_len":2,
"factor":2
},
"dataset_params":
{ "class": "TemporalLoader",
"temporal_feats": ["month", "day", "day_of_week", "hour"],
"training_path": "tests/test_data/keag_small.csv",
"validation_path": "tests/test_data/keag_small.csv",
"test_path": "tests/test_data/keag_small.csv",
"batch_size":4,
"forecast_history":10,
"forecast_length":2,
"train_end": 200,
"valid_start":201,
"valid_end": 220,
"test_start":299,
"test_end": 400,
"target_col": ["cfs"],
"relevant_cols": ["cfs", "precip", "temp"],
"scaler": "StandardScaler",
"sort_column":"datetime",
"interpolate": false,
"feature_param":
{
"datetime_params":{
"month": "numerical",
"day": "numerical",
"day_of_week": "numerical",
"hour":"numerical"
}
}
},
"early_stopping":
{
"patience":2

},
"training_params":
{
"criterion":"MSE",
"optimizer": "Adam",
"optim_params":
{

},
"lr": 0.3,
"epochs": 4,
"batch_size":4

},
"GCS": false,

"wandb": {
"name": "flood_forecast_circleci",
"tags": ["dummy_run", "circleci"],
"project":"repo-flood_forecast"
},
"forward_params":{
},
"metrics":["MSE"],
"inference_params":
{
"datetime_start":"2016-05-31",
"hours_to_forecast":336,
"test_csv_path":"tests/test_data/keag_small.csv",
"decoder_params":{
"decoder_function": "greedy_decode",
"unsqueeze_dim": 1},
"dataset_params":{
"file_path": "tests/test_data/keag_small.csv",
"forecast_history":10,
"forecast_length":2,
"relevant_cols": ["cfs", "precip", "temp"],
"target_col": ["cfs"],
"scaling": "StandardScaler",
"interpolate_param": false,
"sort_column":"datetime",
"feature_params":
{
"datetime_params":{
"month": "numerical",
"day": "numerical",
"day_of_week": "numerical",
"hour":"numerical"
}
}
}
}
}

1 change: 1 addition & 0 deletions tests/test_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
class JoinTest(unittest.TestCase):
def setUp(self):
self.test_data_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_data")
# stuff

def test_join_function(self):
df = pd.read_csv(os.path.join(self.test_data_path, "fake_test_small.csv"), sep="\t")
Expand Down

0 comments on commit 91393a8

Please sign in to comment.