Skip to content

Commit

Permalink
Merge pull request #847 from dianna-ai/789-add-tabular-example-in-das…
Browse files Browse the repository at this point in the history
…hboard

789 add tabular example in dashboard and more example upgrades
  • Loading branch information
elboyran authored Sep 18, 2024
2 parents 45fd4ff + 2f41d15 commit f742047
Show file tree
Hide file tree
Showing 18 changed files with 879 additions and 332 deletions.
7 changes: 4 additions & 3 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ jobs:
- name: Run unit tests
run: python -m pytest -v

- name: Verify that we can build the package
run: python setup.py sdist bdist_wheel
#- name: Verify that we can build the package
# run: python setup.py sdist bdist_wheel

test_downloader:
name: Test file downloader
Expand All @@ -73,7 +73,8 @@ jobs:

test_dashboard:
name: Test dashboard
if: github.event.pull_request.draft == false
if: always()
#github.event.pull_request.draft == false
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
Expand Down
1 change: 1 addition & 0 deletions dianna/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def dashboard():
*('--theme.primaryColor', '7030a0'),
*('--theme.secondaryBackgroundColor', 'e4f3f9'),
*('--browser.gatherUsageStats', 'false'),
*('--client.showSidebarNavigation', 'false'),
*args,
]

Expand Down
8 changes: 0 additions & 8 deletions dianna/dashboard/Home.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,6 @@
with and for (academic) researchers and research software engineers working on machine
learning projects.
### Pages
- <a href="/Images" target="_parent">Image data</a>
- <a href="/Tabular" target="_parent">Tabular data</a>
- <a href="/Text" target="_parent">Text data</a>
- <a href="/Time_series" target="_parent">Time series data</a>
### More information
- [Source code](https://github.com/dianna-ai/dianna)
Expand Down
39 changes: 39 additions & 0 deletions dianna/dashboard/_model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
import onnx
import pandas as pd
from sklearn.model_selection import train_test_split


def load_data(file):
Expand Down Expand Up @@ -42,3 +43,41 @@ def load_labels(file):

def load_training_data(file):
return np.float32(np.load(file, allow_pickle=False))


def load_sunshine(file):
"""Tabular sunshine example.
Load the csv file in a pandas dataframe and split the data in a train and test set.
"""
data = load_data(file)

# Drop unused columns
X_data = data.drop(columns=['DATE', 'MONTH', 'Index'])[:-1]
y_data = data.loc[1:]["BASEL_sunshine"]

# Split the data
X_train, X_holdout, _, y_holdout = train_test_split(X_data, y_data, test_size=0.3, random_state=0)
_, X_test, _, _ = train_test_split(X_holdout, y_holdout, test_size=0.5, random_state=0)
X_test = X_test.reset_index(drop=True)
X_test.insert(0, 'Index', X_test.index)

return X_train.to_numpy(dtype=np.float32), X_test

def load_penguins(penguins):
"""Prep the data for the penguin model example as per ntoebook."""
# Remove categorial columns and NaN values
penguins_filtered = penguins.drop(columns=['island', 'sex']).dropna()


# Extract inputs and target
input_features = penguins_filtered.drop(columns=['species'])
target = pd.get_dummies(penguins_filtered['species'])

X_train, X_test, _, _ = train_test_split(input_features, target, test_size=0.2,
random_state=0, shuffle=True, stratify=target)

X_test = X_test.reset_index(drop=True)
X_test.insert(0, 'Index', X_test.index)

return X_train.to_numpy(dtype=np.float32), X_test
55 changes: 41 additions & 14 deletions dianna/dashboard/_models_tabular.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,55 @@
import tempfile
import numpy as np
import onnxruntime as ort
import streamlit as st
from dianna import explain_tabular
from dianna.utils.onnx_runner import SimpleModelRunner


@st.cache_data
def predict(*, model, tabular_input):
model_runner = SimpleModelRunner(model)
predictions = model_runner(tabular_input.reshape(1,-1).astype(np.float32))
return predictions
# Make sure that tabular input is provided as float32
sess = ort.InferenceSession(model)
input_name = sess.get_inputs()[0].name
output_name = sess.get_outputs()[0].name

onnx_input = {input_name: tabular_input.astype(np.float32)}
pred_onnx = sess.run([output_name], onnx_input)[0]

return pred_onnx


@st.cache_data
def _run_rise_tabular(_model, table, training_data, **kwargs):
def _run_rise_tabular(_model, table, training_data,_feature_names, **kwargs):
# convert streamlit kwarg requirement back to dianna kwarg requirement
if "_preprocess_function" in kwargs:
kwargs["preprocess_function"] = kwargs["_preprocess_function"]
del kwargs["_preprocess_function"]

def run_model(tabular_input):
return predict(model=_model, tabular_input=tabular_input)

relevances = explain_tabular(
_model,
run_model,
table,
method='RISE',
training_data=training_data,
feature_names=_feature_names,
**kwargs,
)
return relevances


@st.cache_data
def _run_lime_tabular(_model, table, training_data, _feature_names, **kwargs):
# convert streamlit kwarg requirement back to dianna kwarg requirement
if "_preprocess_function" in kwargs:
kwargs["preprocess_function"] = kwargs["_preprocess_function"]
del kwargs["_preprocess_function"]

def run_model(tabular_input):
return predict(model=_model, tabular_input=tabular_input)

relevances = explain_tabular(
_model,
run_model,
table,
method='LIME',
training_data=training_data,
Expand All @@ -37,17 +59,22 @@ def _run_lime_tabular(_model, table, training_data, _feature_names, **kwargs):
return relevances

@st.cache_data
def _run_kernelshap_tabular(model, table, training_data, **kwargs):
def _run_kernelshap_tabular(model, table, training_data, _feature_names, **kwargs):
# Kernelshap interface is different. Write model to temporary file.
with tempfile.NamedTemporaryFile() as f:
f.write(model)
f.flush()
relevances = explain_tabular(f.name,
if "_preprocess_function" in kwargs:
kwargs["preprocess_function"] = kwargs["_preprocess_function"]
del kwargs["_preprocess_function"]

def run_model(tabular_input):
return predict(model=model, tabular_input=tabular_input)

relevances = explain_tabular(run_model,
table,
method='KernelSHAP',
training_data=training_data,
feature_names=_feature_names,
**kwargs)
return relevances[0]
return np.array(relevances)


explain_tabular_dispatcher = {
Expand Down
27 changes: 22 additions & 5 deletions dianna/dashboard/_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,25 @@ def _methods_checkboxes(*, choices: Sequence, key):

def _get_params(method: str, key):
if method == 'RISE':
n_masks = 1000
fr = 8
pkeep = 0.1
if 'FRB' in key:
n_masks = 5000
fr = 16
elif 'Tabular' in key:
pkeep = 0.5
elif 'Weather' in key:
n_masks = 10000
elif 'Digits' in key:
n_masks = 5000
return {
'n_masks':
st.number_input('Number of masks', value=1000, key=f'{key}_{method}_nmasks'),
st.number_input('Number of masks', value=n_masks, key=f'{key}_{method}_nmasks'),
'feature_res':
st.number_input('Feature resolution', value=6, key=f'{key}_{method}_fr'),
st.number_input('Feature resolution', value=fr, key=f'{key}_{method}_fr'),
'p_keep':
st.number_input('Probability to be kept unmasked', value=0.1, key=f'{key}_{method}_pkeep'),
st.number_input('Probability to be kept unmasked', value=pkeep, key=f'{key}_{method}_pkeep'),
}

elif method == 'KernelSHAP':
Expand All @@ -97,9 +109,14 @@ def _get_params(method: str, key):
}

elif method == 'LIME':
return {
'random_state': st.number_input('Random state', value=2, key=f'{key}_{method}_rs'),
if 'Tabular' in key:
return {
'random_state': st.number_input('Random state', value=0, key=f'{key}_{method}_rs'),
}
else:
return {
'random_state': st.number_input('Random state', value=2, key=f'{key}_{method}_rs'),
}

else:
raise ValueError(f'No such method: {method}')
Expand Down
Binary file removed dianna/dashboard/dashboard-screenshot.png
Binary file not shown.
6 changes: 5 additions & 1 deletion dianna/dashboard/pages/Images.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
image_model_file = download('mnist_model_tf.onnx', 'model')
image_label_file = download('labels_mnist.txt', 'label')

imagekey = 'Digits_Image_cb'

st.markdown(
"""
This example demonstrates the use of DIANNA on a pretrained binary
Expand Down Expand Up @@ -71,6 +73,8 @@
image_label_file = st.sidebar.file_uploader('Select labels',
type='txt')

imagekey = 'Image_cb'

if input_type is None:
st.info('Select which input type to use in the left panel to continue')
st.stop()
Expand All @@ -93,7 +97,7 @@

with st.container(border=True):
prediction_placeholder = st.empty()
methods, method_params = _methods_checkboxes(choices=choices, key='Image_cb')
methods, method_params = _methods_checkboxes(choices=choices, key=imagekey)

with st.spinner('Predicting class'):
predictions = predict(model=model, image=image)
Expand Down
Loading

0 comments on commit f742047

Please sign in to comment.