-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix remaining issues with dashboard and revise layout (#551)
* Change accent color to purple * Capitalize tab names * Add expander to hide method parameters * Use checkboxes instead of multiselect * Add check box to load example data * Open links in same tab * Change dianna's dashboard title * Refactor methods checkboxes * Make number input smaller * Add class column for text and image * Fix sorting of indices * Remove tick parameters from images * Refactor getting top indices and labels * Add cli entry point for the dashboard * Tweak top results spinner and predicted class output * Add sidebar logo * Add boilerplate for timeseries data * Refactor method parameter definition code * Filter kernelshap debug messages * Add warning for timeseries page * Add importlib resources backport * Improve formatting of text visualization
- Loading branch information
1 parent
750ccee
commit 9c0b6a3
Showing
21 changed files
with
631 additions
and
306 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
def predict(): | ||
pass | ||
|
||
|
||
explain_ts_dispatcher = {} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
import base64 | ||
from pathlib import Path | ||
from typing import Any | ||
from typing import Dict | ||
from typing import Sequence | ||
import numpy as np | ||
import streamlit as st | ||
|
||
|
||
data_directory = Path(__file__).parent / 'data' | ||
|
||
|
||
@st.cache_data | ||
def get_base64_of_bin_file(png_file): | ||
with open(png_file, 'rb') as f: | ||
data = f.read() | ||
return base64.b64encode(data).decode() | ||
|
||
|
||
def build_markup_for_logo( | ||
png_file, | ||
background_position='50% 10%', | ||
margin_top='10%', | ||
image_width='60%', | ||
image_height='', | ||
): | ||
binary_string = get_base64_of_bin_file(png_file) | ||
return f""" | ||
<style> | ||
[data-testid="stSidebarNav"] {{ | ||
background-image: url("data:image/png;base64,{binary_string}"); | ||
background-repeat: no-repeat; | ||
background-position: {background_position}; | ||
margin-top: {margin_top}; | ||
background-size: {image_width} {image_height}; | ||
}} | ||
</style> | ||
""" | ||
|
||
|
||
def add_sidebar_logo(): | ||
"""Based on: https://stackoverflow.com/a/73278825.""" | ||
png_file = data_directory / 'logo.png' | ||
logo_markup = build_markup_for_logo(png_file) | ||
st.markdown( | ||
logo_markup, | ||
unsafe_allow_html=True, | ||
) | ||
|
||
|
||
def _methods_checkboxes(*, choices: Sequence): | ||
"""Get methods from a horizontal row of checkboxes.""" | ||
n_choices = len(choices) | ||
methods = [] | ||
for col, method in zip(st.columns(n_choices), choices): | ||
with col: | ||
if st.checkbox(method): | ||
methods.append(method) | ||
|
||
if not methods: | ||
st.info('Select a method to continue') | ||
st.stop() | ||
|
||
return methods | ||
|
||
|
||
def _get_params(method: str): | ||
if method == 'RISE': | ||
return { | ||
'n_masks': | ||
st.number_input('Number of masks', value=1000), | ||
'feature_res': | ||
st.number_input('Feature resolution', value=6), | ||
'p_keep': | ||
st.number_input('Probability to be kept unmasked', value=0.1), | ||
} | ||
|
||
elif method == 'KernelSHAP': | ||
return { | ||
'nsamples': st.number_input('Number of samples', value=1000), | ||
'background': st.number_input('Background', value=0), | ||
'n_segments': st.number_input('Number of segments', value=200), | ||
'sigma': st.number_input('σ', value=0), | ||
} | ||
|
||
elif method == 'LIME': | ||
return { | ||
'rand_state': st.number_input('Random state', value=2), | ||
} | ||
|
||
else: | ||
raise ValueError(f'No such method: {method}') | ||
|
||
|
||
def _get_method_params(methods: Sequence[str]) -> Dict[str, Dict[str, Any]]: | ||
method_params = {} | ||
|
||
with st.expander('Click to modify method parameters'): | ||
for method, col in zip(methods, st.columns(len(methods))): | ||
with col: | ||
st.header(method) | ||
method_params[method] = _get_params(method) | ||
|
||
return method_params | ||
|
||
|
||
def _get_top_indices(predictions, n_top): | ||
indices = np.array(np.argpartition(predictions, -n_top)[-n_top:]) | ||
indices = indices[np.argsort(predictions[indices])] | ||
indices = np.flip(indices) | ||
return indices | ||
|
||
|
||
def _get_top_indices_and_labels(*, predictions, labels): | ||
c1, c2 = st.columns(2) | ||
|
||
with c2: | ||
n_top = st.number_input('Number of top results to show', | ||
value=2, | ||
min_value=1, | ||
max_value=len(labels)) | ||
|
||
top_indices = _get_top_indices(predictions, n_top) | ||
top_labels = [labels[i] for i in top_indices] | ||
|
||
with c1: | ||
st.metric('Predicted class', top_labels[0]) | ||
|
||
return top_indices, top_labels |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
import numpy as np | ||
|
||
|
||
def open_timeseries(file): | ||
"""Open a time series from a file and returns it as a numpy array.""" | ||
return np.arange(10), np.arange(10)**2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
import streamlit as st | ||
from _image_utils import open_image | ||
from _model_utils import load_labels | ||
from _model_utils import load_model | ||
from _models_image import explain_image_dispatcher | ||
from _models_image import predict | ||
from _shared import _get_method_params | ||
from _shared import _get_top_indices_and_labels | ||
from _shared import _methods_checkboxes | ||
from _shared import add_sidebar_logo | ||
from _shared import data_directory | ||
from dianna.visualization import plot_image | ||
|
||
|
||
add_sidebar_logo() | ||
|
||
st.title('Image explanation') | ||
|
||
with st.sidebar: | ||
st.header('Input data') | ||
|
||
load_example = st.checkbox('Load example data', key='image_example_check') | ||
|
||
image_file = st.file_uploader('Select image', | ||
type=('png', 'jpg', 'jpeg'), | ||
disabled=load_example) | ||
|
||
if image_file: | ||
st.image(image_file) | ||
|
||
image_model_file = st.file_uploader('Select model', | ||
type='onnx', | ||
disabled=load_example) | ||
|
||
image_label_file = st.file_uploader('Select labels', | ||
type='txt', | ||
disabled=load_example) | ||
|
||
if load_example: | ||
image_file = (data_directory / 'digit0.png') | ||
image_model_file = (data_directory / 'mnist_model_tf.onnx') | ||
image_label_file = (data_directory / 'labels_mnist.txt') | ||
|
||
if not (image_file and image_model_file and image_label_file): | ||
st.info('Add your input data in the left panel to continue') | ||
st.stop() | ||
|
||
image, _ = open_image(image_file) | ||
|
||
model = load_model(image_model_file) | ||
serialized_model = model.SerializeToString() | ||
|
||
labels = load_labels(image_label_file) | ||
|
||
choices = ('RISE', 'KernelSHAP', 'LIME') | ||
methods = _methods_checkboxes(choices=choices) | ||
|
||
method_params = _get_method_params(methods) | ||
|
||
with st.spinner('Predicting class'): | ||
predictions = predict(model=model, image=image) | ||
|
||
top_indices, top_labels = _get_top_indices_and_labels(predictions=predictions, | ||
labels=labels) | ||
|
||
# check which axis is color channel | ||
original_data = image[:, :, 0] if image.shape[2] <= 3 else image[1, :, :] | ||
axis_labels = {2: 'channels'} if image.shape[2] <= 3 else {0: 'channels'} | ||
|
||
weight = 0.9 / len(methods) | ||
column_spec = [0.1, *[weight for _ in methods]] | ||
|
||
_, *columns = st.columns(column_spec) | ||
for col, method in zip(columns, methods): | ||
with col: | ||
st.header(method) | ||
|
||
for index, label in zip(top_indices, top_labels): | ||
index_col, *columns = st.columns(column_spec) | ||
|
||
with index_col: | ||
st.markdown(f'##### {label}') | ||
|
||
for col, method in zip(columns, methods): | ||
kwargs = method_params[method].copy() | ||
kwargs['axis_labels'] = axis_labels | ||
kwargs['labels'] = [index] | ||
|
||
func = explain_image_dispatcher[method] | ||
|
||
with col: | ||
with st.spinner(f'Running {method}'): | ||
heatmap = func(serialized_model, image, index, **kwargs) | ||
|
||
fig = plot_image(heatmap, | ||
original_data=original_data, | ||
heatmap_cmap='bwr', | ||
show_plot=False) | ||
|
||
st.pyplot(fig) |
Oops, something went wrong.