diff --git a/ax/service/scheduler.py b/ax/service/scheduler.py index 7bf156b0044..4c3da077f8b 100644 --- a/ax/service/scheduler.py +++ b/ax/service/scheduler.py @@ -18,6 +18,10 @@ from typing import Any, Callable, cast, NamedTuple, Optional import ax.service.utils.early_stopping as early_stopping_utils +from ax.analysis.analysis import Analysis, AnalysisCard +from ax.analysis.plotly.parallel_coordinates.parallel_coordinates import ( + ParallelCoordinatesPlot, +) from ax.core.base_trial import BaseTrial, TrialStatus from ax.core.experiment import Experiment from ax.core.generation_strategy_interface import GenerationStrategyInterface @@ -1768,6 +1772,36 @@ def generate_candidates( ) return new_trials + def compute_analyses( + self, analyses: Optional[Iterable[Analysis]] = None + ) -> list[AnalysisCard]: + analyses = analyses if analyses is not None else self._choose_analyses() + + results = [ + analysis.compute_result( + experiment=self.experiment, generation_strategy=self.generation_strategy + ) + for analysis in analyses + ] + + # TODO Accumulate Es into their own card, perhaps via unwrap_or_else + cards = [result.unwrap() for result in results if result.is_ok()] + + self._save_analysis_cards_to_db_if_possible( + analysis_cards=cards, + experiment=self.experiment, + ) + + return cards + + def _choose_analyses(self) -> list[Analysis]: + """ + Choose Analyses to compute based on the Experiment, GenerationStrategy, etc. + """ + + # TODO Create a useful heuristic for choosing analyses + return [ParallelCoordinatesPlot()] + def _gen_new_trials_from_generation_strategy( self, num_trials: int, diff --git a/ax/service/tests/scheduler_test_utils.py b/ax/service/tests/scheduler_test_utils.py index 4b405ba0f59..ae6187fbca7 100644 --- a/ax/service/tests/scheduler_test_utils.py +++ b/ax/service/tests/scheduler_test_utils.py @@ -17,6 +17,9 @@ from unittest.mock import call, Mock, patch, PropertyMock import pandas as pd +from ax.analysis.plotly.parallel_coordinates.parallel_coordinates import ( + ParallelCoordinatesPlot, +) from ax.core.arm import Arm from ax.core.base_trial import BaseTrial, TrialStatus @@ -84,6 +87,7 @@ SpecialGenerationStrategy, ) from ax.utils.testing.mock import fast_botorch_optimize +from ax.utils.testing.modeling_stubs import get_generation_strategy from pyre_extensions import none_throws from sqlalchemy.orm.exc import StaleDataError @@ -2502,3 +2506,45 @@ def test_generate_candidates_does_not_generate_if_overconstrained(self) -> None: 1, str(scheduler.experiment.trials), ) + + def test_compute_analyses(self) -> None: + scheduler = Scheduler( + experiment=get_branin_experiment(with_completed_trial=True), + generation_strategy=get_generation_strategy(), + options=SchedulerOptions( + total_trials=0, + tolerated_trial_failure_rate=0.2, + init_seconds_between_polls=10, + ), + ) + + cards = scheduler.compute_analyses(analyses=[ParallelCoordinatesPlot()]) + + self.assertEqual(len(cards), 1) + self.assertEqual(cards[0].name, "ParallelCoordinatesPlot(metric_name=None)") + + scheduler = Scheduler( + experiment=get_branin_experiment(with_completed_trial=False), + generation_strategy=get_generation_strategy(), + options=SchedulerOptions( + total_trials=0, + tolerated_trial_failure_rate=0.2, + init_seconds_between_polls=10, + ), + ) + + with self.assertLogs(logger="ax.analysis", level="ERROR") as lg: + + cards = scheduler.compute_analyses(analyses=[ParallelCoordinatesPlot()]) + + self.assertEqual(len(cards), 0) + self.assertTrue( + any( + ( + "Failed to compute ParallelCoordinatesPlot(metric_name=None): " + "No data found for metric branin" + ) + in msg + for msg in lg.output + ) + ) diff --git a/ax/service/utils/with_db_settings_base.py b/ax/service/utils/with_db_settings_base.py index 0d0ee3b5631..65d4d1176df 100644 --- a/ax/service/utils/with_db_settings_base.py +++ b/ax/service/utils/with_db_settings_base.py @@ -10,7 +10,9 @@ import time from logging import INFO, Logger -from typing import Any, Optional +from typing import Any, Iterable, Optional + +from ax.analysis.analysis import AnalysisCard from ax.core.base_trial import BaseTrial from ax.core.experiment import Experiment @@ -22,6 +24,7 @@ UnsupportedError, ) from ax.modelbridge.generation_strategy import GenerationStrategy +from ax.storage.sqa_store.save import save_analysis_cards from ax.utils.common.executils import retry_on_exception from ax.utils.common.logger import _round_floats_for_logging, get_logger from ax.utils.common.typeutils import not_none @@ -468,6 +471,21 @@ def _update_experiment_properties_in_db( return True return False + def _save_analysis_cards_to_db_if_possible( + self, + experiment: Experiment, + analysis_cards: Iterable[AnalysisCard], + ) -> bool: + if self.db_settings_set: + _save_analysis_cards_to_db_if_possible( + experiment=experiment, + analysis_cards=analysis_cards, + config=self.db_settings.encoder.config, + ) + return True + + return False + # ------------- Utils for storage that assume `DBSettings` are provided -------- @@ -590,3 +608,21 @@ def _update_experiment_properties_in_db( experiment_with_updated_properties=experiment_with_updated_properties, config=sqa_config, ) + + +@retry_on_exception( + retries=3, + default_return_on_suppression=False, + exception_types=RETRY_EXCEPTION_TYPES, +) +def _save_analysis_cards_to_db_if_possible( + experiment: Experiment, + analysis_cards: Iterable[AnalysisCard], + sqa_config: SQAConfig, + suppress_all_errors: bool, # Used by the decorator. +) -> None: + save_analysis_cards( + experiment=experiment, + analysis_cards=[*analysis_cards], + config=sqa_config, + )