Skip to content

Commit

Permalink
Add Grid.figure to (eventually) replace Grid.fig (#2639)
Browse files Browse the repository at this point in the history
* Add Grid.figure to (eventually) replace Grid.fig

The advantage of this change is that you will be able to do

    obj = sns.<func>()
    obj.figure.<method>

regardless of whether  is figure-level or axes-level.

The `Grid.fig` methods are being soft-deprecated: discouraged in the
attribute documentation but not (currently) issuing a warning when used.

* Update docs, tests, and release notes
  • Loading branch information
mwaskom authored Aug 14, 2021
1 parent 473ebef commit 091f4c0
Show file tree
Hide file tree
Showing 10 changed files with 92 additions and 83 deletions.
2 changes: 1 addition & 1 deletion doc/docstrings/FacetGrid.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@
"source": [
"g = sns.FacetGrid(tips, col=\"sex\", row=\"time\", margin_titles=True, despine=False)\n",
"g.map_dataframe(sns.scatterplot, x=\"total_bill\", y=\"tip\")\n",
"g.fig.subplots_adjust(wspace=0, hspace=0)\n",
"g.figure.subplots_adjust(wspace=0, hspace=0)\n",
"for (row_val, col_val), ax in g.axes_dict.items():\n",
" if row_val == \"Lunch\" and col_val == \"Female\":\n",
" ax.set_facecolor(\".95\")\n",
Expand Down
2 changes: 1 addition & 1 deletion doc/introduction.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@
")\n",
"g.set_axis_labels(\"Bill length (mm)\", \"Bill depth (mm)\", labelpad=10)\n",
"g.legend.set_title(\"Body mass (g)\")\n",
"g.fig.set_size_inches(6.5, 4.5)\n",
"g.figure.set_size_inches(6.5, 4.5)\n",
"g.ax.margins(.15)\n",
"g.despine(trim=True)"
]
Expand Down
8 changes: 5 additions & 3 deletions doc/releases/v0.11.2.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@ This is a minor release that addresses issues in the v0.11 series and adds a sma

- |Feature| In :func:`kdeplot`, added the `warn_singular` parameter to silence the warning about data with zero variance (:pr:`2566`).

- |Enhancement| In :class:`FacetGrid` and functions that use it, visibility of the interior axis labels is now disabled, and exterior axis labels are no longer erased when adding additional layers. This produces the same results for plots made by seaborn functions, but it may produce different (better, in most cases) results for customized facet plots (:pr:`2583`).

- |Enhancement| In :func:`histplot`, improved performance with large datasets and many groupings/facets (:pr:`2559`, :pr:`2570`).

- |Enhancement| The :class:`FacetGrid`, :class:`PairGrid`, and :class:`JointGrid` objects now reference the underlying matplotlib figure with a `.figure` attribute. The existing `.fig` attribute still exists but is discouraged and may eventually be deprecated. The effect is that you can now call `obj.figure` on the return value from any seaborn function to access the matplotlib object (:pr:`2639`).

- |Enhancement| In :class:`FacetGrid` and functions that use it, visibility of the interior axis labels is now disabled, and exterior axis labels are no longer erased when adding additional layers. This produces the same results for plots made by seaborn functions, but it may produce different (better, in most cases) results for customized facet plots (:pr:`2583`).

- |Enhancement| In :class:`FacetGrid`, :class:`PairGrid`, and functions that use them, the matplotlib `figure.autolayout` parameter is disabled to avoid having the legend overlap the plot (:pr:`2571`).

- |Enhancement| The :func:`load_dataset` helper now produces a more informative error when fed a dataframe, easing a common beginner mistake (:pr:`2604`).
Expand Down Expand Up @@ -54,4 +56,4 @@ This is a minor release that addresses issues in the v0.11 series and adds a sma

- |Fix| Fixed an issue that prevented Python from running in `-OO` mode while using seaborn (:pr:`2473`).

- |Docs| Improved the API documentation for theme-related functions (:pr:`2573`).
- |Docs| Improved the API documentation for theme-related functions (:pr:`2573`).
6 changes: 3 additions & 3 deletions doc/tutorial/axis_grids.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -245,14 +245,14 @@
"g.map(sns.scatterplot, \"total_bill\", \"tip\", color=\"#334488\")\n",
"g.set_axis_labels(\"Total bill (US Dollars)\", \"Tip\")\n",
"g.set(xticks=[10, 30, 50], yticks=[2, 6, 10])\n",
"g.fig.subplots_adjust(wspace=.02, hspace=.02)"
"g.figure.subplots_adjust(wspace=.02, hspace=.02)"
]
},
{
"cell_type": "raw",
"metadata": {},
"source": [
"For even more customization, you can work directly with the underling matplotlib ``Figure`` and ``Axes`` objects, which are stored as member attributes at ``fig`` and ``axes`` (a two-dimensional array), respectively. When making a figure without row or column faceting, you can also use the ``ax`` attribute to directly access the single axes."
"For even more customization, you can work directly with the underling matplotlib ``Figure`` and ``Axes`` objects, which are stored as member attributes at ``figure`` and ``axes_dict``, respectively. When making a figure without row or column faceting, you can also use the ``ax`` attribute to directly access the single axes."
]
},
{
Expand All @@ -263,7 +263,7 @@
"source": [
"g = sns.FacetGrid(tips, col=\"smoker\", margin_titles=True, height=4)\n",
"g.map(plt.scatter, \"total_bill\", \"tip\", color=\"#338844\", edgecolor=\"white\", s=50, lw=1)\n",
"for ax in g.axes.flat:\n",
"for ax in g.axes_dict.values():\n",
" ax.axline((0, 0), slope=.2, c=\".2\", ls=\"--\", zorder=0)\n",
"g.set(xlim=(0, 60), ylim=(0, 14))"
]
Expand Down
4 changes: 2 additions & 2 deletions doc/tutorial/relational.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@
"df = pd.DataFrame(dict(time=np.arange(500),\n",
" value=np.random.randn(500).cumsum()))\n",
"g = sns.relplot(x=\"time\", y=\"value\", kind=\"line\", data=df)\n",
"g.fig.autofmt_xdate()"
"g.figure.autofmt_xdate()"
]
},
{
Expand Down Expand Up @@ -519,7 +519,7 @@
"df = pd.DataFrame(dict(time=pd.date_range(\"2017-1-1\", periods=500),\n",
" value=np.random.randn(500).cumsum()))\n",
"g = sns.relplot(x=\"time\", y=\"value\", kind=\"line\", data=df)\n",
"g.fig.autofmt_xdate()"
"g.figure.autofmt_xdate()"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion examples/joint_histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
g.ax_joint.set(yscale="log")

# Create an inset legend for the histogram colorbar
cax = g.fig.add_axes([.15, .55, .02, .2])
cax = g.figure.add_axes([.15, .55, .02, .2])

# Add the joint and marginal histogram plots
g.plot_joint(
Expand Down
2 changes: 1 addition & 1 deletion examples/kde_ridgeplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def label(x, color, label):
g.map(label, "x")

# Set the subplots to overlap
g.fig.subplots_adjust(hspace=-.25)
g.figure.subplots_adjust(hspace=-.25)

# Remove axes details that don't play well with overlap
g.set_titles("")
Expand Down
94 changes: 52 additions & 42 deletions seaborn/axisgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,8 @@
)


class Grid:
class _BaseGrid:
"""Base class for grids of subplots."""
_margin_titles = False
_legend_out = True

def __init__(self):

self._tight_layout_rect = [0, 0, 1, 1]
self._tight_layout_pad = None

# This attribute is set externally and is a hack to handle newer functions that
# don't add proxy artists onto the Axes. We need an overall cleaner approach.
self._extract_legend_handles = False

def set(self, **kwargs):
"""Set attributes on each subplot Axes."""
Expand All @@ -48,19 +37,54 @@ def set(self, **kwargs):
ax.set(**kwargs)
return self

@property
def fig(self):
"""DEPRECATED: prefer the `figure` property."""
# Grid.figure is preferred because it matches the Axes attribute name.
# But as the maintanace burden on having this property is minimal,
# let's be slow about formally deprecating it. For now just note its deprecation
# in the docstring; add a warning in version 0.13, and eventually remove it.
return self._figure

@property
def figure(self):
"""Access the :class:`matplotlib.figure.Figure` object underlying the grid."""
return self._figure

def savefig(self, *args, **kwargs):
"""Save the figure."""
"""
Save an image of the plot.
This wraps :meth:`matplotlib.figure.Figure.savefig`, using bbox_inches="tight"
by default. Parameters are passed through to the matplotlib function.
"""
kwargs = kwargs.copy()
kwargs.setdefault("bbox_inches", "tight")
self.fig.savefig(*args, **kwargs)
self.figure.savefig(*args, **kwargs)


class Grid(_BaseGrid):
"""A grid that can have multiple subplots and an external legend."""
_margin_titles = False
_legend_out = True

def __init__(self):

self._tight_layout_rect = [0, 0, 1, 1]
self._tight_layout_pad = None

# This attribute is set externally and is a hack to handle newer functions that
# don't add proxy artists onto the Axes. We need an overall cleaner approach.
self._extract_legend_handles = False

def tight_layout(self, *args, **kwargs):
"""Call fig.tight_layout within rect that exclude the legend."""
kwargs = kwargs.copy()
kwargs.setdefault("rect", self._tight_layout_rect)
if self._tight_layout_pad is not None:
kwargs.setdefault("pad", self._tight_layout_pad)
self.fig.tight_layout(*args, **kwargs)
self._figure.tight_layout(*args, **kwargs)

def add_legend(self, legend_data=None, title=None, label_order=None,
adjust_subtitles=False, **kwargs):
Expand Down Expand Up @@ -122,7 +146,7 @@ def add_legend(self, legend_data=None, title=None, label_order=None,
kwargs.setdefault("loc", "center right")

# Draw a full-figure legend outside the grid
figlegend = self.fig.legend(handles, labels, **kwargs)
figlegend = self._figure.legend(handles, labels, **kwargs)

self._legend = figlegend
figlegend.set_title(title, prop={"size": title_size})
Expand All @@ -131,25 +155,25 @@ def add_legend(self, legend_data=None, title=None, label_order=None,
adjust_legend_subtitles(figlegend)

# Draw the plot to set the bounding boxes correctly
_draw_figure(self.fig)
_draw_figure(self._figure)

# Calculate and set the new width of the figure so the legend fits
legend_width = figlegend.get_window_extent().width / self.fig.dpi
fig_width, fig_height = self.fig.get_size_inches()
self.fig.set_size_inches(fig_width + legend_width, fig_height)
legend_width = figlegend.get_window_extent().width / self._figure.dpi
fig_width, fig_height = self._figure.get_size_inches()
self._figure.set_size_inches(fig_width + legend_width, fig_height)

# Draw the plot again to get the new transformations
_draw_figure(self.fig)
_draw_figure(self._figure)

# Now calculate how much space we need on the right side
legend_width = figlegend.get_window_extent().width / self.fig.dpi
legend_width = figlegend.get_window_extent().width / self._figure.dpi
space_needed = legend_width / (fig_width + legend_width)
margin = .04 if self._margin_titles else .01
self._space_needed = margin + space_needed
right = 1 - self._space_needed

# Place the subplot axes to give space for the legend
self.fig.subplots_adjust(right=right)
self._figure.subplots_adjust(right=right)
self._tight_layout_rect[2] = right

else:
Expand Down Expand Up @@ -418,7 +442,7 @@ def __init__(

# Attributes that are part of the public API but accessed through
# a property so that Sphinx adds them to the auto class doc
self._fig = fig
self._figure = fig
self._axes = axes
self._axes_dict = axes_dict
self._legend = None
Expand Down Expand Up @@ -801,7 +825,7 @@ def facet_axis(self, row_i, col_j, modify_state=True):

def despine(self, **kwargs):
"""Remove axis spines from the facets."""
utils.despine(self.fig, **kwargs)
utils.despine(self._figure, **kwargs)
return self

def set_axis_labels(self, x_var=None, y_var=None, clear_inner=True, **kwargs):
Expand Down Expand Up @@ -991,11 +1015,6 @@ def refline(self, *, x=None, y=None, color='.5', linestyle='--', **line_kws):

# ------ Properties that are part of the public API and documented by Sphinx

@property
def fig(self):
"""The :class:`matplotlib.figure.Figure` with the plot."""
return self._fig

@property
def axes(self):
"""An array of the :class:`matplotlib.axes.Axes` objects in the grid."""
Expand Down Expand Up @@ -1233,7 +1252,7 @@ def __init__(
axes[i, j].remove()
axes[i, j] = None

self.fig = fig
self._figure = fig
self.axes = axes
self.data = data

Expand Down Expand Up @@ -1618,7 +1637,7 @@ def _find_numeric_cols(self, data):
return numeric_cols


class JointGrid(object):
class JointGrid(_BaseGrid):
"""Grid for drawing a bivariate plot with marginal univariate plots.
Many plots can be drawn by using the figure-level interface :func:`jointplot`.
Expand Down Expand Up @@ -1650,7 +1669,7 @@ def __init__(
ax_marg_x = f.add_subplot(gs[0, :-1], sharex=ax_joint)
ax_marg_y = f.add_subplot(gs[1:, -1], sharey=ax_joint)

self.fig = f
self._figure = f
self.ax_joint = ax_joint
self.ax_marg_x = ax_marg_x
self.ax_marg_y = ax_marg_y
Expand Down Expand Up @@ -1913,15 +1932,6 @@ def set_axis_labels(self, xlabel="", ylabel="", **kwargs):
self.ax_joint.set_ylabel(ylabel, **kwargs)
return self

def savefig(self, *args, **kwargs):
"""Save the figure using a "tight" bounding box by default.
Wraps :meth:`matplotlib.figure.Figure.savefig`.
"""
kwargs.setdefault("bbox_inches", "tight")
self.fig.savefig(*args, **kwargs)


JointGrid.__init__.__doc__ = """\
Set up the grid of subplots and store data internally for easy plotting.
Expand Down
23 changes: 9 additions & 14 deletions seaborn/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,7 +811,7 @@ def __init__(self, data, pivot_kws=None, z_score=None, standard_scale=None,

self.mask = _matrix_mask(self.data2d, mask)

self.fig = plt.figure(figsize=figsize)
self._figure = plt.figure(figsize=figsize)

self.row_colors, self.row_color_labels = \
self._preprocess_colors(data, row_colors, axis=0)
Expand Down Expand Up @@ -842,28 +842,28 @@ def __init__(self, data, pivot_kws=None, z_score=None, standard_scale=None,
width_ratios=width_ratios,
height_ratios=height_ratios)

self.ax_row_dendrogram = self.fig.add_subplot(self.gs[-1, 0])
self.ax_col_dendrogram = self.fig.add_subplot(self.gs[0, -1])
self.ax_row_dendrogram = self._figure.add_subplot(self.gs[-1, 0])
self.ax_col_dendrogram = self._figure.add_subplot(self.gs[0, -1])
self.ax_row_dendrogram.set_axis_off()
self.ax_col_dendrogram.set_axis_off()

self.ax_row_colors = None
self.ax_col_colors = None

if self.row_colors is not None:
self.ax_row_colors = self.fig.add_subplot(
self.ax_row_colors = self._figure.add_subplot(
self.gs[-1, 1])
if self.col_colors is not None:
self.ax_col_colors = self.fig.add_subplot(
self.ax_col_colors = self._figure.add_subplot(
self.gs[1, -1])

self.ax_heatmap = self.fig.add_subplot(self.gs[-1, -1])
self.ax_heatmap = self._figure.add_subplot(self.gs[-1, -1])
if cbar_pos is None:
self.ax_cbar = self.cax = None
else:
# Initialize the colorbar axes in the gridspec so that tight_layout
# works. We will move it where it belongs later. This is a hack.
self.ax_cbar = self.fig.add_subplot(self.gs[0, 0])
self.ax_cbar = self._figure.add_subplot(self.gs[0, 0])
self.cax = self.ax_cbar # Backwards compatibility
self.cbar_pos = cbar_pos

Expand Down Expand Up @@ -1066,11 +1066,6 @@ def color_list_to_matrix_and_cmap(colors, ind, axis=0):
cmap = mpl.colors.ListedColormap(list(unique_colors))
return matrix, cmap

def savefig(self, *args, **kwargs):
if 'bbox_inches' not in kwargs:
kwargs['bbox_inches'] = 'tight'
self.fig.savefig(*args, **kwargs)

def plot_dendrograms(self, row_cluster, col_cluster, metric, method,
row_linkage, col_linkage, tree_kws):
# Plot the row dendrogram
Expand Down Expand Up @@ -1208,13 +1203,13 @@ def plot_matrix(self, colorbar_kws, xind, yind, **kws):

tight_params = dict(h_pad=.02, w_pad=.02)
if self.ax_cbar is None:
self.fig.tight_layout(**tight_params)
self._figure.tight_layout(**tight_params)
else:
# Turn the colorbar axes off for tight layout so that its
# ticks don't interfere with the rest of the plot layout.
# Then move it.
self.ax_cbar.set_axis_off()
self.fig.tight_layout(**tight_params)
self._figure.tight_layout(**tight_params)
self.ax_cbar.set_axis_on()
self.ax_cbar.set_position(self.cbar_pos)

Expand Down
Loading

0 comments on commit 091f4c0

Please sign in to comment.