diff --git a/ax/modelbridge/generation_strategy.py b/ax/modelbridge/generation_strategy.py index 2ab4428d48a..3c4e3fbef29 100644 --- a/ax/modelbridge/generation_strategy.py +++ b/ax/modelbridge/generation_strategy.py @@ -104,7 +104,6 @@ class GenerationStrategy(GenerationStrategyInterface): # Experiment, for which this generation strategy has generated trials, if # it exists. _experiment: Optional[Experiment] = None - # Trial indices as last seen by the model; updated in `_model` property setter. _model: Optional[ModelBridge] = None # Current model. def __init__( @@ -376,6 +375,55 @@ def gen( **kwargs, )[0] + def gen_with_multiple_nodes( + self, + experiment: Experiment, + data: Optional[Data] = None, + n: int = 1, # Total arms to generate + pending_observations: Optional[Dict[str, List[ObservationFeatures]]] = None, + **kwargs: Any, # TODO: @mgarrard Ensure correct dispatch to nodes + ) -> List[GeneratorRun]: + """Produces a List of GeneratorRuns for a single trial, either ```Trial``` or + ```BatchTrial```, and if producing a ```BatchTrial`` allows for multiple + models to be used to generate GeneratorRuns for that trial. + + NOTE: This method is in development. Please do not use it yet. + + Args: + experiment: Experiment, for which the generation strategy is producing + a new generator run in the course of `gen`, and to which that + generator run will be added as trial(s). Information stored on the + experiment (e.g., trial statuses) is used to determine which model + will be used to produce the generator run returned from this method. + data: Optional data to be passed to the underlying model's `gen`, which + is called within this method and actually produces the resulting + generator run. By default, data is all data on the `experiment`. + n: Integer representing how total arms to generate for this trial. + pending_observations: A map from metric name to pending + observations for that metric, used by some models to avoid + resuggesting points that are currently being evaluated. + + Returns: + A list of ```GeneratorRuns``` for a single trial. + """ + # TODO: @mgarrard merge into gen method, just starting here to derisk + grs = [] + gen_for_complete_trial = False + + while not gen_for_complete_trial: + grs.extend( + self._gen_multiple( + experiment=experiment, + num_generator_runs=1, + data=data, + n=n, + pending_observations=pending_observations, + **kwargs, + ) + ) + gen_for_complete_trial = self._will_gen_for_new_trial() + return grs + def gen_for_multiple_trials_with_multiple_models( self, experiment: Experiment, @@ -708,6 +756,46 @@ def _gen_multiple( ) return generator_runs + def _will_gen_for_new_trial(self) -> bool: + """Uses the transition criterion defined on ```GenerationNode``` to determine + if generation for this trial is complete, and we should move to the next trial. + + Returns: + A boolean which represents if generation for a trial is complete + """ + # if no criterion defined, always move to the next trial + if len(self._curr.transition_criteria) == 0: + return True + + trial_moving_criterion = [ + tc for tc in self._curr.transition_criteria if tc.complete_trial_generation + ] + if len(trial_moving_criterion) > 0: + trial_tc = [ + tc for tc in trial_moving_criterion if tc.block_transition_if_unmet + ] + + if len(trial_tc) > 0: + return all( + tc.is_met( + experiment=self._curr.experiment, + trials_from_node=self._curr.trials_from_node, + node_name=self._curr.node_name, + node_that_generated_last_gr=( + self.last_generator_run._generation_node_name + if self.last_generator_run is not None + else None + ), + ) + for tc in trial_tc + ) + # no trial_tc is defined so move to next trial + return True + + # if no enforce_next_trial criterion is specified, default not moving to + # the next trial + return False + # ------------------------- Model selection logic helpers. ------------------------- def _fit_current_model(self, data: Optional[Data]) -> None: diff --git a/ax/modelbridge/tests/test_generation_node.py b/ax/modelbridge/tests/test_generation_node.py index 115d38be506..8cfe42db402 100644 --- a/ax/modelbridge/tests/test_generation_node.py +++ b/ax/modelbridge/tests/test_generation_node.py @@ -172,7 +172,8 @@ def test_node_string_representation(self) -> None: "'only_in_statuses': [.RUNNING], " "'not_in_statuses': None, 'transition_to': None, " "'block_transition_if_unmet': True, 'block_gen_if_met': False, " - "'use_all_trials_in_exp': False})])" + "'use_all_trials_in_exp': False, " + "'complete_trial_generation': True})])" ), ) diff --git a/ax/modelbridge/tests/test_generation_strategy.py b/ax/modelbridge/tests/test_generation_strategy.py index e14e2def1d7..33fb55a4d62 100644 --- a/ax/modelbridge/tests/test_generation_strategy.py +++ b/ax/modelbridge/tests/test_generation_strategy.py @@ -1045,7 +1045,7 @@ def test_gen_for_multiple_trials_with_multiple_models_with_fixed_features( GenerationStep( model=Models.BOTORCH_MODULAR, model_kwargs={ - # this will cause and error if the model + # this will cause an error if the model # doesn't get fixed features "transforms": ST_MTGP_trans, **self.step_model_kwargs, @@ -1445,34 +1445,45 @@ def test_node_gs_with_auto_transitions(self) -> None: node_name="sobol_2", model_specs=[self.sobol_model_spec], transition_criteria=[ - AutoTransitionAfterGenCriterion(transition_to="gpei_2") + AutoTransitionAfterGenCriterion(transition_to="sobol_3") ], ), GenerationNode( - node_name="gpei_2", - model_specs=[self.gpei_model_spec], + node_name="sobol_3", + model_specs=[self.sobol_model_spec], transition_criteria=[ - AutoTransitionAfterGenCriterion(transition_to="gpei") + AutoTransitionAfterGenCriterion( + transition_to="gpei", + block_transition_if_unmet=True, + complete_trial_generation=True, + ) ], ), ], ) exp = get_branin_experiment() + gs.experiment = exp + # for the first trial, we start on sobol, we generate the trial, but it hasn't + # been run yet, so we remain on sobol, after the trial is run, the subsequent + # trials should be from node gpei, sobol_2, and sobol_3 + self.assertEqual(gs.current_node_name, "sobol") + trial0 = exp.new_batch_trial(generator_runs=gs.gen_with_multiple_nodes(exp)) self.assertEqual(gs.current_node_name, "sobol") - exp.new_trial(generator_run=gs.gen(exp)).run() # while here, test the last generator run property on node self.assertEqual(gs.current_node.node_that_generated_last_gr, "sobol") - gs.gen(exp) - self.assertEqual(gs.current_node_name, "gpei") - gs.gen(exp) - self.assertEqual(gs.current_node_name, "sobol_2") - gs.gen(exp) - self.assertEqual(gs.current_node_name, "gpei_2") - gs.gen(exp) - self.assertEqual(gs.current_node_name, "gpei") - # TODO: @mgarrard modify below test when gen handles multiple nodes + trial0.run() + for _i in range(0, 2): + trial = exp.new_batch_trial(generator_runs=gs.gen_with_multiple_nodes(exp)) + self.assertEqual(gs.current_node_name, "sobol_3") + self.assertEqual(len(trial.generator_runs), 3) + self.assertEqual(trial.generator_runs[0]._generation_node_name, "gpei") + self.assertEqual(trial.generator_runs[1]._generation_node_name, "sobol_2") + self.assertEqual(trial.generator_runs[2]._generation_node_name, "sobol_3") + + def test_node_gs_with_auto_transitions_three_phase(self) -> None: + exp = get_branin_experiment() gs_2 = GenerationStrategy( nodes=[ GenerationNode( @@ -1484,44 +1495,65 @@ def test_node_gs_with_auto_transitions(self) -> None: node_name="gpei", model_specs=[self.gpei_model_spec], transition_criteria=[ - AutoTransitionAfterGenCriterion(transition_to="sobol_2") + AutoTransitionAfterGenCriterion( + transition_to="sobol_2", + ) ], ), GenerationNode( node_name="sobol_2", model_specs=[self.sobol_model_spec], transition_criteria=[ - AutoTransitionAfterGenCriterion(transition_to="gpei_2") + AutoTransitionAfterGenCriterion(transition_to="sobol_3") ], ), GenerationNode( - node_name="gpei_2", - model_specs=[self.gpei_model_spec], + node_name="sobol_3", + model_specs=[self.sobol_model_spec], transition_criteria=[ MaxTrials( threshold=2, - transition_to="sobol_3", + transition_to="sobol_4", block_transition_if_unmet=True, only_in_statuses=[TrialStatus.RUNNING], use_all_trials_in_exp=True, - ) + ), + AutoTransitionAfterGenCriterion(transition_to="gpei"), ], ), GenerationNode( - node_name="sobol_3", + node_name="sobol_4", model_specs=[self.sobol_model_spec], ), ], ) + gs_2.experiment = exp + + # for the first trial, we start on sobol, we generate the trial, but it hasn't + # been run yet, so we remain on sobol + self.assertEqual(gs_2.current_node_name, "sobol") + trial0 = exp.new_batch_trial(generator_runs=gs_2.gen_with_multiple_nodes(exp)) self.assertEqual(gs_2.current_node_name, "sobol") - exp.new_trial(generator_run=gs_2.gen(exp)).run() - self.assertEqual(gs_2.current_node_name, "gpei") - gs_2.gen(exp) - self.assertEqual(gs_2.current_node_name, "sobol_2") - gs_2.gen(exp) # noqa - self.assertEqual(gs_2.current_node_name, "gpei_2") - exp.new_trial(generator_run=gs_2.gen(exp)).run() - self.assertEqual(gs_2.current_node_name, "sobol_3") + trial0.run() + + # after trial 0 is run, we create a trial with nodes gpei, sobol_2, and sobol_3 + # However, the sobol_3 criterion requires that we have two running trials. We + # don't move onto sobol_4 until we have two running trials, instead we reset + # to the last first node in a trial. + for _i in range(0, 2): + trial = exp.new_batch_trial( + generator_runs=gs_2.gen_with_multiple_nodes(exp) + ) + self.assertEqual(gs_2.current_node_name, "sobol_3") + self.assertEqual(len(trial.generator_runs), 3) + self.assertEqual(trial.generator_runs[0]._generation_node_name, "gpei") + self.assertEqual(trial.generator_runs[1]._generation_node_name, "sobol_2") + self.assertEqual(trial.generator_runs[2]._generation_node_name, "sobol_3") + + # after running the next trial should be made from sobol 4 + trial.run() + trial = exp.new_batch_trial(generator_runs=gs_2.gen_with_multiple_nodes(exp)) + self.assertEqual(trial.generator_runs[0]._generation_node_name, "sobol_4") # ------------- Testing helpers (put tests above this line) ------------- diff --git a/ax/modelbridge/tests/test_transition_criterion.py b/ax/modelbridge/tests/test_transition_criterion.py index 5ed9fafdac8..b88a0d63f91 100644 --- a/ax/modelbridge/tests/test_transition_criterion.py +++ b/ax/modelbridge/tests/test_transition_criterion.py @@ -428,7 +428,8 @@ def test_repr(self) -> None: + "'transition_to': 'GenerationStep_1', " + "'block_transition_if_unmet': False, " + "'block_gen_if_met': True, " - + "'use_all_trials_in_exp': False})", + + "'use_all_trials_in_exp': False, " + + "'complete_trial_generation': True})", ) minimum_trials_in_status_criterion = MinTrials( only_in_statuses=[TrialStatus.COMPLETED, TrialStatus.EARLY_STOPPED], @@ -446,7 +447,8 @@ def test_repr(self) -> None: + "'transition_to': 'GenerationStep_2', " + "'block_transition_if_unmet': False, " + "'block_gen_if_met': True, " - + "'use_all_trials_in_exp': False})", + + "'use_all_trials_in_exp': False, " + + "'complete_trial_generation': True})", ) minimum_preference_occurances_criterion = MinimumPreferenceOccurances( metric_name="m1", threshold=3 @@ -483,12 +485,15 @@ def test_repr(self) -> None: + "'transition_to': 'GenerationStep_2', " + "'block_transition_if_unmet': False, " + "'block_gen_if_met': True, " - + "'use_all_trials_in_exp': False})", + + "'use_all_trials_in_exp': False, " + + "'complete_trial_generation': False})", ) auto_transition = AutoTransitionAfterGenCriterion( transition_to="GenerationStep_2" ) self.assertEqual( str(auto_transition), - "AutoTransitionAfterGenCriterion({'transition_to': 'GenerationStep_2'})", + "AutoTransitionAfterGenCriterion({'transition_to': 'GenerationStep_2', " + + "'complete_trial_generation': False, " + + "'block_transition_if_unmet': True})", ) diff --git a/ax/modelbridge/transition_criterion.py b/ax/modelbridge/transition_criterion.py index 3ca78041d43..eae8c9dbbd8 100644 --- a/ax/modelbridge/transition_criterion.py +++ b/ax/modelbridge/transition_criterion.py @@ -40,6 +40,12 @@ class TransitionCriterion(SortableBase, SerializationMixin): being able to transition to another node. Ex: MaxGenerationParallelism defaults to setting this to False since we can complete and move on from this node without ever reaching its threshold. + complete_trial_generation: A flag to indicate that all generation for a given + trial is completed. This is necessary because in ```BatchTrial``` there + are multiple arms per trial, and we enable generation of arms within a + batch from different ```GenerationNodes```. This flag should be set to + True for the last node in a set of ```GenerationNodes``` expected to + create a given ```BatchTrial```. """ _transition_to: Optional[str] = None @@ -49,10 +55,12 @@ def __init__( transition_to: Optional[str] = None, block_transition_if_unmet: Optional[bool] = True, block_gen_if_met: Optional[bool] = False, + complete_trial_generation: Optional[bool] = True, ) -> None: self._transition_to = transition_to self.block_transition_if_unmet = block_transition_if_unmet self.block_gen_if_met = block_gen_if_met + self.complete_trial_generation = complete_trial_generation @property def transition_to(self) -> Optional[str]: @@ -104,10 +112,29 @@ class AutoTransitionAfterGenCriterion(TransitionCriterion): Args: transition_to: The name of the GenerationNode the GenerationStrategy should transition to next. + block_transition_if_unmet: A flag to prevent the node from completing and + being able to transition to another node. Ex: This criterion defaults to + setting this to True to ensure we validate a GeneratorRun is generated by + the current GenerationNode. + complete_trial_generation: A flag to indicate that all generation for a given + trial is completed. This is necessary because in ```BatchTrial``` there + are multiple arms per trial, and we enable generation of arms within a + batch from different ```GenerationNodes```. This flag should be set to + True for the last node in a set of ```GenerationNodes``` expected to + create a given ```BatchTrial```. """ - def __init__(self, transition_to: str) -> None: - super().__init__(transition_to=transition_to) + def __init__( + self, + transition_to: str, + complete_trial_generation: Optional[bool] = False, + block_transition_if_unmet: Optional[bool] = True, + ) -> None: + super().__init__( + transition_to=transition_to, + complete_trial_generation=complete_trial_generation, + block_transition_if_unmet=block_transition_if_unmet, + ) def is_met( self, @@ -116,7 +143,9 @@ def is_met( node_that_generated_last_gr: Optional[str] = None, node_name: Optional[str] = None, ) -> bool: - """Return true as soon as any trial is generated by this GenerationNode.""" + """Return True as soon as any GeneratorRun is generated by this + GenerationNode. + """ return node_that_generated_last_gr == node_name def block_continued_generation_error( @@ -154,6 +183,12 @@ class TrialBasedCriterion(TransitionCriterion): transition to when this criterion is met, if it exists. use_all_trials_in_exp: A flag to use all trials in the experiment, instead of only those generated by the current GenerationNode. + complete_trial_generation: A flag to indicate that all generation for a given + trial is completed. This is necessary because in ```BatchTrial``` there + are multiple arms per trial, and we enable generation of arms within a + batch from different ```GenerationNodes```. This flag should be set to + True for the last node in a set of ```GenerationNodes``` expected to + create a given ```BatchTrial```. """ def __init__( @@ -165,6 +200,7 @@ def __init__( not_in_statuses: Optional[List[TrialStatus]] = None, transition_to: Optional[str] = None, use_all_trials_in_exp: Optional[bool] = False, + complete_trial_generation: Optional[bool] = True, ) -> None: self.threshold = threshold self.only_in_statuses = only_in_statuses @@ -174,6 +210,7 @@ def __init__( transition_to=transition_to, block_transition_if_unmet=block_transition_if_unmet, block_gen_if_met=block_gen_if_met, + complete_trial_generation=complete_trial_generation, ) def experiment_trials_by_status( @@ -306,6 +343,14 @@ class MaxGenerationParallelism(TrialBasedCriterion): until MinimumTrialsInStatus is met (thus overriding MaxTrials). use_all_trials_in_exp: A flag to use all trials in the experiment, instead of only those generated by the current GenerationNode. + complete_trial_generation: A flag to indicate that all generation for a given + trial is completed. This is necessary because in ```BatchTrial``` there + are multiple arms per trial, and we enable generation of arms within a + batch from different ```GenerationNodes```. This flag should be set to + True for the last node in a set of ```GenerationNodes``` expected to + create a given ```BatchTrial```. Defautls to False for + MaxGenerationParallelism since this criterion isn't currently used for + node -> node or trial -> trial transition. """ def __init__( @@ -317,6 +362,7 @@ def __init__( block_transition_if_unmet: Optional[bool] = False, block_gen_if_met: Optional[bool] = True, use_all_trials_in_exp: Optional[bool] = False, + complete_trial_generation: Optional[bool] = False, ) -> None: super().__init__( threshold=threshold, @@ -326,6 +372,7 @@ def __init__( block_gen_if_met=block_gen_if_met, block_transition_if_unmet=block_transition_if_unmet, use_all_trials_in_exp=use_all_trials_in_exp, + complete_trial_generation=complete_trial_generation, ) def block_continued_generation_error( @@ -380,6 +427,12 @@ class MaxTrials(TrialBasedCriterion): until MinimumTrialsInStatus is met (thus overriding MaxTrials). use_all_trials_in_exp: A flag to use all trials in the experiment, instead of only those generated by the current GenerationNode. + complete_trial_generation: A flag to indicate that all generation for a given + trial is completed. This is necessary because in ```BatchTrial``` there + are multiple arms per trial, and we enable generation of arms within a + batch from different ```GenerationNodes```. This flag should be set to + True for the last node in a set of ```GenerationNodes``` expected to + create a given ```BatchTrial```. """ def __init__( @@ -391,6 +444,7 @@ def __init__( block_transition_if_unmet: Optional[bool] = True, block_gen_if_met: Optional[bool] = False, use_all_trials_in_exp: Optional[bool] = False, + complete_trial_generation: Optional[bool] = True, ) -> None: super().__init__( threshold=threshold, @@ -400,6 +454,7 @@ def __init__( block_gen_if_met=block_gen_if_met, block_transition_if_unmet=block_transition_if_unmet, use_all_trials_in_exp=use_all_trials_in_exp, + complete_trial_generation=complete_trial_generation, ) def block_continued_generation_error( @@ -448,6 +503,12 @@ class MinTrials(TrialBasedCriterion): until MinimumTrialsInStatus is met (thus overriding MaxTrials). use_all_trials_in_exp: A flag to use all trials in the experiment, instead of only those generated by the current GenerationNode. + complete_trial_generation: A flag to indicate that all generation for a given + trial is completed. This is necessary because in ```BatchTrial``` there + are multiple arms per trial, and we enable generation of arms within a + batch from different ```GenerationNodes```. This flag should be set to + True for the last node in a set of ```GenerationNodes``` expected to + create a given ```BatchTrial```. """ def __init__( @@ -459,6 +520,7 @@ def __init__( block_transition_if_unmet: Optional[bool] = True, block_gen_if_met: Optional[bool] = False, use_all_trials_in_exp: Optional[bool] = False, + complete_trial_generation: Optional[bool] = True, ) -> None: super().__init__( threshold=threshold, @@ -468,6 +530,7 @@ def __init__( block_gen_if_met=block_gen_if_met, block_transition_if_unmet=block_transition_if_unmet, use_all_trials_in_exp=use_all_trials_in_exp, + complete_trial_generation=complete_trial_generation, ) def block_continued_generation_error(