From 84b307def75c65072bf4a52892d6fdf1998e6d23 Mon Sep 17 00:00:00 2001
From: Miles Olson
Date: Wed, 21 Aug 2024 09:29:28 -0700
Subject: [PATCH] Add Scheduler.compute_analyses method (#2660)

Summary:
Pull Request resolved: https://github.com/facebook/Ax/pull/2660

Adds a method for bulk computing Analyses from a Scheduler instance and
saving the cards to the DB. This can be used to generate a "report" as
needed (see the usage sketch below).

Notes:
* If any Analysis fails to compute, an Error is logged, but
  compute_analyses otherwise succeeds. Down the line we may wish to
  accumulate the errors into their own ErrorCard.
* If no Analyses are provided to compute_analyses, _choose_analyses will
  be called to infer which analyses should be computed for the given
  Scheduler based on the Experiment, GenerationStrategy, etc. Since
  ParallelCoordinatesPlot is the only Analysis implemented right now,
  this always returns [ParallelCoordinatesPlot()].
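A minimal usage sketch (illustrative only, not part of the diff itself).
It reuses the stubs exercised in scheduler_test_utils.py, and assumes
that get_branin_experiment lives in ax.utils.testing.core_stubs and that
SchedulerOptions is importable from ax.service.scheduler:

    from ax.analysis.plotly.parallel_coordinates.parallel_coordinates import (
        ParallelCoordinatesPlot,
    )
    from ax.service.scheduler import Scheduler, SchedulerOptions
    from ax.utils.testing.core_stubs import get_branin_experiment
    from ax.utils.testing.modeling_stubs import get_generation_strategy

    # Toy experiment with a completed trial so ParallelCoordinatesPlot has data.
    scheduler = Scheduler(
        experiment=get_branin_experiment(with_completed_trial=True),
        generation_strategy=get_generation_strategy(),
        options=SchedulerOptions(total_trials=0),
    )

    # Let _choose_analyses infer which analyses to compute ...
    cards = scheduler.compute_analyses()

    # ... or request specific analyses explicitly.
    cards = scheduler.compute_analyses(analyses=[ParallelCoordinatesPlot()])

Cards are only persisted when the Scheduler was constructed with
db_settings; otherwise compute_analyses simply returns them.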
Reviewed By: Cesar-Cardoso

Differential Revision: D61338432

fbshipit-source-id: 279c499be092f28fd5542cc14f70f484a1f3b57c
---
 ax/service/scheduler.py                   | 34 +++++++++++++++++
 ax/service/tests/scheduler_test_utils.py  | 46 +++++++++++++++++++++++
 ax/service/utils/with_db_settings_base.py | 38 ++++++++++++++++++-
 3 files changed, 117 insertions(+), 1 deletion(-)

diff --git a/ax/service/scheduler.py b/ax/service/scheduler.py
index 7bf156b0044..4c3da077f8b 100644
--- a/ax/service/scheduler.py
+++ b/ax/service/scheduler.py
@@ -18,6 +18,10 @@
 from typing import Any, Callable, cast, NamedTuple, Optional
 
 import ax.service.utils.early_stopping as early_stopping_utils
+from ax.analysis.analysis import Analysis, AnalysisCard
+from ax.analysis.plotly.parallel_coordinates.parallel_coordinates import (
+    ParallelCoordinatesPlot,
+)
 from ax.core.base_trial import BaseTrial, TrialStatus
 from ax.core.experiment import Experiment
 from ax.core.generation_strategy_interface import GenerationStrategyInterface
@@ -1768,6 +1772,36 @@ def generate_candidates(
         )
         return new_trials
 
+    def compute_analyses(
+        self, analyses: Optional[Iterable[Analysis]] = None
+    ) -> list[AnalysisCard]:
+        analyses = analyses if analyses is not None else self._choose_analyses()
+
+        results = [
+            analysis.compute_result(
+                experiment=self.experiment, generation_strategy=self.generation_strategy
+            )
+            for analysis in analyses
+        ]
+
+        # TODO Accumulate Es into their own card, perhaps via unwrap_or_else
+        cards = [result.unwrap() for result in results if result.is_ok()]
+
+        self._save_analysis_cards_to_db_if_possible(
+            analysis_cards=cards,
+            experiment=self.experiment,
+        )
+
+        return cards
+
+    def _choose_analyses(self) -> list[Analysis]:
+        """
+        Choose Analyses to compute based on the Experiment, GenerationStrategy, etc.
+        """
+
+        # TODO Create a useful heuristic for choosing analyses
+        return [ParallelCoordinatesPlot()]
+
     def _gen_new_trials_from_generation_strategy(
         self,
         num_trials: int,
diff --git a/ax/service/tests/scheduler_test_utils.py b/ax/service/tests/scheduler_test_utils.py
index 4b405ba0f59..ae6187fbca7 100644
--- a/ax/service/tests/scheduler_test_utils.py
+++ b/ax/service/tests/scheduler_test_utils.py
@@ -17,6 +17,9 @@
 from unittest.mock import call, Mock, patch, PropertyMock
 
 import pandas as pd
+from ax.analysis.plotly.parallel_coordinates.parallel_coordinates import (
+    ParallelCoordinatesPlot,
+)
 
 from ax.core.arm import Arm
 from ax.core.base_trial import BaseTrial, TrialStatus
@@ -84,6 +87,7 @@
     SpecialGenerationStrategy,
 )
 from ax.utils.testing.mock import fast_botorch_optimize
+from ax.utils.testing.modeling_stubs import get_generation_strategy
 from pyre_extensions import none_throws
 from sqlalchemy.orm.exc import StaleDataError
 
@@ -2502,3 +2506,45 @@ def test_generate_candidates_does_not_generate_if_overconstrained(self) -> None:
             1,
             str(scheduler.experiment.trials),
         )
+
+    def test_compute_analyses(self) -> None:
+        scheduler = Scheduler(
+            experiment=get_branin_experiment(with_completed_trial=True),
+            generation_strategy=get_generation_strategy(),
+            options=SchedulerOptions(
+                total_trials=0,
+                tolerated_trial_failure_rate=0.2,
+                init_seconds_between_polls=10,
+            ),
+        )
+
+        cards = scheduler.compute_analyses(analyses=[ParallelCoordinatesPlot()])
+
+        self.assertEqual(len(cards), 1)
+        self.assertEqual(cards[0].name, "ParallelCoordinatesPlot(metric_name=None)")
+
+        scheduler = Scheduler(
+            experiment=get_branin_experiment(with_completed_trial=False),
+            generation_strategy=get_generation_strategy(),
+            options=SchedulerOptions(
+                total_trials=0,
+                tolerated_trial_failure_rate=0.2,
+                init_seconds_between_polls=10,
+            ),
+        )
+
+        with self.assertLogs(logger="ax.analysis", level="ERROR") as lg:
+
+            cards = scheduler.compute_analyses(analyses=[ParallelCoordinatesPlot()])
+
+            self.assertEqual(len(cards), 0)
+            self.assertTrue(
+                any(
+                    (
+                        "Failed to compute ParallelCoordinatesPlot(metric_name=None): "
+                        "No data found for metric branin"
+                    )
+                    in msg
+                    for msg in lg.output
+                )
+            )
diff --git a/ax/service/utils/with_db_settings_base.py b/ax/service/utils/with_db_settings_base.py
index 0d0ee3b5631..65d4d1176df 100644
--- a/ax/service/utils/with_db_settings_base.py
+++ b/ax/service/utils/with_db_settings_base.py
@@ -10,7 +10,9 @@
 
 import time
 from logging import INFO, Logger
-from typing import Any, Optional
+from typing import Any, Iterable, Optional
+
+from ax.analysis.analysis import AnalysisCard
 
 from ax.core.base_trial import BaseTrial
 from ax.core.experiment import Experiment
@@ -22,6 +24,7 @@
     UnsupportedError,
 )
 from ax.modelbridge.generation_strategy import GenerationStrategy
+from ax.storage.sqa_store.save import save_analysis_cards
 from ax.utils.common.executils import retry_on_exception
 from ax.utils.common.logger import _round_floats_for_logging, get_logger
 from ax.utils.common.typeutils import not_none
@@ -468,6 +471,21 @@ def _update_experiment_properties_in_db(
             return True
         return False
 
+    def _save_analysis_cards_to_db_if_possible(
+        self,
+        experiment: Experiment,
+        analysis_cards: Iterable[AnalysisCard],
+    ) -> bool:
+        if self.db_settings_set:
+            _save_analysis_cards_to_db_if_possible(
+                experiment=experiment,
+                analysis_cards=analysis_cards,
+                sqa_config=self.db_settings.encoder.config,
+            )
+            return True
+
+        return False
+
 
 # ------------- Utils for storage that assume `DBSettings` are provided --------
 
@@ -590,3 +608,21 @@ def _update_experiment_properties_in_db(
         experiment_with_updated_properties=experiment_with_updated_properties,
         config=sqa_config,
     )
+
+
+@retry_on_exception(
+    retries=3,
+    default_return_on_suppression=False,
+    exception_types=RETRY_EXCEPTION_TYPES,
+)
+def _save_analysis_cards_to_db_if_possible(
+    experiment: Experiment,
+    analysis_cards: Iterable[AnalysisCard],
+    sqa_config: SQAConfig,
+    suppress_all_errors: bool,  # Used by the decorator.
+) -> None:
+    save_analysis_cards(
+        experiment=experiment,
+        analysis_cards=[*analysis_cards],
+        config=sqa_config,
+    )