-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
113 changed files
with
7,926 additions
and
21,417 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,29 +22,34 @@ | |
""" | ||
import importlib | ||
import logging | ||
from collections.abc import Callable | ||
from collections.abc import Iterable | ||
from typing import Union | ||
import numpy as np | ||
from . import utils | ||
|
||
logging.getLogger(__name__).addHandler(logging.NullHandler()) | ||
|
||
__author__ = 'DIANNA Team' | ||
__email__ = '[email protected]' | ||
__version__ = '1.4.0' | ||
__version__ = '1.5.0' | ||
|
||
|
||
def explain_timeseries(model_or_function, input_timeseries, method, labels, | ||
**kwargs): | ||
def explain_timeseries(model_or_function: Union[Callable, str], | ||
input_timeseries: np.ndarray, method: str, | ||
labels: Iterable[int], **kwargs) -> np.ndarray: | ||
"""Explain timeseries data given a model and a chosen method. | ||
Args: | ||
model_or_function (callable or str): The function that runs the model to be explained _or_ | ||
the path to a ONNX model on disk. | ||
input_timeseries (np.ndarray): Timeseries data to be explained | ||
method (string): One of the supported methods: RISE, LIME or KernelSHAP | ||
method (str): One of the supported methods: RISE, LIME or KernelSHAP | ||
labels (Iterable(int)): Labels to be explained | ||
kwargs: key word arguments | ||
Returns: | ||
One heatmap per class. | ||
np.ndarray: One heatmap per class. | ||
""" | ||
explainer = _get_explainer(method, kwargs, modality='Timeseries') | ||
|
@@ -58,19 +63,21 @@ def explain_timeseries(model_or_function, input_timeseries, method, labels, | |
**explain_timeseries_kwargs) | ||
|
||
|
||
def explain_image(model_or_function, input_image, method, labels, **kwargs): | ||
def explain_image(model_or_function: Union[Callable, | ||
str], input_image: np.ndarray, | ||
method: str, labels: Iterable[int], **kwargs) -> np.ndarray: | ||
"""Explain an image (input_data) given a model and a chosen method. | ||
Args: | ||
model_or_function (callable or str): The function that runs the model to be explained _or_ | ||
the path to a ONNX model on disk. | ||
input_image (np.ndarray): Image data to be explained | ||
method (string): One of the supported methods: RISE, LIME or KernelSHAP | ||
method (str): One of the supported methods: RISE, LIME or KernelSHAP | ||
labels (Iterable(int)): Labels to be explained | ||
kwargs: These keyword parameters are passed on | ||
Returns: | ||
One heatmap (2D array) per class. | ||
np.ndarray: An array containing the heat maps for each class. | ||
""" | ||
if method.upper() == 'KERNELSHAP': | ||
|
@@ -87,21 +94,22 @@ def explain_image(model_or_function, input_image, method, labels, **kwargs): | |
**explain_image_kwargs) | ||
|
||
|
||
def explain_text(model_or_function, input_text, tokenizer, method, labels, | ||
**kwargs): | ||
def explain_text(model_or_function: Union[Callable, | ||
str], input_text: str, tokenizer, | ||
method: str, labels: Iterable[int], **kwargs) -> list: | ||
"""Explain text (input_text) given a model and a chosen method. | ||
Args: | ||
model_or_function (callable or str): The function that runs the model to be explained _or_ | ||
the path to a ONNX model on disk. | ||
input_text (string): Text to be explained | ||
tokenizer : Tokenizer class with tokenize and convert_tokens_to_string methods, and mask_token attribute | ||
method (string): One of the supported methods: RISE or LIME | ||
input_text (str): Text to be explained | ||
tokenizer: Tokenizer class with tokenize and convert_tokens_to_string methods, and mask_token attribute | ||
method (str): One of the supported methods: RISE or LIME | ||
labels (Iterable(int)): Labels to be explained | ||
kwargs: These keyword parameters are passed on | ||
Returns: | ||
List of (word, index of word in raw text, importance for target class) tuples. | ||
list: List of tuples (word, index of word in raw text, importance for target class) for each class. | ||
""" | ||
explainer = _get_explainer(method, kwargs, modality='Text') | ||
|
@@ -120,23 +128,23 @@ def explain_text(model_or_function, input_text, tokenizer, method, labels, | |
) | ||
|
||
|
||
def explain_tabular(model_or_function, | ||
input_tabular, | ||
method, | ||
def explain_tabular(model_or_function: Union[Callable, str], | ||
input_tabular: np.ndarray, | ||
method: str, | ||
labels=None, | ||
**kwargs): | ||
**kwargs) -> np.ndarray: | ||
"""Explain tabular (input_text) given a model and a chosen method. | ||
Args: | ||
model_or_function (callable or str): The function that runs the model to be explained _or_ | ||
the path to a ONNX model on disk. | ||
input_tabular (np.ndarray): Tabular data to be explained | ||
method (string): One of the supported methods: RISE, LIME or KernelSHAP | ||
method (str): One of the supported methods: RISE, LIME or KernelSHAP | ||
labels (Iterable(int), optional): Labels to be explained | ||
kwargs: These keyword parameters are passed on | ||
Returns: | ||
One heatmap (2D array) per class. | ||
np.ndarray: An array containing the heat maps for each class. | ||
""" | ||
explainer = _get_explainer(method, kwargs, modality='Tabular') | ||
explain_tabular_kwargs = utils.get_kwargs_applicable_to_function( | ||
|
Oops, something went wrong.