diff --git a/autoPyTorch/api/base_task.py b/autoPyTorch/api/base_task.py index edd505d86..b4d20165e 100644 --- a/autoPyTorch/api/base_task.py +++ b/autoPyTorch/api/base_task.py @@ -1513,6 +1513,9 @@ def plot_perf_over_time( The settings of a pair of color and label for each plot. args, kwargs (Any): Arguments for the ax.plot. + + Note: + You might need to run `export DISPLAY=:0.0` if you are using non-GUI based environment. """ if not hasattr(metrics, metric_name): diff --git a/autoPyTorch/utils/results_visualizer.py b/autoPyTorch/utils/results_visualizer.py index 64c87ba94..e1debe29c 100644 --- a/autoPyTorch/utils/results_visualizer.py +++ b/autoPyTorch/utils/results_visualizer.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from enum import Enum -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, NamedTuple, Optional, Tuple import matplotlib.pyplot as plt @@ -71,8 +71,7 @@ def extract_dicts( return colors, labels -@dataclass(frozen=True) -class PlotSettingParams: +class PlotSettingParams(NamedTuple): """ Parameters for the plot environment. @@ -93,12 +92,28 @@ class PlotSettingParams: The range of x axis. ylim (Tuple[float, float]): The range of y axis. + grid (bool): + Whether to have grid lines. + If users would like to define lines in detail, + they need to deactivate it. legend (bool): Whether to have legend in the figure. - legend_loc (str): - The location of the legend. + legend_kwargs (Dict[str, Any]): + The kwargs for ax.legend. + Ref: https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.legend.html + title (Optional[str]): + The title of the figure. + title_kwargs (Dict[str, Any]): + The kwargs for ax.set_title except title label. + Ref: https://matplotlib.org/3.1.1/api/_as_gen/matplotlib.axes.Axes.set_title.html show (bool): Whether to show the plot. + If figname is not None, the save will be prioritized. + figname (Optional[str]): + Name of a figure to save. If None, no figure will be saved. + savefig_kwargs (Dict[str, Any]): + The kwargs for plt.savefig except filename. + Ref: https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.savefig.html args, kwargs (Any): Arguments for the ax.plot. """ @@ -108,12 +123,16 @@ class PlotSettingParams: xlabel: Optional[str] = None ylabel: Optional[str] = None title: Optional[str] = None + title_kwargs: Dict[str, Any] = {} xlim: Optional[Tuple[float, float]] = None ylim: Optional[Tuple[float, float]] = None + grid: bool = True legend: bool = True - legend_loc: str = 'best' + legend_kwargs: Dict[str, Any] = {} show: bool = False + figname: Optional[str] = None figsize: Optional[Tuple[int, int]] = None + savefig_kwargs: Dict[str, Any] = {} class ScaleChoices(Enum): @@ -201,17 +220,22 @@ def _set_plot_args( ax.set_xscale(plot_setting_params.xscale) ax.set_yscale(plot_setting_params.yscale) - if plot_setting_params.xscale == 'log' or plot_setting_params.yscale == 'log': - ax.grid(True, which='minor', color='gray', linestyle=':') - ax.grid(True, which='major', color='black') + if plot_setting_params.grid: + if plot_setting_params.xscale == 'log' or plot_setting_params.yscale == 'log': + ax.grid(True, which='minor', color='gray', linestyle=':') + + ax.grid(True, which='major', color='black') if plot_setting_params.legend: - ax.legend(loc=plot_setting_params.legend_loc) + ax.legend(**plot_setting_params.legend_kwargs) if plot_setting_params.title is not None: - ax.set_title(plot_setting_params.title) - if plot_setting_params.show: + ax.set_title(plot_setting_params.title, **plot_setting_params.title_kwargs) + + if plot_setting_params.figname is not None: + plt.savefig(plot_setting_params.figname, **plot_setting_params.savefig_kwargs) + elif plot_setting_params.show: plt.show() @staticmethod diff --git a/examples/40_advanced/example_plot_over_time.py b/examples/40_advanced/example_plot_over_time.py index 9c103452e..cf672fc46 100644 --- a/examples/40_advanced/example_plot_over_time.py +++ b/examples/40_advanced/example_plot_over_time.py @@ -62,21 +62,20 @@ xlabel='Runtime', ylabel='Accuracy', title='Toy Example', - show=False # If you would like to show, make it True + figname='example_plot_over_time.png', + savefig_kwargs={'bbox_inches': 'tight'}, + show=False # If you would like to show, make it True and set figname=None ) ############################################################################ # Plot with the Specified Setting Parameters # ========================================== -_, ax = plt.subplots() +# _, ax = plt.subplots() <=== You can feed it to post-process the figure. +# You might need to run `export DISPLAY=:0.0` if you are using non-GUI based environment. api.plot_perf_over_time( - ax=ax, # You do not have to provide. metric_name=metric_name, plot_setting_params=params, marker='*', markersize=10 ) - -# plt.show() might cause issue depending on environments -plt.savefig('example_plot_over_time.png') diff --git a/test/test_utils/test_results_manager.py b/test/test_utils/test_results_manager.py index 60ee11f42..8998009a4 100644 --- a/test/test_utils/test_results_manager.py +++ b/test/test_utils/test_results_manager.py @@ -165,11 +165,9 @@ def test_extract_results_from_run_history(): time=1.0, status=StatusType.CAPPED, ) - with pytest.raises(ValueError) as excinfo: + with pytest.raises(ValueError): SearchResults(metric=accuracy, scoring_functions=[], run_history=run_history) - assert excinfo._excinfo[0] == ValueError - def test_raise_error_in_update_and_sort_by_time(): cs = ConfigurationSpace() @@ -179,7 +177,7 @@ def test_raise_error_in_update_and_sort_by_time(): sr = SearchResults(metric=accuracy, scoring_functions=[], run_history=RunHistory()) er = EnsembleResults(metric=accuracy, ensemble_performance_history=[]) - with pytest.raises(RuntimeError) as excinfo: + with pytest.raises(RuntimeError): sr._update( config=config, run_key=RunKey(config_id=0, instance_id=0, seed=0), @@ -189,19 +187,13 @@ def test_raise_error_in_update_and_sort_by_time(): ) ) - assert excinfo._excinfo[0] == RuntimeError - - with pytest.raises(RuntimeError) as excinfo: + with pytest.raises(RuntimeError): sr._sort_by_endtime() - assert excinfo._excinfo[0] == RuntimeError - - with pytest.raises(RuntimeError) as excinfo: + with pytest.raises(RuntimeError): er._update(data={}) - assert excinfo._excinfo[0] == RuntimeError - - with pytest.raises(RuntimeError) as excinfo: + with pytest.raises(RuntimeError): er._sort_by_endtime() @@ -244,11 +236,9 @@ def test_raise_error_in_get_start_time(): status=StatusType.CAPPED, ) - with pytest.raises(ValueError) as excinfo: + with pytest.raises(ValueError): get_start_time(run_history) - assert excinfo._excinfo[0] == ValueError - def test_search_results_sort_by_endtime(): run_history = RunHistory() @@ -364,11 +354,9 @@ def test_metric_results(metric, scores, ensemble_ends_later): def test_search_results_sprint_statistics(): api = BaseTask() for method in ['get_search_results', 'sprint_statistics', 'get_incumbent_results']: - with pytest.raises(RuntimeError) as excinfo: + with pytest.raises(RuntimeError): getattr(api, method)() - assert excinfo._excinfo[0] == RuntimeError - run_history_data = json.load(open(os.path.join(os.path.dirname(__file__), 'runhistory.json'), mode='r'))['data'] @@ -420,11 +408,9 @@ def test_check_run_history(run_history): manager = ResultsManager() manager.run_history = run_history - with pytest.raises(RuntimeError) as excinfo: + with pytest.raises(RuntimeError): manager._check_run_history() - assert excinfo._excinfo[0] == RuntimeError - @pytest.mark.parametrize('include_traditional', (True, False)) @pytest.mark.parametrize('metric', (accuracy, log_loss)) diff --git a/test/test_utils/test_results_visualizer.py b/test/test_utils/test_results_visualizer.py index 926d21e6f..c463fa063 100644 --- a/test/test_utils/test_results_visualizer.py +++ b/test/test_utils/test_results_visualizer.py @@ -55,15 +55,46 @@ def test_extract_dicts(cl_settings, with_ensemble): @pytest.mark.parametrize('params', ( PlotSettingParams(show=True), - PlotSettingParams(show=False) + PlotSettingParams(show=False), + PlotSettingParams(show=True, figname='dummy') )) def test_plt_show_in_set_plot_args(params): # TODO plt.show = MagicMock() + plt.savefig = MagicMock() _, ax = plt.subplots(nrows=1, ncols=1) viz = ResultsVisualizer() viz._set_plot_args(ax, params) - assert plt.show._mock_called == params.show + # if figname is not None, show will not be called. (due to the matplotlib design) + assert plt.show._mock_called == (params.figname is None and params.show) + plt.close() + + +@pytest.mark.parametrize('params', ( + PlotSettingParams(), + PlotSettingParams(figname='fig') +)) +def test_plt_savefig_in_set_plot_args(params): # TODO + plt.savefig = MagicMock() + _, ax = plt.subplots(nrows=1, ncols=1) + viz = ResultsVisualizer() + + viz._set_plot_args(ax, params) + assert plt.savefig._mock_called == (params.figname is not None) + plt.close() + + +@pytest.mark.parametrize('params', ( + PlotSettingParams(grid=True), + PlotSettingParams(grid=False) +)) +def test_ax_grid_in_set_plot_args(params): # TODO + _, ax = plt.subplots(nrows=1, ncols=1) + ax.grid = MagicMock() + viz = ResultsVisualizer() + + viz._set_plot_args(ax, params) + assert ax.grid._mock_called == params.grid plt.close() @@ -77,10 +108,9 @@ def test_raise_value_error_in_set_plot_args(params): # TODO _, ax = plt.subplots(nrows=1, ncols=1) viz = ResultsVisualizer() - with pytest.raises(ValueError) as excinfo: + with pytest.raises(ValueError): viz._set_plot_args(ax, params) - assert excinfo._excinfo[0] == ValueError plt.close() @@ -119,13 +149,11 @@ def test_raise_error_in_plot_perf_over_time_in_base_task(metric_name): api = BaseTask() if metric_name == 'unknown': - with pytest.raises(ValueError) as excinfo: + with pytest.raises(ValueError): api.plot_perf_over_time(metric_name) - assert excinfo._excinfo[0] == ValueError else: - with pytest.raises(RuntimeError) as excinfo: + with pytest.raises(RuntimeError): api.plot_perf_over_time(metric_name) - assert excinfo._excinfo[0] == RuntimeError @pytest.mark.parametrize('metric_name', ('balanced_accuracy', 'accuracy')) @@ -175,7 +203,7 @@ def test_raise_error_get_perf_and_time(params): results = np.linspace(-1, 1, 10) cum_times = np.linspace(0, 1, 10) - with pytest.raises(ValueError) as excinfo: + with pytest.raises(ValueError): _get_perf_and_time( cum_results=results, cum_times=cum_times, @@ -183,8 +211,6 @@ def test_raise_error_get_perf_and_time(params): worst_val=np.inf ) - assert excinfo._excinfo[0] == ValueError - @pytest.mark.parametrize('params', ( PlotSettingParams(n_points=20, xscale='linear', yscale='linear'),