diff --git a/doc/changes/devel.rst b/doc/changes/devel.rst index d4532d60721..84a48ef6e5c 100644 --- a/doc/changes/devel.rst +++ b/doc/changes/devel.rst @@ -23,7 +23,7 @@ Version 1.6.dev0 (development) Enhancements ~~~~~~~~~~~~ -- Improve tests for saving splits with `Epochs` (:gh:`11884` by `Dmitrii Altukhov`_) +- Improve tests for saving splits with :class:`mne.Epochs` (:gh:`11884` by `Dmitrii Altukhov`_) - Added functionality for linking interactive figures together, such that changing one figure will affect another, see :ref:`tut-ui-events` and :mod:`mne.viz.ui_events`. Current figures implementing UI events are :func:`mne.viz.plot_topomap` and :func:`mne.viz.plot_source_estimates` (:gh:`11685` :gh:`11891` by `Marijn van Vliet`_) - HTML anchors for :class:`mne.Report` now reflect the ``section-title`` of the report items rather than using a global incrementor ``global-N`` (:gh:`11890` by `Eric Larson`_) - Added public :func:`mne.io.write_info` to complement :func:`mne.io.read_info` (:gh:`11918` by `Eric Larson`_) @@ -37,6 +37,7 @@ Enhancements - Add support for writing forward solutions to HDF5 and convenience function :meth:`mne.Forward.save` (:gh:`12036` by `Eric Larson`_) - Refactored internals of :func:`mne.read_annotations` (:gh:`11964` by `Paul Roujansky`_) - Enhance :func:`~mne.viz.plot_evoked_field` with a GUI that has controls for time, colormap, and contour lines (:gh:`11942` by `Marijn van Vliet`_) +- Add :class:`mne.viz.ui_events.UIEvent` linking for interactive colorbars, allowing users to link figures and change the colormap and limits interactively. This supports :func:`~mne.viz.plot_evoked_topomap`, :func:`~mne.viz.plot_ica_components`, :func:`~mne.viz.plot_tfr_topomap`, :func:`~mne.viz.plot_projs_topomap`, :meth:`~mne.Evoked.plot_image`, and :meth:`~mne.Epochs.plot_image` (:gh:`12057` by `Santeri Ruuskanen`_) Bugs ~~~~ diff --git a/mne/viz/epochs.py b/mne/viz/epochs.py index 3e94308b10c..d173c80a45b 100644 --- a/mne/viz/epochs.py +++ b/mne/viz/epochs.py @@ -661,7 +661,9 @@ def _plot_epochs_image( this_colorbar = cbar(im, cax=ax["colorbar"]) this_colorbar.ax.set_ylabel(unit, rotation=270, labelpad=12) if cmap[1]: - ax_im.CB = DraggableColorbar(this_colorbar, im) + ax_im.CB = DraggableColorbar( + this_colorbar, im, kind="epochs_image", ch_type=unit + ) with warnings.catch_warnings(record=True): warnings.simplefilter("ignore") tight_layout(fig=fig) diff --git a/mne/viz/evoked.py b/mne/viz/evoked.py index 34a9d60dfe3..687203cad49 100644 --- a/mne/viz/evoked.py +++ b/mne/viz/evoked.py @@ -958,7 +958,7 @@ def _plot_image( cbar = plt.colorbar(im, ax=ax) cbar.ax.set_title(ch_unit) if cmap[1]: - ax.CB = DraggableColorbar(cbar, im) + ax.CB = DraggableColorbar(cbar, im, "evoked_image", this_type) ylabel = "Channels" if show_names else "Channel (index)" t = titles[this_type] + " (%d channel%s" % (len(data), _pl(data)) + t_end diff --git a/mne/viz/tests/test_utils.py b/mne/viz/tests/test_utils.py index b4a418652f1..c7f339c9997 100644 --- a/mne/viz/tests/test_utils.py +++ b/mne/viz/tests/test_utils.py @@ -12,6 +12,8 @@ from mne.viz.utils import ( compare_fiff, _fake_click, + _fake_keypress, + _fake_scroll, _compute_scalings, _validate_if_list_of_axes, _get_color_list, @@ -20,15 +22,18 @@ _make_event_color_dict, concatenate_images, ) +from mne.viz.ui_events import link, subscribe, ColormapRange from mne.viz import ClickableImage, add_background_image, mne_analyze_colormap from mne.io import read_raw_fif from mne.event import read_events from mne.epochs import Epochs +from mne import read_evokeds base_dir = Path(__file__).parent.parent.parent / "io" / "tests" / "data" raw_fname = base_dir / "test_raw.fif" cov_fname = base_dir / "test-cov.fif" ev_fname = base_dir / "test_raw-eve.fif" +ave_fname = base_dir / "test-ave.fif" def test_setup_vmin_vmax_warns(): @@ -202,3 +207,71 @@ def test_concatenate_images(a_w, a_h, b_w, b_h, axis): else: want_shape = (max(a_h, b_h), a_w + b_w, 3) assert img.shape == want_shape + + +def test_draggable_colorbar(): + """Test that DraggableColorbar publishes correct UI Events.""" + evokeds = read_evokeds(ave_fname) + left_auditory = evokeds[0] + right_auditory = evokeds[1] + vmin, vmax = -400, 400 + fig = left_auditory.plot_topomap("interactive", vlim=(vmin, vmax)) + fig2 = right_auditory.plot_topomap("interactive", vlim=(vmin, vmax)) + link(fig, fig2) + callback_calls = [] + + def callback(event): + callback_calls.append(event) + + subscribe(fig, "colormap_range", callback) + + # Test that correct event is published + _fake_keypress(fig, "down") + _fake_keypress(fig, "up") + assert len(callback_calls) == 2 + event = callback_calls.pop() + assert type(event) is ColormapRange + # Test that scrolling changes color limits + _fake_scroll(fig, 10, 10, 1) + event = callback_calls.pop() + assert abs(event.fmin) < abs(vmin) + assert abs(event.fmax) < abs(vmax) + fmin, fmax = event.fmin, event.fmax + _fake_scroll(fig, 10, 10, -1) + event = callback_calls.pop() + assert abs(event.fmin) > abs(fmin) + assert abs(event.fmax) > abs(fmax) + fmin, fmax = event.fmin, event.fmax + # Test that plus and minus change color limits + _fake_keypress(fig, "+") + event = callback_calls.pop() + assert abs(event.fmin) < abs(fmin) + assert abs(event.fmax) < abs(fmax) + fmin, fmax = event.fmin, event.fmax + _fake_keypress(fig, "-") + event = callback_calls.pop() + assert abs(event.fmin) > abs(fmin) + assert abs(event.fmax) > abs(fmax) + fmin, fmax = event.fmin, event.fmax + # Test that page up and page down change color limits + _fake_keypress(fig, "pageup") + event = callback_calls.pop() + assert event.fmin < fmin + assert event.fmax < fmax + fmin, fmax = event.fmin, event.fmax + _fake_keypress(fig, "pagedown") + event = callback_calls.pop() + assert event.fmin > fmin + assert event.fmax > fmax + # Test that space key resets color limits + _fake_keypress(fig, " ") + event = callback_calls.pop() + assert event.fmax == vmax + assert event.fmin == vmin + # Test that colormap change in one figure changes that of another one + cmap_want = fig.axes[0].CB.cycle[fig.axes[0].CB.index + 1] + cmap_old = fig.axes[0].CB.mappable.get_cmap().name + _fake_keypress(fig, "down") + cmap_new1 = fig.axes[0].CB.mappable.get_cmap().name + cmap_new2 = fig2.axes[0].CB.mappable.get_cmap().name + assert cmap_new1 == cmap_new2 == cmap_want != cmap_old diff --git a/mne/viz/topo.py b/mne/viz/topo.py index a01ee72a0c2..683c22d9a6a 100644 --- a/mne/viz/topo.py +++ b/mne/viz/topo.py @@ -457,7 +457,7 @@ def _imshow_tfr( else: cbar = plt.colorbar(mappable=img, ax=ax) if interactive_cmap: - ax.CB = DraggableColorbar(cbar, img) + ax.CB = DraggableColorbar(cbar, img, kind="tfr_image", ch_type=None) ax.RS = RectangleSelector(ax, onselect=onselect) # reference must be kept return t_end diff --git a/mne/viz/topomap.py b/mne/viz/topomap.py index 0802362a27f..d47ec145e07 100644 --- a/mne/viz/topomap.py +++ b/mne/viz/topomap.py @@ -298,7 +298,16 @@ def _plot_update_evoked_topomap(params, bools): def _add_colorbar( - ax, im, cmap, side="right", pad=0.05, title=None, format=None, size="5%" + ax, + im, + cmap, + side="right", + pad=0.05, + title=None, + format=None, + size="5%", + kind=None, + ch_type=None, ): """Add a colorbar to an axis.""" import matplotlib.pyplot as plt @@ -308,7 +317,7 @@ def _add_colorbar( cax = divider.append_axes(side, size=size, pad=pad) cbar = plt.colorbar(im, cax=cax, format=format) if cmap is not None and cmap[1]: - ax.CB = DraggableColorbar(cbar, im) + ax.CB = DraggableColorbar(cbar, im, kind, ch_type) if title is not None: cax.set_title(title, y=1.05, fontsize=10) return cbar, cax @@ -587,7 +596,15 @@ def _plot_projs_topomap( ) if colorbar: - _add_colorbar(ax, im, cmap, title=units, format=cbar_fmt) + _add_colorbar( + ax, + im, + cmap, + title=units, + format=cbar_fmt, + kind="projs_topomap", + ch_type=_ch_type, + ) return ax.get_figure() @@ -973,7 +990,7 @@ def plot_topomap( .. versionadded:: 0.20 %(res_topomap)s %(size_topomap)s - %(cmap_topomap_simple)s + %(cmap_topomap)s %(vlim_plot_topomap)s .. versionadded:: 1.2 @@ -1454,7 +1471,16 @@ def _plot_ica_topomap( ch_type=ch_type, )[0] if colorbar: - cbar, cax = _add_colorbar(axes, im, cmap, pad=0.05, title="AU", format="%3.2f") + cbar, cax = _add_colorbar( + axes, + im, + cmap, + pad=0.05, + title="AU", + format="%3.2f", + kind="ica_topomap", + ch_type=ch_type, + ) cbar.ax.tick_params(labelsize=12) cbar.set_ticks(vlim) _hide_frame(axes) @@ -1685,7 +1711,15 @@ def plot_ica_components( im.axes.set_label(ica._ica_names[ii]) if colorbar: cbar, cax = _add_colorbar( - ax, im, cmap, title="AU", side="right", pad=0.05, format=cbar_fmt + ax, + im, + cmap, + title="AU", + side="right", + pad=0.05, + format=cbar_fmt, + kind="ica_comp_topomap", + ch_type=ch_type, ) cbar.ax.tick_params(labelsize=12) cbar.set_ticks(_vlim) @@ -1956,7 +1990,15 @@ def plot_tfr_topomap( from matplotlib import ticker units = _handle_default("units", units)["misc"] - cbar, cax = _add_colorbar(axes, im, cmap, title=units, format=cbar_fmt) + cbar, cax = _add_colorbar( + axes, + im, + cmap, + title=units, + format=cbar_fmt, + kind="tfr_topomap", + ch_type=ch_type, + ) if locator is None: locator = ticker.MaxNLocator(nbins=5) cbar.locator = locator @@ -2363,6 +2405,11 @@ def _slider_changed(val): kwargs=kwargs, ), ) + subscribe( + fig, + "colormap_range", + partial(_on_colormap_range, kwargs=kwargs), + ) if colorbar: if interactive: @@ -2383,7 +2430,9 @@ def _slider_changed(val): cbar.ax.tick_params(labelsize=7) if cmap[1]: for im in images: - im.axes.CB = DraggableColorbar(cbar, im) + im.axes.CB = DraggableColorbar( + cbar, im, kind="evoked_topomap", ch_type=ch_type + ) if proj == "interactive": _check_delayed_ssp(evoked) @@ -2460,6 +2509,11 @@ def _on_time_change( ax.figure.canvas.draw_idle() +def _on_colormap_range(event, kwargs): + """Handle updating colormap range.""" + kwargs.update(vlim=(event.fmin, event.fmax), cmap=event.cmap) + + def _plot_topomap_multi_cbar( data, pos, diff --git a/mne/viz/ui_events.py b/mne/viz/ui_events.py index ba5b1db9a33..78c1419ca2f 100644 --- a/mne/viz/ui_events.py +++ b/mne/viz/ui_events.py @@ -11,10 +11,12 @@ """ import contextlib from dataclasses import dataclass -from typing import Optional, List +from typing import Optional, List, Union import weakref import re +from matplotlib.colors import Colormap + from ..utils import warn, fill_doc, _validate_type, logger, verbose # Global dict {fig: channel} containing all currently active event channels. @@ -114,26 +116,38 @@ class ColormapRange(UIEvent): kind : str Kind of colormap being updated. The Notes section of the drawing routine publishing this event should mention the possible kinds. + ch_type : str + Type of sensor the data originates from. %(fmin_fmid_fmax)s %(alpha)s + cmap : str + The colormap to use. Either string or matplotlib.colors.Colormap + instance. Attributes ---------- kind : str Kind of colormap being updated. The Notes section of the drawing routine publishing this event should mention the possible kinds. + ch_type : str + Type of sensor the data originates from. unit : str The unit of the values. %(ui_event_name_source)s %(fmin_fmid_fmax)s %(alpha)s + cmap : str + The colormap to use. Either string or matplotlib.colors.Colormap + instance. """ kind: str + ch_type: Optional[str] = None fmin: Optional[float] = None fmid: Optional[float] = None fmax: Optional[float] = None alpha: Optional[bool] = None + cmap: Optional[Union[Colormap, str]] = None @dataclass diff --git a/mne/viz/utils.py b/mne/viz/utils.py index 264505b67ad..78f05ee9109 100644 --- a/mne/viz/utils.py +++ b/mne/viz/utils.py @@ -63,6 +63,7 @@ check_version, _check_decim, ) +from .ui_events import publish, subscribe, ColormapRange from ..transforms import apply_trans @@ -1569,11 +1570,14 @@ class DraggableColorbar: See http://www.ster.kuleuven.be/~pieterd/python/html/plotting/interactive_colorbar.html """ # noqa: E501 - def __init__(self, cbar, mappable): + def __init__(self, cbar, mappable, kind, ch_type): import matplotlib.pyplot as plt self.cbar = cbar self.mappable = mappable + self.kind = kind + self.ch_type = ch_type + self.fig = self.cbar.ax.figure self.press = None self.cycle = sorted( [i for i in dir(plt.cm) if hasattr(getattr(plt.cm, i), "N")] @@ -1582,6 +1586,7 @@ def __init__(self, cbar, mappable): self.index = self.cycle.index(mappable.get_cmap().name) self.lims = (self.cbar.norm.vmin, self.cbar.norm.vmax) self.connect() + subscribe(self.fig, "colormap_range", self._on_colormap_range) def connect(self): """Connect to all the events we need.""" @@ -1640,7 +1645,7 @@ def key_press(self, event): self.cbar.mappable.set_cmap(cmap) _draw_without_rendering(self.cbar) self.mappable.set_cmap(cmap) - self._update() + self._publish() def on_motion(self, event): """Handle mouse movements.""" @@ -1659,7 +1664,7 @@ def on_motion(self, event): elif event.button == 3: self.cbar.norm.vmin -= (perc * scale) * np.sign(dy) self.cbar.norm.vmax += (perc * scale) * np.sign(dy) - self._update() + self._publish() def on_release(self, event): """Handle release.""" @@ -1671,8 +1676,32 @@ def on_scroll(self, event): scale = 1.1 if event.step < 0 else 1.0 / 1.1 self.cbar.norm.vmin *= scale self.cbar.norm.vmax *= scale + self._publish() + + def _on_colormap_range(self, event): + if event.kind != self.kind or event.ch_type != self.ch_type: + return + if event.fmin is not None: + self.cbar.norm.vmin = event.fmin + if event.fmax is not None: + self.cbar.norm.vmax = event.fmax + if event.cmap is not None: + self.cbar.mappable.set_cmap(event.cmap) + self.mappable.set_cmap(event.cmap) self._update() + def _publish(self): + publish( + self.fig, + ColormapRange( + kind=self.kind, + ch_type=self.ch_type, + fmin=self.cbar.norm.vmin, + fmax=self.cbar.norm.vmax, + cmap=self.mappable.get_cmap(), + ), + ) + def _update(self): from matplotlib.ticker import AutoLocator