Skip to content

Commit

Permalink
Merge pull request #584 from dianna-ai/notebooks
Browse files Browse the repository at this point in the history
Notebooks
  • Loading branch information
APJansen authored May 9, 2023
2 parents 35145d4 + e807a5f commit c1e8c49
Show file tree
Hide file tree
Showing 12 changed files with 1,950 additions and 995 deletions.
11 changes: 11 additions & 0 deletions tutorials/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,14 @@ The datasets used in the tutorials are represented with their respective logo:
| [Stanford sentiment treebank](https://nlp.stanford.edu/sentiment/index.html) | <img width="25" alt="nlp-logo_half_size" src="https://user-images.githubusercontent.com/3244249/152540890-c8e1e37d-f0cc-4f84-80a4-2c59176cbf4c.png">|

The models used in the tutorials are available at [tutorials/models](https://github.com/dianna-ai/dianna/tree/main/tutorials/models).


## Colab
The tutorials can also be run directly in Google Colab, by clicking on the links/buttons below, or for a general demo here: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dianna-ai/dianna/blob/main/tutorials/demo.ipynb).

| modality \ method | RISE | LIME | KernelSHAP |
|-------------------|------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------|
| images | [mnist](https://colab.research.google.com/github/dianna-ai/dianna/blob/main/tutorials/rise_mnist.ipynb), [imagenet](https://colab.research.google.com/github/dianna-ai/dianna/blob/main/tutorials/rise_imagenet.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dianna-ai/dianna/blob/main/tutorials/lime_images.ipynb) | [mnist](https://colab.research.google.com/github/dianna-ai/dianna/blob/main/tutorials/kernelshap_mnist.ipynb), [geometric shapes](https://colab.research.google.com/github/dianna-ai/dianna/blob/main/tutorials/kernelshap_geometric_shapes.ipynb) |
| text | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dianna-ai/dianna/blob/main/tutorials/rise_text.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dianna-ai/dianna/blob/main/tutorials/lime_text.ipynb) | - |
| timeseries | [weather](https://colab.research.google.com/github/dianna-ai/dianna/blob/main/tutorials/rise_timeseries_weather.ipynb) | [weather](https://colab.research.google.com/github/dianna-ai/dianna/blob/main/tutorials/lime_timeseries_weather.ipynb), [coffee](https://colab.research.google.com/github/dianna-ai/dianna/blob/main/tutorials/lime_timeseries_coffee.ipynb) | - |

35 changes: 31 additions & 4 deletions tutorials/demo.ipynb

Large diffs are not rendered by default.

112 changes: 87 additions & 25 deletions tutorials/kernelshap_geometric_shapes.ipynb

Large diffs are not rendered by default.

106 changes: 82 additions & 24 deletions tutorials/kernelshap_mnist.ipynb

Large diffs are not rendered by default.

65 changes: 39 additions & 26 deletions tutorials/lime_images.ipynb

Large diffs are not rendered by default.

116 changes: 93 additions & 23 deletions tutorials/lime_text.ipynb

Large diffs are not rendered by default.

36 changes: 35 additions & 1 deletion tutorials/lime_timeseries_coffee.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,40 @@
"*NOTE*: This tutorial is still work-in-progress, the final results need to be improved by tweaking the LIME parameters"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Colab Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"running_in_colab = 'google.colab' in str(get_ipython())\n",
"if running_in_colab:\n",
" # install dianna\n",
" !python3 -m pip install dianna[notebooks]\n",
" \n",
" # download data used in this demo\n",
" import os \n",
" base_url = 'https://raw.githubusercontent.com/dianna-ai/dianna/main/tutorials/'\n",
" paths_to_download = ['data/coffee_train.csv', 'data/coffee_test.csv', 'models/coffee.onnx']\n",
" for path in paths_to_download:\n",
" !wget {base_url + path} -P {os.path.dirname(path)}"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Libraries"
]
},
{
"cell_type": "code",
"execution_count": 1,
Expand Down Expand Up @@ -442,7 +476,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
"version": "3.9.12"
},
"orig_nbformat": 4
},
Expand Down
36 changes: 35 additions & 1 deletion tutorials/lime_timeseries_weather.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,40 @@
"*NOTE*: This tutorial is still work-in-progress, the final results need to be improved by tweaking the LIME parameters"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Colab Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"running_in_colab = 'google.colab' in str(get_ipython())\n",
"if running_in_colab:\n",
" # install dianna\n",
" !python3 -m pip install dianna[notebooks]\n",
" \n",
" # download data used in this demo\n",
" import os \n",
" base_url = 'https://raw.githubusercontent.com/dianna-ai/dianna/main/tutorials/'\n",
" paths_to_download = ['models/season_prediction_model_temp_max_binary.onnx']\n",
" for path in paths_to_download:\n",
" !wget {base_url + path} -P {os.path.dirname(path)}"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Libraries"
]
},
{
"cell_type": "code",
"execution_count": 1,
Expand Down Expand Up @@ -1065,7 +1099,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
"version": "3.9.12"
},
"orig_nbformat": 4
},
Expand Down
104 changes: 69 additions & 35 deletions tutorials/rise_imagenet.ipynb

Large diffs are not rendered by default.

76 changes: 61 additions & 15 deletions tutorials/rise_mnist.ipynb

Large diffs are not rendered by default.

124 changes: 102 additions & 22 deletions tutorials/rise_text.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,35 @@
"*NOTE*: This tutorial is still work-in-progress, the final results need to be improved by tweaking the RISE parameters"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "40dc5e32",
"metadata": {},
"source": [
"#### Colab Setup"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "236ca562",
"metadata": {},
"outputs": [],
"source": [
"running_in_colab = 'google.colab' in str(get_ipython())\n",
"if running_in_colab:\n",
" # install dianna\n",
" !python3 -m pip install dianna[notebooks]\n",
" \n",
" # download data used in this demo\n",
" import os \n",
" base_url = 'https://raw.githubusercontent.com/dianna-ai/dianna/main/tutorials/'\n",
" paths_to_download = ['data/movie_reviews_word_vectors.txt', 'models/movie_review_model.onnx']\n",
" for path in paths_to_download:\n",
" !wget {base_url + path} -P {os.path.dirname(path)}"
]
},
{
"cell_type": "markdown",
"id": "a5cf6f82-c1c7-4814-ae0f-5a1c0b8578f6",
Expand All @@ -27,10 +56,19 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 2,
"id": "34b556d8-5337-44dc-8efe-14d1dff6f011",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import os\n",
"import matplotlib.pyplot as plt\n",
Expand All @@ -48,7 +86,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"id": "c616916c-78ef-48d0-a744-b25b37b62a3f",
"metadata": {},
"outputs": [],
Expand All @@ -71,20 +109,62 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"id": "486540bd-2676-4dfa-bbe8-ee8aa289acd3",
"metadata": {
"tags": []
},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Collecting en-core-web-sm==3.2.0\n",
" Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.2.0/en_core_web_sm-3.2.0-py3-none-any.whl (13.9 MB)\n",
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 13.9/13.9 MB 2.2 MB/s eta 0:00:00\n",
"Requirement already satisfied: spacy<3.3.0,>=3.2.0 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from en-core-web-sm==3.2.0) (3.2.4)\n",
"Requirement already satisfied: click<8.1.0 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (8.0.4)\n",
"Requirement already satisfied: blis<0.8.0,>=0.4.0 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (0.7.7)\n",
"Requirement already satisfied: numpy>=1.15.0 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (1.21.6)\n",
"Requirement already satisfied: jinja2 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (3.1.1)\n",
"Requirement already satisfied: spacy-legacy<3.1.0,>=3.0.8 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (3.0.9)\n",
"Requirement already satisfied: preshed<3.1.0,>=3.0.2 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (3.0.6)\n",
"Requirement already satisfied: packaging>=20.0 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (21.3)\n",
"Requirement already satisfied: cymem<2.1.0,>=2.0.2 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (2.0.6)\n",
"Requirement already satisfied: langcodes<4.0.0,>=3.2.0 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (3.3.0)\n",
"Requirement already satisfied: requests<3.0.0,>=2.13.0 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (2.27.1)\n",
"Requirement already satisfied: pathy>=0.3.5 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (0.6.1)\n",
"Requirement already satisfied: setuptools in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (62.1.0)\n",
"Requirement already satisfied: murmurhash<1.1.0,>=0.28.0 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (1.0.6)\n",
"Requirement already satisfied: thinc<8.1.0,>=8.0.12 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (8.0.15)\n",
"Requirement already satisfied: catalogue<2.1.0,>=2.0.6 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (2.0.7)\n",
"Requirement already satisfied: srsly<3.0.0,>=2.4.1 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (2.4.3)\n",
"Requirement already satisfied: typer<0.5.0,>=0.3.0 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (0.4.1)\n",
"Requirement already satisfied: tqdm<5.0.0,>=4.38.0 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (4.64.0)\n",
"Requirement already satisfied: spacy-loggers<2.0.0,>=1.0.0 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (1.0.2)\n",
"Requirement already satisfied: wasabi<1.1.0,>=0.8.1 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (0.9.1)\n",
"Requirement already satisfied: pydantic!=1.8,!=1.8.1,<1.9.0,>=1.7.4 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (1.8.2)\n",
"Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from packaging>=20.0->spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (3.0.8)\n",
"Requirement already satisfied: smart-open<6.0.0,>=5.0.0 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from pathy>=0.3.5->spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (5.2.1)\n",
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from pydantic!=1.8,!=1.8.1,<1.9.0,>=1.7.4->spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (4.1.1)\n",
"Requirement already satisfied: idna<4,>=2.5 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from requests<3.0.0,>=2.13.0->spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (3.3)\n",
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from requests<3.0.0,>=2.13.0->spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (1.26.9)\n",
"Requirement already satisfied: charset-normalizer~=2.0.0 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from requests<3.0.0,>=2.13.0->spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (2.0.12)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from requests<3.0.0,>=2.13.0->spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (2021.10.8)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /opt/homebrew/Caskroom/miniforge/base/envs/dianna/lib/python3.9/site-packages (from jinja2->spacy<3.3.0,>=3.2.0->en-core-web-sm==3.2.0) (2.1.1)\n",
"\u001b[38;5;2m✔ Download and installation successful\u001b[0m\n",
"You can now load the package via spacy.load('en_core_web_sm')\n"
]
}
],
"source": [
"# ensure the tokenizer for english is available\n",
"spacy.cli.download('en_core_web_sm')"
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"id": "555842c5-3f82-4f63-93bb-696645d4b447",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -126,7 +206,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"id": "443e8a99-6fa3-4a73-9311-2fbe0251c2b1",
"metadata": {},
"outputs": [],
Expand All @@ -152,7 +232,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 7,
"id": "7fc6ebcb-2328-4c06-ae67-c5590032eb69",
"metadata": {},
"outputs": [],
Expand All @@ -162,7 +242,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 8,
"id": "7c0bfd7d-df1d-4981-b714-496bc16b9347",
"metadata": {},
"outputs": [
Expand All @@ -177,23 +257,23 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Explaining: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:17<00:00, 1.72s/it]\n"
"Explaining: 100%|██████████| 10/10 [00:03<00:00, 2.75it/s]\n"
]
},
{
"data": {
"text/plain": [
"[('A', 0, 0.7158780014514923),\n",
" ('delectable', 1, 0.913871341049671),\n",
" ('and', 2, 0.6892129376530648),\n",
" ('intriguing', 3, 1.0620161551237106),\n",
" ('thriller', 4, 0.840078490972519),\n",
" ('filled', 5, 0.6051010835170746),\n",
" ('with', 6, 0.6926153092086315),\n",
" ('surprises', 7, 0.6697717276215553)]"
"[('A', 0, 0.5653130280971527),\n",
" ('delectable', 1, 0.8641307824850082),\n",
" ('and', 2, 0.7081780250370502),\n",
" ('intriguing', 3, 1.004394978582859),\n",
" ('thriller', 4, 0.9396217280626297),\n",
" ('filled', 5, 0.6516930902004242),\n",
" ('with', 6, 0.7476113395392894),\n",
" ('surprises', 7, 0.7425235873460769)]"
]
},
"execution_count": 7,
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -217,14 +297,14 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 9,
"id": "0136005d-a22f-43a0-80da-4ec1f283f870",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<html><body><span style=\"background:rgba(255, 0, 0, 0.54)\">A</span> <span style=\"background:rgba(255, 0, 0, 0.69)\">delectable</span> <span style=\"background:rgba(255, 0, 0, 0.52)\">and</span> <span style=\"background:rgba(255, 0, 0, 0.80)\">intriguing</span> <span style=\"background:rgba(255, 0, 0, 0.63)\">thriller</span> <span style=\"background:rgba(255, 0, 0, 0.46)\">filled</span> <span style=\"background:rgba(255, 0, 0, 0.52)\">with</span> <span style=\"background:rgba(255, 0, 0, 0.50)\">surprises</span></body></html>"
"<mark style=\"background-color: hsl(0, 100%, 72%, 0.8); line-height:1.75\">A</mark> <mark style=\"background-color: hsl(0, 100%, 57%, 0.8); line-height:1.75\">delectable</mark> <mark style=\"background-color: hsl(0, 100%, 65%, 0.8); line-height:1.75\">and</mark> <mark style=\"background-color: hsl(0, 100%, 50%, 0.8); line-height:1.75\">intriguing</mark> <mark style=\"background-color: hsl(0, 100%, 54%, 0.8); line-height:1.75\">thriller</mark> <mark style=\"background-color: hsl(0, 100%, 68%, 0.8); line-height:1.75\">filled</mark> <mark style=\"background-color: hsl(0, 100%, 63%, 0.8); line-height:1.75\">with</mark> <mark style=\"background-color: hsl(0, 100%, 64%, 0.8); line-height:1.75\">surprises</mark>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
Expand Down Expand Up @@ -263,7 +343,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
"version": "3.9.12"
}
},
"nbformat": 4,
Expand Down
Loading

0 comments on commit c1e8c49

Please sign in to comment.