diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index de265c91..5dcc714b 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -55,8 +55,8 @@ jobs:
- name: Run unit tests
run: python -m pytest -v
- - name: Verify that we can build the package
- run: python setup.py sdist bdist_wheel
+ #- name: Verify that we can build the package
+ # run: python setup.py sdist bdist_wheel
test_downloader:
name: Test file downloader
@@ -73,7 +73,8 @@ jobs:
test_dashboard:
name: Test dashboard
- if: github.event.pull_request.draft == false
+ if: always()
+ #github.event.pull_request.draft == false
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
diff --git a/dianna/cli.py b/dianna/cli.py
index 4c3e1977..d8508b24 100644
--- a/dianna/cli.py
+++ b/dianna/cli.py
@@ -21,6 +21,7 @@ def dashboard():
*('--theme.primaryColor', '7030a0'),
*('--theme.secondaryBackgroundColor', 'e4f3f9'),
*('--browser.gatherUsageStats', 'false'),
+ *('--client.showSidebarNavigation', 'false'),
*args,
]
diff --git a/dianna/dashboard/Home.py b/dianna/dashboard/Home.py
index 51d5a118..f3cd87e2 100644
--- a/dianna/dashboard/Home.py
+++ b/dianna/dashboard/Home.py
@@ -46,14 +46,6 @@
with and for (academic) researchers and research software engineers working on machine
learning projects.
- ### Pages
-
- - Image data
- - Tabular data
- - Text data
- - Time series data
-
-
### More information
- [Source code](https://github.com/dianna-ai/dianna)
diff --git a/dianna/dashboard/_model_utils.py b/dianna/dashboard/_model_utils.py
index cc8084d0..67b61b2a 100644
--- a/dianna/dashboard/_model_utils.py
+++ b/dianna/dashboard/_model_utils.py
@@ -2,6 +2,7 @@
import numpy as np
import onnx
import pandas as pd
+from sklearn.model_selection import train_test_split
def load_data(file):
@@ -42,3 +43,41 @@ def load_labels(file):
def load_training_data(file):
return np.float32(np.load(file, allow_pickle=False))
+
+
+def load_sunshine(file):
+ """Tabular sunshine example.
+
+ Load the csv file in a pandas dataframe and split the data in a train and test set.
+ """
+ data = load_data(file)
+
+ # Drop unused columns
+ X_data = data.drop(columns=['DATE', 'MONTH', 'Index'])[:-1]
+ y_data = data.loc[1:]["BASEL_sunshine"]
+
+ # Split the data
+ X_train, X_holdout, _, y_holdout = train_test_split(X_data, y_data, test_size=0.3, random_state=0)
+ _, X_test, _, _ = train_test_split(X_holdout, y_holdout, test_size=0.5, random_state=0)
+ X_test = X_test.reset_index(drop=True)
+ X_test.insert(0, 'Index', X_test.index)
+
+ return X_train.to_numpy(dtype=np.float32), X_test
+
+def load_penguins(penguins):
+ """Prep the data for the penguin model example as per ntoebook."""
+ # Remove categorial columns and NaN values
+ penguins_filtered = penguins.drop(columns=['island', 'sex']).dropna()
+
+
+ # Extract inputs and target
+ input_features = penguins_filtered.drop(columns=['species'])
+ target = pd.get_dummies(penguins_filtered['species'])
+
+ X_train, X_test, _, _ = train_test_split(input_features, target, test_size=0.2,
+ random_state=0, shuffle=True, stratify=target)
+
+ X_test = X_test.reset_index(drop=True)
+ X_test.insert(0, 'Index', X_test.index)
+
+ return X_train.to_numpy(dtype=np.float32), X_test
diff --git a/dianna/dashboard/_models_tabular.py b/dianna/dashboard/_models_tabular.py
index 96573326..38685917 100644
--- a/dianna/dashboard/_models_tabular.py
+++ b/dianna/dashboard/_models_tabular.py
@@ -1,24 +1,38 @@
-import tempfile
import numpy as np
+import onnxruntime as ort
import streamlit as st
from dianna import explain_tabular
-from dianna.utils.onnx_runner import SimpleModelRunner
@st.cache_data
def predict(*, model, tabular_input):
- model_runner = SimpleModelRunner(model)
- predictions = model_runner(tabular_input.reshape(1,-1).astype(np.float32))
- return predictions
+ # Make sure that tabular input is provided as float32
+ sess = ort.InferenceSession(model)
+ input_name = sess.get_inputs()[0].name
+ output_name = sess.get_outputs()[0].name
+
+ onnx_input = {input_name: tabular_input.astype(np.float32)}
+ pred_onnx = sess.run([output_name], onnx_input)[0]
+
+ return pred_onnx
@st.cache_data
-def _run_rise_tabular(_model, table, training_data, **kwargs):
+def _run_rise_tabular(_model, table, training_data,_feature_names, **kwargs):
+ # convert streamlit kwarg requirement back to dianna kwarg requirement
+ if "_preprocess_function" in kwargs:
+ kwargs["preprocess_function"] = kwargs["_preprocess_function"]
+ del kwargs["_preprocess_function"]
+
+ def run_model(tabular_input):
+ return predict(model=_model, tabular_input=tabular_input)
+
relevances = explain_tabular(
- _model,
+ run_model,
table,
method='RISE',
training_data=training_data,
+ feature_names=_feature_names,
**kwargs,
)
return relevances
@@ -26,8 +40,16 @@ def _run_rise_tabular(_model, table, training_data, **kwargs):
@st.cache_data
def _run_lime_tabular(_model, table, training_data, _feature_names, **kwargs):
+ # convert streamlit kwarg requirement back to dianna kwarg requirement
+ if "_preprocess_function" in kwargs:
+ kwargs["preprocess_function"] = kwargs["_preprocess_function"]
+ del kwargs["_preprocess_function"]
+
+ def run_model(tabular_input):
+ return predict(model=_model, tabular_input=tabular_input)
+
relevances = explain_tabular(
- _model,
+ run_model,
table,
method='LIME',
training_data=training_data,
@@ -37,17 +59,22 @@ def _run_lime_tabular(_model, table, training_data, _feature_names, **kwargs):
return relevances
@st.cache_data
-def _run_kernelshap_tabular(model, table, training_data, **kwargs):
+def _run_kernelshap_tabular(model, table, training_data, _feature_names, **kwargs):
# Kernelshap interface is different. Write model to temporary file.
- with tempfile.NamedTemporaryFile() as f:
- f.write(model)
- f.flush()
- relevances = explain_tabular(f.name,
+ if "_preprocess_function" in kwargs:
+ kwargs["preprocess_function"] = kwargs["_preprocess_function"]
+ del kwargs["_preprocess_function"]
+
+ def run_model(tabular_input):
+ return predict(model=model, tabular_input=tabular_input)
+
+ relevances = explain_tabular(run_model,
table,
method='KernelSHAP',
training_data=training_data,
+ feature_names=_feature_names,
**kwargs)
- return relevances[0]
+ return np.array(relevances)
explain_tabular_dispatcher = {
diff --git a/dianna/dashboard/_shared.py b/dianna/dashboard/_shared.py
index 6ed408ae..0f0c91f8 100644
--- a/dianna/dashboard/_shared.py
+++ b/dianna/dashboard/_shared.py
@@ -74,13 +74,25 @@ def _methods_checkboxes(*, choices: Sequence, key):
def _get_params(method: str, key):
if method == 'RISE':
+ n_masks = 1000
+ fr = 8
+ pkeep = 0.1
+ if 'FRB' in key:
+ n_masks = 5000
+ fr = 16
+ elif 'Tabular' in key:
+ pkeep = 0.5
+ elif 'Weather' in key:
+ n_masks = 10000
+ elif 'Digits' in key:
+ n_masks = 5000
return {
'n_masks':
- st.number_input('Number of masks', value=1000, key=f'{key}_{method}_nmasks'),
+ st.number_input('Number of masks', value=n_masks, key=f'{key}_{method}_nmasks'),
'feature_res':
- st.number_input('Feature resolution', value=6, key=f'{key}_{method}_fr'),
+ st.number_input('Feature resolution', value=fr, key=f'{key}_{method}_fr'),
'p_keep':
- st.number_input('Probability to be kept unmasked', value=0.1, key=f'{key}_{method}_pkeep'),
+ st.number_input('Probability to be kept unmasked', value=pkeep, key=f'{key}_{method}_pkeep'),
}
elif method == 'KernelSHAP':
@@ -97,9 +109,14 @@ def _get_params(method: str, key):
}
elif method == 'LIME':
- return {
- 'random_state': st.number_input('Random state', value=2, key=f'{key}_{method}_rs'),
+ if 'Tabular' in key:
+ return {
+ 'random_state': st.number_input('Random state', value=0, key=f'{key}_{method}_rs'),
}
+ else:
+ return {
+ 'random_state': st.number_input('Random state', value=2, key=f'{key}_{method}_rs'),
+ }
else:
raise ValueError(f'No such method: {method}')
diff --git a/dianna/dashboard/dashboard-screenshot.png b/dianna/dashboard/dashboard-screenshot.png
deleted file mode 100644
index 61a86a17..00000000
Binary files a/dianna/dashboard/dashboard-screenshot.png and /dev/null differ
diff --git a/dianna/dashboard/pages/Images.py b/dianna/dashboard/pages/Images.py
index edab213d..c418db01 100644
--- a/dianna/dashboard/pages/Images.py
+++ b/dianna/dashboard/pages/Images.py
@@ -41,6 +41,8 @@
image_model_file = download('mnist_model_tf.onnx', 'model')
image_label_file = download('labels_mnist.txt', 'label')
+ imagekey = 'Digits_Image_cb'
+
st.markdown(
"""
This example demonstrates the use of DIANNA on a pretrained binary
@@ -71,6 +73,8 @@
image_label_file = st.sidebar.file_uploader('Select labels',
type='txt')
+ imagekey = 'Image_cb'
+
if input_type is None:
st.info('Select which input type to use in the left panel to continue')
st.stop()
@@ -93,7 +97,7 @@
with st.container(border=True):
prediction_placeholder = st.empty()
- methods, method_params = _methods_checkboxes(choices=choices, key='Image_cb')
+ methods, method_params = _methods_checkboxes(choices=choices, key=imagekey)
with st.spinner('Predicting class'):
predictions = predict(model=model, image=image)
diff --git a/dianna/dashboard/pages/Tabular.py b/dianna/dashboard/pages/Tabular.py
index f9825648..55dc5834 100644
--- a/dianna/dashboard/pages/Tabular.py
+++ b/dianna/dashboard/pages/Tabular.py
@@ -1,8 +1,11 @@
import numpy as np
+import seaborn as sns
import streamlit as st
from _model_utils import load_data
from _model_utils import load_labels
from _model_utils import load_model
+from _model_utils import load_penguins
+from _model_utils import load_sunshine
from _model_utils import load_training_data
from _models_tabular import explain_tabular_dispatcher
from _models_tabular import predict
@@ -10,9 +13,11 @@
from _shared import _methods_checkboxes
from _shared import add_sidebar_logo
from _shared import reset_example
+from _shared import reset_method
from st_aggrid import AgGrid
from st_aggrid import GridOptionsBuilder
from st_aggrid import GridUpdateMode
+from dianna.utils.downloader import download
from dianna.visualization import plot_tabular
add_sidebar_logo()
@@ -31,14 +36,64 @@
# Use the examples
if input_type == 'Use an example':
- """load_example = st.sidebar.radio(
+ load_example = st.sidebar.radio(
label='Use example',
- options=(''),
+ options=('Sunshine hours prediction', 'Penguin identification'),
index = None,
on_change = reset_method,
- key='Tabular_load_example')"""
- st.info("No examples availble yet")
- st.stop()
+ key='Tabular_load_example')
+
+ if load_example == "Sunshine hours prediction":
+ tabular_data_file = download('weather_prediction_dataset_light.csv', 'data')
+ tabular_model_file = download('sunshine_hours_regression_model.onnx', 'model')
+ tabular_training_data_file = tabular_data_file
+ tabular_label_file = None
+
+ training_data, data = load_sunshine(tabular_data_file)
+ labels = None
+
+ mode = 'regression'
+ st.markdown(
+ """
+ This example demonstrates the use of DIANNA on a pre-trained regression
+ [model to predict tomorrow's sunshine hours](https://zenodo.org/records/10580833)
+ based on meteorological data from today.
+ The model is trained on the
+ [weather prediction dataset](https://zenodo.org/records/5071376).
+ The meteorological data includes for various European cities the
+ cloud coverage,humidity, air pressure, global radiation, precipitation, and
+ mean, min and max temeprature.
+
+ DIANNA's visualisation shows the top most important features contributing to the
+ sunshine hours prediction, where features contrinuting positively are indicated in red
+ and those who contribute negatively in blue.
+ """)
+ elif load_example == 'Penguin identification':
+ tabular_model_file = download('penguin_model.onnx', 'model')
+ data_penguins = sns.load_dataset('penguins')
+ labels = data_penguins['species'].unique()
+
+ training_data, data = load_penguins(data_penguins)
+
+ mode = 'classification'
+
+ st.markdown(
+ """
+ This example demonstrates the use of DIANNA on a pre-trained classification
+ [model to classify penguins in to three different species](https://zenodo.org/records/10580743)
+ based on a number of measurable physical characteristics.
+ The model is trained on the
+ [weather prediction dataset](https://zenodo.org/records/5071376). The data is obtained from
+ the Python seaborn package
+ The penguin characteristics include the bill length, bill depth, flipper length and body mass.
+
+ DIANNA's visualisation shows the top most important characteristics contributing to the
+ penguin species classification, where characteristics contributing positively are indicated in red
+ and those who contribute negatively in blue.
+ """)
+ else:
+ st.info('Select an example in the left panel to coninue')
+ st.stop()
# Option to upload your own data
if input_type == 'Use your own data':
@@ -47,29 +102,29 @@
tabular_training_data_file = st.sidebar.file_uploader('Select training data', type='npy')
tabular_label_file = st.sidebar.file_uploader('Select labels in case of classification model', type='txt')
+ if not (tabular_data_file and tabular_model_file and tabular_training_data_file):
+ st.info('Add your input data in the left panel to continue')
+ st.stop()
+
+ data = load_data(tabular_data_file)
+ model = load_model(tabular_model_file)
+ training_data = load_training_data(tabular_training_data_file)
+
+ if tabular_label_file:
+ labels = load_labels(tabular_label_file)
+ mode = 'classification'
+ else:
+ labels = None
+ mode = 'regression'
+
if input_type is None:
st.info('Select which input type to use in the left panel to continue')
st.stop()
-if not (tabular_data_file and tabular_model_file and tabular_training_data_file):
- st.info('Add your input data in the left panel to continue')
- st.stop()
-
-data = load_data(tabular_data_file)
-
model = load_model(tabular_model_file)
serialized_model = model.SerializeToString()
-training_data = load_training_data(tabular_training_data_file)
-
-if tabular_label_file:
- labels = load_labels(tabular_label_file)
- mode = 'classification'
-else:
- labels = None
- mode = 'regression'
-
-choices = ('RISE', 'LIME')
+choices = ('RISE', 'LIME', 'KernelSHAP')
st.text("")
st.text("")
@@ -94,10 +149,10 @@
)
if grid_response['selected_rows'] is not None:
- selected_row = grid_response['selected_rows']['Index'].iloc[0]
- selected_data = data.iloc[selected_row, 1:].to_numpy(dtype=np.float32)
+ selected_row = int(grid_response['selected_rows'].index[0])
+ selected_data = data.iloc[selected_row].to_numpy()[1:]
with st.spinner('Predicting class'):
- predictions = predict(model=serialized_model, tabular_input=selected_data)
+ predictions = predict(model=serialized_model, tabular_input=selected_data.reshape(1,-1))
with prediction_placeholder:
top_indices, top_labels = _get_top_indices_and_labels(
@@ -125,17 +180,21 @@
for col, method in zip(columns, methods):
kwargs = method_params[method].copy()
- kwargs['labels'] = [index]
kwargs['mode'] = mode
- if method == 'LIME':
- kwargs['_feature_names']=data[:1].columns.to_list()
+ kwargs['_feature_names']=data.columns.to_list()[1:]
func = explain_tabular_dispatcher[method]
with col:
with st.spinner(f'Running {method}'):
relevances = func(serialized_model, selected_data, training_data, **kwargs)
- fig, _ = plot_tabular(x=relevances, y=data[:1].columns, num_features=10, show_plot=False)
+ if mode == 'classification':
+ plot_relevances = relevances[np.argmax(predictions)]
+ else:
+ plot_relevances = relevances
+
+ fig, _ = plot_tabular(x=plot_relevances, y=kwargs['_feature_names'],
+ num_features=10, show_plot=False)
st.pyplot(fig)
# add some white space to separate rows
diff --git a/dianna/dashboard/pages/Text.py b/dianna/dashboard/pages/Text.py
index 387e37a2..b414d5f1 100644
--- a/dianna/dashboard/pages/Text.py
+++ b/dianna/dashboard/pages/Text.py
@@ -36,7 +36,9 @@
key='Text_load_example')
if load_example == 'Movie sentiment':
- text_input = 'The movie started out great but the ending was disappointing'
+ text_input = st.sidebar.text_input(
+ 'Input string',
+ value='The movie started out great but the ending was disappointing')
text_model_file = download('movie_review_model.onnx', 'model')
text_label_file = download('labels_text.txt', 'label')
@@ -46,7 +48,8 @@
Treebank dataset](https://nlp.stanford.edu/sentiment/index.html) which
contains one-sentence movie reviews. A pre-trained neural network
classifier is used, which identifies whether a movie review is positive
- or negative.
+ or negative. The input string to which the model is applied can be modified
+ in the left menu.
""")
else:
st.info('Select an example in the left panel to coninue')
diff --git a/dianna/dashboard/pages/Time_series.py b/dianna/dashboard/pages/Time_series.py
index 07c88e15..a8657da0 100644
--- a/dianna/dashboard/pages/Time_series.py
+++ b/dianna/dashboard/pages/Time_series.py
@@ -11,8 +11,8 @@
from _shared import reset_method
from _ts_utils import _convert_to_segments
from _ts_utils import open_timeseries
+from matplotlib import pyplot as plt
from dianna.utils.downloader import download
-from dianna.visualization import plot_image
from dianna.visualization import plot_timeseries
st.title('Time series explanation')
@@ -44,6 +44,8 @@
'season_prediction_model_temp_max_binary.onnx', 'model')
ts_label_file = download('weather_data_labels.txt', 'label')
+ param_key = 'Weather_TS_cb'
+
st.markdown(
"""
This example demonstrates the use of DIANNA
@@ -72,6 +74,8 @@ def preprocess(data):
ts_data_explainer = ts_data.T[None, ...]
ts_data_predictor = ts_data[None, ..., None]
+ param_key = 'FRB_TS_cb'
+
st.markdown(
"""This example demonstrates the use of DIANNA
on a pre-trained binary classification model trained to classify
@@ -98,6 +102,8 @@ def preprocess(data):
ts_label_file = st.sidebar.file_uploader('Select labels',
type='txt')
+ param_key = 'TS_cb'
+
if input_type is None:
st.info('Select which input type to use in the left panel to continue')
st.stop()
@@ -126,7 +132,7 @@ def preprocess(data):
with st.container(border=True):
prediction_placeholder = st.empty()
- methods, method_params = _methods_checkboxes(choices=choices, key='TS_cb')
+ methods, method_params = _methods_checkboxes(choices=choices, key=param_key)
with st.spinner('Predicting class'):
predictions = predict(model=serialized_model, ts_data=ts_data_predictor)
@@ -162,8 +168,21 @@ def preprocess(data):
explanation = func(serialized_model, ts_data=ts_data_explainer, **kwargs)
if load_example == "Scientific case: FRB":
- # FRB data: get rid of last dimension
- fig, _ = plot_image(explanation[0, :, ::-1].T)
+ fig, axes = plt.subplots(ncols=2, figsize=(14, 5))
+ # FRB: plot original data
+ ax = axes[0]
+ ax.imshow(ts_data, aspect='auto', origin='lower')
+ ax.set_xlabel('Time step')
+ ax.set_ylabel('Channel index')
+ ax.set_title('Input data')
+ # FRB data explanation has to be transposed
+ ax = axes[1]
+ plot = ax.imshow(explanation[0].T, aspect='auto', origin='lower', cmap='bwr')
+ ax.set_xlabel('Time step')
+ ax.set_ylabel('Channel index')
+ ax.set_title('Explanation')
+ fig.colorbar(plot)
+
else:
segments = _convert_to_segments(explanation)
diff --git a/setup.cfg b/setup.cfg
index 6bef3249..b363bd44 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -94,7 +94,9 @@ dashboard =
Pillow
plotly
scipy
+ seaborn
spacy
+ streamlit-aggrid
streamlit
streamlit_option_menu
torchtext
diff --git a/tests/test_dashboard.py b/tests/test_dashboard.py
deleted file mode 100644
index 0ce296c0..00000000
--- a/tests/test_dashboard.py
+++ /dev/null
@@ -1,267 +0,0 @@
-"""Module to test the dashboard.
-
-This test module uses (playwright)[https://playwright.dev/python/]
-to test the user workflow.
-
-Installation:
-
- pip install pytest-playwright
- playwright install
-
-Make sure that the server is running by:
-```bash
-cd dianna/dashboard
-streamlit run Home.py
-```
-Then, set variable `LOCAL=True` (see below) to connect to local instance for
-debugging. Then, you can run the tests with:
-
-```bash
-pytest -v -m dashboard --dashboard
-```
-See more documentation about dashboard in: dianna/dashboard/readme.md
-
-For Code generation (https://playwright.dev/python/docs/codegen):
-
- playwright codegen http://localhost:8501
-"""
-
-import time
-from contextlib import contextmanager
-import pytest
-from playwright.sync_api import Page
-from playwright.sync_api import expect
-
-LOCAL = False
-
-PORT = '8501' if LOCAL else '8502'
-BASE_URL = f'localhost:{PORT}'
-
-pytestmark = pytest.mark.dashboard
-
-
-@pytest.fixture(scope='module', autouse=True)
-def before_module():
- """Run dashboard in module scope."""
- with run_streamlit():
- yield
-
-
-@contextmanager
-def run_streamlit():
- """Run the dashboard."""
- import subprocess
-
- if not LOCAL:
- p = subprocess.Popen([
- 'dianna-dashboard',
- '--server.port',
- PORT,
- '--server.headless',
- 'true',
- ])
- time.sleep(5)
-
- yield
-
- if not LOCAL:
- p.kill()
-
-
-def test_page_load(page: Page):
- """Test performance of landing page."""
- page.goto(BASE_URL)
-
- selector = page.get_by_text('Running...')
- selector.wait_for(state='detached')
-
- expect(page).to_have_title("Dianna's dashboard")
- for selector in (
- page.get_by_role('img', name='0'),
- page.get_by_text('Pages'),
- page.get_by_text('More information'),
- ):
- expect(selector).to_be_visible()
-
-
-def test_text_page(page: Page):
- """Test performance of text page."""
- page.goto(f'{BASE_URL}/Text')
-
- page.get_by_text('Running...').wait_for(state='detached')
-
- expect(page).to_have_title('Text')
-
- # Movie sentiment example
- page.locator("label").filter(has_text="Use an example").locator("div").nth(1).click()
- page.get_by_text("Movie sentiment").click()
- expect(page.get_by_text("Select a method to continue")).to_be_visible(timeout=50_000)
-
- page.locator('label').filter(has_text='RISE').locator('span').click()
- page.locator('label').filter(has_text='LIME').locator('span').click()
- page.get_by_test_id("stNumberInput-StepUp").click()
- page.get_by_text('Running...').wait_for(state='detached', timeout=100_000)
-
- for selector in (
- page.get_by_role('heading', name='RISE').get_by_text('RISE'),
- page.get_by_role('heading', name='LIME').get_by_text('LIME'),
- # Images for positive (RISE/LIME)
- page.get_by_role('heading',
- name='positive').get_by_text('positive'),
- page.get_by_role('img', name='0').first,
- page.get_by_role('img', name='0').nth(1),
-
- # Images for negative (RISE/LIME)
- page.get_by_role('heading',
- name='negative').get_by_text('negative'),
- page.get_by_role('img', name='0').nth(2),
- page.get_by_role('img', name='0').nth(3),
- ):
- expect(selector).to_be_visible()
-
- # Own data option
- page.locator("label").filter(has_text="Use your own data").locator("div").nth(1).click()
- selector = page.get_by_text(
- 'Add your input data in the left panel to continue')
-
- expect(selector).to_be_visible(timeout=30_000)
-
- # Check input panel
- page.get_by_label("Input string").click()
- expect(page.get_by_label("Select model").get_by_test_id("baseButton-secondary")).to_be_visible()
- page.get_by_label("Select labels").get_by_test_id("baseButton-secondary").click()
-
-
-def test_image_page(page: Page):
- """Test performance of image page."""
- page.goto(f'{BASE_URL}/Images')
-
- page.get_by_text('Running...').wait_for(state='detached')
-
- expect(page).to_have_title('Images')
-
- expect(
- page.get_by_text('Select which input type to')
- ).to_be_visible(timeout=100_000)
-
- # Digits example
- page.locator("label").filter(has_text="Use an example").locator("div").nth(1).click()
- page.get_by_text("Hand-written digit recognition").click()
-
- expect(page.get_by_text('Select a method to continue')).to_be_visible(timeout=100_000)
-
- page.locator('label').filter(has_text='RISE').locator('span').click()
- page.locator('label').filter(has_text='KernelSHAP').locator('span').click()
- page.locator('label').filter(has_text='LIME').locator('span').click()
- page.get_by_test_id("stNumberInput-StepUp").click()
- page.get_by_text('Running...').wait_for(state='detached', timeout=50_000)
-
- for selector in (
- page.get_by_role('heading', name='RISE').get_by_text('RISE'),
- page.get_by_role('heading', name='KernelSHAP').get_by_text('KernelSHAP'),
- page.get_by_role('heading', name='LIME').get_by_text('LIME'),
- # first image
- page.get_by_role('heading', name='0').get_by_text('0'),
- page.get_by_role('img', name='0').first,
- page.get_by_role('img', name='0').nth(1),
- page.get_by_role('img', name='0').nth(2),
- # second image
- page.get_by_role('heading', name='1').get_by_text('1'),
- page.get_by_role('img', name='0').nth(3),
- page.get_by_role('img', name='0').nth(4),
- page.get_by_role('img', name='0').nth(5),
- ):
- expect(selector).to_be_visible(timeout=100_000)
-
- # Own data
- page.locator("label").filter(has_text="Use your own data").locator("div").nth(1).click()
- expect(page.get_by_label("Select image").get_by_test_id("baseButton-secondary")).to_be_visible()
- page.get_by_label("Select model").get_by_test_id("baseButton-secondary").click()
- page.get_by_label("Select labels").get_by_test_id("baseButton-secondary").click()
-
-
-def test_timeseries_page(page: Page):
- """Test performance of timeseries page."""
- page.goto(f'{BASE_URL}/Time_series')
-
- page.get_by_text('Running...').wait_for(state='detached')
-
- expect(page).to_have_title('Time_series')
-
- expect(page.get_by_text("Select which input type to")).to_be_visible(timeout=100_000)
-
- page.locator("label").filter(has_text="Use an example").locator("div").nth(1).click()
- expect(page.get_by_text("Select an example in the left")).to_be_visible()
- expect(page.get_by_text("Weather")).to_be_visible()
- expect(page.get_by_text("FRB")).to_be_visible()
-
- # Test weather example
- page.locator("label").filter(has_text="Weather").locator("div").nth(1).click()
- expect(page.get_by_text("Select a method to continue")).to_be_visible(timeout=100_000)
-
- page.locator('label').filter(has_text='LIME').locator('span').click()
- page.locator('label').filter(has_text='RISE').locator('span').click()
- page.get_by_test_id("stNumberInput-StepUp").click()
- page.get_by_text('Running...').wait_for(state='detached', timeout=100_000)
-
- for selector in (
- page.get_by_role('heading', name='LIME').get_by_text('LIME'),
- page.get_by_role('heading', name='RISE').get_by_text('RISE'),
- # First image
- page.get_by_role('heading', name='winter').get_by_text('winter'),
- page.get_by_role('img', name='0').first,
- page.get_by_role('img', name='0').nth(1),
- # Second image
- page.get_by_role('heading', name='summer').get_by_text('summer'),
- page.get_by_role('img', name='0').nth(2),
- page.get_by_role('img', name='0').nth(3),
- ):
- expect(selector).to_be_visible()
-
- # Test FRB example
- page.locator("label").filter(has_text="FRB").locator("div").nth(1).click()
- expect(page.get_by_text("Select a method to continue")).to_be_visible(timeout=100_000)
-
- page.locator('label').filter(has_text='RISE').locator('span').click()
-
- page.get_by_text('Running...').wait_for(state='detached', timeout=100_000)
-
- for selector in (
- page.get_by_role('heading', name='RISE').get_by_text('RISE'),
- # First image
- page.get_by_role('heading', name='FRB').get_by_text('FRB'),
- page.get_by_role('img', name='0').first,
- page.get_by_role('img', name='0').nth(1),
- ):
- expect(selector).to_be_visible()
-
- # Test using your own data
- page.locator("label").filter(
- has_text="Use your own data").locator("div").nth(1).click()
- page.get_by_label("Select input data").get_by_test_id(
- "baseButton-secondary").click()
- page.get_by_label("Select model").get_by_test_id(
- "baseButton-secondary").click()
- page.get_by_label("Select labels").get_by_test_id(
- "baseButton-secondary").click()
-
-
-def test_tabular_page(page: Page):
- """Test performance of tabular page."""
- page.goto(f'{BASE_URL}/Tabular')
-
- page.get_by_text('Running...').wait_for(state='detached')
-
- expect(page).to_have_title('Tabular')
-
- expect(page.get_by_text("Select which input type to")).to_be_visible(timeout=100_000)
-
- page.locator("label").filter(has_text="Use an example").locator("div").nth(1).click()
-
- # Test using your own data
- page.locator("label").filter(
- has_text="Use your own data").locator("div").nth(1).click()
- page.get_by_label("Select tabular data").get_by_test_id("baseButton-secondary").click()
- page.get_by_label("Select model").get_by_test_id("baseButton-secondary").click()
- page.get_by_label("Select training data").get_by_test_id("baseButton-secondary").click()
- page.get_by_label("Select labels in case of").get_by_test_id("baseButton-secondary").click()
diff --git a/tests/test_dashboard_image.py b/tests/test_dashboard_image.py
new file mode 100644
index 00000000..73b1af5b
--- /dev/null
+++ b/tests/test_dashboard_image.py
@@ -0,0 +1,123 @@
+"""Module to test the dashboard.
+
+This test module uses (playwright)[https://playwright.dev/python/]
+to test the user workflow.
+
+Installation:
+
+ pip install pytest-playwright
+ playwright install
+
+Make sure that the server is running by:
+```bash
+cd dianna/dashboard
+streamlit run Home.py
+```
+Then, set variable `LOCAL=True` (see below) to connect to local instance for
+debugging. Then, you can run the tests with:
+
+```bash
+pytest -v -m dashboard --dashboard
+```
+See more documentation about dashboard in: dianna/dashboard/readme.md
+
+For Code generation (https://playwright.dev/python/docs/codegen):
+
+ playwright codegen http://localhost:8501
+"""
+
+import time
+from contextlib import contextmanager
+import pytest
+from playwright.sync_api import Page
+from playwright.sync_api import expect
+
+LOCAL = False
+
+PORT = '8501' if LOCAL else '8502'
+BASE_URL = f'localhost:{PORT}'
+
+pytestmark = pytest.mark.dashboard
+
+
+@pytest.fixture(scope='module', autouse=True)
+def before_module():
+ """Run dashboard in module scope."""
+ with run_streamlit():
+ yield
+
+
+@contextmanager
+def run_streamlit():
+ """Run the dashboard."""
+ import subprocess
+
+ if not LOCAL:
+ p = subprocess.Popen([
+ 'dianna-dashboard',
+ '--server.port',
+ PORT,
+ '--server.headless',
+ 'true',
+ ])
+ time.sleep(5)
+
+ yield
+
+ if not LOCAL:
+ p.kill()
+
+
+def test_image_page(page: Page):
+ """Test performance of image page."""
+ page.set_viewport_size({"width": 1920, "height": 1080})
+
+ page.goto(f'{BASE_URL}/Images')
+
+ page.get_by_text('Running...').wait_for(state='detached')
+
+ expect(page).to_have_title('Images')
+
+ expect(
+ page.get_by_text('Select which input type to')
+ ).to_be_visible(timeout=100_000)
+
+ # Digits example
+ page.locator("label").filter(has_text="Use an example").locator("div").nth(1).click()
+ page.get_by_text("Hand-written digit recognition").click()
+
+ expect(page.get_by_text('Select a method to continue')).to_be_visible(timeout=100_000)
+
+ time.sleep(2)
+
+ page.locator('label').filter(has_text='RISE').locator('span').click()
+ page.locator('label').filter(has_text='KernelSHAP').locator('span').click()
+ page.locator('label').filter(has_text='LIME').locator('span').click()
+
+ page.get_by_label("Number of top classes to show").fill("2")
+ page.get_by_label("Number of top classes to show").press("Enter")
+ page.get_by_text('Running...').wait_for(state='detached', timeout=100_000)
+
+ for selector in (
+ page.get_by_role('heading', name='RISE').get_by_text('RISE'),
+ page.get_by_role('heading', name='KernelSHAP').get_by_text('KernelSHAP'),
+ page.get_by_role('heading', name='LIME').get_by_text('LIME'),
+ # first image
+ page.get_by_role('heading', name='0').get_by_text('0'),
+ page.get_by_role('img', name='0').first,
+ page.get_by_role('img', name='0').nth(1),
+ page.get_by_role('img', name='0').nth(2),
+ # second image
+ page.get_by_role('heading', name='1').get_by_text('1'),
+ page.get_by_role('img', name='0').nth(3),
+ page.get_by_role('img', name='0').nth(4),
+ page.get_by_role('img', name='0').nth(5),
+ ):
+ expect(selector).to_be_visible(timeout=200_000)
+
+ # Own data
+ page.locator("label").filter(has_text="Use your own data").locator("div").nth(1).click()
+
+ page.get_by_label("Select image").click()
+ page.get_by_label("Select model").click()
+ page.get_by_label("Select labels").click()
diff --git a/tests/test_dashboard_setup.py b/tests/test_dashboard_setup.py
new file mode 100644
index 00000000..75e1c602
--- /dev/null
+++ b/tests/test_dashboard_setup.py
@@ -0,0 +1,85 @@
+"""Module to test the dashboard.
+
+This test module uses (playwright)[https://playwright.dev/python/]
+to test the user workflow.
+
+Installation:
+
+ pip install pytest-playwright
+ playwright install
+
+Make sure that the server is running by:
+```bash
+cd dianna/dashboard
+streamlit run Home.py
+```
+Then, set variable `LOCAL=True` (see below) to connect to local instance for
+debugging. Then, you can run the tests with:
+
+```bash
+pytest -v -m dashboard --dashboard
+```
+See more documentation about dashboard in: dianna/dashboard/readme.md
+
+For Code generation (https://playwright.dev/python/docs/codegen):
+
+ playwright codegen http://localhost:8501
+"""
+
+import time
+from contextlib import contextmanager
+import pytest
+from playwright.sync_api import Page
+from playwright.sync_api import expect
+
+LOCAL = False
+
+PORT = '8501' if LOCAL else '8502'
+BASE_URL = f'localhost:{PORT}'
+
+pytestmark = pytest.mark.dashboard
+
+
+@pytest.fixture(scope='module', autouse=True)
+def before_module():
+ """Run dashboard in module scope."""
+ with run_streamlit():
+ yield
+
+
+@contextmanager
+def run_streamlit():
+ """Run the dashboard."""
+ import subprocess
+
+ if not LOCAL:
+ p = subprocess.Popen([
+ 'dianna-dashboard',
+ '--server.port',
+ PORT,
+ '--server.headless',
+ 'true',
+ ])
+ time.sleep(5)
+
+ yield
+
+ if not LOCAL:
+ p.kill()
+
+
+def test_page_load(page: Page):
+ """Test performance of landing page."""
+ page.goto(BASE_URL)
+
+ selector = page.get_by_text('Running...')
+ selector.wait_for(state='detached')
+
+ expect(page).to_have_title("Dianna's dashboard")
+
+ for selector in (
+ page.get_by_role('img', name='0'),
+ page.get_by_text('More information'),
+ ):
+ expect(selector).to_be_visible()
+
diff --git a/tests/test_dashboard_tabular.py b/tests/test_dashboard_tabular.py
new file mode 100644
index 00000000..f1fb4e76
--- /dev/null
+++ b/tests/test_dashboard_tabular.py
@@ -0,0 +1,181 @@
+"""Module to test the dashboard.
+
+This test module uses (playwright)[https://playwright.dev/python/]
+to test the user workflow.
+
+Installation:
+
+ pip install pytest-playwright
+ playwright install
+
+Make sure that the server is running by:
+```bash
+cd dianna/dashboard
+streamlit run Home.py
+```
+Then, set variable `LOCAL=True` (see below) to connect to local instance for
+debugging. Then, you can run the tests with:
+
+```bash
+pytest -v -m dashboard --dashboard
+```
+See more documentation about dashboard in: dianna/dashboard/readme.md
+
+For Code generation (https://playwright.dev/python/docs/codegen):
+
+ playwright codegen http://localhost:8501
+"""
+
+import time
+from contextlib import contextmanager
+import pytest
+from playwright.sync_api import Page
+from playwright.sync_api import expect
+
+LOCAL = False
+
+PORT = '8501' if LOCAL else '8502'
+BASE_URL = f'localhost:{PORT}'
+
+pytestmark = pytest.mark.dashboard
+
+
+@pytest.fixture(scope='module', autouse=True)
+def before_module():
+ """Run dashboard in module scope."""
+ with run_streamlit():
+ yield
+
+
+@contextmanager
+def run_streamlit():
+ """Run the dashboard."""
+ import subprocess
+
+ if not LOCAL:
+ p = subprocess.Popen([
+ 'dianna-dashboard',
+ '--server.port',
+ PORT,
+ '--server.headless',
+ 'true',
+ ])
+ time.sleep(5)
+
+ yield
+
+ if not LOCAL:
+ p.kill()
+
+
+def test_tabular_page(page: Page):
+ """Test performance of tabular page."""
+ page.set_viewport_size({"width": 1920, "height": 1080})
+
+ page.goto(f'{BASE_URL}/Tabular')
+
+ page.get_by_text('Running...').wait_for(state='detached')
+
+ expect(page).to_have_title('Tabular')
+
+ expect(page.get_by_text("Select which input type to")).to_be_visible(timeout=100_000)
+
+ # Test using your own data
+ page.locator("label").filter(
+ has_text="Use your own data").locator("div").nth(1).click()
+
+ page.get_by_label("Select tabular data").click()
+ page.get_by_label("Select model").click()
+ page.get_by_label("Select training data").click()
+ page.get_by_label("Select labels in case of").click()
+
+
+def test_tabular_sunshine(page: Page):
+ """Test tabular sunshine example."""
+ page.set_viewport_size({"width": 1920, "height": 1080})
+
+ page.goto(f'{BASE_URL}/Tabular')
+
+ page.get_by_text('Running...').wait_for(state='detached')
+
+ expect(page).to_have_title('Tabular')
+
+ expect(page.get_by_text("Select which input type to")).to_be_visible(timeout=100_000)
+
+ page.locator("label").filter(has_text="Use an example").locator("div").nth(1).click()
+ expect(page.get_by_text("Select an example in the left")).to_be_visible()
+ expect(page.get_by_text("Sunshine hours prediction")).to_be_visible()
+ expect(page.get_by_text("Penguin identification")).to_be_visible()
+
+ # Test sunshine example
+ page.locator("label").filter(has_text="Use an example").locator("div").nth(1).click()
+ page.locator("label").filter(has_text="Sunshine hours prediction").locator("div").nth(1).click()
+ expect(page.get_by_text("Select a method to continue")).to_be_visible(timeout=100_000)
+
+ time.sleep(2)
+
+ page.locator("label").filter(has_text="RISE").locator("span").click()
+ page.locator("label").filter(has_text="LIME").locator("span").click()
+ page.locator("label").filter(has_text="KernelSHAP").locator("span").click()
+ page.locator("summary").filter(has_text="Click to modify RISE").get_by_test_id("stExpanderToggleIcon").click()
+
+ expect(page.get_by_text("Select the input data by")).to_be_visible(timeout=100_000)
+ page.frame_locator("iframe[title=\"st_aggrid\\.agGrid\"]").get_by_role(
+ "gridcell", name="10", exact=True).click()
+ page.get_by_text('Running...').wait_for(state='detached', timeout=200_000)
+
+ expect(page.get_by_text("3.07")).to_be_visible(timeout=200_000)
+
+ for selector in (
+ page.get_by_role('heading', name='RISE').get_by_text('RISE'),
+ page.get_by_role('heading', name='KernelSHAP').get_by_text('KernelSHAP'),
+ page.get_by_role('heading', name='LIME').get_by_text('LIME'),
+ page.get_by_role('img', name='0').first,
+ page.get_by_role('img', name='0').nth(1),
+ page.get_by_role('img', name='0').nth(2),
+ ):
+ expect(selector).to_be_visible(timeout=100_000)
+
+
+def test_tabular_penguin(page: Page):
+ """Test performance of tabular penguin example."""
+ page.set_viewport_size({"width": 1920, "height": 1080})
+
+ page.goto(f'{BASE_URL}/Tabular')
+ page.get_by_text('Running...').wait_for(state='detached')
+
+ expect(page).to_have_title('Tabular')
+ expect(page.get_by_text("Select which input type to")).to_be_visible(timeout=100_000)
+
+ page.locator("label").filter(has_text="Use an example").locator("div").nth(1).click()
+ expect(page.get_by_text("Select an example in the left")).to_be_visible()
+ expect(page.get_by_text("Sunshine hours prediction")).to_be_visible()
+ expect(page.get_by_text("Penguin identification")).to_be_visible()
+
+ # Test sunshine example
+ page.locator("label").filter(has_text="Use an example").locator("div").nth(1).click()
+ page.locator("label").filter(has_text="Penguin identification").locator("div").nth(1).click()
+ expect(page.get_by_text("Select a method to continue")).to_be_visible(timeout=100_000)
+
+ time.sleep(2)
+
+ page.locator("label").filter(has_text="RISE").locator("span").click(timeout=300_000)
+ page.locator("label").filter(has_text="LIME").locator("span").click(timeout=300_000)
+ page.locator("label").filter(has_text="KernelSHAP").locator("span").click(timeout=300_000)
+
+ expect(page.get_by_text("Select the input data by")).to_be_visible(timeout=300_000)
+ page.frame_locator("iframe[title=\"st_aggrid\\.agGrid\"]").get_by_role(
+ "gridcell", name="10", exact=True).click()
+ page.get_by_text('Running...').wait_for(state='detached', timeout=300_000)
+
+ for selector in (
+ page.get_by_text('Predicted class:'),
+ page.get_by_test_id('stMetricValue').get_by_text('Gentoo'),
+ page.get_by_role('heading', name='RISE').get_by_text('RISE'),
+ page.get_by_role('heading', name='KernelSHAP').get_by_text('KernelSHAP'),
+ page.get_by_role('heading', name='LIME').get_by_text('LIME'),
+ page.get_by_role('img', name='0').first,
+ page.get_by_role('img', name='0').nth(1),
+ page.get_by_role('img', name='0').nth(2),
+ ):
+ expect(selector).to_be_visible(timeout=200_000)
diff --git a/tests/test_dashboard_text.py b/tests/test_dashboard_text.py
new file mode 100644
index 00000000..4c10d5ec
--- /dev/null
+++ b/tests/test_dashboard_text.py
@@ -0,0 +1,115 @@
+"""Module to test the dashboard.
+
+This test module uses (playwright)[https://playwright.dev/python/]
+to test the user workflow.
+
+Installation:
+
+ pip install pytest-playwright
+ playwright install
+
+Make sure that the server is running by:
+```bash
+cd dianna/dashboard
+streamlit run Home.py
+```
+Then, set variable `LOCAL=True` (see below) to connect to local instance for
+debugging. Then, you can run the tests with:
+
+```bash
+pytest -v -m dashboard --dashboard
+```
+See more documentation about dashboard in: dianna/dashboard/readme.md
+
+For Code generation (https://playwright.dev/python/docs/codegen):
+
+ playwright codegen http://localhost:8501
+"""
+
+import time
+from contextlib import contextmanager
+import pytest
+from playwright.sync_api import Page
+from playwright.sync_api import expect
+
+LOCAL = False
+
+PORT = '8501' if LOCAL else '8502'
+BASE_URL = f'localhost:{PORT}'
+
+pytestmark = pytest.mark.dashboard
+
+
+@pytest.fixture(scope='module', autouse=True)
+def before_module():
+ """Run dashboard in module scope."""
+ with run_streamlit():
+ yield
+
+
+@contextmanager
+def run_streamlit():
+ """Run the dashboard."""
+ import subprocess
+
+ if not LOCAL:
+ p = subprocess.Popen([
+ 'dianna-dashboard',
+ '--server.port',
+ PORT,
+ '--server.headless',
+ 'true',
+ ])
+ time.sleep(5)
+
+ yield
+
+ if not LOCAL:
+ p.kill()
+
+
+def test_text_page(page: Page):
+ """Test performance of text page."""
+ page.set_viewport_size({"width": 1920, "height": 1080})
+
+ page.goto(f'{BASE_URL}/Text')
+ page.get_by_text('Running...').wait_for(state='detached')
+ expect(page).to_have_title('Text')
+ # Movie sentiment example
+ page.locator("label").filter(has_text="Use an example").locator("div").nth(1).click()
+ page.get_by_text("Movie sentiment").click()
+ expect(page.get_by_text("Select a method to continue")).to_be_visible(timeout=50_000)
+
+ time.sleep(2)
+ page.locator('label').filter(has_text='RISE').locator('span').click()
+ page.locator('label').filter(has_text='LIME').locator('span').click()
+
+ page.get_by_label("Number of top classes to show").fill("2")
+ page.get_by_label("Number of top classes to show").press("Enter")
+ page.get_by_text('Running...').wait_for(state='detached', timeout=100_000)
+
+ for selector in (
+ page.get_by_role('heading', name='RISE').get_by_text('RISE'),
+ page.get_by_role('heading', name='LIME').get_by_text('LIME'),
+ # Images for positive (RISE/LIME)
+ page.get_by_role('heading',
+ name='positive').get_by_text('positive'),
+ page.get_by_role('img', name='0').first,
+ page.get_by_role('img', name='0').nth(1),
+# # Images for negative (RISE/LIME)
+ page.get_by_role('heading',
+ name='negative').get_by_text('negative'),
+ page.get_by_role('img', name='0').nth(2),
+ page.get_by_role('img', name='0').nth(3),
+ ):
+ expect(selector).to_be_visible(timeout=100_000)
+
+ # Own data option
+ page.locator("label").filter(has_text="Use your own data").locator("div").nth(1).click()
+ selector = page.get_by_text(
+ 'Add your input data in the left panel to continue')
+ expect(selector).to_be_visible(timeout=30_000)
+ # Check input panel
+ expect(page.get_by_label("Input string")).to_be_visible(timeout=200_000)
+ page.get_by_label("Select model").click()
+ page.get_by_label("Select labels").click()
diff --git a/tests/test_dashboard_time_series.py b/tests/test_dashboard_time_series.py
new file mode 100644
index 00000000..a4baacd1
--- /dev/null
+++ b/tests/test_dashboard_time_series.py
@@ -0,0 +1,146 @@
+"""Module to test the dashboard.
+
+This test module uses (playwright)[https://playwright.dev/python/]
+to test the user workflow.
+
+Installation:
+
+ pip install pytest-playwright
+ playwright install
+
+Make sure that the server is running by:
+```bash
+cd dianna/dashboard
+streamlit run Home.py
+```
+Then, set variable `LOCAL=True` (see below) to connect to local instance for
+debugging. Then, you can run the tests with:
+
+```bash
+pytest -v -m dashboard --dashboard
+```
+See more documentation about dashboard in: dianna/dashboard/readme.md
+
+For Code generation (https://playwright.dev/python/docs/codegen):
+
+ playwright codegen http://localhost:8501
+"""
+
+import time
+from contextlib import contextmanager
+import pytest
+from playwright.sync_api import Page
+from playwright.sync_api import expect
+
+LOCAL = False
+
+PORT = '8501' if LOCAL else '8502'
+BASE_URL = f'localhost:{PORT}'
+
+pytestmark = pytest.mark.dashboard
+
+
+@pytest.fixture(scope='module', autouse=True)
+def before_module():
+ """Run dashboard in module scope."""
+ with run_streamlit():
+ yield
+
+
+@contextmanager
+def run_streamlit():
+ """Run the dashboard."""
+ import subprocess
+
+ if not LOCAL:
+ p = subprocess.Popen([
+ 'dianna-dashboard',
+ '--server.port',
+ PORT,
+ '--server.headless',
+ 'true',
+ ])
+ time.sleep(5)
+
+ yield
+
+ if not LOCAL:
+ p.kill()
+
+
+def test_timeseries_page(page: Page):
+ """Test performance of timeseries page."""
+ page.set_viewport_size({"width": 1920, "height": 1080})
+
+ page.goto(f'{BASE_URL}/Time_series')
+
+ page.get_by_text('Running...').wait_for(state='detached')
+
+ expect(page).to_have_title('Time_series')
+
+ expect(page.get_by_text("Select which input type to")).to_be_visible(timeout=100_000)
+
+ page.locator("label").filter(has_text="Use an example").locator("div").nth(1).click()
+ expect(page.get_by_text("Select an example in the left")).to_be_visible(timeout=200_000)
+ expect(page.get_by_text("Weather")).to_be_visible()
+ expect(page.get_by_text("FRB")).to_be_visible()
+
+ # Test weather example
+ page.locator("label").filter(has_text="Use an example").locator("div").nth(1).click()
+ page.locator("label").filter(has_text="Weather").locator("div").nth(1).click()
+ expect(page.get_by_text("Select a method to continue")).to_be_visible(timeout=100_000)
+
+ time.sleep(2)
+
+ page.locator('label').filter(has_text='LIME').locator('span').click(timeout=200_000)
+ page.locator('label').filter(has_text='RISE').locator('span').click(timeout=200_000)
+
+ page.get_by_label("Number of top classes to show").fill("2")
+ page.get_by_label("Number of top classes to show").press("Enter")
+ page.get_by_text('Running...').wait_for(state='detached', timeout=100_000)
+
+ for selector in (
+ page.get_by_role('heading', name='LIME').get_by_text('LIME'),
+ page.get_by_role('heading', name='RISE').get_by_text('RISE'),
+ # First image
+ page.get_by_role('heading', name='winter').get_by_text('winter'),
+ page.get_by_role('img', name='0').first,
+ page.get_by_role('img', name='0').nth(1),
+ # Second image
+ page.get_by_role('heading', name='summer').get_by_text('summer'),
+ page.get_by_role('img', name='0').nth(2),
+ page.get_by_role('img', name='0').nth(3),
+ ):
+ expect(selector).to_be_visible(timeout=100_000)
+
+ # Test FRB example
+ page.locator("label").filter(has_text="Use an example").locator("div").nth(1).click()
+ page.locator("label").filter(has_text="FRB").locator("div").nth(1).click()
+ expect(page.get_by_text("Select a method to continue")).to_be_visible(timeout=100_000)
+
+ time.sleep(2)
+
+ page.locator('label').filter(has_text='RISE').locator('span').click()
+
+ page.get_by_label("Number of top classes to show").fill("2")
+ page.get_by_label("Number of top classes to show").press("Enter")
+
+ page.get_by_text('Running...').wait_for(state='detached', timeout=100_000)
+
+ for selector in (
+ page.get_by_role('heading', name='RISE').get_by_text('RISE'),
+ # First image
+ page.get_by_role('heading', name='FRB').get_by_text('FRB'),
+ page.get_by_role('img', name='0').nth(1),
+ # Second image
+ page.get_by_role('heading', name='Noise').get_by_text('Noise'),
+ page.get_by_role('img', name='0').nth(2),
+ ):
+ expect(selector).to_be_visible(timeout=300_000)
+
+ # Test using your own data
+ page.locator("label").filter(
+ has_text="Use your own data").locator("div").nth(1).click()
+ page.get_by_label("Select input data").click()
+ page.get_by_label("Select model").click()
+ page.get_by_label("Select labels").click()