Skip to content

Commit

Permalink
test(interactive): improve interactive tool test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
kmnhan committed Jul 15, 2024
1 parent 37d93bc commit 7f24b81
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 20 deletions.
2 changes: 2 additions & 0 deletions src/erlab/interactive/fermiedge.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
AnalysisWindow,
ParameterGroup,
ROIControls,
_coverage_resolve_trace,
gen_function_code,
xImageItem,
)
Expand Down Expand Up @@ -74,6 +75,7 @@ def abort_fit(self) -> None:
self.parallel_obj._aborting = True
self.parallel_obj._exception = True

@_coverage_resolve_trace
def run(self) -> None:
self.sigIterated.emit(0)
with joblib_progress_qt(self.sigIterated) as _:
Expand Down
15 changes: 7 additions & 8 deletions src/erlab/interactive/imagetool/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,12 @@ def _parse_input(
| npt.NDArray
| xr.Dataset,
) -> list[xr.DataArray]:
if isinstance(data, xr.Dataset):
data = [d for d in data.data_vars.values() if d.ndim >= 2 and d.ndim <= 4]
if len(data) == 0:
raise ValueError("No valid data for ImageTool found in the Dataset")

if isinstance(data, np.ndarray | xr.DataArray):
data = (data,)
elif isinstance(data, xr.Dataset):
data = tuple(d for d in data.data_vars.values() if d.ndim >= 2 and d.ndim <= 4)
if len(data) == 0:
raise ValueError("No valid data for ImageTool found in the Dataset")

return [xr.DataArray(d) if not isinstance(d, xr.DataArray) else d for d in data]

Expand Down Expand Up @@ -165,9 +164,6 @@ def itool(
for w in itool_list:
w.show()

if len(itool_list) == 0:
raise ValueError("No data provided")

if link:
linker = SlicerLinkProxy( # noqa: F841
*[w.slicer_area for w in itool_list], link_colors=link_colors
Expand Down Expand Up @@ -403,8 +399,11 @@ def _update_title(self) -> None:
path: str | None = self.slicer_area._file_path

if name is not None and name.strip() == "":
# Name contains only whitespace
name = None

if path is not None:
# If opened from a file
path = os.path.basename(path)

if name is None and path is None:
Expand Down
13 changes: 1 addition & 12 deletions src/erlab/interactive/imagetool/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
__all__ = ["is_running", "main", "show_in_manager"]

import contextlib
import functools
import gc
import os
import pickle
Expand All @@ -37,6 +36,7 @@
from erlab.interactive.imagetool import ImageTool, _parse_input
from erlab.interactive.imagetool.controls import IconButton
from erlab.interactive.imagetool.core import SlicerLinkProxy
from erlab.interactive.utils import _coverage_resolve_trace

if TYPE_CHECKING:
from collections.abc import Collection
Expand Down Expand Up @@ -72,17 +72,6 @@
"""Colors for different linkers."""


def _coverage_resolve_trace(fn):
# https://github.com/nedbat/coveragepy/issues/686#issuecomment-634932753
@functools.wraps(fn)
def _wrapped_for_coverage(*args, **kwargs) -> None:
if threading._trace_hook: # type: ignore[attr-defined]
sys.settrace(threading._trace_hook) # type: ignore[attr-defined]
fn(*args, **kwargs)

return _wrapped_for_coverage


def _save_pickle(obj: Any, filename: str) -> None:
with open(filename, "wb") as file:
pickle.dump(obj, file)
Expand Down
13 changes: 13 additions & 0 deletions src/erlab/interactive/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

from __future__ import annotations

import functools
import re
import sys
import threading
import types
import warnings
from typing import TYPE_CHECKING, Any, Literal, cast, no_type_check
Expand Down Expand Up @@ -37,6 +39,17 @@
]


def _coverage_resolve_trace(fn):
# https://github.com/nedbat/coveragepy/issues/686#issuecomment-634932753
@functools.wraps(fn)
def _wrapped_for_coverage(*args, **kwargs) -> None:
if threading._trace_hook: # type: ignore[attr-defined]
sys.settrace(threading._trace_hook) # type: ignore[attr-defined]
fn(*args, **kwargs)

return _wrapped_for_coverage


def parse_data(data) -> xr.DataArray:
if isinstance(data, xr.Dataset):
raise TypeError(
Expand Down
45 changes: 45 additions & 0 deletions tests/interactive/test_imagetool.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,51 @@ def test_itool(qtbot):
win.slicer_area.state = expected_state
assert win.slicer_area.state == expected_state

# Setting data
win.slicer_area.set_data(data.rename("new_data"))
assert win.windowTitle() == "new_data"


def test_itool_ds(qtbot):
# If no 2D to 4D data is present in given Dataset, ValueError is raised
with pytest.raises(
ValueError, match="No valid data for ImageTool found in the Dataset"
):
itool(
xr.Dataset(
{
"data1d": xr.DataArray(np.arange(5), dims=["x"]),
"data0d": 1,
}
),
execute=False,
)

data = xr.Dataset(
{
"data1d": xr.DataArray(np.arange(5), dims=["x"]),
"a": xr.DataArray(np.arange(25).reshape((5, 5)), dims=["x", "y"]),
"b": xr.DataArray(np.arange(25).reshape((5, 5)), dims=["x", "y"]),
}
)
wins = itool(data, execute=False, link=True)
assert isinstance(wins, list)
assert len(wins) == 2

qtbot.addWidget(wins[0])
qtbot.addWidget(wins[1])

with qtbot.waitExposed(wins[0]):
wins[0].show()
with qtbot.waitExposed(wins[1]):
wins[1].show()

assert wins[0].windowTitle() == "a"
assert wins[1].windowTitle() == "b"

# Check if properly linked
assert wins[0].slicer_area._linking_proxy == wins[1].slicer_area._linking_proxy


def test_value_update(qtbot):
data = xr.DataArray(np.arange(25).reshape((5, 5)), dims=["x", "y"])
Expand Down
18 changes: 18 additions & 0 deletions tests/interactive/test_tools.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import pyperclip
import xarray as xr
from erlab.interactive.bzplot import BZPlotter
from erlab.interactive.curvefittingtool import edctool, mdctool
Expand All @@ -15,6 +16,23 @@ def test_goldtool(qtbot, gold):
win.show()
win.activateWindow()
win.raise_()
win.params_edge.widgets["# CPU"].setValue(1)
win.params_edge.widgets["Fast"].setChecked(True)
with qtbot.waitSignal(win.fitter.sigFinished):
win.params_edge.widgets["go"].click()
win.params_poly.widgets["copy"].click()
assert (
pyperclip.paste()
== """modelresult = era.gold.poly(
gold,
angle_range=(-13.5, 13.5),
eV_range=(-0.204, 0.276),
bin_size=(1, 1),
method="leastsq",
degree=4,
fast=True,
)"""
)


def test_dtool(qtbot):
Expand Down

0 comments on commit 7f24b81

Please sign in to comment.