Skip to content

Commit

Permalink
Add Scheduler.compute_analyses method (#2660)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2660

Adds a method for bulk computing Analyses from a Scheduler instance and save the cards to the DB. This can be used to generate a "report" as needed.

Notes:
* If any Analysis fails to compute, an Error is logged, but compute_analyses otherwise succeeds. Down the line we may wish to accumulate the errors into their own ErrorCard
* If no Analyses are provided to compute_analyses, _choose_analyses will be called to infer which analyses should be computed for the given Scheduler based on the Experiment, GenerationStrategy, etc. Since only ParallelCoordinates is implemented right now this always returns [ParallelCoordinates()]

Reviewed By: Cesar-Cardoso

Differential Revision: D61338432

fbshipit-source-id: 279c499be092f28fd5542cc14f70f484a1f3b57c
  • Loading branch information
mpolson64 authored and facebook-github-bot committed Aug 21, 2024
1 parent 6e7e798 commit 84b307d
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 1 deletion.
34 changes: 34 additions & 0 deletions ax/service/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
from typing import Any, Callable, cast, NamedTuple, Optional

import ax.service.utils.early_stopping as early_stopping_utils
from ax.analysis.analysis import Analysis, AnalysisCard
from ax.analysis.plotly.parallel_coordinates.parallel_coordinates import (
ParallelCoordinatesPlot,
)
from ax.core.base_trial import BaseTrial, TrialStatus
from ax.core.experiment import Experiment
from ax.core.generation_strategy_interface import GenerationStrategyInterface
Expand Down Expand Up @@ -1768,6 +1772,36 @@ def generate_candidates(
)
return new_trials

def compute_analyses(
    self, analyses: Optional[Iterable[Analysis]] = None
) -> list[AnalysisCard]:
    """Compute AnalysisCards for this Scheduler's Experiment and persist them
    to the DB when storage is configured.

    Args:
        analyses: The Analyses to compute. When None, a default set is
            inferred via ``_choose_analyses``.

    Returns:
        The successfully computed AnalysisCards. Analyses whose computation
        fails are omitted from the result.
    """
    if analyses is None:
        analyses = self._choose_analyses()

    # TODO Accumulate errors into their own card, perhaps via unwrap_or_else
    cards: list[AnalysisCard] = []
    for analysis in analyses:
        result = analysis.compute_result(
            experiment=self.experiment,
            generation_strategy=self.generation_strategy,
        )
        if result.is_ok():
            cards.append(result.unwrap())

    self._save_analysis_cards_to_db_if_possible(
        analysis_cards=cards,
        experiment=self.experiment,
    )

    return cards

def _choose_analyses(self) -> list[Analysis]:
    """Infer which Analyses to compute for this Scheduler.

    Intended to take the Experiment, GenerationStrategy, etc. into account;
    for now a fixed default is always returned.
    """

    # TODO Create a useful heuristic for choosing analyses
    default_analyses: list[Analysis] = [ParallelCoordinatesPlot()]
    return default_analyses

def _gen_new_trials_from_generation_strategy(
self,
num_trials: int,
Expand Down
46 changes: 46 additions & 0 deletions ax/service/tests/scheduler_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
from unittest.mock import call, Mock, patch, PropertyMock

import pandas as pd
from ax.analysis.plotly.parallel_coordinates.parallel_coordinates import (
ParallelCoordinatesPlot,
)

from ax.core.arm import Arm
from ax.core.base_trial import BaseTrial, TrialStatus
Expand Down Expand Up @@ -84,6 +87,7 @@
SpecialGenerationStrategy,
)
from ax.utils.testing.mock import fast_botorch_optimize
from ax.utils.testing.modeling_stubs import get_generation_strategy
from pyre_extensions import none_throws

from sqlalchemy.orm.exc import StaleDataError
Expand Down Expand Up @@ -2502,3 +2506,45 @@ def test_generate_candidates_does_not_generate_if_overconstrained(self) -> None:
1,
str(scheduler.experiment.trials),
)

def test_compute_analyses(self) -> None:
    """compute_analyses produces a card when data exists and logs an error
    (returning no cards) when the metric has no data."""

    def make_scheduler(has_completed_trial: bool) -> Scheduler:
        # Both cases share identical SchedulerOptions; only the experiment's
        # trial data differs.
        return Scheduler(
            experiment=get_branin_experiment(
                with_completed_trial=has_completed_trial
            ),
            generation_strategy=get_generation_strategy(),
            options=SchedulerOptions(
                total_trials=0,
                tolerated_trial_failure_rate=0.2,
                init_seconds_between_polls=10,
            ),
        )

    # With a completed trial the analysis computes successfully.
    scheduler = make_scheduler(has_completed_trial=True)
    cards = scheduler.compute_analyses(analyses=[ParallelCoordinatesPlot()])

    self.assertEqual(len(cards), 1)
    self.assertEqual(cards[0].name, "ParallelCoordinatesPlot(metric_name=None)")

    # Without any completed trial there is no data: the failure is logged
    # and no card is returned.
    scheduler = make_scheduler(has_completed_trial=False)

    with self.assertLogs(logger="ax.analysis", level="ERROR") as logs:
        cards = scheduler.compute_analyses(analyses=[ParallelCoordinatesPlot()])

    self.assertEqual(len(cards), 0)
    expected_message = (
        "Failed to compute ParallelCoordinatesPlot(metric_name=None): "
        "No data found for metric branin"
    )
    self.assertTrue(any(expected_message in msg for msg in logs.output))
38 changes: 37 additions & 1 deletion ax/service/utils/with_db_settings_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
import time

from logging import INFO, Logger
from typing import Any, Optional
from typing import Any, Iterable, Optional

from ax.analysis.analysis import AnalysisCard

from ax.core.base_trial import BaseTrial
from ax.core.experiment import Experiment
Expand All @@ -22,6 +24,7 @@
UnsupportedError,
)
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.storage.sqa_store.save import save_analysis_cards
from ax.utils.common.executils import retry_on_exception
from ax.utils.common.logger import _round_floats_for_logging, get_logger
from ax.utils.common.typeutils import not_none
Expand Down Expand Up @@ -468,6 +471,21 @@ def _update_experiment_properties_in_db(
return True
return False

def _save_analysis_cards_to_db_if_possible(
    self,
    experiment: Experiment,
    analysis_cards: Iterable[AnalysisCard],
) -> bool:
    """Save the given AnalysisCards to the DB if DB settings are configured.

    Args:
        experiment: The Experiment the cards belong to.
        analysis_cards: The cards to persist.

    Returns:
        True if a save was attempted (DB settings are set), False otherwise.
    """
    if self.db_settings_set:
        # The bare name below resolves to the module-level helper of the
        # same name (decorated with retry_on_exception), not this method.
        # Fix: the helper's parameter is `sqa_config`, not `config`, and its
        # `suppress_all_errors` parameter has no default, so both must be
        # passed explicitly — the previous call raised TypeError.
        _save_analysis_cards_to_db_if_possible(
            experiment=experiment,
            analysis_cards=analysis_cards,
            sqa_config=self.db_settings.encoder.config,
            suppress_all_errors=True,
        )
        return True

    return False


# ------------- Utils for storage that assume `DBSettings` are provided --------

Expand Down Expand Up @@ -590,3 +608,21 @@ def _update_experiment_properties_in_db(
experiment_with_updated_properties=experiment_with_updated_properties,
config=sqa_config,
)


@retry_on_exception(
    retries=3,
    default_return_on_suppression=False,
    exception_types=RETRY_EXCEPTION_TYPES,
)
def _save_analysis_cards_to_db_if_possible(
    experiment: Experiment,
    analysis_cards: Iterable[AnalysisCard],
    sqa_config: SQAConfig,
    suppress_all_errors: bool,  # Used by the decorator.
) -> None:
    """Persist the given AnalysisCards for an experiment to the DB.

    Retried up to 3 times on transient storage errors via retry_on_exception.
    """
    # Materialize the iterable so save_analysis_cards receives a concrete list.
    cards = list(analysis_cards)
    save_analysis_cards(
        experiment=experiment,
        analysis_cards=cards,
        config=sqa_config,
    )

0 comments on commit 84b307d

Please sign in to comment.