Skip to content

Commit

Permalink
Merge pull request #826 from ioam/fix_aspect
Browse files Browse the repository at this point in the history
Fix for aspects of matplotlib LayoutPlot
  • Loading branch information
jlstevens authored Aug 22, 2016
2 parents c91b70b + e34c20f commit 20877f5
Show file tree
Hide file tree
Showing 6 changed files with 162 additions and 55 deletions.
5 changes: 4 additions & 1 deletion holoviews/plotting/mpl/chart3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,10 @@ def _finalize_axis(self, key, **kwargs):
return super(Plot3D, self)._finalize_axis(key, **kwargs)


def _draw_colorbar(self, artist, element, dim=None):
def _draw_colorbar(self, dim=None):
element = self.hmap.last
artist = self.handles.get('artist', None)

fig = self.handles['fig']
ax = self.handles['axis']
# Get colorbar label
Expand Down
14 changes: 8 additions & 6 deletions holoviews/plotting/mpl/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,21 +553,23 @@ def _adjust_cbar(self, cbar, label, dim):


def _finalize_artist(self, key):
element = self.hmap.last
artist = self.handles.get('artist', None)
if artist and self.colorbar:
self._draw_colorbar(artist, element)
self._draw_colorbar()


def _draw_colorbar(self, artist, element, dim=None):
def _draw_colorbar(self, dim=None, redraw=True):
element = self.hmap.last
artist = self.handles.get('artist', None)
fig = self.handles['fig']
axis = self.handles['axis']
ax_colorbars, position = ColorbarPlot._colorbars.get(id(axis), ([], None))
specs = [spec[:2] for _, _, spec, _ in ax_colorbars]
spec = util.get_spec(element)

if position is None:
fig.canvas.draw()
if position is None or not redraw:
if redraw:
fig.canvas.draw()
bbox = axis.get_position()
l, b, w, h = bbox.x0, bbox.y0, bbox.width, bbox.height
else:
Expand All @@ -594,7 +596,7 @@ def _draw_colorbar(self, artist, element, dim=None):
self.handles['bbox_extra_artists'] += [cax, ylabel]
ax_colorbars.append((artist, cax, spec, label))

for i, (artist, cax, spec, label) in enumerate(ax_colorbars[:-1]):
for i, (artist, cax, spec, label) in enumerate(ax_colorbars):
scaled_w = w*width
cax.set_position([l+w+padding+(scaled_w+padding+w*0.15)*i,
b, scaled_w, h])
Expand Down
2 changes: 1 addition & 1 deletion holoviews/plotting/mpl/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def init_artists(self, ax, plot_args, plot_kwargs):
collection = PatchCollection(*plot_args, **plot_kwargs)
ax.add_collection(collection)
if self.colorbar:
self._draw_colorbar(collection, self.current_frame)
self._draw_colorbar()
return {'artist': collection, 'polys': plot_args[0]}


Expand Down
50 changes: 38 additions & 12 deletions holoviews/plotting/mpl/plot.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import division

from itertools import chain

import numpy as np
import matplotlib as mpl
from mpl_toolkits.mplot3d import Axes3D # noqa (For 3D plots)
Expand All @@ -15,7 +17,7 @@
from ..plot import DimensionedPlot, GenericLayoutPlot, GenericCompositePlot
from ..util import get_dynamic_mode, initialize_sampled
from .renderer import MPLRenderer
from .util import compute_ratios
from .util import compute_ratios, fix_aspect


class MPLPlot(DimensionedPlot):
Expand Down Expand Up @@ -54,7 +56,7 @@ class MPLPlot(DimensionedPlot):
fig_rcparams = param.Dict(default={}, doc="""
matplotlib rc parameters to apply to the overall figure.""")

fig_size = param.Integer(default=100, bounds=(1, None), doc="""
fig_size = param.Number(default=100., bounds=(1, None), doc="""
Size relative to the supplied overall fig_inches in percent.""")

initial_hooks = param.HookList(default=[], doc="""
Expand Down Expand Up @@ -97,12 +99,12 @@ def __init__(self, fig=None, axis=None, **params):
self._create_fig = True
super(MPLPlot, self).__init__(**params)
# List of handles to matplotlib objects for animation update
scale = self.fig_size/100.
self.fig_scale = self.fig_size/100.
if isinstance(self.fig_inches, (tuple, list)):
self.fig_inches = [None if i is None else i*scale
self.fig_inches = [None if i is None else i*self.fig_scale
for i in self.fig_inches]
else:
self.fig_inches *= scale
self.fig_inches *= self.fig_scale
fig, axis = self._init_axis(fig, axis)
self.handles['fig'] = fig
self.handles['axis'] = axis
Expand Down Expand Up @@ -617,7 +619,7 @@ def initialize_plot(self, ranges=None):
self.drawn = True


def adjust_positions(self):
def adjust_positions(self, redraw=True):
"""
Make adjustments to the positions of subplots (if available)
relative to the main plot axes as required.
Expand All @@ -631,7 +633,8 @@ def adjust_positions(self):
top = all('top' in check for check in checks)
if not 'main' in self.subplots or not (top or right):
return
self.handles['fig'].canvas.draw()
if redraw:
self.handles['fig'].canvas.draw()
main_ax = self.subplots['main'].handles['axis']
bbox = main_ax.get_position()
if right:
Expand Down Expand Up @@ -712,6 +715,11 @@ class LayoutPlot(GenericLayoutPlot, CompositePlot):

fontsize = param.Parameter(default={'title':16}, allow_None=True)

# Whether to enable fix for non-square figures
# Will be enabled by default in v1.7
# If enabled default vspace should be increased to 0.3
v17_layout_format = False

def __init__(self, layout, **params):
super(LayoutPlot, self).__init__(layout=layout, **params)
self.subplots, self.subaxes, self.layout = self._compute_gridspec(layout)
Expand Down Expand Up @@ -1025,12 +1033,30 @@ def initialize_plot(self):
subplot.initialize_plot(ranges=ranges)

# Create title handle
if self.show_title and len(self.coords) > 1:
title = self._format_title(key)
title = self.handles['fig'].suptitle(title, **self._fontsize('title'))
self.handles['title'] = title
self.handles['bbox_extra_artists'] += [title]
title_obj = None
title = self._format_title(key)
if self.show_title and len(self.coords) > 1 and title:
title_obj = self.handles['fig'].suptitle(title, **self._fontsize('title'))
self.handles['title'] = title_obj
self.handles['bbox_extra_artists'] += [title_obj]

fig = self.handles['fig']
if (not self.traverse(specs=[GridPlot]) and not isinstance(self.fig_inches, tuple)
and self.v17_layout_format):
traverse_fn = lambda x: x.handles.get('bbox_extra_artists', None)
extra_artists = list(chain(*[artists for artists in self.traverse(traverse_fn)
if artists is not None]))
aspect = fix_aspect(fig, self.rows, self.cols,
title_obj, extra_artists,
vspace=self.vspace*self.fig_scale,
hspace=self.hspace*self.fig_scale)
colorbars = self.traverse(specs=[lambda x: hasattr(x, 'colorbar')])
for cbar_plot in colorbars:
if cbar_plot.colorbar:
cbar_plot._draw_colorbar(redraw=False)
adjoined = self.traverse(specs=[AdjointLayoutPlot])
for adjoined in adjoined:
adjoined.adjust_positions(redraw=False)
return self._finalize_axis(None)


Expand Down
40 changes: 5 additions & 35 deletions holoviews/plotting/mpl/renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from ..renderer import Renderer, MIME_TYPES
from .widgets import MPLSelectionWidget, MPLScrubberWidget
from .util import get_tight_bbox

class OutputWarning(param.Parameterized):pass
outputwarning = OutputWarning(name='Warning')
Expand Down Expand Up @@ -121,15 +122,9 @@ def plot_options(cls, obj, percent_size):
factor = percent_size / 100.0
obj = obj.last if isinstance(obj, HoloMap) else obj
options = Store.lookup_options(cls.backend, obj, 'plot').options
fig_inches = options.get('fig_inches', MPLPlot.fig_inches)
fig_size = options.get('fig_size', MPLPlot.fig_size)*factor

if isinstance(fig_inches, (list, tuple)):
fig_inches = (None if fig_inches[0] is None else fig_inches[0] * factor,
None if fig_inches[1] is None else fig_inches[1] * factor)
else:
fig_inches = MPLPlot.fig_inches * factor

return dict({'fig_inches':fig_inches},
return dict({'fig_size':fig_size},
**Store.lookup_options(cls.backend, obj, 'plot').options)


Expand Down Expand Up @@ -233,34 +228,9 @@ def _compute_bbox(self, fig, kw):
if not fig_id in MPLRenderer.drawn:
fig.set_dpi(self.dpi)
fig.canvas.draw()
renderer = fig._cachedRenderer
bbox_inches = fig.get_tightbbox(renderer)
bbox_artists = kw.pop("bbox_extra_artists", [])
bbox_artists += fig.get_default_bbox_extra_artists()
bbox_filtered = []
for a in bbox_artists:
bbox = a.get_window_extent(renderer)
if isinstance(bbox, tuple):
continue
if a.get_clip_on():
clip_box = a.get_clip_box()
if clip_box is not None:
bbox = Bbox.intersection(bbox, clip_box)
clip_path = a.get_clip_path()
if clip_path is not None and bbox is not None:
clip_path = clip_path.get_fully_transformed_path()
bbox = Bbox.intersection(bbox,
clip_path.get_extents())
if bbox is not None and (bbox.width != 0 or
bbox.height != 0):
bbox_filtered.append(bbox)
if bbox_filtered:
_bbox = Bbox.union(bbox_filtered)
trans = Affine2D().scale(1.0 / self.dpi)
bbox_extra = TransformedBbox(_bbox, trans)
bbox_inches = Bbox.union([bbox_inches, bbox_extra])
extra_artists = kw.pop("bbox_extra_artists", [])
pad = plt.rcParams['savefig.pad_inches']
bbox_inches = bbox_inches.padded(pad)
bbox_inches = get_tight_bbox(fig, extra_artists, pad=pad)
MPLRenderer.drawn[fig_id] = bbox_inches
kw['bbox_inches'] = bbox_inches
else:
Expand Down
106 changes: 106 additions & 0 deletions holoviews/plotting/mpl/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np
from matplotlib import ticker
from matplotlib.transforms import Bbox, TransformedBbox, Affine2D

from ...core.util import basestring

Expand Down Expand Up @@ -58,3 +59,108 @@ def compute_ratios(ratios, normalized=True):
with warnings.catch_warnings():
warnings.filterwarnings('ignore', r'All-NaN (slice|axis) encountered')
return np.nanmax(np.vstack([v for _, v in sorted_ratios]), axis=0)


def axis_overlap(ax1, ax2):
"""
Tests whether two axes overlap vertically
"""
b1, t1 = ax1.get_position().intervaly
b2, t2 = ax2.get_position().intervaly
return t1 > b2 and b1 < t2


def resolve_rows(rows):
"""
Recursively iterate over lists of axes merging
them by their vertical overlap leaving a list
of rows.
"""
merged_rows = []
for row in rows:
overlap = False
for mrow in merged_rows:
if any(axis_overlap(ax1, ax2) for ax1 in row
for ax2 in mrow):
mrow += row
overlap = True
break
if not overlap:
merged_rows.append(row)
if rows == merged_rows:
return rows
else:
return resolve_rows(merged_rows)


def fix_aspect(fig, nrows, ncols, title=None, extra_artists=[],
vspace=0.2, hspace=0.2):
"""
Calculate heights and widths of axes and adjust
the size of the figure to match the aspect.
"""
fig.canvas.draw()
w, h = fig.get_size_inches()

# Compute maximum height and width of each row and columns
rows = resolve_rows([[ax] for ax in fig.axes])
rs, cs = len(rows), max([len(r) for r in rows])
heights = [[] for i in range(cs)]
widths = [[] for i in range(rs)]
for r, row in enumerate(rows):
for c, ax in enumerate(row):
bbox = ax.get_tightbbox(fig.canvas.renderer)
heights[c].append(bbox.height)
widths[r].append(bbox.width)
height = (max([sum(c) for c in heights])) + nrows*vspace*fig.dpi
width = (max([sum(r) for r in widths])) + ncols*hspace*fig.dpi

# Compute aspect and set new size (in inches)
aspect = height/width
offset = 0
if title and title.get_text():
offset = title.get_window_extent().height/fig.dpi
fig.set_size_inches(w, (w*aspect)+offset)

# Redraw and adjust title position if defined
fig.canvas.draw()
if title and title.get_text():
extra_artists = [a for a in extra_artists
if a is not title]
bbox = get_tight_bbox(fig, extra_artists)
top = bbox.intervaly[1]
if title and title.get_text():
title.set_y((top/(w*aspect)))


def get_tight_bbox(fig, bbox_extra_artists=[], pad=None):
"""
Compute a tight bounding box around all the artists in the figure.
"""
renderer = fig._cachedRenderer
bbox_inches = fig.get_tightbbox(renderer)
bbox_artists = bbox_extra_artists[:]
bbox_artists += fig.get_default_bbox_extra_artists()
bbox_filtered = []
for a in bbox_artists:
bbox = a.get_window_extent(renderer)
if isinstance(bbox, tuple):
continue
if a.get_clip_on():
clip_box = a.get_clip_box()
if clip_box is not None:
bbox = Bbox.intersection(bbox, clip_box)
clip_path = a.get_clip_path()
if clip_path is not None and bbox is not None:
clip_path = clip_path.get_fully_transformed_path()
bbox = Bbox.intersection(bbox,
clip_path.get_extents())
if bbox is not None and (bbox.width != 0 or
bbox.height != 0):
bbox_filtered.append(bbox)
if bbox_filtered:
_bbox = Bbox.union(bbox_filtered)
trans = Affine2D().scale(1.0 / fig.dpi)
bbox_extra = TransformedBbox(_bbox, trans)
bbox_inches = Bbox.union([bbox_inches, bbox_extra])
return bbox_inches.padded(pad) if pad else bbox_inches

0 comments on commit 20877f5

Please sign in to comment.