diff --git a/dianna/dashboard/Home.py b/dianna/dashboard/Home.py
index c86fd264..51d5a118 100644
--- a/dianna/dashboard/Home.py
+++ b/dianna/dashboard/Home.py
@@ -1,12 +1,11 @@
import importlib
import streamlit as st
-from _shared import add_sidebar_logo
from _shared import data_directory
from streamlit_option_menu import option_menu
st.set_page_config(page_title="Dianna's dashboard",
page_icon='📊',
- layout='centered',
+ layout='wide',
initial_sidebar_state='auto',
menu_items={
'Get help':
@@ -22,6 +21,7 @@
pages = {
"Home": "home",
"Images": "pages.Images",
+ "Tabular": "pages.Tabular",
"Text": "pages.Text",
"Time series": "pages.Time_series"
}
@@ -30,7 +30,7 @@
selected = option_menu(
menu_title=None,
options=list(pages.keys()),
- icons=["house", "camera", "alphabet", "clock"],
+ icons=["house", "camera", "table", "alphabet", "clock"],
menu_icon="cast",
default_index=0,
orientation="horizontal"
@@ -38,8 +38,6 @@
# Display the content of the selected page
if selected == "Home":
- add_sidebar_logo()
-
st.image(str(data_directory / 'logo.png'))
st.markdown("""
@@ -50,9 +48,10 @@
### Pages
- - Images
- - Text
- - Time series
+ - Image data
+ - Tabular data
+ - Text data
+ - Time series data
### More information
@@ -70,6 +69,10 @@
for k in st.session_state.keys():
if 'Image' in k:
st.session_state.pop(k, None)
+ if selected != 'Tabular':
+ for k in st.session_state.keys():
+ if 'Tabular' in k:
+ st.session_state.pop(k, None)
if selected != 'Text':
for k in st.session_state.keys():
if 'Text' in k:
diff --git a/dianna/dashboard/_model_utils.py b/dianna/dashboard/_model_utils.py
index 272e4a40..cc8084d0 100644
--- a/dianna/dashboard/_model_utils.py
+++ b/dianna/dashboard/_model_utils.py
@@ -1,6 +1,15 @@
from pathlib import Path
import numpy as np
import onnx
+import pandas as pd
+
+
+def load_data(file):
+ """Open data from a file and returns it as pandas DataFrame."""
+ df = pd.read_csv(file, parse_dates=True)
+ # Add index column
+ df.insert(0, 'Index', df.index)
+ return df
def preprocess_function(image):
@@ -29,3 +38,7 @@ def load_labels(file):
if labels is None or labels == ['']:
raise ValueError(labels)
return labels
+
+
+def load_training_data(file):
+ return np.float32(np.load(file, allow_pickle=False))
diff --git a/dianna/dashboard/_models_tabular.py b/dianna/dashboard/_models_tabular.py
new file mode 100644
index 00000000..96573326
--- /dev/null
+++ b/dianna/dashboard/_models_tabular.py
@@ -0,0 +1,57 @@
+import tempfile
+import numpy as np
+import streamlit as st
+from dianna import explain_tabular
+from dianna.utils.onnx_runner import SimpleModelRunner
+
+
+@st.cache_data
+def predict(*, model, tabular_input):
+ model_runner = SimpleModelRunner(model)
+ predictions = model_runner(tabular_input.reshape(1,-1).astype(np.float32))
+ return predictions
+
+
+@st.cache_data
+def _run_rise_tabular(_model, table, training_data, **kwargs):
+ relevances = explain_tabular(
+ _model,
+ table,
+ method='RISE',
+ training_data=training_data,
+ **kwargs,
+ )
+ return relevances
+
+
+@st.cache_data
+def _run_lime_tabular(_model, table, training_data, _feature_names, **kwargs):
+ relevances = explain_tabular(
+ _model,
+ table,
+ method='LIME',
+ training_data=training_data,
+ feature_names=_feature_names,
+ **kwargs,
+ )
+ return relevances
+
+@st.cache_data
+def _run_kernelshap_tabular(model, table, training_data, **kwargs):
+ # Kernelshap interface is different. Write model to temporary file.
+ with tempfile.NamedTemporaryFile() as f:
+ f.write(model)
+ f.flush()
+ relevances = explain_tabular(f.name,
+ table,
+ method='KernelSHAP',
+ training_data=training_data,
+ **kwargs)
+ return relevances[0]
+
+
+explain_tabular_dispatcher = {
+ 'RISE': _run_rise_tabular,
+ 'LIME': _run_lime_tabular,
+ 'KernelSHAP': _run_kernelshap_tabular
+}
diff --git a/dianna/dashboard/_shared.py b/dianna/dashboard/_shared.py
index 35060525..6ed408ae 100644
--- a/dianna/dashboard/_shared.py
+++ b/dianna/dashboard/_shared.py
@@ -1,7 +1,5 @@
import base64
import sys
-from typing import Any
-from typing import Dict
from typing import Sequence
import numpy as np
import streamlit as st
@@ -46,71 +44,67 @@ def build_markup_for_logo(
def add_sidebar_logo():
- """Based on: https://stackoverflow.com/a/73278825."""
- png_file = data_directory / 'logo.png'
- logo_markup = build_markup_for_logo(png_file)
- st.markdown(
- logo_markup,
- unsafe_allow_html=True,
- )
+ """Upload DIANNA logo to sidebar element."""
+ st.sidebar.image(str(data_directory / 'logo.png'))
def _methods_checkboxes(*, choices: Sequence, key):
- """Get methods from a horizontal row of checkboxes."""
+ """Get methods from a horizontal row of checkboxes and the corresponding parameters."""
n_choices = len(choices)
methods = []
+ method_params = {}
+
+ # Create a container for the message
+ message_container = st.empty()
+
for col, method in zip(st.columns(n_choices), choices):
with col:
- if st.checkbox(method, key=key + method):
+ if st.checkbox(method, key=f'{key}_{method}'):
methods.append(method)
+ with st.expander(f'Click to modify {method} parameters'):
+ method_params[method] = _get_params(method, key=f'{key}_param')
if not methods:
- st.info('Select a method to continue')
+ # Put the message in the container above
+ message_container.info('Select a method to continue')
st.stop()
- return methods
+ return methods, method_params
def _get_params(method: str, key):
if method == 'RISE':
return {
'n_masks':
- st.number_input('Number of masks', value=1000, key=key + method + 'nmasks'),
+ st.number_input('Number of masks', value=1000, key=f'{key}_{method}_nmasks'),
'feature_res':
- st.number_input('Feature resolution', value=6, key=key + method + 'fr'),
+ st.number_input('Feature resolution', value=6, key=f'{key}_{method}_fr'),
'p_keep':
- st.number_input('Probability to be kept unmasked', value=0.1, key=key + method + 'pkeep'),
+ st.number_input('Probability to be kept unmasked', value=0.1, key=f'{key}_{method}_pkeep'),
}
elif method == 'KernelSHAP':
- return {
- 'nsamples': st.number_input('Number of samples', value=1000, key=key + method + 'nsamp'),
- 'background': st.number_input('Background', value=0, key=key + method + 'background'),
- 'n_segments': st.number_input('Number of segments', value=200, key=key + method + 'nseg'),
- 'sigma': st.number_input('σ', value=0, key=key + method + 'sigma'),
- }
+ if 'Tabular' in key:
+ return {'training_data_kmeans': st.number_input('Training data kmeans', value=5,
+ key=f'{key}_{method}_training_data_kmeans'),
+ }
+ else:
+ return {
+ 'nsamples': st.number_input('Number of samples', value=1000, key=f'{key}_{method}_nsamp'),
+ 'background': st.number_input('Background', value=0, key=f'{key}_{method}_background'),
+ 'n_segments': st.number_input('Number of segments', value=200, key=f'{key}_{method}_nseg'),
+ 'sigma': st.number_input('σ', value=0, key=f'{key}_{method}_sigma'),
+ }
elif method == 'LIME':
return {
- 'random_state': st.number_input('Random state', value=2, key=key + method + 'rs'),
+ 'random_state': st.number_input('Random state', value=2, key=f'{key}_{method}_rs'),
}
else:
raise ValueError(f'No such method: {method}')
-def _get_method_params(methods: Sequence[str], key) -> Dict[str, Dict[str, Any]]:
- method_params = {}
-
- with st.expander('Click to modify method parameters'):
- for method, col in zip(methods, st.columns(len(methods))):
- with col:
- st.header(method)
- method_params[method] = _get_params(method, key=key)
-
- return method_params
-
-
def _get_top_indices(predictions, n_top):
indices = np.array(np.argpartition(predictions, -n_top)[-n_top:])
indices = indices[np.argsort(predictions[indices])]
@@ -119,29 +113,35 @@ def _get_top_indices(predictions, n_top):
def _get_top_indices_and_labels(*, predictions, labels):
- c1, c2 = st.columns(2)
+ cols = st.columns(4)
- with c2:
- n_top = st.number_input('Number of top results to show',
- value=2,
- min_value=1,
- max_value=len(labels))
+ if labels is not None:
+ with cols[-1]:
+ n_top = st.number_input('Number of top classes to show',
+ value=1,
+ min_value=1,
+ max_value=len(labels))
- top_indices = _get_top_indices(predictions, n_top)
- top_labels = [labels[i] for i in top_indices]
+ top_indices = _get_top_indices(predictions, n_top)
+ top_labels = [labels[i] for i in top_indices]
- with c1:
- st.metric('Predicted class', top_labels[0])
+ with cols[0]:
+ st.metric('Predicted class:', top_labels[0])
+ else:
+ # If not a classifier, only return the predicted value
+ top_indices = top_labels = " "
+ with cols[0]:
+ st.metric('Predicted value:', f"{predictions[0]:.2f}")
return top_indices, top_labels
def reset_method():
# Clear selection
for k in st.session_state.keys():
- if '_cb_' in k:
- st.session_state[k] = False
- if 'params' in k:
+ if '_param' in k:
st.session_state.pop(k)
+ elif '_cb' in k:
+ st.session_state[k] = False
def reset_example():
# Clear selection
diff --git a/dianna/dashboard/pages/Images.py b/dianna/dashboard/pages/Images.py
index 7a36fd86..edab213d 100644
--- a/dianna/dashboard/pages/Images.py
+++ b/dianna/dashboard/pages/Images.py
@@ -4,7 +4,6 @@
from _model_utils import load_model
from _models_image import explain_image_dispatcher
from _models_image import predict
-from _shared import _get_method_params
from _shared import _get_top_indices_and_labels
from _shared import _methods_checkboxes
from _shared import add_sidebar_logo
@@ -88,15 +87,23 @@
labels = load_labels(image_label_file)
choices = ('RISE', 'KernelSHAP', 'LIME')
-methods = _methods_checkboxes(choices=choices, key='Image_cb_')
-method_params = _get_method_params(methods, key='Image_params_')
+st.text("")
+st.text("")
-with st.spinner('Predicting class'):
- predictions = predict(model=model, image=image)
+with st.container(border=True):
+ prediction_placeholder = st.empty()
+ methods, method_params = _methods_checkboxes(choices=choices, key='Image_cb')
-top_indices, top_labels = _get_top_indices_and_labels(predictions=predictions,
- labels=labels)
+ with st.spinner('Predicting class'):
+ predictions = predict(model=model, image=image)
+
+ with prediction_placeholder:
+ top_indices, top_labels = _get_top_indices_and_labels(
+ predictions=predictions,labels=labels)
+
+st.text("")
+st.text("")
# check which axis is color channel
original_data = image[:, :, 0] if image.shape[2] <= 3 else image[1, :, :]
@@ -107,11 +114,11 @@
_, *columns = st.columns(column_spec)
for col, method in zip(columns, methods):
- col.header(method)
+ col.markdown(f"
{method}
", unsafe_allow_html=True)
for index, label in zip(top_indices, top_labels):
index_col, *columns = st.columns(column_spec)
- index_col.markdown(f'##### {label}')
+ index_col.markdown(f'##### Class: {label}')
for col, method in zip(columns, methods):
kwargs = method_params[method].copy()
diff --git a/dianna/dashboard/pages/Tabular.py b/dianna/dashboard/pages/Tabular.py
new file mode 100644
index 00000000..f9825648
--- /dev/null
+++ b/dianna/dashboard/pages/Tabular.py
@@ -0,0 +1,145 @@
+import numpy as np
+import streamlit as st
+from _model_utils import load_data
+from _model_utils import load_labels
+from _model_utils import load_model
+from _model_utils import load_training_data
+from _models_tabular import explain_tabular_dispatcher
+from _models_tabular import predict
+from _shared import _get_top_indices_and_labels
+from _shared import _methods_checkboxes
+from _shared import add_sidebar_logo
+from _shared import reset_example
+from st_aggrid import AgGrid
+from st_aggrid import GridOptionsBuilder
+from st_aggrid import GridUpdateMode
+from dianna.visualization import plot_tabular
+
+add_sidebar_logo()
+
+st.title('Tabular data explanation')
+
+st.sidebar.header('Input data')
+
+input_type = st.sidebar.radio(
+ label='Select which input to use',
+ options = ('Use an example', 'Use your own data'),
+ index = None,
+ on_change = reset_example,
+ key = 'Tabular_input_type'
+ )
+
+# Use the examples
+if input_type == 'Use an example':
+ """load_example = st.sidebar.radio(
+ label='Use example',
+ options=(''),
+ index = None,
+ on_change = reset_method,
+ key='Tabular_load_example')"""
+ st.info("No examples availble yet")
+ st.stop()
+
+# Option to upload your own data
+if input_type == 'Use your own data':
+ tabular_data_file = st.sidebar.file_uploader('Select tabular data', type='csv')
+ tabular_model_file = st.sidebar.file_uploader('Select model', type='onnx')
+ tabular_training_data_file = st.sidebar.file_uploader('Select training data', type='npy')
+ tabular_label_file = st.sidebar.file_uploader('Select labels in case of classification model', type='txt')
+
+if input_type is None:
+ st.info('Select which input type to use in the left panel to continue')
+ st.stop()
+
+if not (tabular_data_file and tabular_model_file and tabular_training_data_file):
+ st.info('Add your input data in the left panel to continue')
+ st.stop()
+
+data = load_data(tabular_data_file)
+
+model = load_model(tabular_model_file)
+serialized_model = model.SerializeToString()
+
+training_data = load_training_data(tabular_training_data_file)
+
+if tabular_label_file:
+ labels = load_labels(tabular_label_file)
+ mode = 'classification'
+else:
+ labels = None
+ mode = 'regression'
+
+choices = ('RISE', 'LIME')
+
+st.text("")
+st.text("")
+
+# Get predictions and create parameter box
+with st.container(border=True):
+ prediction_placeholder = st.empty()
+ methods, method_params = _methods_checkboxes(choices=choices, key='Tabular_cb')
+
+
+# Configure Ag-Grid options
+gb = GridOptionsBuilder.from_dataframe(data)
+gb.configure_selection('single')
+grid_options = gb.build()
+
+# Display the grid with the DataFrame
+grid_response = AgGrid(
+ data,
+ gridOptions=grid_options,
+ update_mode=GridUpdateMode.SELECTION_CHANGED,
+ theme='streamlit'
+)
+
+if grid_response['selected_rows'] is not None:
+ selected_row = grid_response['selected_rows']['Index'].iloc[0]
+ selected_data = data.iloc[selected_row, 1:].to_numpy(dtype=np.float32)
+ with st.spinner('Predicting class'):
+ predictions = predict(model=serialized_model, tabular_input=selected_data)
+
+ with prediction_placeholder:
+ top_indices, top_labels = _get_top_indices_and_labels(
+ predictions=predictions[0], labels=labels)
+
+else:
+ st.info("Select the input data by clicking a row in the table.")
+ st.stop()
+
+st.text("")
+st.text("")
+
+weight = 0.85 / len(methods)
+column_spec = [0.15, *[weight for _ in methods]]
+
+_, *columns = st.columns(column_spec)
+for col, method in zip(columns, methods):
+ col.markdown(f"{method}
", unsafe_allow_html=True)
+
+for index, label in zip(top_indices, top_labels):
+ index_col, *columns = st.columns(column_spec)
+
+ if mode == 'classification':
+ index_col.markdown(f'##### Class: {label}')
+
+ for col, method in zip(columns, methods):
+ kwargs = method_params[method].copy()
+ kwargs['labels'] = [index]
+ kwargs['mode'] = mode
+ if method == 'LIME':
+ kwargs['_feature_names']=data[:1].columns.to_list()
+
+ func = explain_tabular_dispatcher[method]
+
+ with col:
+ with st.spinner(f'Running {method}'):
+ relevances = func(serialized_model, selected_data, training_data, **kwargs)
+ fig, _ = plot_tabular(x=relevances, y=data[:1].columns, num_features=10, show_plot=False)
+ st.pyplot(fig)
+
+ # add some white space to separate rows
+ st.markdown('')
+
+
+st.stop()
diff --git a/dianna/dashboard/pages/Text.py b/dianna/dashboard/pages/Text.py
index c3864c69..387e37a2 100644
--- a/dianna/dashboard/pages/Text.py
+++ b/dianna/dashboard/pages/Text.py
@@ -4,7 +4,6 @@
from _models_text import explain_text_dispatcher
from _models_text import predict
from _movie_model import MovieReviewsModelRunner
-from _shared import _get_method_params
from _shared import _get_top_indices_and_labels
from _shared import _methods_checkboxes
from _shared import add_sidebar_logo
@@ -34,10 +33,10 @@
options=('Movie sentiment',),
index = None,
on_change = reset_method,
- key='Text_example_check_moviesentiment')
+ key='Text_load_example')
if load_example == 'Movie sentiment':
- text_input = 'The movie started out great but the ending was dissappointing'
+ text_input = 'The movie started out great but the ending was disappointing'
text_model_file = download('movie_review_model.onnx', 'model')
text_label_file = download('labels_text.txt', 'label')
@@ -80,29 +79,37 @@
labels = load_labels(text_label_file)
choices = ('RISE', 'LIME')
-methods = _methods_checkboxes(choices=choices, key='Text_cb_')
-method_params = _get_method_params(methods, key='Text_params_')
+st.text("")
+st.text("")
-model_runner = MovieReviewsModelRunner(serialized_model)
+with st.container(border=True):
+ prediction_placeholder = st.empty()
+ methods, method_params = _methods_checkboxes(choices=choices, key='Text_cb')
-with st.spinner('Predicting class'):
- predictions = predict(model=serialized_model, text_input=text_input)
+ model_runner = MovieReviewsModelRunner(serialized_model)
-top_indices, top_labels = _get_top_indices_and_labels(
- predictions=predictions[0], labels=labels)
+ with st.spinner('Predicting class'):
+ predictions = predict(model=serialized_model, text_input=text_input)
+
+ with prediction_placeholder:
+ top_indices, top_labels = _get_top_indices_and_labels(
+ predictions=predictions[0], labels=labels)
+
+st.text("")
+st.text("")
weight = 0.85 / len(methods)
column_spec = [0.15, *[weight for _ in methods]]
_, *columns = st.columns(column_spec)
for col, method in zip(columns, methods):
- col.header(method)
+ col.markdown(f"{method}
", unsafe_allow_html=True)
for index, label in zip(top_indices, top_labels):
index_col, *columns = st.columns(column_spec)
- index_col.markdown(f'##### {label}')
+ index_col.markdown(f'##### Class: {label}')
for col, method in zip(columns, methods):
kwargs = method_params[method].copy()
diff --git a/dianna/dashboard/pages/Time_series.py b/dianna/dashboard/pages/Time_series.py
index fb99cd25..07c88e15 100644
--- a/dianna/dashboard/pages/Time_series.py
+++ b/dianna/dashboard/pages/Time_series.py
@@ -4,7 +4,6 @@
from _model_utils import load_model
from _models_ts import explain_ts_dispatcher
from _models_ts import predict
-from _shared import _get_method_params
from _shared import _get_top_indices_and_labels
from _shared import _methods_checkboxes
from _shared import add_sidebar_logo
@@ -16,10 +15,9 @@
from dianna.visualization import plot_image
from dianna.visualization import plot_timeseries
-add_sidebar_logo()
-
st.title('Time series explanation')
+add_sidebar_logo()
st.sidebar.header('Input data')
input_type = st.sidebar.radio(
@@ -104,7 +102,6 @@ def preprocess(data):
st.info('Select which input type to use in the left panel to continue')
st.stop()
-
if not (ts_data_file and ts_model_file and ts_label_file):
st.info('Add your input data in the left panel to continue')
st.stop()
@@ -123,26 +120,34 @@ def preprocess(data):
choices = ('RISE',)
else:
choices = ('RISE', 'LIME')
-methods = _methods_checkboxes(choices=choices, key='TS_cb_')
-method_params = _get_method_params(methods, key='TS_params_')
+st.text("")
+st.text("")
+
+with st.container(border=True):
+ prediction_placeholder = st.empty()
+ methods, method_params = _methods_checkboxes(choices=choices, key='TS_cb')
+
+ with st.spinner('Predicting class'):
+ predictions = predict(model=serialized_model, ts_data=ts_data_predictor)
-with st.spinner('Predicting class'):
- predictions = predict(model=serialized_model, ts_data=ts_data_predictor)
+ with prediction_placeholder:
+ top_indices, top_labels = _get_top_indices_and_labels(
+ predictions=predictions[0], labels=labels)
-top_indices, top_labels = _get_top_indices_and_labels(
- predictions=predictions[0], labels=labels)
+st.text("")
+st.text("")
weight = 0.9 / len(methods)
column_spec = [0.1, *[weight for _ in methods]]
_, *columns = st.columns(column_spec)
for col, method in zip(columns, methods):
- col.header(method)
+ col.markdown(f"{method}
", unsafe_allow_html=True)
for index, label in zip(top_indices, top_labels):
index_col, *columns = st.columns(column_spec)
- index_col.markdown(f'##### {label}')
+ index_col.markdown(f'##### Class: {label}')
for col, method in zip(columns, methods):
kwargs = method_params[method].copy()
diff --git a/dianna/methods/lime_tabular.py b/dianna/methods/lime_tabular.py
index f6b4b1fd..ff7864c7 100644
--- a/dianna/methods/lime_tabular.py
+++ b/dianna/methods/lime_tabular.py
@@ -63,6 +63,7 @@ def __init__(
LimeTabularExplainer, kwargs)
# temporary solution for setting num_features and top_labels
+ # when fixed, also fix in dashboard Tabular.py -> _feature_names
self.num_features = len(feature_names)
self.explainer = LimeTabularExplainer(
diff --git a/tests/test_dashboard.py b/tests/test_dashboard.py
index 08fb02a4..0ce296c0 100644
--- a/tests/test_dashboard.py
+++ b/tests/test_dashboard.py
@@ -95,11 +95,11 @@ def test_text_page(page: Page):
# Movie sentiment example
page.locator("label").filter(has_text="Use an example").locator("div").nth(1).click()
page.get_by_text("Movie sentiment").click()
- expect(page.get_by_text("Select a method to continue")).to_be_visible()
+ expect(page.get_by_text("Select a method to continue")).to_be_visible(timeout=50_000)
page.locator('label').filter(has_text='RISE').locator('span').click()
page.locator('label').filter(has_text='LIME').locator('span').click()
-
+ page.get_by_test_id("stNumberInput-StepUp").click()
page.get_by_text('Running...').wait_for(state='detached', timeout=100_000)
for selector in (
@@ -117,7 +117,6 @@ def test_text_page(page: Page):
page.get_by_role('img', name='0').nth(2),
page.get_by_role('img', name='0').nth(3),
):
- print(selector)
expect(selector).to_be_visible()
# Own data option
@@ -149,13 +148,13 @@ def test_image_page(page: Page):
page.locator("label").filter(has_text="Use an example").locator("div").nth(1).click()
page.get_by_text("Hand-written digit recognition").click()
- expect(page.get_by_text('Select a method to continue')).to_be_visible()
+ expect(page.get_by_text('Select a method to continue')).to_be_visible(timeout=100_000)
page.locator('label').filter(has_text='RISE').locator('span').click()
page.locator('label').filter(has_text='KernelSHAP').locator('span').click()
page.locator('label').filter(has_text='LIME').locator('span').click()
-
- page.get_by_text('Running...').wait_for(state='detached', timeout=45_000)
+ page.get_by_test_id("stNumberInput-StepUp").click()
+ page.get_by_text('Running...').wait_for(state='detached', timeout=50_000)
for selector in (
page.get_by_role('heading', name='RISE').get_by_text('RISE'),
@@ -189,7 +188,7 @@ def test_timeseries_page(page: Page):
expect(page).to_have_title('Time_series')
- expect(page.get_by_text("Select which input type to")).to_be_visible()
+ expect(page.get_by_text("Select which input type to")).to_be_visible(timeout=100_000)
page.locator("label").filter(has_text="Use an example").locator("div").nth(1).click()
expect(page.get_by_text("Select an example in the left")).to_be_visible()
@@ -198,11 +197,11 @@ def test_timeseries_page(page: Page):
# Test weather example
page.locator("label").filter(has_text="Weather").locator("div").nth(1).click()
- expect(page.get_by_text("Select a method to continue")).to_be_visible()
+ expect(page.get_by_text("Select a method to continue")).to_be_visible(timeout=100_000)
page.locator('label').filter(has_text='LIME').locator('span').click()
page.locator('label').filter(has_text='RISE').locator('span').click()
-
+ page.get_by_test_id("stNumberInput-StepUp").click()
page.get_by_text('Running...').wait_for(state='detached', timeout=100_000)
for selector in (
@@ -221,7 +220,7 @@ def test_timeseries_page(page: Page):
# Test FRB example
page.locator("label").filter(has_text="FRB").locator("div").nth(1).click()
- expect(page.get_by_text("Select a method to continue")).to_be_visible()
+ expect(page.get_by_text("Select a method to continue")).to_be_visible(timeout=100_000)
page.locator('label').filter(has_text='RISE').locator('span').click()
@@ -245,3 +244,24 @@ def test_timeseries_page(page: Page):
"baseButton-secondary").click()
page.get_by_label("Select labels").get_by_test_id(
"baseButton-secondary").click()
+
+
+def test_tabular_page(page: Page):
+ """Test performance of tabular page."""
+ page.goto(f'{BASE_URL}/Tabular')
+
+ page.get_by_text('Running...').wait_for(state='detached')
+
+ expect(page).to_have_title('Tabular')
+
+ expect(page.get_by_text("Select which input type to")).to_be_visible(timeout=100_000)
+
+ page.locator("label").filter(has_text="Use an example").locator("div").nth(1).click()
+
+ # Test using your own data
+ page.locator("label").filter(
+ has_text="Use your own data").locator("div").nth(1).click()
+ page.get_by_label("Select tabular data").get_by_test_id("baseButton-secondary").click()
+ page.get_by_label("Select model").get_by_test_id("baseButton-secondary").click()
+ page.get_by_label("Select training data").get_by_test_id("baseButton-secondary").click()
+ page.get_by_label("Select labels in case of").get_by_test_id("baseButton-secondary").click()