diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index de265c91..5dcc714b 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -55,8 +55,8 @@ jobs: - name: Run unit tests run: python -m pytest -v - - name: Verify that we can build the package - run: python setup.py sdist bdist_wheel + #- name: Verify that we can build the package + # run: python setup.py sdist bdist_wheel test_downloader: name: Test file downloader @@ -73,7 +73,8 @@ jobs: test_dashboard: name: Test dashboard - if: github.event.pull_request.draft == false + if: always() + #github.event.pull_request.draft == false runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 diff --git a/dianna/cli.py b/dianna/cli.py index 4c3e1977..d8508b24 100644 --- a/dianna/cli.py +++ b/dianna/cli.py @@ -21,6 +21,7 @@ def dashboard(): *('--theme.primaryColor', '7030a0'), *('--theme.secondaryBackgroundColor', 'e4f3f9'), *('--browser.gatherUsageStats', 'false'), + *('--client.showSidebarNavigation', 'false'), *args, ] diff --git a/dianna/dashboard/Home.py b/dianna/dashboard/Home.py index 51d5a118..f3cd87e2 100644 --- a/dianna/dashboard/Home.py +++ b/dianna/dashboard/Home.py @@ -46,14 +46,6 @@ with and for (academic) researchers and research software engineers working on machine learning projects. - ### Pages - - - Image data - - Tabular data - - Text data - - Time series data - - ### More information - [Source code](https://github.com/dianna-ai/dianna) diff --git a/dianna/dashboard/_model_utils.py b/dianna/dashboard/_model_utils.py index cc8084d0..67b61b2a 100644 --- a/dianna/dashboard/_model_utils.py +++ b/dianna/dashboard/_model_utils.py @@ -2,6 +2,7 @@ import numpy as np import onnx import pandas as pd +from sklearn.model_selection import train_test_split def load_data(file): @@ -42,3 +43,41 @@ def load_labels(file): def load_training_data(file): return np.float32(np.load(file, allow_pickle=False)) + + +def load_sunshine(file): + """Tabular sunshine example. + + Load the csv file in a pandas dataframe and split the data in a train and test set. + """ + data = load_data(file) + + # Drop unused columns + X_data = data.drop(columns=['DATE', 'MONTH', 'Index'])[:-1] + y_data = data.loc[1:]["BASEL_sunshine"] + + # Split the data + X_train, X_holdout, _, y_holdout = train_test_split(X_data, y_data, test_size=0.3, random_state=0) + _, X_test, _, _ = train_test_split(X_holdout, y_holdout, test_size=0.5, random_state=0) + X_test = X_test.reset_index(drop=True) + X_test.insert(0, 'Index', X_test.index) + + return X_train.to_numpy(dtype=np.float32), X_test + +def load_penguins(penguins): + """Prep the data for the penguin model example as per ntoebook.""" + # Remove categorial columns and NaN values + penguins_filtered = penguins.drop(columns=['island', 'sex']).dropna() + + + # Extract inputs and target + input_features = penguins_filtered.drop(columns=['species']) + target = pd.get_dummies(penguins_filtered['species']) + + X_train, X_test, _, _ = train_test_split(input_features, target, test_size=0.2, + random_state=0, shuffle=True, stratify=target) + + X_test = X_test.reset_index(drop=True) + X_test.insert(0, 'Index', X_test.index) + + return X_train.to_numpy(dtype=np.float32), X_test diff --git a/dianna/dashboard/_models_tabular.py b/dianna/dashboard/_models_tabular.py index 96573326..38685917 100644 --- a/dianna/dashboard/_models_tabular.py +++ b/dianna/dashboard/_models_tabular.py @@ -1,24 +1,38 @@ -import tempfile import numpy as np +import onnxruntime as ort import streamlit as st from dianna import explain_tabular -from dianna.utils.onnx_runner import SimpleModelRunner @st.cache_data def predict(*, model, tabular_input): - model_runner = SimpleModelRunner(model) - predictions = model_runner(tabular_input.reshape(1,-1).astype(np.float32)) - return predictions + # Make sure that tabular input is provided as float32 + sess = ort.InferenceSession(model) + input_name = sess.get_inputs()[0].name + output_name = sess.get_outputs()[0].name + + onnx_input = {input_name: tabular_input.astype(np.float32)} + pred_onnx = sess.run([output_name], onnx_input)[0] + + return pred_onnx @st.cache_data -def _run_rise_tabular(_model, table, training_data, **kwargs): +def _run_rise_tabular(_model, table, training_data,_feature_names, **kwargs): + # convert streamlit kwarg requirement back to dianna kwarg requirement + if "_preprocess_function" in kwargs: + kwargs["preprocess_function"] = kwargs["_preprocess_function"] + del kwargs["_preprocess_function"] + + def run_model(tabular_input): + return predict(model=_model, tabular_input=tabular_input) + relevances = explain_tabular( - _model, + run_model, table, method='RISE', training_data=training_data, + feature_names=_feature_names, **kwargs, ) return relevances @@ -26,8 +40,16 @@ def _run_rise_tabular(_model, table, training_data, **kwargs): @st.cache_data def _run_lime_tabular(_model, table, training_data, _feature_names, **kwargs): + # convert streamlit kwarg requirement back to dianna kwarg requirement + if "_preprocess_function" in kwargs: + kwargs["preprocess_function"] = kwargs["_preprocess_function"] + del kwargs["_preprocess_function"] + + def run_model(tabular_input): + return predict(model=_model, tabular_input=tabular_input) + relevances = explain_tabular( - _model, + run_model, table, method='LIME', training_data=training_data, @@ -37,17 +59,22 @@ def _run_lime_tabular(_model, table, training_data, _feature_names, **kwargs): return relevances @st.cache_data -def _run_kernelshap_tabular(model, table, training_data, **kwargs): +def _run_kernelshap_tabular(model, table, training_data, _feature_names, **kwargs): # Kernelshap interface is different. Write model to temporary file. - with tempfile.NamedTemporaryFile() as f: - f.write(model) - f.flush() - relevances = explain_tabular(f.name, + if "_preprocess_function" in kwargs: + kwargs["preprocess_function"] = kwargs["_preprocess_function"] + del kwargs["_preprocess_function"] + + def run_model(tabular_input): + return predict(model=model, tabular_input=tabular_input) + + relevances = explain_tabular(run_model, table, method='KernelSHAP', training_data=training_data, + feature_names=_feature_names, **kwargs) - return relevances[0] + return np.array(relevances) explain_tabular_dispatcher = { diff --git a/dianna/dashboard/_shared.py b/dianna/dashboard/_shared.py index 6ed408ae..0f0c91f8 100644 --- a/dianna/dashboard/_shared.py +++ b/dianna/dashboard/_shared.py @@ -74,13 +74,25 @@ def _methods_checkboxes(*, choices: Sequence, key): def _get_params(method: str, key): if method == 'RISE': + n_masks = 1000 + fr = 8 + pkeep = 0.1 + if 'FRB' in key: + n_masks = 5000 + fr = 16 + elif 'Tabular' in key: + pkeep = 0.5 + elif 'Weather' in key: + n_masks = 10000 + elif 'Digits' in key: + n_masks = 5000 return { 'n_masks': - st.number_input('Number of masks', value=1000, key=f'{key}_{method}_nmasks'), + st.number_input('Number of masks', value=n_masks, key=f'{key}_{method}_nmasks'), 'feature_res': - st.number_input('Feature resolution', value=6, key=f'{key}_{method}_fr'), + st.number_input('Feature resolution', value=fr, key=f'{key}_{method}_fr'), 'p_keep': - st.number_input('Probability to be kept unmasked', value=0.1, key=f'{key}_{method}_pkeep'), + st.number_input('Probability to be kept unmasked', value=pkeep, key=f'{key}_{method}_pkeep'), } elif method == 'KernelSHAP': @@ -97,9 +109,14 @@ def _get_params(method: str, key): } elif method == 'LIME': - return { - 'random_state': st.number_input('Random state', value=2, key=f'{key}_{method}_rs'), + if 'Tabular' in key: + return { + 'random_state': st.number_input('Random state', value=0, key=f'{key}_{method}_rs'), } + else: + return { + 'random_state': st.number_input('Random state', value=2, key=f'{key}_{method}_rs'), + } else: raise ValueError(f'No such method: {method}') diff --git a/dianna/dashboard/dashboard-screenshot.png b/dianna/dashboard/dashboard-screenshot.png deleted file mode 100644 index 61a86a17..00000000 Binary files a/dianna/dashboard/dashboard-screenshot.png and /dev/null differ diff --git a/dianna/dashboard/pages/Images.py b/dianna/dashboard/pages/Images.py index edab213d..c418db01 100644 --- a/dianna/dashboard/pages/Images.py +++ b/dianna/dashboard/pages/Images.py @@ -41,6 +41,8 @@ image_model_file = download('mnist_model_tf.onnx', 'model') image_label_file = download('labels_mnist.txt', 'label') + imagekey = 'Digits_Image_cb' + st.markdown( """ This example demonstrates the use of DIANNA on a pretrained binary @@ -71,6 +73,8 @@ image_label_file = st.sidebar.file_uploader('Select labels', type='txt') + imagekey = 'Image_cb' + if input_type is None: st.info('Select which input type to use in the left panel to continue') st.stop() @@ -93,7 +97,7 @@ with st.container(border=True): prediction_placeholder = st.empty() - methods, method_params = _methods_checkboxes(choices=choices, key='Image_cb') + methods, method_params = _methods_checkboxes(choices=choices, key=imagekey) with st.spinner('Predicting class'): predictions = predict(model=model, image=image) diff --git a/dianna/dashboard/pages/Tabular.py b/dianna/dashboard/pages/Tabular.py index f9825648..55dc5834 100644 --- a/dianna/dashboard/pages/Tabular.py +++ b/dianna/dashboard/pages/Tabular.py @@ -1,8 +1,11 @@ import numpy as np +import seaborn as sns import streamlit as st from _model_utils import load_data from _model_utils import load_labels from _model_utils import load_model +from _model_utils import load_penguins +from _model_utils import load_sunshine from _model_utils import load_training_data from _models_tabular import explain_tabular_dispatcher from _models_tabular import predict @@ -10,9 +13,11 @@ from _shared import _methods_checkboxes from _shared import add_sidebar_logo from _shared import reset_example +from _shared import reset_method from st_aggrid import AgGrid from st_aggrid import GridOptionsBuilder from st_aggrid import GridUpdateMode +from dianna.utils.downloader import download from dianna.visualization import plot_tabular add_sidebar_logo() @@ -31,14 +36,64 @@ # Use the examples if input_type == 'Use an example': - """load_example = st.sidebar.radio( + load_example = st.sidebar.radio( label='Use example', - options=(''), + options=('Sunshine hours prediction', 'Penguin identification'), index = None, on_change = reset_method, - key='Tabular_load_example')""" - st.info("No examples availble yet") - st.stop() + key='Tabular_load_example') + + if load_example == "Sunshine hours prediction": + tabular_data_file = download('weather_prediction_dataset_light.csv', 'data') + tabular_model_file = download('sunshine_hours_regression_model.onnx', 'model') + tabular_training_data_file = tabular_data_file + tabular_label_file = None + + training_data, data = load_sunshine(tabular_data_file) + labels = None + + mode = 'regression' + st.markdown( + """ + This example demonstrates the use of DIANNA on a pre-trained regression + [model to predict tomorrow's sunshine hours](https://zenodo.org/records/10580833) + based on meteorological data from today. + The model is trained on the + [weather prediction dataset](https://zenodo.org/records/5071376). + The meteorological data includes for various European cities the + cloud coverage,humidity, air pressure, global radiation, precipitation, and + mean, min and max temeprature. + + DIANNA's visualisation shows the top most important features contributing to the + sunshine hours prediction, where features contrinuting positively are indicated in red + and those who contribute negatively in blue. + """) + elif load_example == 'Penguin identification': + tabular_model_file = download('penguin_model.onnx', 'model') + data_penguins = sns.load_dataset('penguins') + labels = data_penguins['species'].unique() + + training_data, data = load_penguins(data_penguins) + + mode = 'classification' + + st.markdown( + """ + This example demonstrates the use of DIANNA on a pre-trained classification + [model to classify penguins in to three different species](https://zenodo.org/records/10580743) + based on a number of measurable physical characteristics. + The model is trained on the + [weather prediction dataset](https://zenodo.org/records/5071376). The data is obtained from + the Python seaborn package + The penguin characteristics include the bill length, bill depth, flipper length and body mass. + + DIANNA's visualisation shows the top most important characteristics contributing to the + penguin species classification, where characteristics contributing positively are indicated in red + and those who contribute negatively in blue. + """) + else: + st.info('Select an example in the left panel to coninue') + st.stop() # Option to upload your own data if input_type == 'Use your own data': @@ -47,29 +102,29 @@ tabular_training_data_file = st.sidebar.file_uploader('Select training data', type='npy') tabular_label_file = st.sidebar.file_uploader('Select labels in case of classification model', type='txt') + if not (tabular_data_file and tabular_model_file and tabular_training_data_file): + st.info('Add your input data in the left panel to continue') + st.stop() + + data = load_data(tabular_data_file) + model = load_model(tabular_model_file) + training_data = load_training_data(tabular_training_data_file) + + if tabular_label_file: + labels = load_labels(tabular_label_file) + mode = 'classification' + else: + labels = None + mode = 'regression' + if input_type is None: st.info('Select which input type to use in the left panel to continue') st.stop() -if not (tabular_data_file and tabular_model_file and tabular_training_data_file): - st.info('Add your input data in the left panel to continue') - st.stop() - -data = load_data(tabular_data_file) - model = load_model(tabular_model_file) serialized_model = model.SerializeToString() -training_data = load_training_data(tabular_training_data_file) - -if tabular_label_file: - labels = load_labels(tabular_label_file) - mode = 'classification' -else: - labels = None - mode = 'regression' - -choices = ('RISE', 'LIME') +choices = ('RISE', 'LIME', 'KernelSHAP') st.text("") st.text("") @@ -94,10 +149,10 @@ ) if grid_response['selected_rows'] is not None: - selected_row = grid_response['selected_rows']['Index'].iloc[0] - selected_data = data.iloc[selected_row, 1:].to_numpy(dtype=np.float32) + selected_row = int(grid_response['selected_rows'].index[0]) + selected_data = data.iloc[selected_row].to_numpy()[1:] with st.spinner('Predicting class'): - predictions = predict(model=serialized_model, tabular_input=selected_data) + predictions = predict(model=serialized_model, tabular_input=selected_data.reshape(1,-1)) with prediction_placeholder: top_indices, top_labels = _get_top_indices_and_labels( @@ -125,17 +180,21 @@ for col, method in zip(columns, methods): kwargs = method_params[method].copy() - kwargs['labels'] = [index] kwargs['mode'] = mode - if method == 'LIME': - kwargs['_feature_names']=data[:1].columns.to_list() + kwargs['_feature_names']=data.columns.to_list()[1:] func = explain_tabular_dispatcher[method] with col: with st.spinner(f'Running {method}'): relevances = func(serialized_model, selected_data, training_data, **kwargs) - fig, _ = plot_tabular(x=relevances, y=data[:1].columns, num_features=10, show_plot=False) + if mode == 'classification': + plot_relevances = relevances[np.argmax(predictions)] + else: + plot_relevances = relevances + + fig, _ = plot_tabular(x=plot_relevances, y=kwargs['_feature_names'], + num_features=10, show_plot=False) st.pyplot(fig) # add some white space to separate rows diff --git a/dianna/dashboard/pages/Text.py b/dianna/dashboard/pages/Text.py index 387e37a2..b414d5f1 100644 --- a/dianna/dashboard/pages/Text.py +++ b/dianna/dashboard/pages/Text.py @@ -36,7 +36,9 @@ key='Text_load_example') if load_example == 'Movie sentiment': - text_input = 'The movie started out great but the ending was disappointing' + text_input = st.sidebar.text_input( + 'Input string', + value='The movie started out great but the ending was disappointing') text_model_file = download('movie_review_model.onnx', 'model') text_label_file = download('labels_text.txt', 'label') @@ -46,7 +48,8 @@ Treebank dataset](https://nlp.stanford.edu/sentiment/index.html) which contains one-sentence movie reviews. A pre-trained neural network classifier is used, which identifies whether a movie review is positive - or negative. + or negative. The input string to which the model is applied can be modified + in the left menu. """) else: st.info('Select an example in the left panel to coninue') diff --git a/dianna/dashboard/pages/Time_series.py b/dianna/dashboard/pages/Time_series.py index 07c88e15..a8657da0 100644 --- a/dianna/dashboard/pages/Time_series.py +++ b/dianna/dashboard/pages/Time_series.py @@ -11,8 +11,8 @@ from _shared import reset_method from _ts_utils import _convert_to_segments from _ts_utils import open_timeseries +from matplotlib import pyplot as plt from dianna.utils.downloader import download -from dianna.visualization import plot_image from dianna.visualization import plot_timeseries st.title('Time series explanation') @@ -44,6 +44,8 @@ 'season_prediction_model_temp_max_binary.onnx', 'model') ts_label_file = download('weather_data_labels.txt', 'label') + param_key = 'Weather_TS_cb' + st.markdown( """ This example demonstrates the use of DIANNA @@ -72,6 +74,8 @@ def preprocess(data): ts_data_explainer = ts_data.T[None, ...] ts_data_predictor = ts_data[None, ..., None] + param_key = 'FRB_TS_cb' + st.markdown( """This example demonstrates the use of DIANNA on a pre-trained binary classification model trained to classify @@ -98,6 +102,8 @@ def preprocess(data): ts_label_file = st.sidebar.file_uploader('Select labels', type='txt') + param_key = 'TS_cb' + if input_type is None: st.info('Select which input type to use in the left panel to continue') st.stop() @@ -126,7 +132,7 @@ def preprocess(data): with st.container(border=True): prediction_placeholder = st.empty() - methods, method_params = _methods_checkboxes(choices=choices, key='TS_cb') + methods, method_params = _methods_checkboxes(choices=choices, key=param_key) with st.spinner('Predicting class'): predictions = predict(model=serialized_model, ts_data=ts_data_predictor) @@ -162,8 +168,21 @@ def preprocess(data): explanation = func(serialized_model, ts_data=ts_data_explainer, **kwargs) if load_example == "Scientific case: FRB": - # FRB data: get rid of last dimension - fig, _ = plot_image(explanation[0, :, ::-1].T) + fig, axes = plt.subplots(ncols=2, figsize=(14, 5)) + # FRB: plot original data + ax = axes[0] + ax.imshow(ts_data, aspect='auto', origin='lower') + ax.set_xlabel('Time step') + ax.set_ylabel('Channel index') + ax.set_title('Input data') + # FRB data explanation has to be transposed + ax = axes[1] + plot = ax.imshow(explanation[0].T, aspect='auto', origin='lower', cmap='bwr') + ax.set_xlabel('Time step') + ax.set_ylabel('Channel index') + ax.set_title('Explanation') + fig.colorbar(plot) + else: segments = _convert_to_segments(explanation) diff --git a/setup.cfg b/setup.cfg index 6bef3249..b363bd44 100644 --- a/setup.cfg +++ b/setup.cfg @@ -94,7 +94,9 @@ dashboard = Pillow plotly scipy + seaborn spacy + streamlit-aggrid streamlit streamlit_option_menu torchtext diff --git a/tests/test_dashboard.py b/tests/test_dashboard.py deleted file mode 100644 index 0ce296c0..00000000 --- a/tests/test_dashboard.py +++ /dev/null @@ -1,267 +0,0 @@ -"""Module to test the dashboard. - -This test module uses (playwright)[https://playwright.dev/python/] -to test the user workflow. - -Installation: - - pip install pytest-playwright - playwright install - -Make sure that the server is running by: -```bash -cd dianna/dashboard -streamlit run Home.py -``` -Then, set variable `LOCAL=True` (see below) to connect to local instance for -debugging. Then, you can run the tests with: - -```bash -pytest -v -m dashboard --dashboard -``` -See more documentation about dashboard in: dianna/dashboard/readme.md - -For Code generation (https://playwright.dev/python/docs/codegen): - - playwright codegen http://localhost:8501 -""" - -import time -from contextlib import contextmanager -import pytest -from playwright.sync_api import Page -from playwright.sync_api import expect - -LOCAL = False - -PORT = '8501' if LOCAL else '8502' -BASE_URL = f'localhost:{PORT}' - -pytestmark = pytest.mark.dashboard - - -@pytest.fixture(scope='module', autouse=True) -def before_module(): - """Run dashboard in module scope.""" - with run_streamlit(): - yield - - -@contextmanager -def run_streamlit(): - """Run the dashboard.""" - import subprocess - - if not LOCAL: - p = subprocess.Popen([ - 'dianna-dashboard', - '--server.port', - PORT, - '--server.headless', - 'true', - ]) - time.sleep(5) - - yield - - if not LOCAL: - p.kill() - - -def test_page_load(page: Page): - """Test performance of landing page.""" - page.goto(BASE_URL) - - selector = page.get_by_text('Running...') - selector.wait_for(state='detached') - - expect(page).to_have_title("Dianna's dashboard") - for selector in ( - page.get_by_role('img', name='0'), - page.get_by_text('Pages'), - page.get_by_text('More information'), - ): - expect(selector).to_be_visible() - - -def test_text_page(page: Page): - """Test performance of text page.""" - page.goto(f'{BASE_URL}/Text') - - page.get_by_text('Running...').wait_for(state='detached') - - expect(page).to_have_title('Text') - - # Movie sentiment example - page.locator("label").filter(has_text="Use an example").locator("div").nth(1).click() - page.get_by_text("Movie sentiment").click() - expect(page.get_by_text("Select a method to continue")).to_be_visible(timeout=50_000) - - page.locator('label').filter(has_text='RISE').locator('span').click() - page.locator('label').filter(has_text='LIME').locator('span').click() - page.get_by_test_id("stNumberInput-StepUp").click() - page.get_by_text('Running...').wait_for(state='detached', timeout=100_000) - - for selector in ( - page.get_by_role('heading', name='RISE').get_by_text('RISE'), - page.get_by_role('heading', name='LIME').get_by_text('LIME'), - # Images for positive (RISE/LIME) - page.get_by_role('heading', - name='positive').get_by_text('positive'), - page.get_by_role('img', name='0').first, - page.get_by_role('img', name='0').nth(1), - - # Images for negative (RISE/LIME) - page.get_by_role('heading', - name='negative').get_by_text('negative'), - page.get_by_role('img', name='0').nth(2), - page.get_by_role('img', name='0').nth(3), - ): - expect(selector).to_be_visible() - - # Own data option - page.locator("label").filter(has_text="Use your own data").locator("div").nth(1).click() - selector = page.get_by_text( - 'Add your input data in the left panel to continue') - - expect(selector).to_be_visible(timeout=30_000) - - # Check input panel - page.get_by_label("Input string").click() - expect(page.get_by_label("Select model").get_by_test_id("baseButton-secondary")).to_be_visible() - page.get_by_label("Select labels").get_by_test_id("baseButton-secondary").click() - - -def test_image_page(page: Page): - """Test performance of image page.""" - page.goto(f'{BASE_URL}/Images') - - page.get_by_text('Running...').wait_for(state='detached') - - expect(page).to_have_title('Images') - - expect( - page.get_by_text('Select which input type to') - ).to_be_visible(timeout=100_000) - - # Digits example - page.locator("label").filter(has_text="Use an example").locator("div").nth(1).click() - page.get_by_text("Hand-written digit recognition").click() - - expect(page.get_by_text('Select a method to continue')).to_be_visible(timeout=100_000) - - page.locator('label').filter(has_text='RISE').locator('span').click() - page.locator('label').filter(has_text='KernelSHAP').locator('span').click() - page.locator('label').filter(has_text='LIME').locator('span').click() - page.get_by_test_id("stNumberInput-StepUp").click() - page.get_by_text('Running...').wait_for(state='detached', timeout=50_000) - - for selector in ( - page.get_by_role('heading', name='RISE').get_by_text('RISE'), - page.get_by_role('heading', name='KernelSHAP').get_by_text('KernelSHAP'), - page.get_by_role('heading', name='LIME').get_by_text('LIME'), - # first image - page.get_by_role('heading', name='0').get_by_text('0'), - page.get_by_role('img', name='0').first, - page.get_by_role('img', name='0').nth(1), - page.get_by_role('img', name='0').nth(2), - # second image - page.get_by_role('heading', name='1').get_by_text('1'), - page.get_by_role('img', name='0').nth(3), - page.get_by_role('img', name='0').nth(4), - page.get_by_role('img', name='0').nth(5), - ): - expect(selector).to_be_visible(timeout=100_000) - - # Own data - page.locator("label").filter(has_text="Use your own data").locator("div").nth(1).click() - expect(page.get_by_label("Select image").get_by_test_id("baseButton-secondary")).to_be_visible() - page.get_by_label("Select model").get_by_test_id("baseButton-secondary").click() - page.get_by_label("Select labels").get_by_test_id("baseButton-secondary").click() - - -def test_timeseries_page(page: Page): - """Test performance of timeseries page.""" - page.goto(f'{BASE_URL}/Time_series') - - page.get_by_text('Running...').wait_for(state='detached') - - expect(page).to_have_title('Time_series') - - expect(page.get_by_text("Select which input type to")).to_be_visible(timeout=100_000) - - page.locator("label").filter(has_text="Use an example").locator("div").nth(1).click() - expect(page.get_by_text("Select an example in the left")).to_be_visible() - expect(page.get_by_text("Weather")).to_be_visible() - expect(page.get_by_text("FRB")).to_be_visible() - - # Test weather example - page.locator("label").filter(has_text="Weather").locator("div").nth(1).click() - expect(page.get_by_text("Select a method to continue")).to_be_visible(timeout=100_000) - - page.locator('label').filter(has_text='LIME').locator('span').click() - page.locator('label').filter(has_text='RISE').locator('span').click() - page.get_by_test_id("stNumberInput-StepUp").click() - page.get_by_text('Running...').wait_for(state='detached', timeout=100_000) - - for selector in ( - page.get_by_role('heading', name='LIME').get_by_text('LIME'), - page.get_by_role('heading', name='RISE').get_by_text('RISE'), - # First image - page.get_by_role('heading', name='winter').get_by_text('winter'), - page.get_by_role('img', name='0').first, - page.get_by_role('img', name='0').nth(1), - # Second image - page.get_by_role('heading', name='summer').get_by_text('summer'), - page.get_by_role('img', name='0').nth(2), - page.get_by_role('img', name='0').nth(3), - ): - expect(selector).to_be_visible() - - # Test FRB example - page.locator("label").filter(has_text="FRB").locator("div").nth(1).click() - expect(page.get_by_text("Select a method to continue")).to_be_visible(timeout=100_000) - - page.locator('label').filter(has_text='RISE').locator('span').click() - - page.get_by_text('Running...').wait_for(state='detached', timeout=100_000) - - for selector in ( - page.get_by_role('heading', name='RISE').get_by_text('RISE'), - # First image - page.get_by_role('heading', name='FRB').get_by_text('FRB'), - page.get_by_role('img', name='0').first, - page.get_by_role('img', name='0').nth(1), - ): - expect(selector).to_be_visible() - - # Test using your own data - page.locator("label").filter( - has_text="Use your own data").locator("div").nth(1).click() - page.get_by_label("Select input data").get_by_test_id( - "baseButton-secondary").click() - page.get_by_label("Select model").get_by_test_id( - "baseButton-secondary").click() - page.get_by_label("Select labels").get_by_test_id( - "baseButton-secondary").click() - - -def test_tabular_page(page: Page): - """Test performance of tabular page.""" - page.goto(f'{BASE_URL}/Tabular') - - page.get_by_text('Running...').wait_for(state='detached') - - expect(page).to_have_title('Tabular') - - expect(page.get_by_text("Select which input type to")).to_be_visible(timeout=100_000) - - page.locator("label").filter(has_text="Use an example").locator("div").nth(1).click() - - # Test using your own data - page.locator("label").filter( - has_text="Use your own data").locator("div").nth(1).click() - page.get_by_label("Select tabular data").get_by_test_id("baseButton-secondary").click() - page.get_by_label("Select model").get_by_test_id("baseButton-secondary").click() - page.get_by_label("Select training data").get_by_test_id("baseButton-secondary").click() - page.get_by_label("Select labels in case of").get_by_test_id("baseButton-secondary").click() diff --git a/tests/test_dashboard_image.py b/tests/test_dashboard_image.py new file mode 100644 index 00000000..73b1af5b --- /dev/null +++ b/tests/test_dashboard_image.py @@ -0,0 +1,123 @@ +"""Module to test the dashboard. + +This test module uses (playwright)[https://playwright.dev/python/] +to test the user workflow. + +Installation: + + pip install pytest-playwright + playwright install + +Make sure that the server is running by: +```bash +cd dianna/dashboard +streamlit run Home.py +``` +Then, set variable `LOCAL=True` (see below) to connect to local instance for +debugging. Then, you can run the tests with: + +```bash +pytest -v -m dashboard --dashboard +``` +See more documentation about dashboard in: dianna/dashboard/readme.md + +For Code generation (https://playwright.dev/python/docs/codegen): + + playwright codegen http://localhost:8501 +""" + +import time +from contextlib import contextmanager +import pytest +from playwright.sync_api import Page +from playwright.sync_api import expect + +LOCAL = False + +PORT = '8501' if LOCAL else '8502' +BASE_URL = f'localhost:{PORT}' + +pytestmark = pytest.mark.dashboard + + +@pytest.fixture(scope='module', autouse=True) +def before_module(): + """Run dashboard in module scope.""" + with run_streamlit(): + yield + + +@contextmanager +def run_streamlit(): + """Run the dashboard.""" + import subprocess + + if not LOCAL: + p = subprocess.Popen([ + 'dianna-dashboard', + '--server.port', + PORT, + '--server.headless', + 'true', + ]) + time.sleep(5) + + yield + + if not LOCAL: + p.kill() + + +def test_image_page(page: Page): + """Test performance of image page.""" + page.set_viewport_size({"width": 1920, "height": 1080}) + + page.goto(f'{BASE_URL}/Images') + + page.get_by_text('Running...').wait_for(state='detached') + + expect(page).to_have_title('Images') + + expect( + page.get_by_text('Select which input type to') + ).to_be_visible(timeout=100_000) + + # Digits example + page.locator("label").filter(has_text="Use an example").locator("div").nth(1).click() + page.get_by_text("Hand-written digit recognition").click() + + expect(page.get_by_text('Select a method to continue')).to_be_visible(timeout=100_000) + + time.sleep(2) + + page.locator('label').filter(has_text='RISE').locator('span').click() + page.locator('label').filter(has_text='KernelSHAP').locator('span').click() + page.locator('label').filter(has_text='LIME').locator('span').click() + + page.get_by_label("Number of top classes to show").fill("2") + page.get_by_label("Number of top classes to show").press("Enter") + page.get_by_text('Running...').wait_for(state='detached', timeout=100_000) + + for selector in ( + page.get_by_role('heading', name='RISE').get_by_text('RISE'), + page.get_by_role('heading', name='KernelSHAP').get_by_text('KernelSHAP'), + page.get_by_role('heading', name='LIME').get_by_text('LIME'), + # first image + page.get_by_role('heading', name='0').get_by_text('0'), + page.get_by_role('img', name='0').first, + page.get_by_role('img', name='0').nth(1), + page.get_by_role('img', name='0').nth(2), + # second image + page.get_by_role('heading', name='1').get_by_text('1'), + page.get_by_role('img', name='0').nth(3), + page.get_by_role('img', name='0').nth(4), + page.get_by_role('img', name='0').nth(5), + ): + expect(selector).to_be_visible(timeout=200_000) + + # Own data + page.locator("label").filter(has_text="Use your own data").locator("div").nth(1).click() + + page.get_by_label("Select image").click() + page.get_by_label("Select model").click() + page.get_by_label("Select labels").click() diff --git a/tests/test_dashboard_setup.py b/tests/test_dashboard_setup.py new file mode 100644 index 00000000..75e1c602 --- /dev/null +++ b/tests/test_dashboard_setup.py @@ -0,0 +1,85 @@ +"""Module to test the dashboard. + +This test module uses (playwright)[https://playwright.dev/python/] +to test the user workflow. + +Installation: + + pip install pytest-playwright + playwright install + +Make sure that the server is running by: +```bash +cd dianna/dashboard +streamlit run Home.py +``` +Then, set variable `LOCAL=True` (see below) to connect to local instance for +debugging. Then, you can run the tests with: + +```bash +pytest -v -m dashboard --dashboard +``` +See more documentation about dashboard in: dianna/dashboard/readme.md + +For Code generation (https://playwright.dev/python/docs/codegen): + + playwright codegen http://localhost:8501 +""" + +import time +from contextlib import contextmanager +import pytest +from playwright.sync_api import Page +from playwright.sync_api import expect + +LOCAL = False + +PORT = '8501' if LOCAL else '8502' +BASE_URL = f'localhost:{PORT}' + +pytestmark = pytest.mark.dashboard + + +@pytest.fixture(scope='module', autouse=True) +def before_module(): + """Run dashboard in module scope.""" + with run_streamlit(): + yield + + +@contextmanager +def run_streamlit(): + """Run the dashboard.""" + import subprocess + + if not LOCAL: + p = subprocess.Popen([ + 'dianna-dashboard', + '--server.port', + PORT, + '--server.headless', + 'true', + ]) + time.sleep(5) + + yield + + if not LOCAL: + p.kill() + + +def test_page_load(page: Page): + """Test performance of landing page.""" + page.goto(BASE_URL) + + selector = page.get_by_text('Running...') + selector.wait_for(state='detached') + + expect(page).to_have_title("Dianna's dashboard") + + for selector in ( + page.get_by_role('img', name='0'), + page.get_by_text('More information'), + ): + expect(selector).to_be_visible() + diff --git a/tests/test_dashboard_tabular.py b/tests/test_dashboard_tabular.py new file mode 100644 index 00000000..f1fb4e76 --- /dev/null +++ b/tests/test_dashboard_tabular.py @@ -0,0 +1,181 @@ +"""Module to test the dashboard. + +This test module uses (playwright)[https://playwright.dev/python/] +to test the user workflow. + +Installation: + + pip install pytest-playwright + playwright install + +Make sure that the server is running by: +```bash +cd dianna/dashboard +streamlit run Home.py +``` +Then, set variable `LOCAL=True` (see below) to connect to local instance for +debugging. Then, you can run the tests with: + +```bash +pytest -v -m dashboard --dashboard +``` +See more documentation about dashboard in: dianna/dashboard/readme.md + +For Code generation (https://playwright.dev/python/docs/codegen): + + playwright codegen http://localhost:8501 +""" + +import time +from contextlib import contextmanager +import pytest +from playwright.sync_api import Page +from playwright.sync_api import expect + +LOCAL = False + +PORT = '8501' if LOCAL else '8502' +BASE_URL = f'localhost:{PORT}' + +pytestmark = pytest.mark.dashboard + + +@pytest.fixture(scope='module', autouse=True) +def before_module(): + """Run dashboard in module scope.""" + with run_streamlit(): + yield + + +@contextmanager +def run_streamlit(): + """Run the dashboard.""" + import subprocess + + if not LOCAL: + p = subprocess.Popen([ + 'dianna-dashboard', + '--server.port', + PORT, + '--server.headless', + 'true', + ]) + time.sleep(5) + + yield + + if not LOCAL: + p.kill() + + +def test_tabular_page(page: Page): + """Test performance of tabular page.""" + page.set_viewport_size({"width": 1920, "height": 1080}) + + page.goto(f'{BASE_URL}/Tabular') + + page.get_by_text('Running...').wait_for(state='detached') + + expect(page).to_have_title('Tabular') + + expect(page.get_by_text("Select which input type to")).to_be_visible(timeout=100_000) + + # Test using your own data + page.locator("label").filter( + has_text="Use your own data").locator("div").nth(1).click() + + page.get_by_label("Select tabular data").click() + page.get_by_label("Select model").click() + page.get_by_label("Select training data").click() + page.get_by_label("Select labels in case of").click() + + +def test_tabular_sunshine(page: Page): + """Test tabular sunshine example.""" + page.set_viewport_size({"width": 1920, "height": 1080}) + + page.goto(f'{BASE_URL}/Tabular') + + page.get_by_text('Running...').wait_for(state='detached') + + expect(page).to_have_title('Tabular') + + expect(page.get_by_text("Select which input type to")).to_be_visible(timeout=100_000) + + page.locator("label").filter(has_text="Use an example").locator("div").nth(1).click() + expect(page.get_by_text("Select an example in the left")).to_be_visible() + expect(page.get_by_text("Sunshine hours prediction")).to_be_visible() + expect(page.get_by_text("Penguin identification")).to_be_visible() + + # Test sunshine example + page.locator("label").filter(has_text="Use an example").locator("div").nth(1).click() + page.locator("label").filter(has_text="Sunshine hours prediction").locator("div").nth(1).click() + expect(page.get_by_text("Select a method to continue")).to_be_visible(timeout=100_000) + + time.sleep(2) + + page.locator("label").filter(has_text="RISE").locator("span").click() + page.locator("label").filter(has_text="LIME").locator("span").click() + page.locator("label").filter(has_text="KernelSHAP").locator("span").click() + page.locator("summary").filter(has_text="Click to modify RISE").get_by_test_id("stExpanderToggleIcon").click() + + expect(page.get_by_text("Select the input data by")).to_be_visible(timeout=100_000) + page.frame_locator("iframe[title=\"st_aggrid\\.agGrid\"]").get_by_role( + "gridcell", name="10", exact=True).click() + page.get_by_text('Running...').wait_for(state='detached', timeout=200_000) + + expect(page.get_by_text("3.07")).to_be_visible(timeout=200_000) + + for selector in ( + page.get_by_role('heading', name='RISE').get_by_text('RISE'), + page.get_by_role('heading', name='KernelSHAP').get_by_text('KernelSHAP'), + page.get_by_role('heading', name='LIME').get_by_text('LIME'), + page.get_by_role('img', name='0').first, + page.get_by_role('img', name='0').nth(1), + page.get_by_role('img', name='0').nth(2), + ): + expect(selector).to_be_visible(timeout=100_000) + + +def test_tabular_penguin(page: Page): + """Test performance of tabular penguin example.""" + page.set_viewport_size({"width": 1920, "height": 1080}) + + page.goto(f'{BASE_URL}/Tabular') + page.get_by_text('Running...').wait_for(state='detached') + + expect(page).to_have_title('Tabular') + expect(page.get_by_text("Select which input type to")).to_be_visible(timeout=100_000) + + page.locator("label").filter(has_text="Use an example").locator("div").nth(1).click() + expect(page.get_by_text("Select an example in the left")).to_be_visible() + expect(page.get_by_text("Sunshine hours prediction")).to_be_visible() + expect(page.get_by_text("Penguin identification")).to_be_visible() + + # Test sunshine example + page.locator("label").filter(has_text="Use an example").locator("div").nth(1).click() + page.locator("label").filter(has_text="Penguin identification").locator("div").nth(1).click() + expect(page.get_by_text("Select a method to continue")).to_be_visible(timeout=100_000) + + time.sleep(2) + + page.locator("label").filter(has_text="RISE").locator("span").click(timeout=300_000) + page.locator("label").filter(has_text="LIME").locator("span").click(timeout=300_000) + page.locator("label").filter(has_text="KernelSHAP").locator("span").click(timeout=300_000) + + expect(page.get_by_text("Select the input data by")).to_be_visible(timeout=300_000) + page.frame_locator("iframe[title=\"st_aggrid\\.agGrid\"]").get_by_role( + "gridcell", name="10", exact=True).click() + page.get_by_text('Running...').wait_for(state='detached', timeout=300_000) + + for selector in ( + page.get_by_text('Predicted class:'), + page.get_by_test_id('stMetricValue').get_by_text('Gentoo'), + page.get_by_role('heading', name='RISE').get_by_text('RISE'), + page.get_by_role('heading', name='KernelSHAP').get_by_text('KernelSHAP'), + page.get_by_role('heading', name='LIME').get_by_text('LIME'), + page.get_by_role('img', name='0').first, + page.get_by_role('img', name='0').nth(1), + page.get_by_role('img', name='0').nth(2), + ): + expect(selector).to_be_visible(timeout=200_000) diff --git a/tests/test_dashboard_text.py b/tests/test_dashboard_text.py new file mode 100644 index 00000000..4c10d5ec --- /dev/null +++ b/tests/test_dashboard_text.py @@ -0,0 +1,115 @@ +"""Module to test the dashboard. + +This test module uses (playwright)[https://playwright.dev/python/] +to test the user workflow. + +Installation: + + pip install pytest-playwright + playwright install + +Make sure that the server is running by: +```bash +cd dianna/dashboard +streamlit run Home.py +``` +Then, set variable `LOCAL=True` (see below) to connect to local instance for +debugging. Then, you can run the tests with: + +```bash +pytest -v -m dashboard --dashboard +``` +See more documentation about dashboard in: dianna/dashboard/readme.md + +For Code generation (https://playwright.dev/python/docs/codegen): + + playwright codegen http://localhost:8501 +""" + +import time +from contextlib import contextmanager +import pytest +from playwright.sync_api import Page +from playwright.sync_api import expect + +LOCAL = False + +PORT = '8501' if LOCAL else '8502' +BASE_URL = f'localhost:{PORT}' + +pytestmark = pytest.mark.dashboard + + +@pytest.fixture(scope='module', autouse=True) +def before_module(): + """Run dashboard in module scope.""" + with run_streamlit(): + yield + + +@contextmanager +def run_streamlit(): + """Run the dashboard.""" + import subprocess + + if not LOCAL: + p = subprocess.Popen([ + 'dianna-dashboard', + '--server.port', + PORT, + '--server.headless', + 'true', + ]) + time.sleep(5) + + yield + + if not LOCAL: + p.kill() + + +def test_text_page(page: Page): + """Test performance of text page.""" + page.set_viewport_size({"width": 1920, "height": 1080}) + + page.goto(f'{BASE_URL}/Text') + page.get_by_text('Running...').wait_for(state='detached') + expect(page).to_have_title('Text') + # Movie sentiment example + page.locator("label").filter(has_text="Use an example").locator("div").nth(1).click() + page.get_by_text("Movie sentiment").click() + expect(page.get_by_text("Select a method to continue")).to_be_visible(timeout=50_000) + + time.sleep(2) + page.locator('label').filter(has_text='RISE').locator('span').click() + page.locator('label').filter(has_text='LIME').locator('span').click() + + page.get_by_label("Number of top classes to show").fill("2") + page.get_by_label("Number of top classes to show").press("Enter") + page.get_by_text('Running...').wait_for(state='detached', timeout=100_000) + + for selector in ( + page.get_by_role('heading', name='RISE').get_by_text('RISE'), + page.get_by_role('heading', name='LIME').get_by_text('LIME'), + # Images for positive (RISE/LIME) + page.get_by_role('heading', + name='positive').get_by_text('positive'), + page.get_by_role('img', name='0').first, + page.get_by_role('img', name='0').nth(1), +# # Images for negative (RISE/LIME) + page.get_by_role('heading', + name='negative').get_by_text('negative'), + page.get_by_role('img', name='0').nth(2), + page.get_by_role('img', name='0').nth(3), + ): + expect(selector).to_be_visible(timeout=100_000) + + # Own data option + page.locator("label").filter(has_text="Use your own data").locator("div").nth(1).click() + selector = page.get_by_text( + 'Add your input data in the left panel to continue') + expect(selector).to_be_visible(timeout=30_000) + # Check input panel + expect(page.get_by_label("Input string")).to_be_visible(timeout=200_000) + page.get_by_label("Select model").click() + page.get_by_label("Select labels").click() diff --git a/tests/test_dashboard_time_series.py b/tests/test_dashboard_time_series.py new file mode 100644 index 00000000..a4baacd1 --- /dev/null +++ b/tests/test_dashboard_time_series.py @@ -0,0 +1,146 @@ +"""Module to test the dashboard. + +This test module uses (playwright)[https://playwright.dev/python/] +to test the user workflow. + +Installation: + + pip install pytest-playwright + playwright install + +Make sure that the server is running by: +```bash +cd dianna/dashboard +streamlit run Home.py +``` +Then, set variable `LOCAL=True` (see below) to connect to local instance for +debugging. Then, you can run the tests with: + +```bash +pytest -v -m dashboard --dashboard +``` +See more documentation about dashboard in: dianna/dashboard/readme.md + +For Code generation (https://playwright.dev/python/docs/codegen): + + playwright codegen http://localhost:8501 +""" + +import time +from contextlib import contextmanager +import pytest +from playwright.sync_api import Page +from playwright.sync_api import expect + +LOCAL = False + +PORT = '8501' if LOCAL else '8502' +BASE_URL = f'localhost:{PORT}' + +pytestmark = pytest.mark.dashboard + + +@pytest.fixture(scope='module', autouse=True) +def before_module(): + """Run dashboard in module scope.""" + with run_streamlit(): + yield + + +@contextmanager +def run_streamlit(): + """Run the dashboard.""" + import subprocess + + if not LOCAL: + p = subprocess.Popen([ + 'dianna-dashboard', + '--server.port', + PORT, + '--server.headless', + 'true', + ]) + time.sleep(5) + + yield + + if not LOCAL: + p.kill() + + +def test_timeseries_page(page: Page): + """Test performance of timeseries page.""" + page.set_viewport_size({"width": 1920, "height": 1080}) + + page.goto(f'{BASE_URL}/Time_series') + + page.get_by_text('Running...').wait_for(state='detached') + + expect(page).to_have_title('Time_series') + + expect(page.get_by_text("Select which input type to")).to_be_visible(timeout=100_000) + + page.locator("label").filter(has_text="Use an example").locator("div").nth(1).click() + expect(page.get_by_text("Select an example in the left")).to_be_visible(timeout=200_000) + expect(page.get_by_text("Weather")).to_be_visible() + expect(page.get_by_text("FRB")).to_be_visible() + + # Test weather example + page.locator("label").filter(has_text="Use an example").locator("div").nth(1).click() + page.locator("label").filter(has_text="Weather").locator("div").nth(1).click() + expect(page.get_by_text("Select a method to continue")).to_be_visible(timeout=100_000) + + time.sleep(2) + + page.locator('label').filter(has_text='LIME').locator('span').click(timeout=200_000) + page.locator('label').filter(has_text='RISE').locator('span').click(timeout=200_000) + + page.get_by_label("Number of top classes to show").fill("2") + page.get_by_label("Number of top classes to show").press("Enter") + page.get_by_text('Running...').wait_for(state='detached', timeout=100_000) + + for selector in ( + page.get_by_role('heading', name='LIME').get_by_text('LIME'), + page.get_by_role('heading', name='RISE').get_by_text('RISE'), + # First image + page.get_by_role('heading', name='winter').get_by_text('winter'), + page.get_by_role('img', name='0').first, + page.get_by_role('img', name='0').nth(1), + # Second image + page.get_by_role('heading', name='summer').get_by_text('summer'), + page.get_by_role('img', name='0').nth(2), + page.get_by_role('img', name='0').nth(3), + ): + expect(selector).to_be_visible(timeout=100_000) + + # Test FRB example + page.locator("label").filter(has_text="Use an example").locator("div").nth(1).click() + page.locator("label").filter(has_text="FRB").locator("div").nth(1).click() + expect(page.get_by_text("Select a method to continue")).to_be_visible(timeout=100_000) + + time.sleep(2) + + page.locator('label').filter(has_text='RISE').locator('span').click() + + page.get_by_label("Number of top classes to show").fill("2") + page.get_by_label("Number of top classes to show").press("Enter") + + page.get_by_text('Running...').wait_for(state='detached', timeout=100_000) + + for selector in ( + page.get_by_role('heading', name='RISE').get_by_text('RISE'), + # First image + page.get_by_role('heading', name='FRB').get_by_text('FRB'), + page.get_by_role('img', name='0').nth(1), + # Second image + page.get_by_role('heading', name='Noise').get_by_text('Noise'), + page.get_by_role('img', name='0').nth(2), + ): + expect(selector).to_be_visible(timeout=300_000) + + # Test using your own data + page.locator("label").filter( + has_text="Use your own data").locator("div").nth(1).click() + page.get_by_label("Select input data").click() + page.get_by_label("Select model").click() + page.get_by_label("Select labels").click()