
Commit

merge main into feature branch
cwmeijer committed May 29, 2024
2 parents 17bf17d + 209a0ed commit ef893e1
Showing 113 changed files with 7,926 additions and 21,417 deletions.
2 changes: 1 addition & 1 deletion .bumpversion.cfg
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 1.4.0
current_version = 1.5.0

[comment]
comment = The contents of this file cannot be merged with that of setup.cfg until https://github.com/c4urself/bump2version/issues/185 is resolved
5 changes: 4 additions & 1 deletion .gitignore
@@ -17,6 +17,9 @@ coverage.xml
.tox
*word_vectors.txt.pt

# tutorial model that is downloaded automatically
apertif_frb_dynamic_spectrum_model.onnx

docs/_build

# ide
@@ -36,4 +39,4 @@ venv3
.python-version

cache/
dashboard/cache/
dashboard/cache/
6 changes: 5 additions & 1 deletion CITATION.cff
@@ -19,6 +19,10 @@ authors:
family-names: Meijer
given-names: Christiaan
orcid: "https://orcid.org/0000-0002-5529-5761"
-
family-names: Alidoost
given-names: Fakhereh (Sarah)
orcid: "https://orcid.org/0000-0001-8407-6472"
-
family-names: Oostrum
given-names: Leon
@@ -49,7 +53,7 @@ authors:
name-particle: "van der"

doi: 10.5281/zenodo.5801485
version: "1.4.0"
version: "1.5.0"
repository-code: "https://github.com/dianna-ai/dianna"
keywords:
- XAI
53 changes: 36 additions & 17 deletions README.md
@@ -68,13 +68,14 @@ After studying the vast XAI landscape we have made choices in the parts of the [

The key points of DIANNA:

* Provides an easy-to-use interface for non (X)AI experts
* Implements well-known XAI methods (LIME, RISE and Kernal SHAP) chosen by systematic and objective evaluation criteria
* Supports the de-facto standard format for neural network models - ONNX.
* Includes clear instructions for export/conversions from Tensorflow, Pytorch, Keras and scikit-learn to ONNX.
* Supports images, text and time series data modalities. Tabular data and even embeddings support is planned.
* Comes with simple intuitive image and text benchmarks
* Easily extendable to other XAI methods
* Provides an easy-to-use interface for non (X)AI experts
* Implements well-known XAI methods LIME, RISE and KernelSHAP, chosen by systematic and objective evaluation criteria
* Supports the de facto standard format for neural network models - ONNX (a conversion sketch follows this list)
* Supports image, text, time series, and tabular data modalities; support for embeddings is under development
* Comes with simple, intuitive image, text, time series, and tabular benchmarks, so it can also support your XAI research
* Includes tutorials for scientific use cases
* Easily extendable to other XAI methods
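
As a minimal sketch of the ONNX point above (assuming PyTorch and its standard `torch.onnx.export` API; the toy model and the file name are placeholders), converting a trained model to ONNX so it can be handed to DIANNA by file path looks roughly like this:

```python
import torch

# A toy classifier standing in for any trained torch.nn.Module.
net = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(28 * 28, 2))
dummy_input = torch.randn(1, 1, 28, 28)  # one example with the model's input shape

# Export to ONNX; the resulting file path can be passed directly to dianna.explain_*.
torch.onnx.export(net, dummy_input, "model.onnx",
                  input_names=["input"], output_names=["output"])
```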


For more information on the unique strengths of DIANNA in comparison with other tools, please see the [context landscape](https://dianna.readthedocs.io/en/latest/CONTEXT.html).

@@ -196,7 +197,7 @@ explanation = dianna.explain_timeseries(model_path, timeseries_data=timeseries_i

```

For visualization of the heatmap please refer to the [tutorial](https://github.com/dianna-ai/dianna/blob/main/tutorials/lime_timeseries_coffee.ipynb)
For visualization of the heatmap please refer to the [tutorial](https://github.com/dianna-ai/dianna/blob/main/tutorials/explainers/LIME/lime_timeseries_coffee.ipynb)
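
For a quick look without opening the notebook, a minimal matplotlib sketch like the one below can overlay the returned relevance scores on the input series. It assumes the explanation holds one relevance value per time step for each explained label, and `timeseries_instance` is a hypothetical name for the explained input; the tutorial uses DIANNA's own plotting helpers instead.

```python
import matplotlib.pyplot as plt
import numpy as np

# `explanation` is the output of dianna.explain_timeseries above; assumed shape
# is (n_labels, n_timesteps). `timeseries_instance` is the explained input series.
relevance = np.asarray(explanation)[0].squeeze()
series = np.asarray(timeseries_instance).squeeze()

fig, ax = plt.subplots(figsize=(8, 3))
ax.plot(series, color="black", linewidth=1)
# Paint the relevance as a background heatmap behind the series.
ax.imshow(relevance[np.newaxis, :], aspect="auto", cmap="bwr", alpha=0.5,
          extent=(0, len(relevance), *ax.get_ylim()))
ax.set_xlabel("time step")
plt.show()
```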

### Tabular example:

@@ -216,6 +217,10 @@ plot_tabular(explanation, X_test.columns, num_features=10) # display 10 most sa

![image](https://github.com/dianna-ai/dianna/assets/25911757/ce0b76b8-f00c-468a-9732-c21704e289f6)
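
The complete example (model, data loading, and the `explanation` object used above) lives in the tabular tutorial. As a condensed, hedged sketch based on the `dianna.explain_tabular` signature shown later in this commit - the `training_data` keyword and the variable names are assumptions - the call looks roughly like:

```python
import dianna

# Condensed sketch: explain a single tabular instance.
# `model_path`, `X_train` and `X_test` are assumed to come from the tutorial notebook.
explanation = dianna.explain_tabular(
    model_path,
    input_tabular=X_test.iloc[0].to_numpy(),
    method='KernelSHAP',
    labels=[0],                           # class to explain
    training_data=X_train.to_numpy(),     # assumption: background data for the explainer
)
plot_tabular(explanation, X_test.columns, num_features=10)
```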

### IMPORTANT: Sensitivity to hyperparameters
The XAI methods (explainers) are sensitive to the choice of their hyperparameters! This sensitivity is investigated in this [work](https://staff.fnwi.uva.nl/a.s.z.belloum/MSctheses/MScthesis_Willem_van_der_Spec.pdf), which draws useful practical conclusions.
The default hyperparameters used in DIANNA for each explainer, as well as the values used in our tutorial examples, are given in the Tutorials [README](./tutorials/README.md#important-hyperparameters).
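
As a hedged illustration of how such hyperparameters are set in practice: explainer-specific settings are passed as extra keyword arguments to the `dianna.explain_*` calls. The model path and input below are placeholders, and the parameter names (`n_masks`, `p_keep` for RISE, `num_samples` for LIME) follow the conventions used in the tutorials; check the Tutorials README for the exact names and defaults.

```python
import numpy as np
import dianna

model_path = 'model.onnx'                  # hypothetical ONNX model
input_image = np.random.rand(28, 28, 3)    # placeholder input with the model's shape

# RISE with explicit hyperparameters instead of the defaults.
heatmaps = dianna.explain_image(model_path, input_image, method='RISE',
                                labels=[0, 1], n_masks=2048, p_keep=0.1)

# LIME with a larger sampling budget.
heatmaps = dianna.explain_image(model_path, input_image, method='LIME',
                                labels=[0, 1], num_samples=2000)
```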

## Dashboard

Explore the explanations of your trained model using the DIANNA dashboard (for now, image, text and time series classification are supported).
@@ -247,7 +252,7 @@ DIANNA comes with simple datasets. Their main goal is to provide intuitive insig

| Dataset | Description | Examples | Generation |
| :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :------------------------------------------------------------------------------------------------------------------------------------------------------------- | :--------------------------------------------------------------------------------------------------------------------------------------- | :------------------------------------------------------------------------ |
| [Coffee dataset](https://timeseriesclassification.com/description.php?Dataset=Coffee) <img width="25" alt="Coffe Logo" src="https://github.com/dianna-ai/dianna/assets/3244249/9ab50a0f-5da3-41d2-80e9-70d2c8769162"> | Food spectographs time series dataset for a two class problem to distinguish between Robusta and Arabica coffee beans. | <img width="500" alt="example image" src="https://github.com/dianna-ai/dianna/assets/3244249/763002c5-40ad-48cc-9de0-ea43d7fa8a75)"> | [data source](https://github.com/QIBChemometrics/Benchtop-NMR-Coffee-Survey) |
| Coffee dataset <img width="25" alt="Coffee Logo" src="https://github.com/dianna-ai/dianna/assets/3244249/9ab50a0f-5da3-41d2-80e9-70d2c8769162"> | Food spectrographs time series dataset for a two-class problem to distinguish between Robusta and Arabica coffee beans. | <img width="500" alt="example image" src="https://github.com/dianna-ai/dianna/assets/3244249/763002c5-40ad-48cc-9de0-ea43d7fa8a75)"> | [data source](https://github.com/QIBChemometrics/Benchtop-NMR-Coffee-Survey) |
| [Weather dataset](https://zenodo.org/record/7525955) <img width="25" alt="Weather Logo" src="https://github.com/dianna-ai/dianna/assets/3244249/3ff3d639-ed2f-4a38-b7ac-957c984bce9f"> | The light version of the weather prediction dataset, which contains daily observations (89 features) for 11 European locations through the years 2000 to 2010. | <img width="500" alt="example image" src="https://github.com/dianna-ai/dianna/assets/3244249/b0a505ac-8a6c-4e1c-b6ad-35e31e52f46d)"> | [data source](https://github.com/florian-huber/weather_prediction_dataset) |

### Tabular
@@ -290,8 +295,9 @@ And here are links to notebooks showing how we created our models on the benchma

| Models | Generation |
| :-------------------------------------------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| [Coffee model](https://zenodo.org/records/10579458) | [Coffee model generation](https://github.com/dianna-ai/dianna-exploration/blob/main/example_data/model_generation/coffee/generate_model.ipynb) |
| [Season prediction model](https://zenodo.org/record/7543883) | [Season prediction model generation](https://github.com/dianna-ai/dianna-exploration/blob/main/example_data/model_generation/season_prediction/generate_model.ipynb) |
| [Coffee model](https://zenodo.org/records/10579458) | [Coffee model generation](https://github.com/dianna-ai/dianna-exploration/blob/main/example_data/model_generation/coffee/generate_model.ipynb) |
| [Season prediction model](https://zenodo.org/record/7543883) | [Season prediction model generation](https://github.com/dianna-ai/dianna-exploration/blob/main/example_data/model_generation/season_prediction/generate_model.ipynb) |
| [Fast Radio Burst classification model](https://zenodo.org/records/10656614) | [Fast Radio Burst classification model generation](https://doi.org/10.3847/1538-3881/aae649) |

### Tabular

@@ -305,7 +311,7 @@ And here are links to notebooks showing how we created our models on the benchma

## Tutorials

DIANNA supports different data modalities and XAI methods. The table contains links to the relevant XAI method's papers (for some explanatory videos on the methods, please see [tutorials](./tutorials)). The DIANNA [tutorials](./tutorials) cover each supported method and data modality on a least one dataset. Our future plans to expand DIANNA with more data modalities and XAI methods are given in the [ROADMAP](https://dianna.readthedocs.io/en/latest/ROADMAP.html).
DIANNA supports different data modalities and XAI methods (explainers). We have evaluated many explainers using objective criteria (see the [How to find your AI explainer](https://blog.esciencecenter.nl/how-to-find-your-artificial-intelligence-explainer-dbb1ac608009) blog post). The table below contains links to the papers of the relevant XAI methods (for some explanatory videos on the methods, please see the [tutorials](./tutorials)). The DIANNA [tutorials](./tutorials) cover each supported method and data modality on at least one dataset, using the default or tuned [hyperparameters](./tutorials/README.md#important-hyperparameters). Our plans to expand DIANNA with more data modalities and explainers are given in the [ROADMAP](https://dianna.readthedocs.io/en/latest/ROADMAP.html).

<!-- see issue: https://github.com/dianna-ai/dianna/issues/142, also related issue: https://github.com/dianna-ai/dianna/issues/148 -->

@@ -314,11 +320,24 @@ DIANNA supports different data modalities and XAI methods. The table contains li
| Images ||||
| Text ||| |
| Timeseries ||| |
| Tabular | planned || planned |
| Embedding | planned | planned | planned
| Graphs* | work in progress | work in progress | work in progress |

[LRP](https://journals.plos.org/plosone/article/file?id=10.1371/journal.pone.0130140&type=printable) and [PatternAttribution](https://arxiv.org/pdf/1705.05598.pdf) also feature in the top 5 of our thoroughly evaluated XAI methods using objective criteria (details in coming blog-post). **Contributing by adding these and more (new) post-hoc explainability methods on ONNX models is very welcome!**
| Tabular | planned |||
| Embedding | work in progress | | |
| Graphs* | next steps | ... | ... |

[LRP](https://journals.plos.org/plosone/article/file?id=10.1371/journal.pone.0130140&type=printable) and [PatternAttribution](https://arxiv.org/pdf/1705.05598.pdf) also feature in the top 5 of our thoroughly evaluated explainers.
Also, [GradCAM](https://openaccess.thecvf.com/content_ICCV_2017/papers/Selvaraju_Grad-CAM_Visual_Explanations_ICCV_2017_paper.pdf) has recently been found to be *semantically continuous*! **Contributing by adding these and more (new) post-hoc explainability methods on ONNX models is very welcome!**


### Scientific use-cases
Our goal is for the scientific community to embrace XAI as a source of novel and unexplored perspectives on scientific problems.
Here, we offer [tutorials](./tutorials) on using XAI in specific scientific use cases:
| Use-case (data) \ XAI | [RISE](http://bmvc2018.org/contents/papers/1064.pdf) | [LIME](https://www.kdd.org/kdd2016/papers/files/rfp0573-ribeiroA.pdf) | [KernelSHAP](https://proceedings.neurips.cc/paper/2017/file/8a20a8621978632d76c43dfd28b67767-Paper.pdf) |
| :--------- | :-------- | :------------------------------ | :-------------------------- |
| Biology (Phytomorphology): Tree Leaves classification (images) | || |
| Astronomy: Fast Radio Burst detection (timeseries) || | |
| Geo-science (raster data) | planned | ... | ... |
| Social sciences (text) | work in progress | ... | ... |
| Climate | planned | ... | ... |

## Reference documentation

48 changes: 28 additions & 20 deletions dianna/__init__.py
@@ -22,29 +22,34 @@
"""
import importlib
import logging
from collections.abc import Callable
from collections.abc import Iterable
from typing import Union
import numpy as np
from . import utils

logging.getLogger(__name__).addHandler(logging.NullHandler())

__author__ = 'DIANNA Team'
__email__ = '[email protected]'
__version__ = '1.4.0'
__version__ = '1.5.0'


def explain_timeseries(model_or_function, input_timeseries, method, labels,
**kwargs):
def explain_timeseries(model_or_function: Union[Callable, str],
input_timeseries: np.ndarray, method: str,
labels: Iterable[int], **kwargs) -> np.ndarray:
"""Explain timeseries data given a model and a chosen method.
Args:
model_or_function (callable or str): The function that runs the model to be explained _or_
the path to an ONNX model on disk.
input_timeseries (np.ndarray): Timeseries data to be explained
method (string): One of the supported methods: RISE, LIME or KernelSHAP
method (str): One of the supported methods: RISE, LIME or KernelSHAP
labels (Iterable(int)): Labels to be explained
kwargs: keyword arguments
Returns:
One heatmap per class.
np.ndarray: One heatmap per class.
"""
explainer = _get_explainer(method, kwargs, modality='Timeseries')
@@ -58,19 +63,21 @@ def explain_timeseries(model_or_function, input_timeseries, method, labels,
**explain_timeseries_kwargs)


def explain_image(model_or_function, input_image, method, labels, **kwargs):
def explain_image(model_or_function: Union[Callable,
str], input_image: np.ndarray,
method: str, labels: Iterable[int], **kwargs) -> np.ndarray:
"""Explain an image (input_data) given a model and a chosen method.
Args:
model_or_function (callable or str): The function that runs the model to be explained _or_
the path to an ONNX model on disk.
input_image (np.ndarray): Image data to be explained
method (string): One of the supported methods: RISE, LIME or KernelSHAP
method (str): One of the supported methods: RISE, LIME or KernelSHAP
labels (Iterable(int)): Labels to be explained
kwargs: These keyword parameters are passed on
Returns:
One heatmap (2D array) per class.
np.ndarray: An array containing the heatmaps for each class.
"""
if method.upper() == 'KERNELSHAP':
@@ -87,21 +94,22 @@ def explain_image(model_or_function, input_image, method, labels, **kwargs):
**explain_image_kwargs)


def explain_text(model_or_function, input_text, tokenizer, method, labels,
**kwargs):
def explain_text(model_or_function: Union[Callable,
str], input_text: str, tokenizer,
method: str, labels: Iterable[int], **kwargs) -> list:
"""Explain text (input_text) given a model and a chosen method.
Args:
model_or_function (callable or str): The function that runs the model to be explained _or_
the path to an ONNX model on disk.
input_text (string): Text to be explained
tokenizer : Tokenizer class with tokenize and convert_tokens_to_string methods, and mask_token attribute
method (string): One of the supported methods: RISE or LIME
input_text (str): Text to be explained
tokenizer: Tokenizer class with tokenize and convert_tokens_to_string methods, and mask_token attribute
method (str): One of the supported methods: RISE or LIME
labels (Iterable(int)): Labels to be explained
kwargs: These keyword parameters are passed on
Returns:
List of (word, index of word in raw text, importance for target class) tuples.
list: List of tuples (word, index of word in raw text, importance for target class) for each class.
"""
explainer = _get_explainer(method, kwargs, modality='Text')
@@ -120,23 +128,23 @@ def explain_text(model_or_function, input_text, tokenizer, method, labels,
)


def explain_tabular(model_or_function,
input_tabular,
method,
def explain_tabular(model_or_function: Union[Callable, str],
input_tabular: np.ndarray,
method: str,
labels=None,
**kwargs):
**kwargs) -> np.ndarray:
"""Explain tabular (input_text) given a model and a chosen method.
Args:
model_or_function (callable or str): The function that runs the model to be explained _or_
the path to an ONNX model on disk.
input_tabular (np.ndarray): Tabular data to be explained
method (string): One of the supported methods: RISE, LIME or KernelSHAP
method (str): One of the supported methods: RISE, LIME or KernelSHAP
labels (Iterable(int), optional): Labels to be explained
kwargs: These keyword parameters are passed on
Returns:
One heatmap (2D array) per class.
np.ndarray: An array containing the heatmaps for each class.
"""
explainer = _get_explainer(method, kwargs, modality='Tabular')
explain_tabular_kwargs = utils.get_kwargs_applicable_to_function(

