Skip to content

Commit

Permalink
Merge pull request #772 from dianna-ai/750-fix-colab-paths
Browse files Browse the repository at this point in the history
Fix data file paths for Colab
  • Loading branch information
loostrum authored May 29, 2024
2 parents 2e7a10c + 4db827f commit 79f4b87
Show file tree
Hide file tree
Showing 17 changed files with 211 additions and 206 deletions.
4 changes: 2 additions & 2 deletions dianna/dashboard/_movie_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
import numpy as np
from _shared import label_directory
from _shared import data_directory
from scipy.special import expit as sigmoid
from torchtext.vocab import Vectors
from dianna import utils
Expand All @@ -13,7 +13,7 @@ class MovieReviewsModelRunner:
def __init__(self, model, word_vectors=None, max_filter_size=5):
"""Initializes the class."""
if word_vectors is None:
word_vectors = label_directory / 'movie_reviews_word_vectors.txt'
word_vectors = data_directory / 'movie_reviews_word_vectors.txt'

self.run_model = utils.get_function(model)
self.vocab = Vectors(word_vectors, cache=os.path.dirname(word_vectors))
Expand Down
File renamed without changes.
23 changes: 9 additions & 14 deletions tutorials/explainers/KernelSHAP/kernelshap_geometric_shapes.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,8 @@
"source": [
"running_in_colab = 'google.colab' in str(get_ipython())\n",
"if running_in_colab:\n",
" # install dianna\n",
" !python3 -m pip install dianna[notebooks]\n",
" \n",
" # download data used in this demo\n",
" import os \n",
" base_url = 'https://raw.githubusercontent.com/dianna-ai/dianna/main/dianna/'\n",
" paths_to_download = ['./data/shapes.npz', './models/geometric_shapes_model.onnx']\n",
" for path in paths_to_download:\n",
" !wget {base_url + path} -P {os.path.dirname(path)}"
" # install dianna\n",
" !python3 -m pip install dianna[notebooks]"
]
},
{
Expand All @@ -64,14 +57,16 @@
},
"outputs": [],
"source": [
"from pathlib import Path\n",
"import warnings\n",
"warnings.filterwarnings('ignore') # disable warnings relateds to versions of tf\n",
"import numpy as np\n",
"import dianna\n",
"import onnx\n",
"from onnx_tf.backend import prepare\n",
"import matplotlib.pyplot as plt\n",
"from pathlib import Path"
"\n",
"root_dir = Path(dianna.__file__).parent"
]
},
{
Expand Down Expand Up @@ -108,7 +103,7 @@
"outputs": [],
"source": [
"# load dataset\n",
"data = np.load(Path('..','..','..','dianna', 'data', 'shapes.npz'))\n",
"data = np.load(Path(root_dir, 'data', 'shapes.npz'))\n",
"# load testing data and the related labels\n",
"X_test = data['X_test'].astype(np.float32).reshape([-1, 1, 64, 64])\n",
"y_test = data['y_test']"
Expand Down Expand Up @@ -136,7 +131,7 @@
"outputs": [],
"source": [
"# Load saved onnx model\n",
"onnx_model_path = Path('..','..','..','dianna','models', 'geometric_shapes_model.onnx')\n",
"onnx_model_path = Path(root_dir, 'models', 'geometric_shapes_model.onnx')\n",
"onnx_model = onnx.load(onnx_model_path)\n",
"# get the output node\n",
"output_node = prepare(onnx_model, gen_tensor_dict=True).outputs[0]"
Expand Down Expand Up @@ -366,7 +361,7 @@
"hash": "e7604e8ec5f09e490e10161e37a4725039efd3ab703d81b1b8a1e00d6741866c"
},
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand All @@ -380,7 +375,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
"version": "3.9.1"
}
},
"nbformat": 4,
Expand Down
23 changes: 9 additions & 14 deletions tutorials/explainers/KernelSHAP/kernelshap_mnist.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,8 @@
"source": [
"running_in_colab = 'google.colab' in str(get_ipython())\n",
"if running_in_colab:\n",
" # install dianna\n",
" !python3 -m pip install dianna[notebooks]\n",
" \n",
" # download data used in this demo\n",
" import os \n",
" base_url = 'https://raw.githubusercontent.com/dianna-ai/dianna/main/dianna/'\n",
" paths_to_download = ['./data/binary-mnist.npz', './models/mnist_model_tf.onnx']\n",
" for path in paths_to_download:\n",
" !wget {base_url + path} -P {os.path.dirname(path)}"
" # install dianna\n",
" !python3 -m pip install dianna[notebooks]"
]
},
{
Expand All @@ -64,14 +57,16 @@
},
"outputs": [],
"source": [
"from pathlib import Path\n",
"import warnings\n",
"warnings.filterwarnings('ignore') # disable warnings relateds to versions of tf\n",
"import numpy as np\n",
"import dianna\n",
"import onnx\n",
"from onnx_tf.backend import prepare\n",
"import matplotlib.pyplot as plt\n",
"from pathlib import Path"
"\n",
"root_dir = Path(dianna.__file__).parent"
]
},
{
Expand Down Expand Up @@ -108,7 +103,7 @@
"outputs": [],
"source": [
"# load dataset\n",
"data = np.load(Path('..','..','..','dianna','data', 'binary-mnist.npz'))\n",
"data = np.load(Path(root_dir, 'data', 'binary-mnist.npz'))\n",
"# load testing data and the related labels\n",
"X_test = data['X_test'].astype(np.float32).reshape([-1, 28, 28, 1]) / 255\n",
"y_test = data['y_test']"
Expand Down Expand Up @@ -136,7 +131,7 @@
"outputs": [],
"source": [
"# Load saved onnx model\n",
"onnx_model_path = Path('..','..','..','dianna','models', 'mnist_model_tf.onnx')\n",
"onnx_model_path = Path(root_dir, 'models', 'mnist_model_tf.onnx')\n",
"onnx_model = onnx.load(onnx_model_path)\n",
"# get the output node\n",
"output_node = prepare(onnx_model, gen_tensor_dict=True).outputs[0]"
Expand Down Expand Up @@ -333,7 +328,7 @@
"hash": "e7604e8ec5f09e490e10161e37a4725039efd3ab703d81b1b8a1e00d6741866c"
},
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand All @@ -347,7 +342,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
"version": "3.9.1"
}
},
"nbformat": 4,
Expand Down
22 changes: 9 additions & 13 deletions tutorials/explainers/KernelSHAP/kernelshap_tabular_penguin.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,8 @@
"source": [
"running_in_colab = 'google.colab' in str(get_ipython())\n",
"if running_in_colab:\n",
" # install dianna\n",
" !python3 -m pip install dianna[notebooks]\n",
" \n",
" # download data used in this demo\n",
" import os \n",
" base_url = 'https://raw.githubusercontent.com/dianna-ai/dianna/main/dianna/'\n",
" paths_to_download = ['./models/penguin_model.onnx']\n",
" for path in paths_to_download:\n",
" !wget {base_url + path} -P {os.path.dirname(path)}"
" # install dianna\n",
" !python3 -m pip install dianna[notebooks]"
]
},
{
Expand All @@ -49,6 +42,7 @@
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"import dianna\n",
"import numpy as np\n",
"import pandas as pd\n",
Expand All @@ -59,7 +53,9 @@
"from numba.core.errors import NumbaDeprecationWarning\n",
"import warnings\n",
"# silence the Numba deprecation warnings in shap\n",
"warnings.simplefilter('ignore', category=NumbaDeprecationWarning)"
"warnings.simplefilter('ignore', category=NumbaDeprecationWarning)\n",
"\n",
"root_dir = Path(dianna.__file__).parent"
]
},
{
Expand Down Expand Up @@ -308,7 +304,7 @@
],
"source": [
"# load onnx model and check the prediction with it\n",
"model_path = '../../../dianna/models/penguin_model.onnx'\n",
"model_path = Path(root_dir, 'models', 'penguin_model.onnx')\n",
"loaded_model = SimpleModelRunner(model_path)\n",
"predictions = loaded_model(data_instance.reshape(1,-1).astype(np.float32))\n",
"species[np.argmax(predictions)]"
Expand Down Expand Up @@ -411,7 +407,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand All @@ -425,7 +421,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
"version": "3.9.1"
}
},
"nbformat": 4,
Expand Down
20 changes: 8 additions & 12 deletions tutorials/explainers/KernelSHAP/kernelshap_tabular_weather.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,8 @@
"source": [
"running_in_colab = 'google.colab' in str(get_ipython())\n",
"if running_in_colab:\n",
" # install dianna\n",
" !python3 -m pip install dianna[notebooks]\n",
" \n",
" # download data used in this demo\n",
" import os\n",
" base_url = 'https://raw.githubusercontent.com/dianna-ai/dianna/main/dianna/'\n",
" paths_to_download = ['./models/sunshine_hours_regression_model.onnx']\n",
" for path in paths_to_download:\n",
" !wget {base_url + path} -P {os.path.dirname(path)}"
" # install dianna\n",
" !python3 -m pip install dianna[notebooks]"
]
},
{
Expand All @@ -51,6 +44,7 @@
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"import dianna\n",
"import numpy as np\n",
"import pandas as pd\n",
Expand All @@ -60,7 +54,9 @@
"from numba.core.errors import NumbaDeprecationWarning\n",
"import warnings\n",
"# silence the Numba deprecation warnings in shap\n",
"warnings.simplefilter('ignore', category=NumbaDeprecationWarning)"
"warnings.simplefilter('ignore', category=NumbaDeprecationWarning)\n",
"\n",
"root_dir = Path(dianna.__file__).parent"
]
},
{
Expand Down Expand Up @@ -257,7 +253,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand All @@ -271,7 +267,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
"version": "3.9.1"
}
},
"nbformat": 4,
Expand Down
Loading

0 comments on commit 79f4b87

Please sign in to comment.