Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add UI Event linking to DraggableColorbar #12057

Merged
merged 19 commits into from
Oct 5, 2023
Merged
4 changes: 3 additions & 1 deletion mne/viz/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,7 +661,9 @@ def _plot_epochs_image(
this_colorbar = cbar(im, cax=ax["colorbar"])
this_colorbar.ax.set_ylabel(unit, rotation=270, labelpad=12)
if cmap[1]:
ax_im.CB = DraggableColorbar(this_colorbar, im)
ax_im.CB = DraggableColorbar(
this_colorbar, im, kind="epochs_image", ch_type=unit
)
with warnings.catch_warnings(record=True):
warnings.simplefilter("ignore")
tight_layout(fig=fig)
Expand Down
2 changes: 1 addition & 1 deletion mne/viz/evoked.py
Original file line number Diff line number Diff line change
Expand Up @@ -958,7 +958,7 @@ def _plot_image(
cbar = plt.colorbar(im, ax=ax)
cbar.ax.set_title(ch_unit)
if cmap[1]:
ax.CB = DraggableColorbar(cbar, im)
ax.CB = DraggableColorbar(cbar, im, "image", this_type)
ruuskas marked this conversation as resolved.
Show resolved Hide resolved

ylabel = "Channels" if show_names else "Channel (index)"
t = titles[this_type] + " (%d channel%s" % (len(data), _pl(data)) + t_end
Expand Down
2 changes: 1 addition & 1 deletion mne/viz/topo.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ def _imshow_tfr(
else:
cbar = plt.colorbar(mappable=img, ax=ax)
if interactive_cmap:
ax.CB = DraggableColorbar(cbar, img)
ax.CB = DraggableColorbar(cbar, img, kind=None, ch_type=None)
ruuskas marked this conversation as resolved.
Show resolved Hide resolved
ax.RS = RectangleSelector(ax, onselect=onselect) # reference must be kept

return t_end
Expand Down
60 changes: 52 additions & 8 deletions mne/viz/topomap.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,16 @@ def _plot_update_evoked_topomap(params, bools):


def _add_colorbar(
ax, im, cmap, side="right", pad=0.05, title=None, format=None, size="5%"
ax,
im,
cmap,
side="right",
pad=0.05,
title=None,
format=None,
size="5%",
kind=None,
ch_type=None,
):
"""Add a colorbar to an axis."""
import matplotlib.pyplot as plt
Expand All @@ -308,7 +317,7 @@ def _add_colorbar(
cax = divider.append_axes(side, size=size, pad=pad)
cbar = plt.colorbar(im, cax=cax, format=format)
if cmap is not None and cmap[1]:
ax.CB = DraggableColorbar(cbar, im)
ax.CB = DraggableColorbar(cbar, im, kind, ch_type)
if title is not None:
cax.set_title(title, y=1.05, fontsize=10)
return cbar, cax
Expand Down Expand Up @@ -587,7 +596,15 @@ def _plot_projs_topomap(
)

if colorbar:
_add_colorbar(ax, im, cmap, title=units, format=cbar_fmt)
_add_colorbar(
ax,
im,
cmap,
title=units,
format=cbar_fmt,
kind="projs_topomap",
ch_type=_ch_type,
)

return ax.get_figure()

Expand Down Expand Up @@ -973,7 +990,7 @@ def plot_topomap(
.. versionadded:: 0.20
%(res_topomap)s
%(size_topomap)s
%(cmap_topomap_simple)s
%(cmap_topomap)s
%(vlim_plot_topomap)s

.. versionadded:: 1.2
Expand Down Expand Up @@ -1454,7 +1471,16 @@ def _plot_ica_topomap(
ch_type=ch_type,
)[0]
if colorbar:
cbar, cax = _add_colorbar(axes, im, cmap, pad=0.05, title="AU", format="%3.2f")
cbar, cax = _add_colorbar(
axes,
im,
cmap,
pad=0.05,
title="AU",
format="%3.2f",
kind="ica_topomap",
ch_type=ch_type,
)
cbar.ax.tick_params(labelsize=12)
cbar.set_ticks(vlim)
_hide_frame(axes)
Expand Down Expand Up @@ -1685,7 +1711,15 @@ def plot_ica_components(
im.axes.set_label(ica._ica_names[ii])
if colorbar:
cbar, cax = _add_colorbar(
ax, im, cmap, title="AU", side="right", pad=0.05, format=cbar_fmt
ax,
im,
cmap,
title="AU",
side="right",
pad=0.05,
format=cbar_fmt,
kind="ica_comp_topomap",
ch_type=ch_type,
)
cbar.ax.tick_params(labelsize=12)
cbar.set_ticks(_vlim)
Expand Down Expand Up @@ -1956,7 +1990,15 @@ def plot_tfr_topomap(
from matplotlib import ticker

units = _handle_default("units", units)["misc"]
cbar, cax = _add_colorbar(axes, im, cmap, title=units, format=cbar_fmt)
cbar, cax = _add_colorbar(
axes,
im,
cmap,
title=units,
format=cbar_fmt,
kind="tfr_topomap",
ch_type=ch_type,
)
if locator is None:
locator = ticker.MaxNLocator(nbins=5)
cbar.locator = locator
Expand Down Expand Up @@ -2383,7 +2425,9 @@ def _slider_changed(val):
cbar.ax.tick_params(labelsize=7)
if cmap[1]:
for im in images:
im.axes.CB = DraggableColorbar(cbar, im)
im.axes.CB = DraggableColorbar(
cbar, im, kind="evoked_topomap", ch_type=ch_type
)

if proj == "interactive":
_check_delayed_ssp(evoked)
Expand Down
9 changes: 8 additions & 1 deletion mne/viz/ui_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@
"""
import contextlib
from dataclasses import dataclass
from typing import Optional, List
from typing import Optional, List, Union
import weakref
import re

from matplotlib.colors import Colormap

from ..utils import warn, fill_doc, _validate_type, logger, verbose

# Global dict {fig: channel} containing all currently active event channels.
Expand Down Expand Up @@ -122,18 +124,23 @@ class ColormapRange(UIEvent):
kind : str
Kind of colormap being updated. The Notes section of the drawing
routine publishing this event should mention the possible kinds.
sensor_type : str
ruuskas marked this conversation as resolved.
Show resolved Hide resolved
Type of sensor the data originates from.
unit : str
The unit of the values.
%(ui_event_name_source)s
%(fmin_fmid_fmax)s
%(alpha)s
%(colormap)s
"""

kind: str
ch_type: Optional[str] = None
fmin: Optional[float] = None
fmid: Optional[float] = None
fmax: Optional[float] = None
alpha: Optional[bool] = None
cmap: Optional[Union[Colormap, str]] = None


@dataclass
Expand Down
34 changes: 31 additions & 3 deletions mne/viz/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
check_version,
_check_decim,
)
from .ui_events import publish, subscribe, ColormapRange
from ..transforms import apply_trans


Expand Down Expand Up @@ -1569,11 +1570,14 @@ class DraggableColorbar:
See http://www.ster.kuleuven.be/~pieterd/python/html/plotting/interactive_colorbar.html
""" # noqa: E501

def __init__(self, cbar, mappable):
def __init__(self, cbar, mappable, kind, ch_type):
import matplotlib.pyplot as plt

self.cbar = cbar
self.mappable = mappable
self.kind = kind
self.ch_type = ch_type
self.fig = self.cbar.ax.figure
self.press = None
self.cycle = sorted(
[i for i in dir(plt.cm) if hasattr(getattr(plt.cm, i), "N")]
Expand All @@ -1582,6 +1586,7 @@ def __init__(self, cbar, mappable):
self.index = self.cycle.index(mappable.get_cmap().name)
self.lims = (self.cbar.norm.vmin, self.cbar.norm.vmax)
self.connect()
subscribe(self.fig, "colormap_range", self._on_colormap_range)

def connect(self):
"""Connect to all the events we need."""
Expand Down Expand Up @@ -1640,7 +1645,7 @@ def key_press(self, event):
self.cbar.mappable.set_cmap(cmap)
_draw_without_rendering(self.cbar)
self.mappable.set_cmap(cmap)
self._update()
self._publish()

def on_motion(self, event):
"""Handle mouse movements."""
Expand All @@ -1659,7 +1664,7 @@ def on_motion(self, event):
elif event.button == 3:
self.cbar.norm.vmin -= (perc * scale) * np.sign(dy)
self.cbar.norm.vmax += (perc * scale) * np.sign(dy)
self._update()
self._publish()

def on_release(self, event):
"""Handle release."""
Expand All @@ -1671,8 +1676,31 @@ def on_scroll(self, event):
scale = 1.1 if event.step < 0 else 1.0 / 1.1
self.cbar.norm.vmin *= scale
self.cbar.norm.vmax *= scale
self._publish()

def _on_colormap_range(self, event):
if event.kind != self.kind or event.ch_type != self.ch_type:
return
if event.fmin is not None:
self.cbar.norm.vmin = event.fmin
if event.fmax is not None:
self.cbar.norm.vmax = event.fmax
if event.cmap is not None:
self.cbar.mappable.set_cmap(event.cmap)
self._update()

def _publish(self):
publish(
self.fig,
ColormapRange(
kind=self.kind,
ch_type=self.ch_type,
fmin=self.cbar.norm.vmin,
fmax=self.cbar.norm.vmax,
cmap=self.mappable.get_cmap(),
),
)

def _update(self):
from matplotlib.ticker import AutoLocator

Expand Down
Loading