Skip to content

Commit

Permalink
Merge pull request #844 from dianna-ai/838-create-tabular-tab-to-dash…
Browse files Browse the repository at this point in the history
…board

838 create tabular tab to dashboard and redesign loaded data results #819
  • Loading branch information
laurasootes authored Sep 18, 2024
2 parents 47dfd99 + f38fd09 commit 45fd4ff
Show file tree
Hide file tree
Showing 10 changed files with 357 additions and 99 deletions.
19 changes: 11 additions & 8 deletions dianna/dashboard/Home.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import importlib
import streamlit as st
from _shared import add_sidebar_logo
from _shared import data_directory
from streamlit_option_menu import option_menu

st.set_page_config(page_title="Dianna's dashboard",
page_icon='📊',
layout='centered',
layout='wide',
initial_sidebar_state='auto',
menu_items={
'Get help':
Expand All @@ -22,6 +21,7 @@
pages = {
"Home": "home",
"Images": "pages.Images",
"Tabular": "pages.Tabular",
"Text": "pages.Text",
"Time series": "pages.Time_series"
}
Expand All @@ -30,16 +30,14 @@
selected = option_menu(
menu_title=None,
options=list(pages.keys()),
icons=["house", "camera", "alphabet", "clock"],
icons=["house", "camera", "table", "alphabet", "clock"],
menu_icon="cast",
default_index=0,
orientation="horizontal"
)

# Display the content of the selected page
if selected == "Home":
add_sidebar_logo()

st.image(str(data_directory / 'logo.png'))

st.markdown("""
Expand All @@ -50,9 +48,10 @@
### Pages
- <a href="/Images" target="_parent">Images</a>
- <a href="/Text" target="_parent">Text</a>
- <a href="/Time_series" target="_parent">Time series</a>
- <a href="/Images" target="_parent">Image data</a>
- <a href="/Tabular" target="_parent">Tabular data</a>
- <a href="/Text" target="_parent">Text data</a>
- <a href="/Time_series" target="_parent">Time series data</a>
### More information
Expand All @@ -70,6 +69,10 @@
for k in st.session_state.keys():
if 'Image' in k:
st.session_state.pop(k, None)
if selected != 'Tabular':
    # Drop every Tabular-related entry once the user leaves the Tabular page.
    # Iterate over a snapshot of the keys: entries are removed inside the loop,
    # and a mapping should not be mutated while its keys are being iterated.
    for k in list(st.session_state.keys()):
        if 'Tabular' in k:
            st.session_state.pop(k, None)
if selected != 'Text':
for k in st.session_state.keys():
if 'Text' in k:
Expand Down
13 changes: 13 additions & 0 deletions dianna/dashboard/_model_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
from pathlib import Path
import numpy as np
import onnx
import pandas as pd


def load_data(file):
    """Read a CSV file into a pandas DataFrame with a leading 'Index' column.

    Args:
        file: Path or file-like object accepted by ``pandas.read_csv``.

    Returns:
        The parsed DataFrame; its first column, 'Index', holds a copy of the
        DataFrame's row index.
    """
    frame = pd.read_csv(file, parse_dates=True)
    # Expose the row index as a regular first column.
    frame.insert(loc=0, column='Index', value=frame.index)
    return frame


def preprocess_function(image):
Expand Down Expand Up @@ -29,3 +38,7 @@ def load_labels(file):
if labels is None or labels == ['']:
raise ValueError(labels)
return labels


def load_training_data(file):
    """Load a NumPy array from ``file`` (pickles disallowed) and return it as float32."""
    raw = np.load(file, allow_pickle=False)
    return np.float32(raw)
57 changes: 57 additions & 0 deletions dianna/dashboard/_models_tabular.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import tempfile
import numpy as np
import streamlit as st
from dianna import explain_tabular
from dianna.utils.onnx_runner import SimpleModelRunner


@st.cache_data
def predict(*, model, tabular_input):
    """Run ``model`` on a single tabular instance and return its predictions.

    The input is reshaped into one row and cast to float32 before it is passed
    to the ``SimpleModelRunner`` built from ``model``.
    """
    runner = SimpleModelRunner(model)
    single_row = tabular_input.reshape(1, -1).astype(np.float32)
    return runner(single_row)


@st.cache_data
def _run_rise_tabular(_model, table, training_data, **kwargs):
    """Explain ``table`` with the RISE method and return the relevances.

    Any extra keyword arguments are forwarded to ``explain_tabular`` unchanged.
    """
    return explain_tabular(_model,
                           table,
                           method='RISE',
                           training_data=training_data,
                           **kwargs)


@st.cache_data
def _run_lime_tabular(_model, table, training_data, _feature_names, **kwargs):
    """Explain ``table`` with the LIME method and return the relevances.

    ``_feature_names`` is passed to ``explain_tabular`` as ``feature_names``;
    any extra keyword arguments are forwarded unchanged.
    """
    return explain_tabular(_model,
                           table,
                           method='LIME',
                           training_data=training_data,
                           feature_names=_feature_names,
                           **kwargs)

@st.cache_data
def _run_kernelshap_tabular(model, table, training_data, **kwargs):
    """Explain ``table`` with the KernelSHAP method.

    The KernelSHAP interface is different from the other explainers: here the
    model is handed to ``explain_tabular`` as a file path, so the model bytes
    are first written to a temporary file.

    Args:
        model: Model content as bytes; written verbatim to the temporary file.
        table: Tabular instance to explain.
        training_data: Training data forwarded to ``explain_tabular``.
        **kwargs: Extra keyword arguments forwarded to ``explain_tabular``.

    Returns:
        The first element of the result returned by ``explain_tabular``.
    """
    # Local import so this function does not rely on a module-level import.
    from pathlib import Path

    # Write into a temporary directory instead of keeping a NamedTemporaryFile
    # open: an open NamedTemporaryFile cannot be reopened by name on every
    # platform (notably Windows). Here the file is fully written and closed
    # before the explainer opens it, and the directory is removed afterwards.
    with tempfile.TemporaryDirectory() as tmp_dir:
        model_path = Path(tmp_dir) / 'model.onnx'
        model_path.write_bytes(model)
        relevances = explain_tabular(str(model_path),
                                     table,
                                     method='KernelSHAP',
                                     training_data=training_data,
                                     **kwargs)
    return relevances[0]


# Maps the method name to the function that runs that explainer on tabular data.
explain_tabular_dispatcher = {
    'RISE': _run_rise_tabular,
    'LIME': _run_lime_tabular,
    'KernelSHAP': _run_kernelshap_tabular,
}
96 changes: 48 additions & 48 deletions dianna/dashboard/_shared.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import base64
import sys
from typing import Any
from typing import Dict
from typing import Sequence
import numpy as np
import streamlit as st
Expand Down Expand Up @@ -46,71 +44,67 @@ def build_markup_for_logo(


def add_sidebar_logo():
"""Based on: https://stackoverflow.com/a/73278825."""
png_file = data_directory / 'logo.png'
logo_markup = build_markup_for_logo(png_file)
st.markdown(
logo_markup,
unsafe_allow_html=True,
)
"""Upload DIANNA logo to sidebar element."""
st.sidebar.image(str(data_directory / 'logo.png'))


def _methods_checkboxes(*, choices: Sequence, key):
"""Get methods from a horizontal row of checkboxes."""
"""Get methods from a horizontal row of checkboxes and the corresponding parameters."""
n_choices = len(choices)
methods = []
method_params = {}

# Create a container for the message
message_container = st.empty()

for col, method in zip(st.columns(n_choices), choices):
with col:
if st.checkbox(method, key=key + method):
if st.checkbox(method, key=f'{key}_{method}'):
methods.append(method)
with st.expander(f'Click to modify {method} parameters'):
method_params[method] = _get_params(method, key=f'{key}_param')

if not methods:
st.info('Select a method to continue')
# Put the message in the container above
message_container.info('Select a method to continue')
st.stop()

return methods
return methods, method_params


def _get_params(method: str, key):
if method == 'RISE':
return {
'n_masks':
st.number_input('Number of masks', value=1000, key=key + method + 'nmasks'),
st.number_input('Number of masks', value=1000, key=f'{key}_{method}_nmasks'),
'feature_res':
st.number_input('Feature resolution', value=6, key=key + method + 'fr'),
st.number_input('Feature resolution', value=6, key=f'{key}_{method}_fr'),
'p_keep':
st.number_input('Probability to be kept unmasked', value=0.1, key=key + method + 'pkeep'),
st.number_input('Probability to be kept unmasked', value=0.1, key=f'{key}_{method}_pkeep'),
}

elif method == 'KernelSHAP':
return {
'nsamples': st.number_input('Number of samples', value=1000, key=key + method + 'nsamp'),
'background': st.number_input('Background', value=0, key=key + method + 'background'),
'n_segments': st.number_input('Number of segments', value=200, key=key + method + 'nseg'),
'sigma': st.number_input('σ', value=0, key=key + method + 'sigma'),
}
if 'Tabular' in key:
return {'training_data_kmeans': st.number_input('Training data kmeans', value=5,
key=f'{key}_{method}_training_data_kmeans'),
}
else:
return {
'nsamples': st.number_input('Number of samples', value=1000, key=f'{key}_{method}_nsamp'),
'background': st.number_input('Background', value=0, key=f'{key}_{method}_background'),
'n_segments': st.number_input('Number of segments', value=200, key=f'{key}_{method}_nseg'),
'sigma': st.number_input('σ', value=0, key=f'{key}_{method}_sigma'),
}

elif method == 'LIME':
return {
'random_state': st.number_input('Random state', value=2, key=key + method + 'rs'),
'random_state': st.number_input('Random state', value=2, key=f'{key}_{method}_rs'),
}

else:
raise ValueError(f'No such method: {method}')


def _get_method_params(methods: Sequence[str], key) -> Dict[str, Dict[str, Any]]:
    """Collect the parameters of every given method inside a single expander.

    One column is rendered per method, headed by the method name and holding
    the inputs created by ``_get_params``. Returns a dict mapping each method
    name to the value returned by ``_get_params`` for it.
    """
    params_by_method = {}

    with st.expander('Click to modify method parameters'):
        columns = st.columns(len(methods))
        for name, column in zip(methods, columns):
            with column:
                st.header(name)
                params_by_method[name] = _get_params(name, key=key)

    return params_by_method


def _get_top_indices(predictions, n_top):
indices = np.array(np.argpartition(predictions, -n_top)[-n_top:])
indices = indices[np.argsort(predictions[indices])]
Expand All @@ -119,29 +113,35 @@ def _get_top_indices(predictions, n_top):


def _get_top_indices_and_labels(*, predictions, labels):
c1, c2 = st.columns(2)
cols = st.columns(4)

with c2:
n_top = st.number_input('Number of top results to show',
value=2,
min_value=1,
max_value=len(labels))
if labels is not None:
with cols[-1]:
n_top = st.number_input('Number of top classes to show',
value=1,
min_value=1,
max_value=len(labels))

top_indices = _get_top_indices(predictions, n_top)
top_labels = [labels[i] for i in top_indices]
top_indices = _get_top_indices(predictions, n_top)
top_labels = [labels[i] for i in top_indices]

with c1:
st.metric('Predicted class', top_labels[0])
with cols[0]:
st.metric('Predicted class:', top_labels[0])
else:
# If not a classifier, only return the predicted value
top_indices = top_labels = " "
with cols[0]:
st.metric('Predicted value:', f"{predictions[0]:.2f}")

return top_indices, top_labels

def reset_method():
# Clear selection
for k in st.session_state.keys():
if '_cb_' in k:
st.session_state[k] = False
if 'params' in k:
if '_param' in k:
st.session_state.pop(k)
elif '_cb' in k:
st.session_state[k] = False

def reset_example():
# Clear selection
Expand Down
25 changes: 16 additions & 9 deletions dianna/dashboard/pages/Images.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from _model_utils import load_model
from _models_image import explain_image_dispatcher
from _models_image import predict
from _shared import _get_method_params
from _shared import _get_top_indices_and_labels
from _shared import _methods_checkboxes
from _shared import add_sidebar_logo
Expand Down Expand Up @@ -88,15 +87,23 @@
labels = load_labels(image_label_file)

choices = ('RISE', 'KernelSHAP', 'LIME')
methods = _methods_checkboxes(choices=choices, key='Image_cb_')

method_params = _get_method_params(methods, key='Image_params_')
st.text("")
st.text("")

with st.spinner('Predicting class'):
predictions = predict(model=model, image=image)
with st.container(border=True):
prediction_placeholder = st.empty()
methods, method_params = _methods_checkboxes(choices=choices, key='Image_cb')

top_indices, top_labels = _get_top_indices_and_labels(predictions=predictions,
labels=labels)
with st.spinner('Predicting class'):
predictions = predict(model=model, image=image)

with prediction_placeholder:
top_indices, top_labels = _get_top_indices_and_labels(
predictions=predictions,labels=labels)

st.text("")
st.text("")

# check which axis is color channel
original_data = image[:, :, 0] if image.shape[2] <= 3 else image[1, :, :]
Expand All @@ -107,11 +114,11 @@

_, *columns = st.columns(column_spec)
for col, method in zip(columns, methods):
col.header(method)
col.markdown(f"<h4 style='text-align: center; '>{method}</h4>", unsafe_allow_html=True)

for index, label in zip(top_indices, top_labels):
index_col, *columns = st.columns(column_spec)
index_col.markdown(f'##### {label}')
index_col.markdown(f'##### Class: {label}')

for col, method in zip(columns, methods):
kwargs = method_params[method].copy()
Expand Down
Loading

0 comments on commit 45fd4ff

Please sign in to comment.