Skip to content

Commit

Permalink
Add UI Event linking to DraggableColorbar (mne-tools#12057)
Browse files Browse the repository at this point in the history
Co-authored-by: Santeri Ruuskanen <[email protected]>
Co-authored-by: Marijn van Vliet <[email protected]>
  • Loading branch information
3 people authored and snwnde committed Mar 20, 2024
1 parent 3a4d023 commit 242c18e
Show file tree
Hide file tree
Showing 8 changed files with 189 additions and 16 deletions.
3 changes: 2 additions & 1 deletion doc/changes/devel.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ Version 1.6.dev0 (development)

Enhancements
~~~~~~~~~~~~
- Improve tests for saving splits with `Epochs` (:gh:`11884` by `Dmitrii Altukhov`_)
- Improve tests for saving splits with :class:`mne.Epochs` (:gh:`11884` by `Dmitrii Altukhov`_)
- Added functionality for linking interactive figures together, such that changing one figure will affect another, see :ref:`tut-ui-events` and :mod:`mne.viz.ui_events`. Current figures implementing UI events are :func:`mne.viz.plot_topomap` and :func:`mne.viz.plot_source_estimates` (:gh:`11685` :gh:`11891` by `Marijn van Vliet`_)
- HTML anchors for :class:`mne.Report` now reflect the ``section-title`` of the report items rather than using a global incrementor ``global-N`` (:gh:`11890` by `Eric Larson`_)
- Added public :func:`mne.io.write_info` to complement :func:`mne.io.read_info` (:gh:`11918` by `Eric Larson`_)
Expand All @@ -37,6 +37,7 @@ Enhancements
- Add support for writing forward solutions to HDF5 and convenience function :meth:`mne.Forward.save` (:gh:`12036` by `Eric Larson`_)
- Refactored internals of :func:`mne.read_annotations` (:gh:`11964` by `Paul Roujansky`_)
- Enhance :func:`~mne.viz.plot_evoked_field` with a GUI that has controls for time, colormap, and contour lines (:gh:`11942` by `Marijn van Vliet`_)
- Add :class:`mne.viz.ui_events.UIEvent` linking for interactive colorbars, allowing users to link figures and change the colormap and limits interactively. This supports :func:`~mne.viz.plot_evoked_topomap`, :func:`~mne.viz.plot_ica_components`, :func:`~mne.viz.plot_tfr_topomap`, :func:`~mne.viz.plot_projs_topomap`, :meth:`~mne.Evoked.plot_image`, and :meth:`~mne.Epochs.plot_image` (:gh:`12057` by `Santeri Ruuskanen`_)

Bugs
~~~~
Expand Down
4 changes: 3 additions & 1 deletion mne/viz/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,7 +661,9 @@ def _plot_epochs_image(
this_colorbar = cbar(im, cax=ax["colorbar"])
this_colorbar.ax.set_ylabel(unit, rotation=270, labelpad=12)
if cmap[1]:
ax_im.CB = DraggableColorbar(this_colorbar, im)
ax_im.CB = DraggableColorbar(
this_colorbar, im, kind="epochs_image", ch_type=unit
)
with warnings.catch_warnings(record=True):
warnings.simplefilter("ignore")
tight_layout(fig=fig)
Expand Down
2 changes: 1 addition & 1 deletion mne/viz/evoked.py
Original file line number Diff line number Diff line change
Expand Up @@ -958,7 +958,7 @@ def _plot_image(
cbar = plt.colorbar(im, ax=ax)
cbar.ax.set_title(ch_unit)
if cmap[1]:
ax.CB = DraggableColorbar(cbar, im)
ax.CB = DraggableColorbar(cbar, im, "evoked_image", this_type)

ylabel = "Channels" if show_names else "Channel (index)"
t = titles[this_type] + " (%d channel%s" % (len(data), _pl(data)) + t_end
Expand Down
73 changes: 73 additions & 0 deletions mne/viz/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from mne.viz.utils import (
compare_fiff,
_fake_click,
_fake_keypress,
_fake_scroll,
_compute_scalings,
_validate_if_list_of_axes,
_get_color_list,
Expand All @@ -20,15 +22,18 @@
_make_event_color_dict,
concatenate_images,
)
from mne.viz.ui_events import link, subscribe, ColormapRange
from mne.viz import ClickableImage, add_background_image, mne_analyze_colormap
from mne.io import read_raw_fif
from mne.event import read_events
from mne.epochs import Epochs
from mne import read_evokeds

base_dir = Path(__file__).parent.parent.parent / "io" / "tests" / "data"
raw_fname = base_dir / "test_raw.fif"
cov_fname = base_dir / "test-cov.fif"
ev_fname = base_dir / "test_raw-eve.fif"
ave_fname = base_dir / "test-ave.fif"


def test_setup_vmin_vmax_warns():
Expand Down Expand Up @@ -202,3 +207,71 @@ def test_concatenate_images(a_w, a_h, b_w, b_h, axis):
else:
want_shape = (max(a_h, b_h), a_w + b_w, 3)
assert img.shape == want_shape


def test_draggable_colorbar():
"""Test that DraggableColorbar publishes correct UI Events."""
evokeds = read_evokeds(ave_fname)
left_auditory = evokeds[0]
right_auditory = evokeds[1]
vmin, vmax = -400, 400
fig = left_auditory.plot_topomap("interactive", vlim=(vmin, vmax))
fig2 = right_auditory.plot_topomap("interactive", vlim=(vmin, vmax))
link(fig, fig2)
callback_calls = []

def callback(event):
callback_calls.append(event)

subscribe(fig, "colormap_range", callback)

# Test that correct event is published
_fake_keypress(fig, "down")
_fake_keypress(fig, "up")
assert len(callback_calls) == 2
event = callback_calls.pop()
assert type(event) is ColormapRange
# Test that scrolling changes color limits
_fake_scroll(fig, 10, 10, 1)
event = callback_calls.pop()
assert abs(event.fmin) < abs(vmin)
assert abs(event.fmax) < abs(vmax)
fmin, fmax = event.fmin, event.fmax
_fake_scroll(fig, 10, 10, -1)
event = callback_calls.pop()
assert abs(event.fmin) > abs(fmin)
assert abs(event.fmax) > abs(fmax)
fmin, fmax = event.fmin, event.fmax
# Test that plus and minus change color limits
_fake_keypress(fig, "+")
event = callback_calls.pop()
assert abs(event.fmin) < abs(fmin)
assert abs(event.fmax) < abs(fmax)
fmin, fmax = event.fmin, event.fmax
_fake_keypress(fig, "-")
event = callback_calls.pop()
assert abs(event.fmin) > abs(fmin)
assert abs(event.fmax) > abs(fmax)
fmin, fmax = event.fmin, event.fmax
# Test that page up and page down change color limits
_fake_keypress(fig, "pageup")
event = callback_calls.pop()
assert event.fmin < fmin
assert event.fmax < fmax
fmin, fmax = event.fmin, event.fmax
_fake_keypress(fig, "pagedown")
event = callback_calls.pop()
assert event.fmin > fmin
assert event.fmax > fmax
# Test that space key resets color limits
_fake_keypress(fig, " ")
event = callback_calls.pop()
assert event.fmax == vmax
assert event.fmin == vmin
# Test that colormap change in one figure changes that of another one
cmap_want = fig.axes[0].CB.cycle[fig.axes[0].CB.index + 1]
cmap_old = fig.axes[0].CB.mappable.get_cmap().name
_fake_keypress(fig, "down")
cmap_new1 = fig.axes[0].CB.mappable.get_cmap().name
cmap_new2 = fig2.axes[0].CB.mappable.get_cmap().name
assert cmap_new1 == cmap_new2 == cmap_want != cmap_old
2 changes: 1 addition & 1 deletion mne/viz/topo.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ def _imshow_tfr(
else:
cbar = plt.colorbar(mappable=img, ax=ax)
if interactive_cmap:
ax.CB = DraggableColorbar(cbar, img)
ax.CB = DraggableColorbar(cbar, img, kind="tfr_image", ch_type=None)
ax.RS = RectangleSelector(ax, onselect=onselect) # reference must be kept

return t_end
Expand Down
70 changes: 62 additions & 8 deletions mne/viz/topomap.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,16 @@ def _plot_update_evoked_topomap(params, bools):


def _add_colorbar(
ax, im, cmap, side="right", pad=0.05, title=None, format=None, size="5%"
ax,
im,
cmap,
side="right",
pad=0.05,
title=None,
format=None,
size="5%",
kind=None,
ch_type=None,
):
"""Add a colorbar to an axis."""
import matplotlib.pyplot as plt
Expand All @@ -308,7 +317,7 @@ def _add_colorbar(
cax = divider.append_axes(side, size=size, pad=pad)
cbar = plt.colorbar(im, cax=cax, format=format)
if cmap is not None and cmap[1]:
ax.CB = DraggableColorbar(cbar, im)
ax.CB = DraggableColorbar(cbar, im, kind, ch_type)
if title is not None:
cax.set_title(title, y=1.05, fontsize=10)
return cbar, cax
Expand Down Expand Up @@ -587,7 +596,15 @@ def _plot_projs_topomap(
)

if colorbar:
_add_colorbar(ax, im, cmap, title=units, format=cbar_fmt)
_add_colorbar(
ax,
im,
cmap,
title=units,
format=cbar_fmt,
kind="projs_topomap",
ch_type=_ch_type,
)

return ax.get_figure()

Expand Down Expand Up @@ -973,7 +990,7 @@ def plot_topomap(
.. versionadded:: 0.20
%(res_topomap)s
%(size_topomap)s
%(cmap_topomap_simple)s
%(cmap_topomap)s
%(vlim_plot_topomap)s
.. versionadded:: 1.2
Expand Down Expand Up @@ -1454,7 +1471,16 @@ def _plot_ica_topomap(
ch_type=ch_type,
)[0]
if colorbar:
cbar, cax = _add_colorbar(axes, im, cmap, pad=0.05, title="AU", format="%3.2f")
cbar, cax = _add_colorbar(
axes,
im,
cmap,
pad=0.05,
title="AU",
format="%3.2f",
kind="ica_topomap",
ch_type=ch_type,
)
cbar.ax.tick_params(labelsize=12)
cbar.set_ticks(vlim)
_hide_frame(axes)
Expand Down Expand Up @@ -1685,7 +1711,15 @@ def plot_ica_components(
im.axes.set_label(ica._ica_names[ii])
if colorbar:
cbar, cax = _add_colorbar(
ax, im, cmap, title="AU", side="right", pad=0.05, format=cbar_fmt
ax,
im,
cmap,
title="AU",
side="right",
pad=0.05,
format=cbar_fmt,
kind="ica_comp_topomap",
ch_type=ch_type,
)
cbar.ax.tick_params(labelsize=12)
cbar.set_ticks(_vlim)
Expand Down Expand Up @@ -1956,7 +1990,15 @@ def plot_tfr_topomap(
from matplotlib import ticker

units = _handle_default("units", units)["misc"]
cbar, cax = _add_colorbar(axes, im, cmap, title=units, format=cbar_fmt)
cbar, cax = _add_colorbar(
axes,
im,
cmap,
title=units,
format=cbar_fmt,
kind="tfr_topomap",
ch_type=ch_type,
)
if locator is None:
locator = ticker.MaxNLocator(nbins=5)
cbar.locator = locator
Expand Down Expand Up @@ -2363,6 +2405,11 @@ def _slider_changed(val):
kwargs=kwargs,
),
)
subscribe(
fig,
"colormap_range",
partial(_on_colormap_range, kwargs=kwargs),
)

if colorbar:
if interactive:
Expand All @@ -2383,7 +2430,9 @@ def _slider_changed(val):
cbar.ax.tick_params(labelsize=7)
if cmap[1]:
for im in images:
im.axes.CB = DraggableColorbar(cbar, im)
im.axes.CB = DraggableColorbar(
cbar, im, kind="evoked_topomap", ch_type=ch_type
)

if proj == "interactive":
_check_delayed_ssp(evoked)
Expand Down Expand Up @@ -2460,6 +2509,11 @@ def _on_time_change(
ax.figure.canvas.draw_idle()


def _on_colormap_range(event, kwargs):
"""Handle updating colormap range."""
kwargs.update(vlim=(event.fmin, event.fmax), cmap=event.cmap)


def _plot_topomap_multi_cbar(
data,
pos,
Expand Down
16 changes: 15 additions & 1 deletion mne/viz/ui_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@
"""
import contextlib
from dataclasses import dataclass
from typing import Optional, List
from typing import Optional, List, Union
import weakref
import re

from matplotlib.colors import Colormap

from ..utils import warn, fill_doc, _validate_type, logger, verbose

# Global dict {fig: channel} containing all currently active event channels.
Expand Down Expand Up @@ -114,26 +116,38 @@ class ColormapRange(UIEvent):
kind : str
Kind of colormap being updated. The Notes section of the drawing
routine publishing this event should mention the possible kinds.
ch_type : str
Type of sensor the data originates from.
%(fmin_fmid_fmax)s
%(alpha)s
cmap : str
The colormap to use. Either string or matplotlib.colors.Colormap
instance.
Attributes
----------
kind : str
Kind of colormap being updated. The Notes section of the drawing
routine publishing this event should mention the possible kinds.
ch_type : str
Type of sensor the data originates from.
unit : str
The unit of the values.
%(ui_event_name_source)s
%(fmin_fmid_fmax)s
%(alpha)s
cmap : str
The colormap to use. Either string or matplotlib.colors.Colormap
instance.
"""

kind: str
ch_type: Optional[str] = None
fmin: Optional[float] = None
fmid: Optional[float] = None
fmax: Optional[float] = None
alpha: Optional[bool] = None
cmap: Optional[Union[Colormap, str]] = None


@dataclass
Expand Down
Loading

0 comments on commit 242c18e

Please sign in to comment.