diff --git a/ax/core/generator_run.py b/ax/core/generator_run.py index 3308b886ada..66930350c89 100644 --- a/ax/core/generator_run.py +++ b/ax/core/generator_run.py @@ -175,15 +175,7 @@ def __init__( "one is provided." ) for arm, weight in zip(arms, weights): - existing_cw = self._arm_weight_table.get(arm.signature) - if existing_cw: - self._arm_weight_table[arm.signature] = ArmWeight( - arm=arm, weight=existing_cw.weight + weight - ) - else: - self._arm_weight_table[arm.signature] = ArmWeight( - arm=arm, weight=weight - ) + self.add_arm(arm=arm, weight=weight) self._generator_run_type: Optional[str] = type self._time_created: datetime = datetime.now() @@ -394,6 +386,22 @@ def clone(self) -> GeneratorRun: ) return generator_run + def add_arm(self, arm: Arm, weight: float = 1.0) -> None: + """Adds an arm to this generator run. This should not be used to + mutate generator runs that are attached to trials. + + Args: + arm: The arm to add. + weight: The weight to associate with the arm. + """ + existing_cw = self._arm_weight_table.get(arm.signature) + if existing_cw: + self._arm_weight_table[arm.signature] = ArmWeight( + arm=arm, weight=existing_cw.weight + weight + ) + else: + self._arm_weight_table[arm.signature] = ArmWeight(arm=arm, weight=weight) + def __repr__(self) -> str: """String representation of a GeneratorRun.""" class_name = self.__class__.__name__ diff --git a/ax/core/utils.py b/ax/core/utils.py index 01b5d55ec08..38692ba91e7 100644 --- a/ax/core/utils.py +++ b/ax/core/utils.py @@ -157,15 +157,11 @@ def best_feasible_objective( return accumulate(f) -def _extract_generator_run(trial: BaseTrial) -> GeneratorRun: +def _extract_generator_runs(trial: BaseTrial) -> List[GeneratorRun]: if isinstance(trial, BatchTrial): - if len(trial.generator_run_structs) > 1: - raise NotImplementedError( - "Run time is not supported with multiple generator runs per trial." - ) - return trial._generator_run_structs[0].generator_run + return trial.generator_runs if isinstance(trial, Trial): - return none_throws(trial.generator_run) + return [none_throws(trial.generator_run)] raise ValueError("Unexpected trial type") @@ -180,7 +176,9 @@ def get_model_trace_of_times( List of fit times, list of gen times. """ generator_runs = [ - _extract_generator_run(trial=trial) for trial in experiment.trials.values() + gr + for trial in experiment.trials.values() + for gr in _extract_generator_runs(trial=trial) ] fit_times = [gr.fit_time for gr in generator_runs] gen_times = [gr.gen_time for gr in generator_runs] diff --git a/ax/telemetry/experiment.py b/ax/telemetry/experiment.py index eaba3902d04..04e2f02957b 100644 --- a/ax/telemetry/experiment.py +++ b/ax/telemetry/experiment.py @@ -26,7 +26,7 @@ from ax.telemetry.common import INITIALIZATION_MODELS, OTHER_MODELS INITIALIZATION_MODEL_STRS: List[str] = [enum.value for enum in INITIALIZATION_MODELS] -OTHER_MODEL_STRS: List[str] = [enum.value for enum in OTHER_MODELS] +OTHER_MODEL_STRS: List[str] = [enum.value for enum in OTHER_MODELS] + [None] @dataclass(frozen=True) @@ -272,25 +272,34 @@ def from_experiment(cls, experiment: Experiment) -> ExperimentCompletedRecord: } model_keys = [ - trial.generator_runs[0]._model_key for trial in experiment.trials.values() + [gr._model_key for gr in trial.generator_runs] + for trial in experiment.trials.values() ] fit_time, gen_time = get_model_times(experiment=experiment) return cls( num_initialization_trials=sum( - 1 for model_key in model_keys if model_key in INITIALIZATION_MODEL_STRS + 1 + for model_key_list in model_keys + if all( + model_key in INITIALIZATION_MODEL_STRS + for model_key in model_key_list + ) ), num_bayesopt_trials=sum( 1 - for model_key in model_keys - if not ( - model_key in INITIALIZATION_MODEL_STRS - or model_key in OTHER_MODEL_STRS + for model_key_list in model_keys + if any( + model_key not in INITIALIZATION_MODEL_STRS + and model_key not in OTHER_MODEL_STRS + for model_key in model_key_list ) ), num_other_trials=sum( - 1 for model_key in model_keys if model_key in OTHER_MODEL_STRS + 1 + for model_key_list in model_keys + if all(model_key in OTHER_MODEL_STRS for model_key in model_key_list) ), num_completed_trials=trial_count_by_status[TrialStatus.COMPLETED], num_failed_trials=trial_count_by_status[TrialStatus.FAILED], diff --git a/ax/telemetry/tests/test_experiment.py b/ax/telemetry/tests/test_experiment.py index 684e1a4c55a..852e7e8b3c0 100644 --- a/ax/telemetry/tests/test_experiment.py +++ b/ax/telemetry/tests/test_experiment.py @@ -7,9 +7,14 @@ # pyre-strict from ax.core.utils import get_model_times +from ax.modelbridge.registry import Models from ax.telemetry.experiment import ExperimentCompletedRecord, ExperimentCreatedRecord from ax.utils.common.testutils import TestCase -from ax.utils.testing.core_stubs import get_experiment_with_custom_runner_and_metric +from ax.utils.testing.core_stubs import ( + get_branin_experiment, + get_experiment_with_custom_runner_and_metric, +) +from ax.utils.testing.mock import fast_botorch_optimize class TestExperiment(TestCase): @@ -63,3 +68,47 @@ def test_experiment_completed_record_from_experiment(self) -> None: total_gen_time=int(gen_time), ) self.assertEqual(record, expected) + + @fast_botorch_optimize + def test_bayesopt_trials_are_trials_containing_bayesopt(self) -> None: + experiment = get_branin_experiment() + sobol = Models.SOBOL(search_space=experiment.search_space) + trial = experiment.new_batch_trial().add_generator_run( + generator_run=sobol.gen(5) + ) + trial.mark_completed(unsafe=True) + + # create a trial that among other things does bayesopt + botorch = Models.BOTORCH_MODULAR( + experiment=experiment, data=experiment.fetch_data() + ) + trial = ( + experiment.new_batch_trial() + .add_generator_run(generator_run=sobol.gen(2)) + .add_generator_run(generator_run=botorch.gen(5)) + ) + trial.add_arm(experiment.arms_by_name["0_0"]) + trial.mark_completed(unsafe=True) + + record = ExperimentCompletedRecord.from_experiment(experiment=experiment) + self.assertEqual(record.num_initialization_trials, 1) + self.assertEqual(record.num_bayesopt_trials, 1) + self.assertEqual(record.num_other_trials, 0) + + def test_other_trials_are_trials_with_no_models(self) -> None: + experiment = get_branin_experiment() + sobol = Models.SOBOL(search_space=experiment.search_space) + trial = experiment.new_batch_trial().add_generator_run( + generator_run=sobol.gen(5) + ) + trial.mark_completed(unsafe=True) + + # create a trial that has no GRs that used models + trial = experiment.new_batch_trial() + trial.add_arm(experiment.arms_by_name["0_0"]) + trial.mark_completed(unsafe=True) + + record = ExperimentCompletedRecord.from_experiment(experiment=experiment) + self.assertEqual(record.num_initialization_trials, 1) + self.assertEqual(record.num_bayesopt_trials, 0) + self.assertEqual(record.num_other_trials, 1)