Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add facet_col and animation_frame argument to imshow #2746

Merged
merged 31 commits into from
Dec 3, 2020
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
afb5c4d
use init_figure from main px core
emmanuelle Sep 3, 2020
8be8ca0
WIP: add facet_col arg to imshow
emmanuelle Sep 3, 2020
d236bc2
animations work for grayscale images, with or without binary string
emmanuelle Sep 4, 2020
c8e852e
animations now work + tests
emmanuelle Sep 5, 2020
12cec34
docs on facets and animations + add subplots titles
emmanuelle Sep 6, 2020
ab427ae
Merge branch 'master' into imshow-animation
emmanuelle Sep 7, 2020
7a3a9f4
solved old unnoticed conflict
emmanuelle Sep 7, 2020
b689a2f
attempt to use imshow with binary strings and xarrays
emmanuelle Sep 7, 2020
fbb3f65
added test
emmanuelle Sep 7, 2020
882810f
animation work for xarrays, still need to fix slider label
emmanuelle Sep 7, 2020
ba65990
added test with xarray and animations
emmanuelle Sep 7, 2020
cf644e5
added doc
emmanuelle Sep 7, 2020
72674b7
added pooch to doc requirements
emmanuelle Sep 7, 2020
bd42385
Update packages/python/plotly/plotly/express/_imshow.py
emmanuelle Sep 8, 2020
fc2375b
Update doc/python/imshow.md
emmanuelle Sep 8, 2020
a431fad
remove commented-out code
emmanuelle Sep 9, 2020
b652039
animation + facet kinda working now, but it broke labels
emmanuelle Sep 17, 2020
59c6622
added test
emmanuelle Sep 17, 2020
c7285a3
simplified code
emmanuelle Sep 17, 2020
91c066e
simplified code
emmanuelle Sep 17, 2020
ac5aa1f
polished code and added doc example
emmanuelle Sep 17, 2020
36b9f98
Merge branch 'imshow-animation' of https://github.com/plotly/plotly.p…
emmanuelle Sep 17, 2020
8cdc6af
updated doc
emmanuelle Nov 17, 2020
cf1c2b9
Merge branch 'master' into imshow-animation
emmanuelle Nov 18, 2020
5d1d8d8
add facet_col_spacing and facet_row_spacing
emmanuelle Nov 24, 2020
c27f88a
modify error message + animation_frame label
emmanuelle Nov 24, 2020
502fdfd
improve code readibility
emmanuelle Nov 24, 2020
135b01b
added example with sequence of images
emmanuelle Nov 24, 2020
6ac3e36
typoe
emmanuelle Nov 24, 2020
a5a2252
label names
emmanuelle Nov 27, 2020
77cb5cd
label name
emmanuelle Nov 30, 2020
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 70 additions & 6 deletions doc/python/imshow.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ jupyter:
extension: .md
format_name: markdown
format_version: '1.2'
jupytext_version: 1.4.2
jupytext_version: 1.3.0
kernelspec:
display_name: Python 3
language: python
Expand All @@ -20,7 +20,7 @@ jupyter:
name: python
nbconvert_exporter: python
pygments_lexer: ipython3
version: 3.7.7
version: 3.7.3
plotly:
description: How to display image data in Python with Plotly.
display_as: scientific
Expand Down Expand Up @@ -399,9 +399,73 @@ for compression_level in range(0, 9):
fig.show()
```

### Exploring 3-D images and timeseries with `facet_col`

*Introduced in plotly 4.11*

For three-dimensional image datasets, obtained for example by MRI or CT in medical imaging, one can explore the dataset by representing its different planes as facets. The `facet_col` argument specifies along which axes the image is sliced through to make the facets. With `facet_col_wrap` , one can set the maximum number of columns. For image datasets passed as xarrays, it is also possible to give an axis name as a string for `facet_col`.
emmanuelle marked this conversation as resolved.
Show resolved Hide resolved

It is recommended to use `binary_string=True` for facetted plots of images in order to keep a small figure size and a short rendering time.

See the [tutorial on facet plots](/python/facet-plots/) for more information on creating and styling facet plots.

```python
import plotly.express as px
from skimage import io
from skimage.data import image_fetcher
path = image_fetcher.fetch('data/cells.tif')
data = io.imread(path)
mkcor marked this conversation as resolved.
Show resolved Hide resolved
img = data[25:40]
fig = px.imshow(img, facet_col=0, binary_string=True, facet_col_wrap=5, height=700)
fig.show()
```

```python
import plotly.express as px
from skimage import io
from skimage.data import image_fetcher
path = image_fetcher.fetch('data/cells.tif')
data = io.imread(path)
img = data[25:40]
fig = px.imshow(img, facet_col=0, binary_string=True, facet_col_wrap=5)
# To have square facets one needs to unmatch axes
fig.update_xaxes(matches=None)
fig.update_yaxes(matches=None)
fig.show()
```

### Exploring 3-D images and timeseries with `animation_frame`

*Introduced in plotly 4.11*

For three-dimensional image datasets, obtained for example by MRI or CT in medical imaging, one can explore the dataset by sliding through its different planes in an animation. The `animation_frame` argument of `px.imshow` sets the axis along which the 3-D image is sliced in the animation.

```python
import plotly.express as px
from skimage import io
from skimage.data import image_fetcher
path = image_fetcher.fetch('data/cells.tif')
data = io.imread(path)
img = data[25:40]
fig = px.imshow(img, animation_frame=0, binary_string=True)
fig.show()
```

### Animations of xarray datasets

*Introduced in plotly 4.11*
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
*Introduced in plotly 4.11*
*Introduced in plotly 4.13*


For xarray datasets, one can pass either an axis number or an axis name to `animation_frame`. Axis names and coordinates are automatically used for the labels, ticks and animation controls of the figure.

```python
import plotly.express as px
import xarray as xr
# Load xarray from dataset included in the xarray tutorial
ds = xr.tutorial.open_dataset('air_temperature').air[:20]
fig = px.imshow(ds, animation_frame='time', zmin=220, zmax=300, color_continuous_scale='RdBu_r')
fig.show()
```

#### Reference
<<<<<<< HEAD
See https://plotly.com/python/reference/#image for more information and chart attribute options!
=======

See https://plotly.com/python/reference/image/ for more information and chart attribute options!
>>>>>>> doc-prod
1 change: 1 addition & 0 deletions doc/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,4 @@ pyarrow
cufflinks==0.17.3
kaleido
umap-learn
pooch
165 changes: 136 additions & 29 deletions packages/python/plotly/plotly/express/_imshow.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import plotly.graph_objs as go
from _plotly_utils.basevalidators import ColorscaleValidator
from ._core import apply_default_cascade
from ._core import apply_default_cascade, init_figure, configure_animation_controls
from io import BytesIO
import base64
from .imshow_utils import rescale_intensity, _integer_ranges, _integer_types
Expand Down Expand Up @@ -133,6 +133,9 @@ def imshow(
labels={},
x=None,
y=None,
animation_frame=None,
facet_col=None,
facet_col_wrap=None,
nicolaskruchten marked this conversation as resolved.
Show resolved Hide resolved
color_continuous_scale=None,
color_continuous_midpoint=None,
range_color=None,
Expand Down Expand Up @@ -186,6 +189,14 @@ def imshow(
their lengths must match the lengths of the second and first dimensions of the
img argument. They are auto-populated if the input is an xarray.

facet_col: int, optional (default None)
axis number along which the image array is slices to create a facetted plot.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
axis number along which the image array is slices to create a facetted plot.
axis along which the image array is sliced to create a facetted plot.

I'm not entirely positive about my suggestion to remove 'number' in 'axis number'... From https://numpy.org/doc/stable/glossary.html, it looks like conventional terminology would be just 'axis' (as in 'axis 0' and 'axis 1'); I was tempted by 'axis index' but this would be confusing with dataframes (as in, 'index' vs 'columns'). Maybe 'axis number' is conventional terminology after all?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

axis is fine indeed. It could also be "axis position" (as in the docstring of np.moveaxis). We can also ask other opinions, what do you think about the terminology @nicolaskruchten ?


facet_col_wrap: int
Maximum number of facet columns. Wraps the column variable at this width,
so that the column facets span multiple rows.
Ignored if `facet_col` is None.

color_continuous_scale : str or list of str
colormap used to map scalar data to colors (for a 2D image). This parameter is
not used for RGB or RGBA images. If a string is provided, it should be the name
Expand Down Expand Up @@ -277,15 +288,38 @@ def imshow(
args = locals()
apply_default_cascade(args)
labels = labels.copy()
nslices = 1
if facet_col is not None:
if isinstance(facet_col, str):
facet_col = img.dims.index(facet_col)
nslices = img.shape[facet_col]
ncols = int(facet_col_wrap) if facet_col_wrap is not None else nslices
nrows = nslices // ncols + 1 if nslices % ncols else nslices // ncols
else:
nrows = 1
ncols = 1
if animation_frame is not None:
if isinstance(animation_frame, str):
animation_frame = img.dims.index(animation_frame)
nslices = img.shape[animation_frame]
slice_through = (facet_col is not None) or (animation_frame is not None)
slice_label = None
slices = range(nslices)
# ----- Define x and y, set labels if img is an xarray -------------------
if xarray_imported and isinstance(img, xarray.DataArray):
if binary_string:
raise ValueError(
"It is not possible to use binary image strings for xarrays."
"Please pass your data as a numpy array instead using"
"`img.values`"
)
y_label, x_label = img.dims[0], img.dims[1]
# if binary_string:
# raise ValueError(
# "It is not possible to use binary image strings for xarrays."
# "Please pass your data as a numpy array instead using"
# "`img.values`"
# )
dims = list(img.dims)
if slice_through:
slice_index = facet_col if facet_col is not None else animation_frame
slices = img.coords[img.dims[slice_index]].values
_ = dims.pop(slice_index)
slice_label = img.dims[slice_index]
y_label, x_label = dims[0], dims[1]
# np.datetime64 is not handled correctly by go.Heatmap
for ax in [x_label, y_label]:
if np.issubdtype(img.coords[ax].dtype, np.datetime64):
Expand All @@ -300,6 +334,8 @@ def imshow(
labels["x"] = x_label
if labels.get("y", None) is None:
labels["y"] = y_label
if labels.get("slice", None) is None:
labels["slice"] = slice_label
if labels.get("color", None) is None:
labels["color"] = xarray.plot.utils.label_from_attrs(img)
labels["color"] = labels["color"].replace("\n", "<br>")
Expand Down Expand Up @@ -334,10 +370,22 @@ def imshow(

# --------------- Starting from here img is always a numpy array --------
img = np.asanyarray(img)
if facet_col is not None:
img = np.moveaxis(img, facet_col, 0)
facet_col = True
if animation_frame is not None:
img = np.moveaxis(img, animation_frame, 0)
animation_frame = True
args["animation_frame"] = (
"slice" if labels.get("slice") is None else labels["slice"]
)

# Default behaviour of binary_string: True for RGB images, False for 2D
if binary_string is None:
binary_string = img.ndim >= 3 and not is_dataframe
if slice_through:
binary_string = img.ndim >= 4 and not is_dataframe
else:
binary_string = img.ndim >= 3 and not is_dataframe

# Cast bools to uint8 (also one byte)
if img.dtype == np.bool:
Expand All @@ -349,7 +397,11 @@ def imshow(

# -------- Contrast rescaling: either minmax or infer ------------------
if contrast_rescaling is None:
contrast_rescaling = "minmax" if img.ndim == 2 else "infer"
contrast_rescaling = (
"minmax"
if (img.ndim == 2 or (img.ndim == 3 and slice_through))
else "infer"
)

# We try to set zmin and zmax only if necessary, because traces have good defaults
if contrast_rescaling == "minmax":
Expand All @@ -366,18 +418,26 @@ def imshow(
zmin = 0

# For 2d data, use Heatmap trace, unless binary_string is True
if img.ndim == 2 and not binary_string:
if y is not None and img.shape[0] != len(y):
if (img.ndim == 2 or (img.ndim == 3 and slice_through)) and not binary_string:
y_index = 1 if slice_through else 0
if y is not None and img.shape[y_index] != len(y):
raise ValueError(
"The length of the y vector must match the length of the first "
+ "dimension of the img matrix."
)
if x is not None and img.shape[1] != len(x):
x_index = 2 if slice_through else 1
if x is not None and img.shape[x_index] != len(x):
raise ValueError(
"The length of the x vector must match the length of the second "
+ "dimension of the img matrix."
)
trace = go.Heatmap(x=x, y=y, z=img, coloraxis="coloraxis1")
if slice_through:
traces = [
go.Heatmap(x=x, y=y, z=img_slice, coloraxis="coloraxis1", name=str(i))
for i, img_slice in enumerate(img)
]
else:
traces = [go.Heatmap(x=x, y=y, z=img, coloraxis="coloraxis1")]
autorange = True if origin == "lower" else "reversed"
layout = dict(yaxis=dict(autorange=autorange))
if aspect == "equal":
Expand All @@ -396,7 +456,9 @@ def imshow(
layout["coloraxis1"]["colorbar"] = dict(title_text=labels["color"])

# For 2D+RGB data, use Image trace
elif img.ndim == 3 and img.shape[-1] in [3, 4] or (img.ndim == 2 and binary_string):
elif (
img.ndim >= 3 and (img.shape[-1] in [3, 4] or slice_through and binary_string)
) or (img.ndim == 2 and binary_string):
rescale_image = True # to check whether image has been modified
if zmin is not None and zmax is not None:
zmin, zmax = (
Expand All @@ -407,40 +469,75 @@ def imshow(
if zmin is None and zmax is None: # no rescaling, faster
img_rescaled = img
rescale_image = False
elif img.ndim == 2:
elif img.ndim == 2 or (img.ndim == 3 and slice_through):
img_rescaled = rescale_intensity(
img, in_range=(zmin[0], zmax[0]), out_range=np.uint8
)
else:
img_rescaled = np.dstack(
img_rescaled = np.stack(
[
rescale_intensity(
img[..., ch],
in_range=(zmin[ch], zmax[ch]),
out_range=np.uint8,
)
for ch in range(img.shape[-1])
]
],
axis=-1,
)
img_str = _array_to_b64str(
img_rescaled,
backend=binary_backend,
compression=binary_compression_level,
ext=binary_format,
)
trace = go.Image(source=img_str)
if slice_through:
img_str = [
_array_to_b64str(
img_rescaled_slice,
backend=binary_backend,
compression=binary_compression_level,
ext=binary_format,
)
for img_rescaled_slice in img_rescaled
]

else:
img_str = [
_array_to_b64str(
img_rescaled,
backend=binary_backend,
compression=binary_compression_level,
ext=binary_format,
)
]
traces = [
go.Image(source=img_str_slice, name=str(i))
for i, img_str_slice in enumerate(img_str)
]
else:
colormodel = "rgb" if img.shape[-1] == 3 else "rgba256"
trace = go.Image(z=img, zmin=zmin, zmax=zmax, colormodel=colormodel)
if slice_through:
traces = [
go.Image(z=img_slice, zmin=zmin, zmax=zmax, colormodel=colormodel)
for img_slice in img
]
else:
traces = [go.Image(z=img, zmin=zmin, zmax=zmax, colormodel=colormodel)]
layout = {}
if origin == "lower":
layout["yaxis"] = dict(autorange=True)
else:
raise ValueError(
"px.imshow only accepts 2D single-channel, RGB or RGBA images. "
"An image of shape %s was provided" % str(img.shape)
"An image of shape %s was provided"
emmanuelle marked this conversation as resolved.
Show resolved Hide resolved
"Alternatively, 3-D single or multichannel datasets can be"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

3- or 4-D ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

"visualized using the `facet_col` or `animation_frame` arguments."
% str(img.shape)
)

# Now build figure
col_labels = []
if facet_col is not None:
slice_label = "slice" if labels.get("slice") is None else labels["slice"]
if slices is None:
slices = range(nslices)
col_labels = ["%s = %d" % (slice_label, i) for i in slices]
fig = init_figure(args, "xy", [], nrows, ncols, col_labels, [])
layout_patch = dict()
for attr_name in ["height", "width"]:
if args[attr_name]:
Expand All @@ -449,7 +546,16 @@ def imshow(
layout_patch["title_text"] = args["title"]
elif args["template"].layout.margin.t is None:
layout_patch["margin"] = {"t": 60}
fig = go.Figure(data=trace, layout=layout)

frame_list = []
for index, (slice_index, trace) in enumerate(zip(slices, traces)):
if facet_col or index == 0:
fig.add_trace(trace, row=nrows - index // ncols, col=index % ncols + 1)
if animation_frame:
frame_list.append(dict(data=trace, layout=layout, name=str(slice_index)))
if animation_frame:
fig.frames = frame_list
fig.update_layout(layout)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's a bit odd to have layout and layout_patch here... it was odd before I guess but might be worth looking at merging them together earlier?

fig.update_layout(layout_patch)
# Hover name, z or color
if binary_string and rescale_image and not np.all(img == img_rescaled):
Expand Down Expand Up @@ -479,5 +585,6 @@ def imshow(
fig.update_xaxes(title_text=labels["x"])
if labels["y"]:
fig.update_yaxes(title_text=labels["y"])
fig.update_layout(template=args["template"], overwrite=True)
configure_animation_controls(args, go.Image, fig)
# fig.update_layout(template=args["template"], overwrite=True)
return fig
Loading