diff --git a/holoviews/plotting/mpl/plot.py b/holoviews/plotting/mpl/plot.py index 2d24f82ea5..cac160754b 100644 --- a/holoviews/plotting/mpl/plot.py +++ b/holoviews/plotting/mpl/plot.py @@ -1,5 +1,7 @@ from __future__ import division +from itertools import chain + import numpy as np import matplotlib as mpl from mpl_toolkits.mplot3d import Axes3D # noqa (For 3D plots) @@ -15,7 +17,7 @@ from ..plot import DimensionedPlot, GenericLayoutPlot, GenericCompositePlot from ..util import get_dynamic_mode, initialize_sampled from .renderer import MPLRenderer -from .util import compute_ratios +from .util import compute_ratios, fix_aspect class MPLPlot(DimensionedPlot): @@ -617,7 +619,7 @@ def initialize_plot(self, ranges=None): self.drawn = True - def adjust_positions(self): + def adjust_positions(self, redraw=True): """ Make adjustments to the positions of subplots (if available) relative to the main plot axes as required. @@ -631,7 +633,8 @@ def adjust_positions(self): top = all('top' in check for check in checks) if not 'main' in self.subplots or not (top or right): return - self.handles['fig'].canvas.draw() + if redraw: + self.handles['fig'].canvas.draw() main_ax = self.subplots['main'].handles['axis'] bbox = main_ax.get_position() if right: @@ -695,6 +698,10 @@ class LayoutPlot(GenericLayoutPlot, CompositePlot): (left, bottom, right, top), defining the size of the border around the subplots.""") + fix_aspect = param.Boolean(default=False, doc="""Apply a fix to the + figure aspect to take into account non-square plots (will be the + default in future versions""") + tight = param.Boolean(default=False, doc=""" Tightly fit the axes in the layout within the fig_bounds and tight_padding.""") @@ -706,7 +713,7 @@ class LayoutPlot(GenericLayoutPlot, CompositePlot): Specifies the space between horizontally adjacent elements in the grid. Default value is set conservatively to avoid overlap of subplots.""") - vspace = param.Number(default=0.1, doc=""" + vspace = param.Number(default=0.3, doc=""" Specifies the space between vertically adjacent elements in the grid. Default value is set conservatively to avoid overlap of subplots.""") @@ -1025,12 +1032,29 @@ def initialize_plot(self): subplot.initialize_plot(ranges=ranges) # Create title handle + title = None if self.show_title and len(self.coords) > 1: title = self._format_title(key) title = self.handles['fig'].suptitle(title, **self._fontsize('title')) self.handles['title'] = title self.handles['bbox_extra_artists'] += [title] + title = self.handles['title'] + fig = self.handles['fig'] + if (not self.traverse(specs=[GridPlot]) and not isinstance(self.fig_inches, tuple) + and self.fix_aspect): + traverse_fn = lambda x: x.handles.get('bbox_extra_artists', None) + extra_artists = list(chain(*[artists for artists in self.traverse(traverse_fn) + if artists is not None])) + aspect = fix_aspect(fig, title, extra_artists, vspace=self.vspace, + hspace=self.hspace) + colorbars = self.traverse(specs=[lambda x: hasattr(x, 'colorbar')]) + for cbar_plot in colorbars: + if cbar_plot.colorbar: + cbar_plot._draw_colorbar(redraw=False) + adjoined = self.traverse(specs=[AdjointLayoutPlot]) + for adjoined in adjoined: + adjoined.adjust_positions(redraw=False) return self._finalize_axis(None) diff --git a/holoviews/plotting/mpl/renderer.py b/holoviews/plotting/mpl/renderer.py index 0eb0f432d6..923ed1277f 100644 --- a/holoviews/plotting/mpl/renderer.py +++ b/holoviews/plotting/mpl/renderer.py @@ -19,6 +19,7 @@ from ..renderer import Renderer, MIME_TYPES from .widgets import MPLSelectionWidget, MPLScrubberWidget +from .util import get_tight_bbox class OutputWarning(param.Parameterized):pass outputwarning = OutputWarning(name='Warning') @@ -233,34 +234,9 @@ def _compute_bbox(self, fig, kw): if not fig_id in MPLRenderer.drawn: fig.set_dpi(self.dpi) fig.canvas.draw() - renderer = fig._cachedRenderer - bbox_inches = fig.get_tightbbox(renderer) - bbox_artists = kw.pop("bbox_extra_artists", []) - bbox_artists += fig.get_default_bbox_extra_artists() - bbox_filtered = [] - for a in bbox_artists: - bbox = a.get_window_extent(renderer) - if isinstance(bbox, tuple): - continue - if a.get_clip_on(): - clip_box = a.get_clip_box() - if clip_box is not None: - bbox = Bbox.intersection(bbox, clip_box) - clip_path = a.get_clip_path() - if clip_path is not None and bbox is not None: - clip_path = clip_path.get_fully_transformed_path() - bbox = Bbox.intersection(bbox, - clip_path.get_extents()) - if bbox is not None and (bbox.width != 0 or - bbox.height != 0): - bbox_filtered.append(bbox) - if bbox_filtered: - _bbox = Bbox.union(bbox_filtered) - trans = Affine2D().scale(1.0 / self.dpi) - bbox_extra = TransformedBbox(_bbox, trans) - bbox_inches = Bbox.union([bbox_inches, bbox_extra]) + extra_artists = kw.pop("bbox_extra_artists", []) pad = plt.rcParams['savefig.pad_inches'] - bbox_inches = bbox_inches.padded(pad) + bbox_inches = get_tight_bbox(fig, extra_artists, pad=pad) MPLRenderer.drawn[fig_id] = bbox_inches kw['bbox_inches'] = bbox_inches else: diff --git a/holoviews/plotting/mpl/util.py b/holoviews/plotting/mpl/util.py index 15dd388912..4343eb276e 100644 --- a/holoviews/plotting/mpl/util.py +++ b/holoviews/plotting/mpl/util.py @@ -4,6 +4,7 @@ import numpy as np from matplotlib import ticker +from matplotlib.transforms import Bbox, TransformedBbox, Affine2D from ...core.util import basestring @@ -58,3 +59,104 @@ def compute_ratios(ratios, normalized=True): with warnings.catch_warnings(): warnings.filterwarnings('ignore', r'All-NaN (slice|axis) encountered') return np.nanmax(np.vstack([v for _, v in sorted_ratios]), axis=0) + + +def axis_overlap(ax1, ax2): + """ + Tests whether two axes overlap vertically + """ + b1, t1 = ax1.get_position().intervaly + b2, t2 = ax2.get_position().intervaly + return t1 >= b2 and b1 <= t2 + + +def resolve_rows(rows): + """ + Recursively iterate over lists of axes merging + them by their vertical overlap leaving a list + of rows. + """ + merged_rows = [] + for row in rows: + overlap = False + for mrow in merged_rows: + if any(axis_overlap(ax1, ax2) for ax1 in row + for ax2 in mrow): + mrow += row + overlap = True + break + if not overlap: + merged_rows.append(row) + if rows == merged_rows: + return rows + else: + return resolve_rows(merged_rows) + + +def fix_aspect(fig, title=None, extra_artists=[], vspace=0.2, hspace=0.2): + """ + Calculate heights and widths of axes and adjust + the size of the figure to match the aspect. + """ + fig.canvas.draw() + w, h = fig.get_size_inches() + + # Compute maximum height and width of each row and columns + rows = resolve_rows([[ax] for ax in fig.axes]) + rs, cs = len(rows), max([len(r) for r in rows]) + heights = [[] for i in range(cs)] + widths = [[] for i in range(rs)] + for r, row in enumerate(rows): + for c, ax in enumerate(row): + bbox = ax.get_tightbbox(fig.canvas.renderer) + heights[c].append(bbox.height) + widths[r].append(bbox.width) + height = (max([sum(c) for c in heights])) + (rs+1)*vspace + width = (max([sum(r) for r in widths])) + (cs+1)*hspace + + # Compute aspect and set new size (in inches) + aspect = height/width + offset = 0.2 if title and title.get_text() else 0 + fig.set_size_inches(w, (w*aspect)+offset) + + # Redraw and adjust title position if defined + fig.canvas.draw() + if title and title.get_text(): + bbox = get_tight_bbox(fig, extra_artists) + top = bbox.intervaly[1] + extra_artists = [a for a in extra_artists if a is not title] + if title and title.get_text(): + title.set_y((top/(w*aspect))) + + +def get_tight_bbox(fig, bbox_extra_artists=[], pad=None): + """ + Compute a tight bounding box around all the artists in the figure. + """ + renderer = fig._cachedRenderer + bbox_inches = fig.get_tightbbox(renderer) + bbox_artists = bbox_extra_artists[:] + bbox_artists += fig.get_default_bbox_extra_artists() + bbox_filtered = [] + for a in bbox_artists: + bbox = a.get_window_extent(renderer) + if isinstance(bbox, tuple): + continue + if a.get_clip_on(): + clip_box = a.get_clip_box() + if clip_box is not None: + bbox = Bbox.intersection(bbox, clip_box) + clip_path = a.get_clip_path() + if clip_path is not None and bbox is not None: + clip_path = clip_path.get_fully_transformed_path() + bbox = Bbox.intersection(bbox, + clip_path.get_extents()) + if bbox is not None and (bbox.width != 0 or + bbox.height != 0): + bbox_filtered.append(bbox) + if bbox_filtered: + _bbox = Bbox.union(bbox_filtered) + trans = Affine2D().scale(1.0 / fig.dpi) + bbox_extra = TransformedBbox(_bbox, trans) + bbox_inches = Bbox.union([bbox_inches, bbox_extra]) + return bbox_inches.padded(pad) if pad else bbox_inches