Skip to content

Commit

Permalink
Fix remaining issues with dashboard and revise layout (#551)
Browse files Browse the repository at this point in the history
* Change accent color to purple
* Capitalize tab names
* Add expander to hide method parameters
* Use checkboxes instead of multiselect
* Add check box to load example data
* Open links in same tab
* Change dianna's dashboard title
* Refactor methods checkboxes
* Make number input smaller
* Add class column for text and image
* Fix sorting of indices
* Remove tick parameters from images
* Refactor getting top indices and labels
* Add cli entry point for the dashboard
* Tweak top results spinner and predicted class output
* Add sidebar logo
* Add boilerplate for timeseries data
* Refactor method parameter definition code
* Filter kernelshap debug messages
* Add warning for timeseries page
* Add importlib resources backport
* Improve formatting of text visualization
  • Loading branch information
stefsmeets authored Apr 18, 2023
1 parent 750ccee commit 9c0b6a3
Show file tree
Hide file tree
Showing 21 changed files with 631 additions and 306 deletions.
2 changes: 1 addition & 1 deletion dashboard/.streamlit/config.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[theme]
base="light"
primaryColor="0e749b"
primaryColor="7030a0"
# backgroundColor=
secondaryBackgroundColor="#e4f3f9"
# textColor=
Expand Down
26 changes: 18 additions & 8 deletions dashboard/dashboard.py → dashboard/Home.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import streamlit as st
from _shared import add_sidebar_logo
from _shared import data_directory


st.set_page_config(page_title="Dianna's dashboard",
Expand All @@ -15,18 +17,26 @@
'https://github.com/dianna-ai/dianna')
})

st.image(
'https://user-images.githubusercontent.com/3244249/151994514-b584b984-a148-4ade-80ee-0f88b0aefa45.png'
)
add_sidebar_logo()

st.title("Dianna's dashboard")
st.image(str(data_directory / 'logo.png'))

st.write("""
st.markdown("""
DIANNA is a Python package that brings explainable AI (XAI) to your research project.
It wraps carefully selected XAI methods in a simple, uniform interface. It's built by,
with and for (academic) researchers and research software engineers working on machine
learning projects.
- [Images](/images)
- [Text](/text)
""")
### Pages
- <a href="/Images" target="_parent">Images</a>
- <a href="/Text" target="_parent">Text</a>
- <a href="/Time series" target="_parent">Time series</a>
### More information
- [Source code](https://github.com/dianna-ai/dianna)
- [Documentation](https://dianna.readthedocs.io/)
""",
unsafe_allow_html=True)
4 changes: 4 additions & 0 deletions dashboard/_model_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from pathlib import Path
import numpy as np
import onnx

Expand All @@ -21,6 +22,9 @@ def load_model(file):


def load_labels(file):
if isinstance(file, (str, Path)):
file = open(file, 'rb')

labels = [line.decode().rstrip() for line in file.readlines()]
if labels is None or labels == ['']:
raise ValueError(labels)
Expand Down
15 changes: 6 additions & 9 deletions dashboard/_models_image.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,11 @@
import tempfile
import numpy as np
import streamlit as st
from _model_utils import fill_segmentation
from _model_utils import preprocess_function
from onnx_tf.backend import prepare
from dianna import explain_image


def get_top_indices(predictions, n_top):
indices = np.array(np.argpartition(predictions, -n_top)[-n_top:])
indices = indices[np.argsort(predictions[indices])]
indices = np.flip(indices)
return indices


@st.cache_data
def predict(*, model, image):
output_node = prepare(model, gen_tensor_dict=True).outputs[0]
Expand All @@ -26,6 +18,7 @@ def _run_rise_image(model, image, i, **kwargs):
relevances = explain_image(
model,
image,
method='RISE',
**kwargs,
)
return relevances[0]
Expand All @@ -37,6 +30,7 @@ def _run_lime_image(model, image, i, **kwargs):
model,
image * 256,
preprocess_function=preprocess_function,
method='LIME',
**kwargs,
)
return relevances[0]
Expand All @@ -48,7 +42,10 @@ def _run_kernelshap_image(model, image, i, **kwargs):
with tempfile.NamedTemporaryFile() as f:
f.write(model)
f.flush()
shap_values, segments_slic = explain_image(f.name, image, **kwargs)
shap_values, segments_slic = explain_image(f.name,
image,
method='KernelSHAP',
**kwargs)

return fill_segmentation(shap_values[i][0], segments_slic)

Expand Down
3 changes: 2 additions & 1 deletion dashboard/_models_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@ def _run_rise_text(_model, text, **kwargs):
_model,
text,
tokenizer,
method='RISE',
**kwargs,
)
return relevances


@st.cache_data
def _run_lime_text(_model, text, **kwargs):
relevances = explain_text(_model, text, tokenizer, **kwargs)
relevances = explain_text(_model, text, tokenizer, method='LIME', **kwargs)
return relevances


Expand Down
5 changes: 5 additions & 0 deletions dashboard/_models_ts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
def predict():
pass


explain_ts_dispatcher = {}
129 changes: 129 additions & 0 deletions dashboard/_shared.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import base64
from pathlib import Path
from typing import Any
from typing import Dict
from typing import Sequence
import numpy as np
import streamlit as st


data_directory = Path(__file__).parent / 'data'


@st.cache_data
def get_base64_of_bin_file(png_file):
with open(png_file, 'rb') as f:
data = f.read()
return base64.b64encode(data).decode()


def build_markup_for_logo(
png_file,
background_position='50% 10%',
margin_top='10%',
image_width='60%',
image_height='',
):
binary_string = get_base64_of_bin_file(png_file)
return f"""
<style>
[data-testid="stSidebarNav"] {{
background-image: url("data:image/png;base64,{binary_string}");
background-repeat: no-repeat;
background-position: {background_position};
margin-top: {margin_top};
background-size: {image_width} {image_height};
}}
</style>
"""


def add_sidebar_logo():
"""Based on: https://stackoverflow.com/a/73278825."""
png_file = data_directory / 'logo.png'
logo_markup = build_markup_for_logo(png_file)
st.markdown(
logo_markup,
unsafe_allow_html=True,
)


def _methods_checkboxes(*, choices: Sequence):
"""Get methods from a horizontal row of checkboxes."""
n_choices = len(choices)
methods = []
for col, method in zip(st.columns(n_choices), choices):
with col:
if st.checkbox(method):
methods.append(method)

if not methods:
st.info('Select a method to continue')
st.stop()

return methods


def _get_params(method: str):
if method == 'RISE':
return {
'n_masks':
st.number_input('Number of masks', value=1000),
'feature_res':
st.number_input('Feature resolution', value=6),
'p_keep':
st.number_input('Probability to be kept unmasked', value=0.1),
}

elif method == 'KernelSHAP':
return {
'nsamples': st.number_input('Number of samples', value=1000),
'background': st.number_input('Background', value=0),
'n_segments': st.number_input('Number of segments', value=200),
'sigma': st.number_input('σ', value=0),
}

elif method == 'LIME':
return {
'rand_state': st.number_input('Random state', value=2),
}

else:
raise ValueError(f'No such method: {method}')


def _get_method_params(methods: Sequence[str]) -> Dict[str, Dict[str, Any]]:
method_params = {}

with st.expander('Click to modify method parameters'):
for method, col in zip(methods, st.columns(len(methods))):
with col:
st.header(method)
method_params[method] = _get_params(method)

return method_params


def _get_top_indices(predictions, n_top):
indices = np.array(np.argpartition(predictions, -n_top)[-n_top:])
indices = indices[np.argsort(predictions[indices])]
indices = np.flip(indices)
return indices


def _get_top_indices_and_labels(*, predictions, labels):
c1, c2 = st.columns(2)

with c2:
n_top = st.number_input('Number of top results to show',
value=2,
min_value=1,
max_value=len(labels))

top_indices = _get_top_indices(predictions, n_top)
top_labels = [labels[i] for i in top_indices]

with c1:
st.metric('Predicted class', top_labels[0])

return top_indices, top_labels
6 changes: 6 additions & 0 deletions dashboard/_ts_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import numpy as np


def open_timeseries(file):
"""Open a time series from a file and returns it as a numpy array."""
return np.arange(10), np.arange(10)**2
100 changes: 100 additions & 0 deletions dashboard/pages/1_Images.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import streamlit as st
from _image_utils import open_image
from _model_utils import load_labels
from _model_utils import load_model
from _models_image import explain_image_dispatcher
from _models_image import predict
from _shared import _get_method_params
from _shared import _get_top_indices_and_labels
from _shared import _methods_checkboxes
from _shared import add_sidebar_logo
from _shared import data_directory
from dianna.visualization import plot_image


add_sidebar_logo()

st.title('Image explanation')

with st.sidebar:
st.header('Input data')

load_example = st.checkbox('Load example data', key='image_example_check')

image_file = st.file_uploader('Select image',
type=('png', 'jpg', 'jpeg'),
disabled=load_example)

if image_file:
st.image(image_file)

image_model_file = st.file_uploader('Select model',
type='onnx',
disabled=load_example)

image_label_file = st.file_uploader('Select labels',
type='txt',
disabled=load_example)

if load_example:
image_file = (data_directory / 'digit0.png')
image_model_file = (data_directory / 'mnist_model_tf.onnx')
image_label_file = (data_directory / 'labels_mnist.txt')

if not (image_file and image_model_file and image_label_file):
st.info('Add your input data in the left panel to continue')
st.stop()

image, _ = open_image(image_file)

model = load_model(image_model_file)
serialized_model = model.SerializeToString()

labels = load_labels(image_label_file)

choices = ('RISE', 'KernelSHAP', 'LIME')
methods = _methods_checkboxes(choices=choices)

method_params = _get_method_params(methods)

with st.spinner('Predicting class'):
predictions = predict(model=model, image=image)

top_indices, top_labels = _get_top_indices_and_labels(predictions=predictions,
labels=labels)

# check which axis is color channel
original_data = image[:, :, 0] if image.shape[2] <= 3 else image[1, :, :]
axis_labels = {2: 'channels'} if image.shape[2] <= 3 else {0: 'channels'}

weight = 0.9 / len(methods)
column_spec = [0.1, *[weight for _ in methods]]

_, *columns = st.columns(column_spec)
for col, method in zip(columns, methods):
with col:
st.header(method)

for index, label in zip(top_indices, top_labels):
index_col, *columns = st.columns(column_spec)

with index_col:
st.markdown(f'##### {label}')

for col, method in zip(columns, methods):
kwargs = method_params[method].copy()
kwargs['axis_labels'] = axis_labels
kwargs['labels'] = [index]

func = explain_image_dispatcher[method]

with col:
with st.spinner(f'Running {method}'):
heatmap = func(serialized_model, image, index, **kwargs)

fig = plot_image(heatmap,
original_data=original_data,
heatmap_cmap='bwr',
show_plot=False)

st.pyplot(fig)
Loading

0 comments on commit 9c0b6a3

Please sign in to comment.