Skip to content

Commit

Permalink
[feat] Add the option to save a figure in plot setting params (#351)
Browse files Browse the repository at this point in the history
* [feat] Add the option to save a figure in plot setting params

Since non-GUI based environments would like to avoid the usage of
show method in the matplotlib, I added the option to savefig and
thus users can complete the operations inside AutoPytorch.

* [doc] Add a comment for non-GUI based computer in plot_perf_over_time method

* [test] Add a test to check the priority of show and savefig

Since plt.savefig and plt.show do not work at the same time due to the
matplotlib design, we need to check whether show will not be called
when a figname is specified. We can actually raise an error, but plot
will be basically called in the end of an optimization, so I wanted
to avoid raising an error and just sticked to a check by tests.
  • Loading branch information
nabenabe0928 authored and ravinkohli committed Dec 8, 2021
1 parent 2a7400e commit 7e2ac06
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 51 deletions.
3 changes: 3 additions & 0 deletions autoPyTorch/api/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -1513,6 +1513,9 @@ def plot_perf_over_time(
The settings of a pair of color and label for each plot.
args, kwargs (Any):
Arguments for the ax.plot.
Note:
You might need to run `export DISPLAY=:0.0` if you are using non-GUI based environment.
"""

if not hasattr(metrics, metric_name):
Expand Down
48 changes: 36 additions & 12 deletions autoPyTorch/utils/results_visualizer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, NamedTuple, Optional, Tuple

import matplotlib.pyplot as plt

Expand Down Expand Up @@ -71,8 +71,7 @@ def extract_dicts(
return colors, labels


@dataclass(frozen=True)
class PlotSettingParams:
class PlotSettingParams(NamedTuple):
"""
Parameters for the plot environment.
Expand All @@ -93,12 +92,28 @@ class PlotSettingParams:
The range of x axis.
ylim (Tuple[float, float]):
The range of y axis.
grid (bool):
Whether to have grid lines.
If users would like to define lines in detail,
they need to deactivate it.
legend (bool):
Whether to have legend in the figure.
legend_loc (str):
The location of the legend.
legend_kwargs (Dict[str, Any]):
The kwargs for ax.legend.
Ref: https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.legend.html
title (Optional[str]):
The title of the figure.
title_kwargs (Dict[str, Any]):
The kwargs for ax.set_title except title label.
Ref: https://matplotlib.org/3.1.1/api/_as_gen/matplotlib.axes.Axes.set_title.html
show (bool):
Whether to show the plot.
If figname is not None, the save will be prioritized.
figname (Optional[str]):
Name of a figure to save. If None, no figure will be saved.
savefig_kwargs (Dict[str, Any]):
The kwargs for plt.savefig except filename.
Ref: https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.savefig.html
args, kwargs (Any):
Arguments for the ax.plot.
"""
Expand All @@ -108,12 +123,16 @@ class PlotSettingParams:
xlabel: Optional[str] = None
ylabel: Optional[str] = None
title: Optional[str] = None
title_kwargs: Dict[str, Any] = {}
xlim: Optional[Tuple[float, float]] = None
ylim: Optional[Tuple[float, float]] = None
grid: bool = True
legend: bool = True
legend_loc: str = 'best'
legend_kwargs: Dict[str, Any] = {}
show: bool = False
figname: Optional[str] = None
figsize: Optional[Tuple[int, int]] = None
savefig_kwargs: Dict[str, Any] = {}


class ScaleChoices(Enum):
Expand Down Expand Up @@ -201,17 +220,22 @@ def _set_plot_args(

ax.set_xscale(plot_setting_params.xscale)
ax.set_yscale(plot_setting_params.yscale)
if plot_setting_params.xscale == 'log' or plot_setting_params.yscale == 'log':
ax.grid(True, which='minor', color='gray', linestyle=':')

ax.grid(True, which='major', color='black')
if plot_setting_params.grid:
if plot_setting_params.xscale == 'log' or plot_setting_params.yscale == 'log':
ax.grid(True, which='minor', color='gray', linestyle=':')

ax.grid(True, which='major', color='black')

if plot_setting_params.legend:
ax.legend(loc=plot_setting_params.legend_loc)
ax.legend(**plot_setting_params.legend_kwargs)

if plot_setting_params.title is not None:
ax.set_title(plot_setting_params.title)
if plot_setting_params.show:
ax.set_title(plot_setting_params.title, **plot_setting_params.title_kwargs)

if plot_setting_params.figname is not None:
plt.savefig(plot_setting_params.figname, **plot_setting_params.savefig_kwargs)
elif plot_setting_params.show:
plt.show()

@staticmethod
Expand Down
11 changes: 5 additions & 6 deletions examples/40_advanced/example_plot_over_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,21 +62,20 @@
xlabel='Runtime',
ylabel='Accuracy',
title='Toy Example',
show=False # If you would like to show, make it True
figname='example_plot_over_time.png',
savefig_kwargs={'bbox_inches': 'tight'},
show=False # If you would like to show, make it True and set figname=None
)

############################################################################
# Plot with the Specified Setting Parameters
# ==========================================
_, ax = plt.subplots()
# _, ax = plt.subplots() <=== You can feed it to post-process the figure.

# You might need to run `export DISPLAY=:0.0` if you are using non-GUI based environment.
api.plot_perf_over_time(
ax=ax, # You do not have to provide.
metric_name=metric_name,
plot_setting_params=params,
marker='*',
markersize=10
)

# plt.show() might cause issue depending on environments
plt.savefig('example_plot_over_time.png')
30 changes: 8 additions & 22 deletions test/test_utils/test_results_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,11 +165,9 @@ def test_extract_results_from_run_history():
time=1.0,
status=StatusType.CAPPED,
)
with pytest.raises(ValueError) as excinfo:
with pytest.raises(ValueError):
SearchResults(metric=accuracy, scoring_functions=[], run_history=run_history)

assert excinfo._excinfo[0] == ValueError


def test_raise_error_in_update_and_sort_by_time():
cs = ConfigurationSpace()
Expand All @@ -179,7 +177,7 @@ def test_raise_error_in_update_and_sort_by_time():
sr = SearchResults(metric=accuracy, scoring_functions=[], run_history=RunHistory())
er = EnsembleResults(metric=accuracy, ensemble_performance_history=[])

with pytest.raises(RuntimeError) as excinfo:
with pytest.raises(RuntimeError):
sr._update(
config=config,
run_key=RunKey(config_id=0, instance_id=0, seed=0),
Expand All @@ -189,19 +187,13 @@ def test_raise_error_in_update_and_sort_by_time():
)
)

assert excinfo._excinfo[0] == RuntimeError

with pytest.raises(RuntimeError) as excinfo:
with pytest.raises(RuntimeError):
sr._sort_by_endtime()

assert excinfo._excinfo[0] == RuntimeError

with pytest.raises(RuntimeError) as excinfo:
with pytest.raises(RuntimeError):
er._update(data={})

assert excinfo._excinfo[0] == RuntimeError

with pytest.raises(RuntimeError) as excinfo:
with pytest.raises(RuntimeError):
er._sort_by_endtime()


Expand Down Expand Up @@ -244,11 +236,9 @@ def test_raise_error_in_get_start_time():
status=StatusType.CAPPED,
)

with pytest.raises(ValueError) as excinfo:
with pytest.raises(ValueError):
get_start_time(run_history)

assert excinfo._excinfo[0] == ValueError


def test_search_results_sort_by_endtime():
run_history = RunHistory()
Expand Down Expand Up @@ -364,11 +354,9 @@ def test_metric_results(metric, scores, ensemble_ends_later):
def test_search_results_sprint_statistics():
api = BaseTask()
for method in ['get_search_results', 'sprint_statistics', 'get_incumbent_results']:
with pytest.raises(RuntimeError) as excinfo:
with pytest.raises(RuntimeError):
getattr(api, method)()

assert excinfo._excinfo[0] == RuntimeError

run_history_data = json.load(open(os.path.join(os.path.dirname(__file__),
'runhistory.json'),
mode='r'))['data']
Expand Down Expand Up @@ -420,11 +408,9 @@ def test_check_run_history(run_history):
manager = ResultsManager()
manager.run_history = run_history

with pytest.raises(RuntimeError) as excinfo:
with pytest.raises(RuntimeError):
manager._check_run_history()

assert excinfo._excinfo[0] == RuntimeError


@pytest.mark.parametrize('include_traditional', (True, False))
@pytest.mark.parametrize('metric', (accuracy, log_loss))
Expand Down
48 changes: 37 additions & 11 deletions test/test_utils/test_results_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,46 @@ def test_extract_dicts(cl_settings, with_ensemble):

@pytest.mark.parametrize('params', (
PlotSettingParams(show=True),
PlotSettingParams(show=False)
PlotSettingParams(show=False),
PlotSettingParams(show=True, figname='dummy')
))
def test_plt_show_in_set_plot_args(params): # TODO
plt.show = MagicMock()
plt.savefig = MagicMock()
_, ax = plt.subplots(nrows=1, ncols=1)
viz = ResultsVisualizer()

viz._set_plot_args(ax, params)
assert plt.show._mock_called == params.show
# if figname is not None, show will not be called. (due to the matplotlib design)
assert plt.show._mock_called == (params.figname is None and params.show)
plt.close()


@pytest.mark.parametrize('params', (
PlotSettingParams(),
PlotSettingParams(figname='fig')
))
def test_plt_savefig_in_set_plot_args(params): # TODO
plt.savefig = MagicMock()
_, ax = plt.subplots(nrows=1, ncols=1)
viz = ResultsVisualizer()

viz._set_plot_args(ax, params)
assert plt.savefig._mock_called == (params.figname is not None)
plt.close()


@pytest.mark.parametrize('params', (
PlotSettingParams(grid=True),
PlotSettingParams(grid=False)
))
def test_ax_grid_in_set_plot_args(params): # TODO
_, ax = plt.subplots(nrows=1, ncols=1)
ax.grid = MagicMock()
viz = ResultsVisualizer()

viz._set_plot_args(ax, params)
assert ax.grid._mock_called == params.grid
plt.close()


Expand All @@ -77,10 +108,9 @@ def test_raise_value_error_in_set_plot_args(params): # TODO
_, ax = plt.subplots(nrows=1, ncols=1)
viz = ResultsVisualizer()

with pytest.raises(ValueError) as excinfo:
with pytest.raises(ValueError):
viz._set_plot_args(ax, params)

assert excinfo._excinfo[0] == ValueError
plt.close()


Expand Down Expand Up @@ -119,13 +149,11 @@ def test_raise_error_in_plot_perf_over_time_in_base_task(metric_name):
api = BaseTask()

if metric_name == 'unknown':
with pytest.raises(ValueError) as excinfo:
with pytest.raises(ValueError):
api.plot_perf_over_time(metric_name)
assert excinfo._excinfo[0] == ValueError
else:
with pytest.raises(RuntimeError) as excinfo:
with pytest.raises(RuntimeError):
api.plot_perf_over_time(metric_name)
assert excinfo._excinfo[0] == RuntimeError


@pytest.mark.parametrize('metric_name', ('balanced_accuracy', 'accuracy'))
Expand Down Expand Up @@ -175,16 +203,14 @@ def test_raise_error_get_perf_and_time(params):
results = np.linspace(-1, 1, 10)
cum_times = np.linspace(0, 1, 10)

with pytest.raises(ValueError) as excinfo:
with pytest.raises(ValueError):
_get_perf_and_time(
cum_results=results,
cum_times=cum_times,
plot_setting_params=params,
worst_val=np.inf
)

assert excinfo._excinfo[0] == ValueError


@pytest.mark.parametrize('params', (
PlotSettingParams(n_points=20, xscale='linear', yscale='linear'),
Expand Down

0 comments on commit 7e2ac06

Please sign in to comment.