Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NF: Adds overlay option in OrthoSlicer3D #850

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions nibabel/spatialimages.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,9 +589,14 @@ def __getitem__(self, idx):
"slicing image array data with `img.dataobj[slice]` or "
"`img.get_fdata()[slice]`")

def orthoview(self):
def orthoview(self, overlay=None, **kwargs):
"""Plot the image using OrthoSlicer3D

Parameters
----------
overlay : ``spatialimage`` instance
Image to be plotted as overlay. Default: None

Returns
-------
viewer : instance of OrthoSlicer3D
Expand All @@ -603,8 +608,12 @@ def orthoview(self):
consider using viewer.show() (equivalently plt.show()) to show
the figure.
"""
return OrthoSlicer3D(self.dataobj, self.affine,
title=self.get_filename())
ortho = OrthoSlicer3D(self.dataobj, self.affine,
title=self.get_filename())
if overlay is not None:
ortho.set_overlay(overlay, **kwargs)

return ortho

def as_reoriented(self, ornt):
"""Apply an orientation change and return a new image
Expand Down
149 changes: 148 additions & 1 deletion nibabel/viewers.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ def __init__(self, data, affine=None, axes=None, title=None):
self._title = title
self._closed = False
self._cross = True
self._overlay = None
self._threshold = None
self._alpha = 1

data = np.asanyarray(data)
if data.ndim < 3:
Expand Down Expand Up @@ -285,6 +288,150 @@ def clim(self, clim):
self._clim = tuple(clim)
self.draw()

@property
def overlay(self):
"""The current overlay """
return self._overlay

@overlay.setter
def overlay(self, img):
if img is None:
self._remove_overlay()
else:
self.set_overlay(img)

@property
def threshold(self):
"""The current data display threshold """
return self._threshold

@threshold.setter
def threshold(self, threshold):
# mask data array
if threshold is not None:
self._data = np.ma.masked_array(np.asarray(self._data),
np.asarray(self._data) <= threshold)
self._threshold = float(threshold)
else:
self._data = np.asarray(self._data)
self._threshold = threshold

# update current volume data w/masked array and re-draw everything
if self._data.ndim > 3:
self._current_vol_data = self._data[..., self._data_idx[3]]
else:
self._current_vol_data = self._data
self._set_position(None, None, None, notify=False)

@property
def alpha(self):
""" The current alpha (transparency) value """
return self._alpha

@alpha.setter
def alpha(self, alpha):
alpha = float(alpha)
if alpha > 1 or alpha < 0:
raise ValueError('alpha must be between 0 and 1')
for im in self._ims:
im.set_alpha(alpha)
self._alpha = alpha
self.draw()

def set_overlay(self, data, affine=None, threshold=None, cmap='viridis',
alpha=0.7):
""" Sets `data` as overlay for currently plotted image

Parameters
----------
data : array-like
The data that will be overlayed on the slicer. Should have 3+
dimensions.
affine : array-like or None, optional
Affine transform for the provided data. This is used to determine
how the data should be sliced for plotting into the sagittal,
coronal, and axial view axes. If this does not match the currently
plotted slicer the provided data will be resampled.
threshold : float or None, optional
Threshold for overlay data; values below this threshold will not
be displayed. Default: None
cmap : str, optional
The Colormap instance or registered colormap name used to map
scalar data to colors. Default: 'viridis'
alpha : [0, 1] float, optional
Set the alpha value used for blending. Default: 0.7
"""
if affine is None:
try: # did we get an image?
affine = data.affine
data = data.dataobj
except AttributeError:
pass

# check that we have sufficient information to match the overlays
if affine is None and data.shape[:3] != self._data.shape[:3]:
raise ValueError('Provided `data` do not match shape of '
'underlay and no `affine` matrix was '
'provided. Please provide an `affine` matrix '
'or resample first three dims of `data` to {}'
.format(self._data.shape[:3]))

# we need to resample the provided data to the already-plotted data
if not np.allclose(affine, self._affine):
from .processing import resample_from_to
from .nifti1 import Nifti1Image
target_shape = self._data.shape[:3] + data.shape[3:]
# we can't just use SpatialImage because we need an image type
# where the spatial axes are _always_ first
data = resample_from_to(Nifti1Image(data, affine),
(target_shape, self._affine)).dataobj
affine = self._affine

# we already have a plotted overlay
if self._overlay is not None:
self._remove_overlay()

axes = self._axes
o_n_volumes = int(np.prod(data.shape[3:]))
# 3D underlay, 4D overlay
if o_n_volumes > self.n_volumes and self.n_volumes == 1:
axes += [axes[0].figure.add_subplot(224)]
# 4D underlay, 3D overlay
elif o_n_volumes < self.n_volumes and o_n_volumes == 1:
axes = axes[:-1]
# 4D underlay, 4D overlay
elif o_n_volumes > 1 and self.n_volumes > 1:
raise TypeError('Cannot set 4D overlay on top of 4D underlay')

# mask array for provided threshold
self._overlay = self.__class__(data, affine=affine, axes=axes)
self._overlay.threshold = threshold

# set transparency and new cmap
self._overlay.cmap = cmap
self._overlay.alpha = alpha

# no double cross-hairs (they get confused when we have linked orthos)
for cross in self._overlay._crosshairs:
cross['horiz'].set_visible(False)
cross['vert'].set_visible(False)
self._overlay._draw()

def _remove_overlay(self):
""" Removes current overlay image + associated axes """
# remove all images + cross hair lines
for nn, im in enumerate(self._overlay._ims):
im.remove()
for line in self._overlay._crosshairs[nn].values():
line.remove()
# remove the fourth axis, if it was created for the overlay
if (self._overlay.n_volumes > 1 and len(self._overlay._axes) > 3
and self.n_volumes == 1):
a = self._axes.pop(-1)
a.remove()

self._overlay = None

def link_to(self, other):
"""Link positional changes between two canvases

Expand Down Expand Up @@ -412,7 +559,7 @@ def _set_position(self, x, y, z, notify=True):
idx = [slice(None)] * len(self._axes)
for ii in range(3):
idx[self._order[ii]] = self._data_idx[ii]
vdata = self._data[tuple(idx)].ravel()
vdata = np.asarray(self._data[tuple(idx)].ravel())
vdata = np.concatenate((vdata, [vdata[-1]]))
self._volume_ax_objs['patch'].set_x(self._data_idx[3] - 0.5)
self._volume_ax_objs['step'].set_ydata(vdata)
Expand Down