Skip to content

Commit

Permalink
fix: seaborn multi plots in same session (#58)
Browse files Browse the repository at this point in the history
  • Loading branch information
SaaiVenkat authored Jul 2, 2024
1 parent e5986aa commit c32fdfd
Show file tree
Hide file tree
Showing 9 changed files with 45 additions and 30 deletions.
5 changes: 3 additions & 2 deletions maidr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

from .core import Maidr
from .core.enum import PlotType
from .patch import barplot, boxplot, heatmap, histogram, lineplot, scatterplot
from .maidr import save_html, show, stacked
from .patch import barplot, boxplot, clear, heatmap, histogram, lineplot, scatterplot
from .maidr import close, save_html, show, stacked

__all__ = [
"close",
"save_html",
"show",
"stacked",
Expand Down
11 changes: 10 additions & 1 deletion maidr/core/figure_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,18 @@ def _get_maidr(cls, fig: Figure) -> Maidr:
def get_maidr(cls, fig: Figure) -> Maidr:
"""Retrieve the Maidr instance for the given Figure."""
if fig not in cls.figs.keys():
raise ValueError(f"No MAIDR found for figure: {fig}.")
raise KeyError(f"No MAIDR found for figure: {fig}.")
return cls.figs[fig]

@classmethod
def destroy(cls, fig: Figure) -> None:
try:
maidr = cls.figs.pop(fig)
except KeyError:
return
maidr.destroy()
del maidr

@staticmethod
def get_axes(
artist: Artist | Axes | BarContainer | dict | list | None,
Expand Down
9 changes: 8 additions & 1 deletion maidr/core/maidr.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class Maidr:
def __init__(self, fig: Figure) -> None:
"""Create a new Maidr for the given ``matplotlib.figure.Figure``."""
self._fig = fig
self._plots = list()
self._plots = []

@property
def fig(self) -> Figure:
Expand Down Expand Up @@ -95,6 +95,13 @@ def show(self, renderer: Literal["auto", "ipython", "browser"] = "auto") -> obje
html = self._create_html_tag()
return html.show(renderer)

def clear(self):
self._plots = []

def destroy(self) -> None:
del self._plots
del self._fig

def _create_html_tag(self) -> Tag:
"""Create the MAIDR HTML using HTML tags."""
svg = self._get_svg()
Expand Down
2 changes: 1 addition & 1 deletion maidr/core/plot/boxplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def _extract_bxp_maidr(self, bxp_stats: dict) -> list[dict] | None:
if bxp_stats is None:
return None

bxp_maidr = list()
bxp_maidr = []
whiskers = self._bxp_extractor.extract_whiskers(bxp_stats["whiskers"])
caps = self._bxp_extractor.extract_caps(bxp_stats["caps"])
medians = self._bxp_extractor.extract_medians(bxp_stats["medians"])
Expand Down
2 changes: 1 addition & 1 deletion maidr/core/plot/grouped_barplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def _extract_grouped_bar_data(
return None

x_level = self.extract_level(self.ax)
data = list()
data = []

for container in plot:
if len(x_level) != len(container.patches):
Expand Down
3 changes: 1 addition & 2 deletions maidr/core/plot/histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ def _extract_bar_container_data(plot: BarContainer | None) -> list[dict] | None:
if plot is None or plot.patches is None:
return None

data = list()

data = []
for patch in plot.patches:
y = float(patch.get_height())
x = float(patch.get_x())
Expand Down
2 changes: 1 addition & 1 deletion maidr/core/plot/scatterplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def _extract_point_data(plot: PathCollection | None) -> list[dict] | None:
if plot is None or plot.get_offsets() is None:
return None

data = list()
data = []
for point in plot.get_offsets().data:
x, y = point
data.append(
Expand Down
24 changes: 3 additions & 21 deletions maidr/maidr.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,24 +29,6 @@ def stacked(plot: Axes | BarContainer) -> Maidr:
return FigureManager.create_maidr(ax, PlotType.STACKED)


def close() -> None:
pass


def test_enum_comparison():
print(
f"PlotType.COUNT == PlotType.BAR: {PlotType.COUNT == PlotType.BAR}"
) # Should be False
print(
f"PlotType.COUNT is PlotType.BAR: {PlotType.COUNT is PlotType.BAR}"
) # Should be False
print(
f"PlotType.COUNT == PlotType.COUNT: {PlotType.COUNT == PlotType.COUNT}"
) # Should be True
print(
f"PlotType.BAR == PlotType.BAR: {PlotType.BAR == PlotType.BAR}"
) # Should be True
print(
f"PlotType.COUNT is PlotType.COUNT: {PlotType.COUNT is PlotType.COUNT}"
) # Should be True
print(f"PlotType.BAR is PlotType.BAR: {PlotType.BAR is PlotType.BAR}")
def close(plot: Any) -> None:
ax = FigureManager.get_axes(plot)
FigureManager.destroy(ax.get_figure())
17 changes: 17 additions & 0 deletions maidr/patch/clear.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from __future__ import annotations

import wrapt

from matplotlib.figure import Figure

from maidr.core.figure_manager import FigureManager


@wrapt.patch_function_wrapper(Figure, "clear")
def clear(wrapped, instance, args, kwargs) -> None:
wrapped(*args, **kwargs)
try:
maidr = FigureManager.get_maidr(instance.get_figure())
except KeyError:
return
maidr.clear()

0 comments on commit c32fdfd

Please sign in to comment.