Skip to content

Commit

Permalink
introducing key locally_run to disable time consuming cell for github…
Browse files Browse the repository at this point in the history
… action
  • Loading branch information
SarahAlidoost committed Dec 19, 2024
1 parent 4492710 commit d92c03b
Showing 1 changed file with 93 additions and 69 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -720,7 +720,7 @@
},
{
"cell_type": "markdown",
"id": "80411a9d-881e-4196-8559-17aaadd15841",
"id": "ddb1e4f0-2674-4242-bcd8-abf66f97c611",
"metadata": {},
"source": [
"#### 5 - Run the explainer at one location for several data instances (here, as an example, a one-month time series)\n",
Expand Down Expand Up @@ -805,6 +805,24 @@
"background_data = x_train.drop(columns=['station', 'date_UTC']).fillna(0).to_numpy()"
]
},
{
"cell_type": "markdown",
"id": "8b612e55-e1ec-40dc-b189-65d90ffb2b1c",
"metadata": {},
"source": [
"This step takes a few minutes, so it is not suitable for GitHub Actions and is skipped by default. To run this step locally, set `locally_run = True`."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "59a54eaa-f6f2-42b5-8849-aceb37b06156",
"metadata": {},
"outputs": [],
"source": [
"locally_run = False"
]
},
{
"cell_type": "code",
"execution_count": 14,
Expand All @@ -821,11 +839,12 @@
],
"source": [
"# run explainer over time series, this might take a few minutes\n",
"explanations[key] = dianna.explain_tabular(runner, input_tabular=features.values, method='kernelshap',\n",
" mode ='regression', training_data=background_data, training_data_kmeans=5,\n",
" feature_names=features.columns, silent=True)\n",
"\n",
"print(\"Dianna is done!\") "
"if locally_run:\n",
" explanations[key] = dianna.explain_tabular(runner, input_tabular=features.values, method='kernelshap',\n",
" mode ='regression', training_data=background_data, training_data_kmeans=5,\n",
" feature_names=features.columns, silent=True)\n",
" \n",
" print(\"Dianna is done!\") "
]
},
{
Expand All @@ -846,30 +865,31 @@
}
],
"source": [
"# create shap_values object\n",
"shap_values = Explanation(explanations[key])\n",
"shap_values.feature_names = features.columns\n",
"\n",
"# create comparison plot: predictions vs test data \n",
"y_predict_time = runner(features.to_numpy())\n",
"y_test_time = y_test[(y_test[\"station\"] == location) & (y_test[\"date_UTC\"].dt.month == month)].drop(columns=['station', 'date_UTC']).fillna(0).to_numpy()\n",
"comparison_plot(y_test_time, y_predict_time, show=False) \n",
"comparison_img = plt.gcf()\n",
"plt.close()\n",
"\n",
"# create summary plot\n",
"shap.summary_plot(shap_values, features.values, feature_names=features.columns, cmap=\"PRGn\", show=False, max_display=15)\n",
"summary_img = plt.gcf()\n",
"plt.close()\n",
"\n",
"# create heatmap plot\n",
"shap.plots.heatmap(shap_values, cmap=\"bwr\", show=False, max_display=15)\n",
"heatmap_img = plt.gcf()\n",
"plt.close()\n",
"\n",
"# plot all three figures in one cell\n",
"figures = [comparison_img, heatmap_img, summary_img]\n",
"display_figures(figures, captions, 1, 3)"
"if locally_run:\n",
" # create shap_values object\n",
" shap_values = Explanation(explanations[key])\n",
" shap_values.feature_names = features.columns\n",
" \n",
" # create comparison plot: predictions vs test data \n",
" y_predict_time = runner(features.to_numpy())\n",
" y_test_time = y_test[(y_test[\"station\"] == location) & (y_test[\"date_UTC\"].dt.month == month)].drop(columns=['station', 'date_UTC']).fillna(0).to_numpy()\n",
" comparison_plot(y_test_time, y_predict_time, show=False) \n",
" comparison_img = plt.gcf()\n",
" plt.close()\n",
" \n",
" # create summary plot\n",
" shap.summary_plot(shap_values, features.values, feature_names=features.columns, cmap=\"PRGn\", show=False, max_display=15)\n",
" summary_img = plt.gcf()\n",
" plt.close()\n",
" \n",
" # create heatmap plot\n",
" shap.plots.heatmap(shap_values, cmap=\"bwr\", show=False, max_display=15)\n",
" heatmap_img = plt.gcf()\n",
" plt.close()\n",
" \n",
" # plot all three figures in one cell\n",
" figures = [comparison_img, heatmap_img, summary_img]\n",
" display_figures(figures, captions, 1, 3)"
]
},
{
Expand All @@ -887,9 +907,10 @@
}
],
"source": [
"relative_mae = np.mean(np.abs(y_predict_time - y_test_time))/ np.mean(y_test_time)\n",
"cor = np.corrcoef(y_predict_time.T, y_test_time.T)[0,1]\n",
"print(f\"Relative MAE is {relative_mae} and correlation is {cor}\")"
"if locally_run:\n",
" relative_mae = np.mean(np.abs(y_predict_time - y_test_time))/ np.mean(y_test_time)\n",
" cor = np.corrcoef(y_predict_time.T, y_test_time.T)[0,1]\n",
" print(f\"Relative MAE is {relative_mae} and correlation is {cor}\")"
]
},
{
Expand Down Expand Up @@ -947,12 +968,13 @@
}
],
"source": [
"# run explainer over time series, this might take a few minutes\n",
"explanations[key] = dianna.explain_tabular(runner, input_tabular=features.values, method='kernelshap',\n",
" mode ='regression', training_data=background_data, training_data_kmeans=5,\n",
" feature_names=features.columns, silent=True)\n",
"\n",
"print(\"Dianna is done!\") "
"if locally_run:\n",
" # run explainer over time series, this might take a few minutes\n",
" explanations[key] = dianna.explain_tabular(runner, input_tabular=features.values, method='kernelshap',\n",
" mode ='regression', training_data=background_data, training_data_kmeans=5,\n",
" feature_names=features.columns, silent=True)\n",
" \n",
" print(\"Dianna is done!\") "
]
},
{
Expand All @@ -973,30 +995,31 @@
}
],
"source": [
"# create shap_values object\n",
"shap_values = Explanation(explanations[key])\n",
"shap_values.feature_names = features.columns\n",
"\n",
"# create comparison plot: predictions vs test data \n",
"y_predict_time = runner(features.to_numpy())\n",
"y_test_time = y_test[(y_test[\"station\"] == location) & (y_test[\"date_UTC\"].dt.month == month)].drop(columns=['station', 'date_UTC']).fillna(0).to_numpy()\n",
"comparison_plot(y_test_time, y_predict_time, show=False) \n",
"comparison_img = plt.gcf()\n",
"plt.close()\n",
"\n",
"# create summary plot\n",
"shap.summary_plot(shap_values, features.values, feature_names=features.columns, cmap=\"PRGn\", show=False, max_display=15)\n",
"summary_img = plt.gcf()\n",
"plt.close()\n",
"\n",
"# create heatmap plot\n",
"shap.plots.heatmap(shap_values, cmap=\"bwr\", show=False, max_display=15)\n",
"heatmap_img = plt.gcf()\n",
"plt.close()\n",
"\n",
"# plot all three figures in one cell\n",
"figures = [comparison_img, heatmap_img, summary_img]\n",
"display_figures(figures, captions, 1, 3)"
"if locally_run:\n",
" # create shap_values object\n",
" shap_values = Explanation(explanations[key])\n",
" shap_values.feature_names = features.columns\n",
" \n",
" # create comparison plot: predictions vs test data \n",
" y_predict_time = runner(features.to_numpy())\n",
" y_test_time = y_test[(y_test[\"station\"] == location) & (y_test[\"date_UTC\"].dt.month == month)].drop(columns=['station', 'date_UTC']).fillna(0).to_numpy()\n",
" comparison_plot(y_test_time, y_predict_time, show=False) \n",
" comparison_img = plt.gcf()\n",
" plt.close()\n",
" \n",
" # create summary plot\n",
" shap.summary_plot(shap_values, features.values, feature_names=features.columns, cmap=\"PRGn\", show=False, max_display=15)\n",
" summary_img = plt.gcf()\n",
" plt.close()\n",
" \n",
" # create heatmap plot\n",
" shap.plots.heatmap(shap_values, cmap=\"bwr\", show=False, max_display=15)\n",
" heatmap_img = plt.gcf()\n",
" plt.close()\n",
" \n",
" # plot all three figures in one cell\n",
" figures = [comparison_img, heatmap_img, summary_img]\n",
" display_figures(figures, captions, 1, 3)"
]
},
{
Expand All @@ -1014,9 +1037,10 @@
}
],
"source": [
"relative_mae = np.mean(np.abs(y_predict_time - y_test_time))/ np.mean(y_test_time)\n",
"cor = np.corrcoef(y_predict_time.T, y_test_time.T)[0,1]\n",
"print(f\"Relative MAE is {relative_mae} and correlation is {cor}\")"
"if locally_run:\n",
" relative_mae = np.mean(np.abs(y_predict_time - y_test_time))/ np.mean(y_test_time)\n",
" cor = np.corrcoef(y_predict_time.T, y_test_time.T)[0,1]\n",
" print(f\"Relative MAE is {relative_mae} and correlation is {cor}\")"
]
},
{
Expand Down Expand Up @@ -1166,6 +1190,9 @@
}
],
"metadata": {
"execution": {
"timeout": 1800
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
Expand All @@ -1182,9 +1209,6 @@
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
},
"execution": {
"timeout": 1800
}
},
"nbformat": 4,
Expand Down

0 comments on commit d92c03b

Please sign in to comment.