diff --git a/dashboard/.streamlit/config.toml b/dashboard/.streamlit/config.toml index 3dc2d6de..afefa20a 100644 --- a/dashboard/.streamlit/config.toml +++ b/dashboard/.streamlit/config.toml @@ -1,6 +1,6 @@ [theme] base="light" -primaryColor="0e749b" +primaryColor="7030a0" # backgroundColor= secondaryBackgroundColor="#e4f3f9" # textColor= diff --git a/dashboard/dashboard.py b/dashboard/Home.py similarity index 65% rename from dashboard/dashboard.py rename to dashboard/Home.py index 1cb9a008..fbb66de1 100644 --- a/dashboard/dashboard.py +++ b/dashboard/Home.py @@ -1,4 +1,6 @@ import streamlit as st +from _shared import add_sidebar_logo +from _shared import data_directory st.set_page_config(page_title="Dianna's dashboard", @@ -15,18 +17,26 @@ 'https://github.com/dianna-ai/dianna') }) -st.image( - 'https://user-images.githubusercontent.com/3244249/151994514-b584b984-a148-4ade-80ee-0f88b0aefa45.png' -) +add_sidebar_logo() -st.title("Dianna's dashboard") +st.image(str(data_directory / 'logo.png')) -st.write(""" +st.markdown(""" DIANNA is a Python package that brings explainable AI (XAI) to your research project. It wraps carefully selected XAI methods in a simple, uniform interface. It's built by, with and for (academic) researchers and research software engineers working on machine learning projects. -- [Images](/images) -- [Text](/text) -""") +### Pages + +- Images +- Text +- Time series + + +### More information + +- [Source code](https://github.com/dianna-ai/dianna) +- [Documentation](https://dianna.readthedocs.io/) +""", + unsafe_allow_html=True) diff --git a/dashboard/_model_utils.py b/dashboard/_model_utils.py index 8f83b7cb..272e4a40 100644 --- a/dashboard/_model_utils.py +++ b/dashboard/_model_utils.py @@ -1,3 +1,4 @@ +from pathlib import Path import numpy as np import onnx @@ -21,6 +22,9 @@ def load_model(file): def load_labels(file): + if isinstance(file, (str, Path)): + file = open(file, 'rb') + labels = [line.decode().rstrip() for line in file.readlines()] if labels is None or labels == ['']: raise ValueError(labels) diff --git a/dashboard/_models_image.py b/dashboard/_models_image.py index 2b707305..2a79997a 100644 --- a/dashboard/_models_image.py +++ b/dashboard/_models_image.py @@ -1,5 +1,4 @@ import tempfile -import numpy as np import streamlit as st from _model_utils import fill_segmentation from _model_utils import preprocess_function @@ -7,13 +6,6 @@ from dianna import explain_image -def get_top_indices(predictions, n_top): - indices = np.array(np.argpartition(predictions, -n_top)[-n_top:]) - indices = indices[np.argsort(predictions[indices])] - indices = np.flip(indices) - return indices - - @st.cache_data def predict(*, model, image): output_node = prepare(model, gen_tensor_dict=True).outputs[0] @@ -26,6 +18,7 @@ def _run_rise_image(model, image, i, **kwargs): relevances = explain_image( model, image, + method='RISE', **kwargs, ) return relevances[0] @@ -37,6 +30,7 @@ def _run_lime_image(model, image, i, **kwargs): model, image * 256, preprocess_function=preprocess_function, + method='LIME', **kwargs, ) return relevances[0] @@ -48,7 +42,10 @@ def _run_kernelshap_image(model, image, i, **kwargs): with tempfile.NamedTemporaryFile() as f: f.write(model) f.flush() - shap_values, segments_slic = explain_image(f.name, image, **kwargs) + shap_values, segments_slic = explain_image(f.name, + image, + method='KernelSHAP', + **kwargs) return fill_segmentation(shap_values[i][0], segments_slic) diff --git a/dashboard/_models_text.py b/dashboard/_models_text.py index b3be8d76..2cd9a148 100644 --- a/dashboard/_models_text.py +++ b/dashboard/_models_text.py @@ -20,6 +20,7 @@ def _run_rise_text(_model, text, **kwargs): _model, text, tokenizer, + method='RISE', **kwargs, ) return relevances @@ -27,7 +28,7 @@ def _run_rise_text(_model, text, **kwargs): @st.cache_data def _run_lime_text(_model, text, **kwargs): - relevances = explain_text(_model, text, tokenizer, **kwargs) + relevances = explain_text(_model, text, tokenizer, method='LIME', **kwargs) return relevances diff --git a/dashboard/_models_ts.py b/dashboard/_models_ts.py new file mode 100644 index 00000000..935dff92 --- /dev/null +++ b/dashboard/_models_ts.py @@ -0,0 +1,5 @@ +def predict(): + pass + + +explain_ts_dispatcher = {} diff --git a/dashboard/_shared.py b/dashboard/_shared.py new file mode 100644 index 00000000..f55f7f23 --- /dev/null +++ b/dashboard/_shared.py @@ -0,0 +1,129 @@ +import base64 +from pathlib import Path +from typing import Any +from typing import Dict +from typing import Sequence +import numpy as np +import streamlit as st + + +data_directory = Path(__file__).parent / 'data' + + +@st.cache_data +def get_base64_of_bin_file(png_file): + with open(png_file, 'rb') as f: + data = f.read() + return base64.b64encode(data).decode() + + +def build_markup_for_logo( + png_file, + background_position='50% 10%', + margin_top='10%', + image_width='60%', + image_height='', +): + binary_string = get_base64_of_bin_file(png_file) + return f""" + + """ + + +def add_sidebar_logo(): + """Based on: https://stackoverflow.com/a/73278825.""" + png_file = data_directory / 'logo.png' + logo_markup = build_markup_for_logo(png_file) + st.markdown( + logo_markup, + unsafe_allow_html=True, + ) + + +def _methods_checkboxes(*, choices: Sequence): + """Get methods from a horizontal row of checkboxes.""" + n_choices = len(choices) + methods = [] + for col, method in zip(st.columns(n_choices), choices): + with col: + if st.checkbox(method): + methods.append(method) + + if not methods: + st.info('Select a method to continue') + st.stop() + + return methods + + +def _get_params(method: str): + if method == 'RISE': + return { + 'n_masks': + st.number_input('Number of masks', value=1000), + 'feature_res': + st.number_input('Feature resolution', value=6), + 'p_keep': + st.number_input('Probability to be kept unmasked', value=0.1), + } + + elif method == 'KernelSHAP': + return { + 'nsamples': st.number_input('Number of samples', value=1000), + 'background': st.number_input('Background', value=0), + 'n_segments': st.number_input('Number of segments', value=200), + 'sigma': st.number_input('σ', value=0), + } + + elif method == 'LIME': + return { + 'rand_state': st.number_input('Random state', value=2), + } + + else: + raise ValueError(f'No such method: {method}') + + +def _get_method_params(methods: Sequence[str]) -> Dict[str, Dict[str, Any]]: + method_params = {} + + with st.expander('Click to modify method parameters'): + for method, col in zip(methods, st.columns(len(methods))): + with col: + st.header(method) + method_params[method] = _get_params(method) + + return method_params + + +def _get_top_indices(predictions, n_top): + indices = np.array(np.argpartition(predictions, -n_top)[-n_top:]) + indices = indices[np.argsort(predictions[indices])] + indices = np.flip(indices) + return indices + + +def _get_top_indices_and_labels(*, predictions, labels): + c1, c2 = st.columns(2) + + with c2: + n_top = st.number_input('Number of top results to show', + value=2, + min_value=1, + max_value=len(labels)) + + top_indices = _get_top_indices(predictions, n_top) + top_labels = [labels[i] for i in top_indices] + + with c1: + st.metric('Predicted class', top_labels[0]) + + return top_indices, top_labels diff --git a/dashboard/_ts_utils.py b/dashboard/_ts_utils.py new file mode 100644 index 00000000..3760df7f --- /dev/null +++ b/dashboard/_ts_utils.py @@ -0,0 +1,6 @@ +import numpy as np + + +def open_timeseries(file): + """Open a time series from a file and returns it as a numpy array.""" + return np.arange(10), np.arange(10)**2 diff --git a/dashboard/pages/1_Images.py b/dashboard/pages/1_Images.py new file mode 100644 index 00000000..e1558954 --- /dev/null +++ b/dashboard/pages/1_Images.py @@ -0,0 +1,100 @@ +import streamlit as st +from _image_utils import open_image +from _model_utils import load_labels +from _model_utils import load_model +from _models_image import explain_image_dispatcher +from _models_image import predict +from _shared import _get_method_params +from _shared import _get_top_indices_and_labels +from _shared import _methods_checkboxes +from _shared import add_sidebar_logo +from _shared import data_directory +from dianna.visualization import plot_image + + +add_sidebar_logo() + +st.title('Image explanation') + +with st.sidebar: + st.header('Input data') + + load_example = st.checkbox('Load example data', key='image_example_check') + + image_file = st.file_uploader('Select image', + type=('png', 'jpg', 'jpeg'), + disabled=load_example) + + if image_file: + st.image(image_file) + + image_model_file = st.file_uploader('Select model', + type='onnx', + disabled=load_example) + + image_label_file = st.file_uploader('Select labels', + type='txt', + disabled=load_example) + + if load_example: + image_file = (data_directory / 'digit0.png') + image_model_file = (data_directory / 'mnist_model_tf.onnx') + image_label_file = (data_directory / 'labels_mnist.txt') + +if not (image_file and image_model_file and image_label_file): + st.info('Add your input data in the left panel to continue') + st.stop() + +image, _ = open_image(image_file) + +model = load_model(image_model_file) +serialized_model = model.SerializeToString() + +labels = load_labels(image_label_file) + +choices = ('RISE', 'KernelSHAP', 'LIME') +methods = _methods_checkboxes(choices=choices) + +method_params = _get_method_params(methods) + +with st.spinner('Predicting class'): + predictions = predict(model=model, image=image) + +top_indices, top_labels = _get_top_indices_and_labels(predictions=predictions, + labels=labels) + +# check which axis is color channel +original_data = image[:, :, 0] if image.shape[2] <= 3 else image[1, :, :] +axis_labels = {2: 'channels'} if image.shape[2] <= 3 else {0: 'channels'} + +weight = 0.9 / len(methods) +column_spec = [0.1, *[weight for _ in methods]] + +_, *columns = st.columns(column_spec) +for col, method in zip(columns, methods): + with col: + st.header(method) + +for index, label in zip(top_indices, top_labels): + index_col, *columns = st.columns(column_spec) + + with index_col: + st.markdown(f'##### {label}') + + for col, method in zip(columns, methods): + kwargs = method_params[method].copy() + kwargs['axis_labels'] = axis_labels + kwargs['labels'] = [index] + + func = explain_image_dispatcher[method] + + with col: + with st.spinner(f'Running {method}'): + heatmap = func(serialized_model, image, index, **kwargs) + + fig = plot_image(heatmap, + original_data=original_data, + heatmap_cmap='bwr', + show_plot=False) + + st.pyplot(fig) diff --git a/dashboard/pages/1_images.py b/dashboard/pages/1_images.py deleted file mode 100644 index 2ce1a5db..00000000 --- a/dashboard/pages/1_images.py +++ /dev/null @@ -1,120 +0,0 @@ -import numpy as np -import streamlit as st -from _image_utils import open_image -from _model_utils import load_labels -from _model_utils import load_model -from _models_image import explain_image_dispatcher -from _models_image import get_top_indices -from _models_image import predict -from dianna.visualization import plot_image - - -st.title("Dianna's dashboard") - -with st.sidebar: - st.header('Input data') - - image_file = st.file_uploader('Select image', type=('png', 'jpg', 'jpeg')) - - if image_file: - st.image(image_file) - - model_file = st.file_uploader('Select model', type='onnx') - - label_file = st.file_uploader('Select labels', type='txt') - -if not (image_file and model_file and label_file): - st.info('Add your input data in the left panel to continue') - st.stop() - -image, _ = open_image(image_file) -assert isinstance(image, np.ndarray) - -model = load_model(model_file) -serialized_model = model.SerializeToString() - -labels = load_labels(label_file) - -methods = st.multiselect('Select XAI methods', - options=('RISE', 'KernelSHAP', 'LIME')) - -n_top = st.number_input('Number of top results to show', - value=2, - min_value=0, - max_value=len(labels)) - -if not methods: - st.info('Select a method to continue') - st.stop() - -tabs = st.tabs(methods) - -kws = {'RISE': {}, 'KernelSHAP': {}, 'LIME': {}} - -for method, tab in zip(methods, tabs): - with tab: - c1, c2 = st.columns(2) - if method == 'RISE': - with c1: - kws['RISE']['n_masks'] = st.number_input('Number of masks', - value=1000) - kws['RISE']['feature_res'] = st.number_input( - 'Feature resolution', value=6) - with c2: - kws['RISE']['p_keep'] = st.number_input( - 'Probability to be kept unmasked', value=0.1) - - if method == 'KernelSHAP': - with c1: - kws['KernelSHAP']['nsamples'] = st.number_input( - 'Number of samples', value=1000) - kws['KernelSHAP']['background'] = st.number_input('Background', - value=0) - with c2: - kws['KernelSHAP']['n_segments'] = st.number_input( - 'Number of segments', value=200) - kws['KernelSHAP']['sigma'] = st.number_input('σ', value=0) - - if method == 'LIME': - with c1: - kws['LIME']['rand_state'] = st.number_input('Random state', - value=2) - -with st.spinner('Predicting class'): - predictions = predict(model=model, image=image) - -predicted_class = labels[np.argmax(predictions)] - -st.info(f'The predicted class is: {predicted_class}') - -top_indices = get_top_indices(predictions, n_top) -top_labels = [labels[i] for i in top_indices] - -# check which axis is color channel -original_data = image[:, :, 0] if image.shape[2] <= 3 else image[1, :, :] -axis_labels = {2: 'channels'} if image.shape[2] <= 3 else {0: 'channels'} - -columns = st.columns(len(methods)) - -for col, method in zip(columns, methods): - kwargs = kws[method].copy() - kwargs['method'] = method - kwargs['axis_labels'] = axis_labels - - func = explain_image_dispatcher[method] - - with col: - st.header(method) - - for index, label in enumerate(top_labels): - with st.spinner(f'Running {method}'): - kwargs['labels'] = [top_indices[index]] - heatmap = func(serialized_model, image, index, **kwargs) - - st.write(f'index={index}, label={label}') - - fig = plot_image(heatmap, - original_data=original_data, - heatmap_cmap='bwr', - show_plot=False) - st.pyplot(fig) diff --git a/dashboard/pages/2_Text.py b/dashboard/pages/2_Text.py new file mode 100644 index 00000000..5c4cf4f6 --- /dev/null +++ b/dashboard/pages/2_Text.py @@ -0,0 +1,92 @@ +import streamlit as st +from _model_utils import load_labels +from _model_utils import load_model +from _models_text import explain_text_dispatcher +from _models_text import predict +from _movie_model import MovieReviewsModelRunner +from _shared import _get_method_params +from _shared import _get_top_indices_and_labels +from _shared import _methods_checkboxes +from _shared import add_sidebar_logo +from _shared import data_directory +from _text_utils import format_word_importances + + +add_sidebar_logo() + +st.title('Text explanation') + +with st.sidebar: + st.header('Input data') + + load_example = st.checkbox('Load example data', key='text_example_check') + + text_input = st.text_input('Input string', disabled=load_example) + + if text_input: + st.write(text_input) + + text_model_file = st.file_uploader('Select model', + type='onnx', + disabled=load_example) + + text_label_file = st.file_uploader('Select labels', + type='txt', + disabled=load_example) + + if load_example: + text_input = 'The movie started out great but the ending was dissappointing' + text_model_file = data_directory / 'movie_review_model.onnx' + text_label_file = data_directory / 'labels_text.txt' + +if not (text_input and text_model_file and text_label_file): + st.info('Add your input data in the left panel to continue') + st.stop() + +model = load_model(text_model_file) +serialized_model = model.SerializeToString() + +labels = load_labels(text_label_file) + +choices = ('RISE', 'LIME') +methods = _methods_checkboxes(choices=choices) + +method_params = _get_method_params(methods) + +model_runner = MovieReviewsModelRunner(serialized_model) + +with st.spinner('Predicting class'): + predictions = predict(model=serialized_model, text_input=text_input) + +top_indices, top_labels = _get_top_indices_and_labels( + predictions=predictions[0], labels=labels) + +weight = 0.85 / len(methods) +column_spec = [0.15, *[weight for _ in methods]] + +_, *columns = st.columns(column_spec) +for col, method in zip(columns, methods): + with col: + st.header(method) + +for index, label in zip(top_indices, top_labels): + index_col, *columns = st.columns(column_spec) + + with index_col: + st.markdown(f'##### {label}') + + for col, method in zip(columns, methods): + kwargs = method_params[method].copy() + kwargs['labels'] = [index] + + func = explain_text_dispatcher[method] + + with col: + with st.spinner(f'Running {method}'): + relevances = func(model_runner, text_input, **kwargs) + + html = format_word_importances(text_input, relevances[0]) + st.write(html, unsafe_allow_html=True) + + # add some white space to separate rows + st.markdown('') diff --git a/dashboard/pages/2_text.py b/dashboard/pages/2_text.py deleted file mode 100644 index 64b3628a..00000000 --- a/dashboard/pages/2_text.py +++ /dev/null @@ -1,88 +0,0 @@ -import numpy as np -import streamlit as st -from _model_utils import load_labels -from _model_utils import load_model -from _models_text import explain_text_dispatcher -from _models_text import predict -from _movie_model import MovieReviewsModelRunner -from _text_utils import format_word_importances - - -st.title("Dianna's dashboard") - -with st.sidebar: - st.header('Input data') - - text_input = st.text_input('Input string') - - if text_input: - st.write(text_input) - - model_file = st.file_uploader('Select model', type='onnx') - - label_file = st.file_uploader('Select labels', type='txt') - -if not (text_input and model_file and label_file): - st.info('Add your input data in the left panel to continue') - st.stop() - -model = load_model(model_file) -serialized_model = model.SerializeToString() - -labels = load_labels(label_file) - -methods = st.multiselect('Select XAI methods', options=('RISE', 'LIME')) - -if not methods: - st.info('Select a method to continue') - st.stop() - -tabs = st.tabs(methods) - -kws = {'RISE': {}, 'LIME': {}} - -for method, tab in zip(methods, tabs): - with tab: - c1, c2 = st.columns(2) - if method == 'RISE': - with c1: - kws['RISE']['n_masks'] = st.number_input('Number of masks', - value=1000) - kws['RISE']['feature_res'] = st.number_input( - 'Feature resolution', value=6) - with c2: - kws['RISE']['p_keep'] = st.number_input( - 'Probability to be kept unmasked', value=0.1) - - if method == 'LIME': - with c1: - kws['LIME']['rand_state'] = st.number_input('Random state', - value=2) - -model_runner = MovieReviewsModelRunner(serialized_model) - -with st.spinner('Predicting class'): - predictions = predict(model=serialized_model, text_input=text_input) - -predicted_class = labels[np.argmax(predictions)] -predicted_index = labels.index(predicted_class) - -st.info(f'The predicted class is: {predicted_class}') - -columns = st.columns(len(methods)) - -for col, method in zip(columns, methods): - kwargs = kws[method].copy() - kwargs['method'] = method - kwargs['labels'] = [predicted_index] - - func = explain_text_dispatcher[method] - - with col: - st.header(method) - - with st.spinner(f'Running {method}'): - relevances = func(model_runner, text_input, **kwargs) - - html = format_word_importances(text_input, relevances[0]) - st.write(html, unsafe_allow_html=True) diff --git a/dashboard/pages/3_Time_series.py b/dashboard/pages/3_Time_series.py new file mode 100644 index 00000000..e2ec69a7 --- /dev/null +++ b/dashboard/pages/3_Time_series.py @@ -0,0 +1,93 @@ +import streamlit as st +from _model_utils import load_labels +from _model_utils import load_model +from _models_ts import explain_ts_dispatcher +from _models_ts import predict +from _shared import _get_method_params +from _shared import _get_top_indices_and_labels +from _shared import _methods_checkboxes +from _shared import add_sidebar_logo +from _shared import data_directory +from _ts_utils import open_timeseries +from dianna.visualization import plot_timeseries + + +add_sidebar_logo() + +st.title('Time series explanation') + +st.error( + 'Time series explanation is still work in progress and not yet functioning!' +) + +with st.sidebar: + st.header('Input data') + + load_example = st.checkbox('Load example data', key='ts_example_check') + + ts_file = st.file_uploader('Select input data', + type=(), + disabled=load_example) + + ts_model_file = st.file_uploader('Select model', + type='onnx', + disabled=load_example) + + ts_label_file = st.file_uploader('Select labels', + type='txt', + disabled=load_example) + + if load_example: + ts_file = (data_directory / 'xxx.suffix') + ts_model_file = (data_directory / 'xxx.onnx') + ts_label_file = (data_directory / 'xxx.txt') + +if not (ts_file and ts_model_file and ts_label_file): + st.info('Add your input data in the left panel to continue') + st.stop() + +ts_data, _ = open_timeseries(ts_file) + +model = load_model(ts_model_file) +serialized_model = model.SerializeToString() + +labels = load_labels(ts_label_file) + +choices = () +methods = _methods_checkboxes(choices=choices) + +method_params = _get_method_params(methods) + +with st.spinner('Predicting class'): + predictions = predict(model=model, ts_data=ts_data) + +top_indices, top_labels = _get_top_indices_and_labels(predictions=predictions, + labels=labels) + +weight = 0.9 / len(methods) +column_spec = [0.1, *[weight for _ in methods]] + +_, *columns = st.columns(column_spec) +for col, method in zip(columns, methods): + with col: + st.header(method) + +for index, label in zip(top_indices, top_labels): + index_col, *columns = st.columns(column_spec) + + with index_col: + st.markdown(f'##### {label}') + + for col, method in zip(columns, methods): + kwargs = method_params[method].copy() + kwargs['labels'] = [index] + + func = explain_ts_dispatcher[method] + + with col: + with st.spinner(f'Running {method}'): + segments = func(serialized_model, ..., **kwargs) + + fig = plot_timeseries(...) + + st.pyplot(fig) diff --git a/dashboard/readme.md b/dashboard/readme.md index 93310e4b..39bab7d0 100644 --- a/dashboard/readme.md +++ b/dashboard/readme.md @@ -3,10 +3,16 @@ A dashboard was created for DIANNA using [streamlit](https://streamlit.io/) that can be used for simple exploration of your trained model explained by DIANNA. The dashboard produces the visual explanation of your selected XAI method. Additionally it allows you to compare the results of different XAI methods, as well as explanations of the top ranked predicted labels. -To open the dashboard, run +To open the dashboard, you can install dianna via `pip install -e .[dashboard]` and run: ```console -streamlit run dashboard.py +dianna-dashboard +``` + +or, from this directory: + +```console +streamlit run Home.py ``` Open the link on which the app is running. Note that you are running the dashboard *only locally*. The data you use in the dashboard is your local data, and it is *not* uploaded to any server. diff --git a/dianna/_logging_utils.py b/dianna/_logging_utils.py new file mode 100644 index 00000000..41b80920 --- /dev/null +++ b/dianna/_logging_utils.py @@ -0,0 +1,42 @@ +import logging + + +class LoggingContext: + """Context manager to Temporarily change logging configuration. + + From https://docs.python.org/3/howto/logging-cookbook.html + + Parameters + ---------- + logger : None, optional + Logging instance to change, defaults to root logger. + level : None, optional + New log level, i.e. `logging.CRITICAL`. + handler : None, optional + Log handler to use. + close : bool, optional + Whether to close the handler after use. + """ + + def __init__(self, logger=None, level=None, handler=None, close=True): + if not logger: + logger = logging.getLogger() + self.logger = logger + self.level = level + self.handler = handler + self.close = close + + def __enter__(self): + if self.level is not None: + self.old_level = self.logger.level + self.logger.setLevel(self.level) + if self.handler: + self.logger.addHandler(self.handler) + + def __exit__(self, et, ev, tb): + if self.level is not None: + self.logger.setLevel(self.old_level) + if self.handler: + self.logger.removeHandler(self.handler) + if self.handler and self.close: + self.handler.close() diff --git a/dianna/cli.py b/dianna/cli.py new file mode 100644 index 00000000..6b783037 --- /dev/null +++ b/dianna/cli.py @@ -0,0 +1,33 @@ +import os +import sys + + +if sys.version_info < (3, 10): + from importlib_resources import files +else: + from importlib.resources import files + + +def dashboard(): + """Start streamlit dashboard.""" + from streamlit.web import cli as stcli + + dashboard_dir = files('dianna').parent / 'dashboard' + os.chdir(dashboard_dir) + + # https://docs.streamlit.io/library/advanced-features/configuration + sys.argv = [ + 'streamlit', + 'run', + 'Home.py', + '--theme.base', + 'light', + '--theme.primaryColor', + '7030a0', + '--theme.secondaryBackgroundColor', + 'e4f3f9', + '--browser.gatherUsageStats', + 'false', + ] + + sys.exit(stcli.main()) diff --git a/dianna/methods/kernelshap.py b/dianna/methods/kernelshap.py index 13ce6984..8cd22c51 100644 --- a/dianna/methods/kernelshap.py +++ b/dianna/methods/kernelshap.py @@ -1,8 +1,9 @@ -import warnings +import logging import numpy as np import shap import skimage.segmentation from dianna import utils +from .._logging_utils import LoggingContext class KERNELSHAPImage: @@ -24,13 +25,7 @@ def __init__(self, axis_labels=None, preprocess_function=None): self.onnx_to_tf = prepare @staticmethod - def _segment_image( - image, - n_segments, - compactness, - sigma, - **kwargs - ): + def _segment_image(image, n_segments, compactness, sigma, **kwargs): """Create segmentation to explain by segment, not every pixel. This could help speed-up the calculation when the input size is very large. @@ -49,13 +44,11 @@ def _segment_image( via the following link: https://scikit-image.org/docs/dev/api/skimage.segmentation.html#skimage.segmentation.slic """ - image_segments = skimage.segmentation.slic( - image=image, - n_segments=n_segments, - compactness=compactness, - sigma=sigma, - **kwargs - ) + image_segments = skimage.segmentation.slic(image=image, + n_segments=n_segments, + compactness=compactness, + sigma=sigma, + **kwargs) return image_segments @@ -64,7 +57,7 @@ def explain( model, input_data, labels, - nsamples="auto", + nsamples='auto', background=None, n_segments=100, compactness=10.0, @@ -120,24 +113,19 @@ def explain( skimage.segmentation.slic, kwargs) # call the segment method to create segmentation of input image - self.image_segments = self._segment_image( - self.input_data, - n_segments, - compactness, - sigma, - **slic_kwargs - ) + self.image_segments = self._segment_image(self.input_data, n_segments, + compactness, sigma, + **slic_kwargs) # call the Kernel SHAP explainer explainer = shap.KernelExplainer( self._runner, np.zeros((len(self.labels), n_segments))) - with warnings.catch_warnings(): - # avoid warnings due to version conflicts - warnings.simplefilter("ignore") - shap_values = explainer.shap_values( - np.ones((len(self.labels), n_segments)), nsamples=nsamples - ) + # Temporarily hide warnings, because shap is very spammy + with LoggingContext(level=logging.CRITICAL): + shap_values = explainer.shap_values(np.ones( + (len(self.labels), n_segments)), + nsamples=nsamples) return shap_values, self.image_segments @@ -150,13 +138,15 @@ def _prepare_image_data(self, input_data): transformed input data """ # automatically determine the location of the channels axis if no axis_labels were provided - axis_label_names = self.axis_labels.values() if isinstance(self.axis_labels, dict) else self.axis_labels + axis_label_names = self.axis_labels.values() if isinstance( + self.axis_labels, dict) else self.axis_labels if not axis_label_names: channels_axis_index = utils.locate_channels_axis(input_data.shape) self.axis_labels = {channels_axis_index: 'channels'} elif 'channels' not in axis_label_names: - raise ValueError("When providing axis_labels it is required to provide the location" - " of the channels axis") + raise ValueError( + 'When providing axis_labels it is required to provide the location' + ' of the channels axis') input_data = utils.to_xarray(input_data, self.axis_labels) # ensure channels axis is last and keep track of where it was so we can move it back @@ -165,10 +155,13 @@ def _prepare_image_data(self, input_data): return input_data - def _mask_image( - self, features, segmentation, image, background=None, - channels_axis_index=2, datatype=np.float32 - ): + def _mask_image(self, + features, + segmentation, + image, + background=None, + channels_axis_index=2, + datatype=np.float32): """Define a function that depends on a binary mask representing if an image region is hidden. Args: @@ -186,9 +179,8 @@ def _mask_image( background = image.mean(axis=(0, 1)) # Create an empty 4D array - out = np.zeros( - (features.shape[0], image.shape[0], image.shape[1], image.shape[2]) - ) + out = np.zeros((features.shape[0], image.shape[0], image.shape[1], + image.shape[2])) for i in range(features.shape[0]): out[i] = image @@ -209,13 +201,11 @@ def _runner(self, features): features (np.ndarray): A matrix of samples (# samples x # features) on which to explain the model's output. """ - model_input = self._mask_image(features, - self.image_segments, - self.input_data, - self.background, + model_input = self._mask_image(features, self.image_segments, + self.input_data, self.background, self.channels_axis_index, - self.input_node_dtype.as_numpy_dtype - ) + self.input_node_dtype.as_numpy_dtype) if self.preprocess_function is not None: model_input = self.preprocess_function(model_input) - return self.onnx_to_tf(self.onnx_model).run(model_input)[f"{self.output_node}"] + return self.onnx_to_tf( + self.onnx_model).run(model_input)[f'{self.output_node}'] diff --git a/dianna/visualization/image.py b/dianna/visualization/image.py index 2b708461..1ab88c05 100644 --- a/dianna/visualization/image.py +++ b/dianna/visualization/image.py @@ -54,6 +54,14 @@ def plot_image(heatmap, alpha = .5 ax.imshow(heatmap, cmap=heatmap_cmap, alpha=alpha) + ax.tick_params(bottom=False, + left=False, + right=False, + top=False, + labelleft=False, + labelbottom=False, + labelright=False, + labeltop=False) if show_plot: plt.show() if output_filename: diff --git a/dianna/visualization/text.py b/dianna/visualization/text.py index 62d464e5..c9fc5204 100644 --- a/dianna/visualization/text.py +++ b/dianna/visualization/text.py @@ -30,31 +30,40 @@ def highlight_text(explanation, display(HTML(output)) -def _create_html(input_tokens, explanation, max_opacity=0.8): - max_importance = max(abs(item[2]) for item in explanation) - explained_indices = [index for _, index, _ in explanation] - highlighted_words = [] - for index, word in enumerate(input_tokens): - # if word has an explanation, highlight based on that, otherwise - # make it grey - try: - explained_index = explained_indices.index(index) - importance = explanation[explained_index][2] - highlighted_words.append( - _highlight_word(word, importance, max_importance, max_opacity)) - except ValueError: - highlighted_words.append( - f'{word}' - ) - - return '
' + ' '.join(highlighted_words) + '' - - -def _highlight_word(word, importance, max_importance, max_opacity): - opacity = max_opacity * abs(importance) / max_importance +def _create_html(tokens, explanation, opacity: float = 0.8): + importance_map = {r[0]: r[2] for r in explanation} + + max_importance = max(abs(val) for val in importance_map.values()) + + tags = [] + for token in tokens: + importance = importance_map.get(token) + + if importance is None: + color = f'hsl(0, 0%, 75%, {opacity})' + else: + # normalize to max importance + importance = importance / max_importance + color = _get_color(importance, opacity) + + tag = (f'{token}') + tags.append(tag) + + html = ' '.join(tags) + + return html + + +def _get_color(importance: float, opacity: float) -> str: + # clip values to prevent CSS errors (Values should be from [-1,1]) + importance = max(-1, min(1, importance)) if importance > 0: - color = f'rgba(255, 0, 0, {opacity:.2f})' + hue = 0 + sat = 100 + lig = 100 - int(50 * importance) else: - color = f'rgba(0, 0, 255, {opacity:2f})' - highlighted_word = f'{word}' - return highlighted_word + hue = 240 + sat = 100 + lig = 100 - int(-50 * importance) + return f'hsl({hue}, {sat}%, {lig}%, {opacity})' diff --git a/setup.cfg b/setup.cfg index 81292ef8..2cd21660 100644 --- a/setup.cfg +++ b/setup.cfg @@ -84,6 +84,7 @@ text = torchtext spacy dashboard = + importlib_resources;python_version<'3.10' keras Pillow plotly @@ -106,6 +107,10 @@ notebooks = torchvision ipywidgets +[options.entry_points] +console_scripts = + dianna-dashboard = dianna.cli:dashboard + [options.packages.find] include = dianna, dianna.* diff --git a/tests/test_text_visualization.py b/tests/test_text_visualization.py index 7859a33f..3fa5849a 100644 --- a/tests/test_text_visualization.py +++ b/tests/test_text_visualization.py @@ -25,11 +25,13 @@ class TextExample: class TextExampleWithExpectedHtml: """Short text and explanation and its expected html output after visualizing.""" - expected_html = 'Such ' \ - 'a ' \ - 'bad ' \ - 'movie ' \ - '.\n' + expected_html = ( + 'Such ' + 'a ' + 'bad ' + 'movie ' + '.\n' + ) original_text = 'Such a bad movie.' @@ -74,6 +76,7 @@ def test_text_visualization_html_output_is_correct(self): with open(self.html_file_path, encoding='utf-8') as result_file: result = result_file.read() + assert result == TextExampleWithExpectedHtml.expected_html def test_text_visualization_show_plot(self):