From 44bdf86d3a14bc7a1cabefdd6bb7162718612da7 Mon Sep 17 00:00:00 2001
From: nabenabe0928
Date: Wed, 1 Dec 2021 19:14:42 +0100
Subject: [PATCH 1/3] [feat] Add the option to save a figure in plot setting params

Since non-GUI environments need to avoid calling matplotlib's show method,
I added a savefig option so that users can complete the whole plotting
workflow inside AutoPyTorch.
---
 autoPyTorch/utils/results_visualizer.py    | 48 ++++++++++++++-----
 .../40_advanced/example_plot_over_time.py  | 11 ++---
 test/test_utils/test_results_manager.py    | 30 ++++--------
 test/test_utils/test_results_visualizer.py | 41 ++++++++++----
 4 files changed, 81 insertions(+), 49 deletions(-)

diff --git a/autoPyTorch/utils/results_visualizer.py b/autoPyTorch/utils/results_visualizer.py
index 64c87ba94..e1debe29c 100644
--- a/autoPyTorch/utils/results_visualizer.py
+++ b/autoPyTorch/utils/results_visualizer.py
@@ -1,6 +1,6 @@
 from dataclasses import dataclass
 from enum import Enum
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, NamedTuple, Optional, Tuple
 
 import matplotlib.pyplot as plt
 
@@ -71,8 +71,7 @@ def extract_dicts(
     return colors, labels
 
 
-@dataclass(frozen=True)
-class PlotSettingParams:
+class PlotSettingParams(NamedTuple):
     """
     Parameters for the plot environment.
 
@@ -93,12 +92,28 @@ class PlotSettingParams:
             The range of x axis.
         ylim (Tuple[float, float]):
             The range of y axis.
+        grid (bool):
+            Whether to show grid lines.
+            Users who would like to configure grid lines in detail
+            should deactivate this option.
         legend (bool):
             Whether to have legend in the figure.
-        legend_loc (str):
-            The location of the legend.
+        legend_kwargs (Dict[str, Any]):
+            The kwargs for ax.legend.
+            Ref: https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.legend.html
+        title (Optional[str]):
+            The title of the figure.
+        title_kwargs (Dict[str, Any]):
+            The kwargs for ax.set_title except the title label.
+            Ref: https://matplotlib.org/3.1.1/api/_as_gen/matplotlib.axes.Axes.set_title.html
         show (bool):
             Whether to show the plot.
+            If figname is not None, saving the figure takes priority.
+        figname (Optional[str]):
+            The file name used to save the figure. If None, no figure is saved.
+        savefig_kwargs (Dict[str, Any]):
+            The kwargs for plt.savefig except the file name.
+            Ref: https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.savefig.html
         args, kwargs (Any):
             Arguments for the ax.plot.
""" @@ -108,12 +123,16 @@ class PlotSettingParams: xlabel: Optional[str] = None ylabel: Optional[str] = None title: Optional[str] = None + title_kwargs: Dict[str, Any] = {} xlim: Optional[Tuple[float, float]] = None ylim: Optional[Tuple[float, float]] = None + grid: bool = True legend: bool = True - legend_loc: str = 'best' + legend_kwargs: Dict[str, Any] = {} show: bool = False + figname: Optional[str] = None figsize: Optional[Tuple[int, int]] = None + savefig_kwargs: Dict[str, Any] = {} class ScaleChoices(Enum): @@ -201,17 +220,22 @@ def _set_plot_args( ax.set_xscale(plot_setting_params.xscale) ax.set_yscale(plot_setting_params.yscale) - if plot_setting_params.xscale == 'log' or plot_setting_params.yscale == 'log': - ax.grid(True, which='minor', color='gray', linestyle=':') - ax.grid(True, which='major', color='black') + if plot_setting_params.grid: + if plot_setting_params.xscale == 'log' or plot_setting_params.yscale == 'log': + ax.grid(True, which='minor', color='gray', linestyle=':') + + ax.grid(True, which='major', color='black') if plot_setting_params.legend: - ax.legend(loc=plot_setting_params.legend_loc) + ax.legend(**plot_setting_params.legend_kwargs) if plot_setting_params.title is not None: - ax.set_title(plot_setting_params.title) - if plot_setting_params.show: + ax.set_title(plot_setting_params.title, **plot_setting_params.title_kwargs) + + if plot_setting_params.figname is not None: + plt.savefig(plot_setting_params.figname, **plot_setting_params.savefig_kwargs) + elif plot_setting_params.show: plt.show() @staticmethod diff --git a/examples/40_advanced/example_plot_over_time.py b/examples/40_advanced/example_plot_over_time.py index 9c103452e..cf672fc46 100644 --- a/examples/40_advanced/example_plot_over_time.py +++ b/examples/40_advanced/example_plot_over_time.py @@ -62,21 +62,20 @@ xlabel='Runtime', ylabel='Accuracy', title='Toy Example', - show=False # If you would like to show, make it True + figname='example_plot_over_time.png', + savefig_kwargs={'bbox_inches': 'tight'}, + show=False # If you would like to show, make it True and set figname=None ) ############################################################################ # Plot with the Specified Setting Parameters # ========================================== -_, ax = plt.subplots() +# _, ax = plt.subplots() <=== You can feed it to post-process the figure. +# You might need to run `export DISPLAY=:0.0` if you are using non-GUI based environment. api.plot_perf_over_time( - ax=ax, # You do not have to provide. 
metric_name=metric_name, plot_setting_params=params, marker='*', markersize=10 ) - -# plt.show() might cause issue depending on environments -plt.savefig('example_plot_over_time.png') diff --git a/test/test_utils/test_results_manager.py b/test/test_utils/test_results_manager.py index 60ee11f42..8998009a4 100644 --- a/test/test_utils/test_results_manager.py +++ b/test/test_utils/test_results_manager.py @@ -165,11 +165,9 @@ def test_extract_results_from_run_history(): time=1.0, status=StatusType.CAPPED, ) - with pytest.raises(ValueError) as excinfo: + with pytest.raises(ValueError): SearchResults(metric=accuracy, scoring_functions=[], run_history=run_history) - assert excinfo._excinfo[0] == ValueError - def test_raise_error_in_update_and_sort_by_time(): cs = ConfigurationSpace() @@ -179,7 +177,7 @@ def test_raise_error_in_update_and_sort_by_time(): sr = SearchResults(metric=accuracy, scoring_functions=[], run_history=RunHistory()) er = EnsembleResults(metric=accuracy, ensemble_performance_history=[]) - with pytest.raises(RuntimeError) as excinfo: + with pytest.raises(RuntimeError): sr._update( config=config, run_key=RunKey(config_id=0, instance_id=0, seed=0), @@ -189,19 +187,13 @@ def test_raise_error_in_update_and_sort_by_time(): ) ) - assert excinfo._excinfo[0] == RuntimeError - - with pytest.raises(RuntimeError) as excinfo: + with pytest.raises(RuntimeError): sr._sort_by_endtime() - assert excinfo._excinfo[0] == RuntimeError - - with pytest.raises(RuntimeError) as excinfo: + with pytest.raises(RuntimeError): er._update(data={}) - assert excinfo._excinfo[0] == RuntimeError - - with pytest.raises(RuntimeError) as excinfo: + with pytest.raises(RuntimeError): er._sort_by_endtime() @@ -244,11 +236,9 @@ def test_raise_error_in_get_start_time(): status=StatusType.CAPPED, ) - with pytest.raises(ValueError) as excinfo: + with pytest.raises(ValueError): get_start_time(run_history) - assert excinfo._excinfo[0] == ValueError - def test_search_results_sort_by_endtime(): run_history = RunHistory() @@ -364,11 +354,9 @@ def test_metric_results(metric, scores, ensemble_ends_later): def test_search_results_sprint_statistics(): api = BaseTask() for method in ['get_search_results', 'sprint_statistics', 'get_incumbent_results']: - with pytest.raises(RuntimeError) as excinfo: + with pytest.raises(RuntimeError): getattr(api, method)() - assert excinfo._excinfo[0] == RuntimeError - run_history_data = json.load(open(os.path.join(os.path.dirname(__file__), 'runhistory.json'), mode='r'))['data'] @@ -420,11 +408,9 @@ def test_check_run_history(run_history): manager = ResultsManager() manager.run_history = run_history - with pytest.raises(RuntimeError) as excinfo: + with pytest.raises(RuntimeError): manager._check_run_history() - assert excinfo._excinfo[0] == RuntimeError - @pytest.mark.parametrize('include_traditional', (True, False)) @pytest.mark.parametrize('metric', (accuracy, log_loss)) diff --git a/test/test_utils/test_results_visualizer.py b/test/test_utils/test_results_visualizer.py index 926d21e6f..545605f73 100644 --- a/test/test_utils/test_results_visualizer.py +++ b/test/test_utils/test_results_visualizer.py @@ -67,6 +67,34 @@ def test_plt_show_in_set_plot_args(params): # TODO plt.close() +@pytest.mark.parametrize('params', ( + PlotSettingParams(), + PlotSettingParams(figname='fig') +)) +def test_plt_savefig_in_set_plot_args(params): # TODO + plt.savefig = MagicMock() + _, ax = plt.subplots(nrows=1, ncols=1) + viz = ResultsVisualizer() + + viz._set_plot_args(ax, params) + assert plt.savefig._mock_called 
== (params.figname is not None)
+    plt.close()
+
+
+@pytest.mark.parametrize('params', (
+    PlotSettingParams(grid=True),
+    PlotSettingParams(grid=False)
+))
+def test_ax_grid_in_set_plot_args(params):  # TODO
+    _, ax = plt.subplots(nrows=1, ncols=1)
+    ax.grid = MagicMock()
+    viz = ResultsVisualizer()
+
+    viz._set_plot_args(ax, params)
+    assert ax.grid._mock_called == params.grid
+    plt.close()
+
+
 @pytest.mark.parametrize('params', (
     PlotSettingParams(xscale='none', yscale='none'),
     PlotSettingParams(xscale='none', yscale='log'),
@@ -77,10 +105,9 @@ def test_raise_value_error_in_set_plot_args(params):  # TODO
     _, ax = plt.subplots(nrows=1, ncols=1)
     viz = ResultsVisualizer()
 
-    with pytest.raises(ValueError) as excinfo:
+    with pytest.raises(ValueError):
         viz._set_plot_args(ax, params)
 
-    assert excinfo._excinfo[0] == ValueError
     plt.close()
 
 
@@ -119,13 +146,11 @@ def test_raise_error_in_plot_perf_over_time_in_base_task(metric_name):
     api = BaseTask()
 
     if metric_name == 'unknown':
-        with pytest.raises(ValueError) as excinfo:
+        with pytest.raises(ValueError):
             api.plot_perf_over_time(metric_name)
-        assert excinfo._excinfo[0] == ValueError
     else:
-        with pytest.raises(RuntimeError) as excinfo:
+        with pytest.raises(RuntimeError):
             api.plot_perf_over_time(metric_name)
-        assert excinfo._excinfo[0] == RuntimeError
 
 
 @pytest.mark.parametrize('metric_name', ('balanced_accuracy', 'accuracy'))
@@ -175,7 +200,7 @@ def test_raise_error_get_perf_and_time(params):
     results = np.linspace(-1, 1, 10)
     cum_times = np.linspace(0, 1, 10)
 
-    with pytest.raises(ValueError) as excinfo:
+    with pytest.raises(ValueError):
         _get_perf_and_time(
             cum_results=results,
             cum_times=cum_times,
@@ -183,8 +208,6 @@
             worst_val=np.inf
         )
 
-    assert excinfo._excinfo[0] == ValueError
-
 
 @pytest.mark.parametrize('params', (
     PlotSettingParams(n_points=20, xscale='linear', yscale='linear'),

From 26c8c1a2d951c04994937c60a52f0958b4b833f8 Mon Sep 17 00:00:00 2001
From: nabenabe0928
Date: Wed, 1 Dec 2021 19:24:24 +0100
Subject: [PATCH 2/3] [doc] Add a note for non-GUI environments to the plot_perf_over_time method

---
 autoPyTorch/api/base_task.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/autoPyTorch/api/base_task.py b/autoPyTorch/api/base_task.py
index edd505d86..b4d20165e 100644
--- a/autoPyTorch/api/base_task.py
+++ b/autoPyTorch/api/base_task.py
@@ -1513,6 +1513,9 @@ def plot_perf_over_time(
                 The settings of a pair of color and label for each plot.
             args, kwargs (Any):
                 Arguments for the ax.plot.
+
+        Note:
+            You might need to run `export DISPLAY=:0.0` if you are using a non-GUI environment.
         """
 
         if not hasattr(metrics, metric_name):

From 08368f7190cee84490b9b51e6d81e11fab9cec09 Mon Sep 17 00:00:00 2001
From: nabenabe0928
Date: Wed, 1 Dec 2021 19:31:26 +0100
Subject: [PATCH 3/3] [test] Add a test to check the priority of show and savefig

Since plt.savefig and plt.show cannot be used at the same time due to the
matplotlib design, we need to check that show is not called when a figname
is specified.
We could raise an error instead, but plotting usually happens at the end of
an optimization, so I avoided raising an error and stuck to checking this
behaviour in the tests.
---
 test/test_utils/test_results_visualizer.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/test/test_utils/test_results_visualizer.py b/test/test_utils/test_results_visualizer.py
index 545605f73..c463fa063 100644
--- a/test/test_utils/test_results_visualizer.py
+++ b/test/test_utils/test_results_visualizer.py
@@ -55,15 +55,18 @@ def test_extract_dicts(cl_settings, with_ensemble):
 
 @pytest.mark.parametrize('params', (
     PlotSettingParams(show=True),
-    PlotSettingParams(show=False)
+    PlotSettingParams(show=False),
+    PlotSettingParams(show=True, figname='dummy')
 ))
 def test_plt_show_in_set_plot_args(params):  # TODO
     plt.show = MagicMock()
+    plt.savefig = MagicMock()
     _, ax = plt.subplots(nrows=1, ncols=1)
     viz = ResultsVisualizer()
 
     viz._set_plot_args(ax, params)
-    assert plt.show._mock_called == params.show
+    # If figname is not None, show will not be called because saving the figure takes priority.
+    assert plt.show._mock_called == (params.figname is None and params.show)
     plt.close()
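
Usage sketch (not part of the patches above): the snippet below illustrates how the PlotSettingParams fields introduced in PATCH 1/3 could be used on a headless machine, selecting matplotlib's non-interactive Agg backend as an alternative to exporting DISPLAY. The task object `api`, the metric name 'accuracy', and the file name are placeholders for your own setup; everything else mirrors the fields and the plot_perf_over_time call shown in the patches.

    # A minimal sketch, assuming `api` is an already fitted AutoPyTorch task object
    # (e.g. a tabular classification task after search) and that the backend is
    # selected before matplotlib.pyplot is imported anywhere in the script.
    import matplotlib

    matplotlib.use('Agg')  # non-interactive backend for machines without a display

    from autoPyTorch.utils.results_visualizer import PlotSettingParams

    params = PlotSettingParams(
        xlabel='Runtime',
        ylabel='Accuracy',
        title='Performance over time',
        grid=True,
        legend=True,
        legend_kwargs={'loc': 'lower right'},      # forwarded to ax.legend
        figname='perf_over_time.png',              # triggers plt.savefig(...)
        savefig_kwargs={'bbox_inches': 'tight'},   # forwarded to plt.savefig
        show=False                                 # ignored while figname is not None
    )

    api.plot_perf_over_time(
        metric_name='accuracy',
        plot_setting_params=params,
    )

Because _set_plot_args prefers figname over show, the same params object can be reused in an interactive session by setting figname=None and show=True.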