From c32fdfd32473dd354d292d33a19610a4c0a2eb63 Mon Sep 17 00:00:00 2001 From: Saairam Venkatesh Date: Tue, 2 Jul 2024 07:40:22 -0400 Subject: [PATCH] fix: seaborn multi plots in same session (#58) --- maidr/__init__.py | 5 +++-- maidr/core/figure_manager.py | 11 ++++++++++- maidr/core/maidr.py | 9 ++++++++- maidr/core/plot/boxplot.py | 2 +- maidr/core/plot/grouped_barplot.py | 2 +- maidr/core/plot/histogram.py | 3 +-- maidr/core/plot/scatterplot.py | 2 +- maidr/maidr.py | 24 +++--------------------- maidr/patch/clear.py | 17 +++++++++++++++++ 9 files changed, 45 insertions(+), 30 deletions(-) create mode 100644 maidr/patch/clear.py diff --git a/maidr/__init__.py b/maidr/__init__.py index 275f4c3..6d82400 100644 --- a/maidr/__init__.py +++ b/maidr/__init__.py @@ -2,10 +2,11 @@ from .core import Maidr from .core.enum import PlotType -from .patch import barplot, boxplot, heatmap, histogram, lineplot, scatterplot -from .maidr import save_html, show, stacked +from .patch import barplot, boxplot, clear, heatmap, histogram, lineplot, scatterplot +from .maidr import close, save_html, show, stacked __all__ = [ + "close", "save_html", "show", "stacked", diff --git a/maidr/core/figure_manager.py b/maidr/core/figure_manager.py index de26516..21bb0d2 100644 --- a/maidr/core/figure_manager.py +++ b/maidr/core/figure_manager.py @@ -74,9 +74,18 @@ def _get_maidr(cls, fig: Figure) -> Maidr: def get_maidr(cls, fig: Figure) -> Maidr: """Retrieve the Maidr instance for the given Figure.""" if fig not in cls.figs.keys(): - raise ValueError(f"No MAIDR found for figure: {fig}.") + raise KeyError(f"No MAIDR found for figure: {fig}.") return cls.figs[fig] + @classmethod + def destroy(cls, fig: Figure) -> None: + try: + maidr = cls.figs.pop(fig) + except KeyError: + return + maidr.destroy() + del maidr + @staticmethod def get_axes( artist: Artist | Axes | BarContainer | dict | list | None, diff --git a/maidr/core/maidr.py b/maidr/core/maidr.py index 892fafd..f686837 100644 --- a/maidr/core/maidr.py +++ b/maidr/core/maidr.py @@ -37,7 +37,7 @@ class Maidr: def __init__(self, fig: Figure) -> None: """Create a new Maidr for the given ``matplotlib.figure.Figure``.""" self._fig = fig - self._plots = list() + self._plots = [] @property def fig(self) -> Figure: @@ -95,6 +95,13 @@ def show(self, renderer: Literal["auto", "ipython", "browser"] = "auto") -> obje html = self._create_html_tag() return html.show(renderer) + def clear(self): + self._plots = [] + + def destroy(self) -> None: + del self._plots + del self._fig + def _create_html_tag(self) -> Tag: """Create the MAIDR HTML using HTML tags.""" svg = self._get_svg() diff --git a/maidr/core/plot/boxplot.py b/maidr/core/plot/boxplot.py index 9ef3179..1af30c2 100644 --- a/maidr/core/plot/boxplot.py +++ b/maidr/core/plot/boxplot.py @@ -162,7 +162,7 @@ def _extract_bxp_maidr(self, bxp_stats: dict) -> list[dict] | None: if bxp_stats is None: return None - bxp_maidr = list() + bxp_maidr = [] whiskers = self._bxp_extractor.extract_whiskers(bxp_stats["whiskers"]) caps = self._bxp_extractor.extract_caps(bxp_stats["caps"]) medians = self._bxp_extractor.extract_medians(bxp_stats["medians"]) diff --git a/maidr/core/plot/grouped_barplot.py b/maidr/core/plot/grouped_barplot.py index 09a5e3a..acba556 100644 --- a/maidr/core/plot/grouped_barplot.py +++ b/maidr/core/plot/grouped_barplot.py @@ -48,7 +48,7 @@ def _extract_grouped_bar_data( return None x_level = self.extract_level(self.ax) - data = list() + data = [] for container in plot: if len(x_level) != len(container.patches): diff --git a/maidr/core/plot/histogram.py b/maidr/core/plot/histogram.py index ec496ad..ffb12e5 100644 --- a/maidr/core/plot/histogram.py +++ b/maidr/core/plot/histogram.py @@ -27,8 +27,7 @@ def _extract_bar_container_data(plot: BarContainer | None) -> list[dict] | None: if plot is None or plot.patches is None: return None - data = list() - + data = [] for patch in plot.patches: y = float(patch.get_height()) x = float(patch.get_x()) diff --git a/maidr/core/plot/scatterplot.py b/maidr/core/plot/scatterplot.py index d424120..1b2de0b 100644 --- a/maidr/core/plot/scatterplot.py +++ b/maidr/core/plot/scatterplot.py @@ -27,7 +27,7 @@ def _extract_point_data(plot: PathCollection | None) -> list[dict] | None: if plot is None or plot.get_offsets() is None: return None - data = list() + data = [] for point in plot.get_offsets().data: x, y = point data.append( diff --git a/maidr/maidr.py b/maidr/maidr.py index 14dc5fa..de423c6 100644 --- a/maidr/maidr.py +++ b/maidr/maidr.py @@ -29,24 +29,6 @@ def stacked(plot: Axes | BarContainer) -> Maidr: return FigureManager.create_maidr(ax, PlotType.STACKED) -def close() -> None: - pass - - -def test_enum_comparison(): - print( - f"PlotType.COUNT == PlotType.BAR: {PlotType.COUNT == PlotType.BAR}" - ) # Should be False - print( - f"PlotType.COUNT is PlotType.BAR: {PlotType.COUNT is PlotType.BAR}" - ) # Should be False - print( - f"PlotType.COUNT == PlotType.COUNT: {PlotType.COUNT == PlotType.COUNT}" - ) # Should be True - print( - f"PlotType.BAR == PlotType.BAR: {PlotType.BAR == PlotType.BAR}" - ) # Should be True - print( - f"PlotType.COUNT is PlotType.COUNT: {PlotType.COUNT is PlotType.COUNT}" - ) # Should be True - print(f"PlotType.BAR is PlotType.BAR: {PlotType.BAR is PlotType.BAR}") +def close(plot: Any) -> None: + ax = FigureManager.get_axes(plot) + FigureManager.destroy(ax.get_figure()) diff --git a/maidr/patch/clear.py b/maidr/patch/clear.py new file mode 100644 index 0000000..700f5f3 --- /dev/null +++ b/maidr/patch/clear.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +import wrapt + +from matplotlib.figure import Figure + +from maidr.core.figure_manager import FigureManager + + +@wrapt.patch_function_wrapper(Figure, "clear") +def clear(wrapped, instance, args, kwargs) -> None: + wrapped(*args, **kwargs) + try: + maidr = FigureManager.get_maidr(instance.get_figure()) + except KeyError: + return + maidr.clear()