From e500e77519400733650fdb94924d8f4f8ba77e65 Mon Sep 17 00:00:00 2001 From: Mia Garrard Date: Mon, 13 May 2024 18:10:26 -0700 Subject: [PATCH] model <> node 4/n: Support multiple node generation for single trial (#2428) Summary: This diff enables multiple nodes to be used to generate a single batch trial. Right now the limitations are that: (1) currently each node only contributes 1 gr to the trial -- we will extend this to n grs in a future diff To do this we added: (1) should_move_trials method (2) updates to the transition criterion args Reviewed By: saitcakmak, lena-kashtelyan Differential Revision: D56743651 --- ax/modelbridge/generation_strategy.py | 82 +++++++++++++++- ax/modelbridge/tests/test_generation_node.py | 3 +- .../tests/test_generation_strategy.py | 94 +++++++++++++------ .../tests/test_transition_criterion.py | 13 ++- ax/modelbridge/transition_criterion.py | 69 +++++++++++++- 5 files changed, 222 insertions(+), 39 deletions(-) diff --git a/ax/modelbridge/generation_strategy.py b/ax/modelbridge/generation_strategy.py index 2ab4428d48a..bacae17983f 100644 --- a/ax/modelbridge/generation_strategy.py +++ b/ax/modelbridge/generation_strategy.py @@ -104,7 +104,6 @@ class GenerationStrategy(GenerationStrategyInterface): # Experiment, for which this generation strategy has generated trials, if # it exists. _experiment: Optional[Experiment] = None - # Trial indices as last seen by the model; updated in `_model` property setter. _model: Optional[ModelBridge] = None # Current model. def __init__( @@ -376,6 +375,55 @@ def gen( **kwargs, )[0] + def gen_with_multiple_nodes( + self, + experiment: Experiment, + data: Optional[Data] = None, + n: int = 1, # Total arms to generate + pending_observations: Optional[Dict[str, List[ObservationFeatures]]] = None, + **kwargs: Any, # TODO: @mgarrard Ensure correct dispatch to nodes + ) -> List[GeneratorRun]: + """Produces a List of GeneratorRuns for a single trial, either ```Trial``` or + ```BatchTrial```, and if producing a ```BatchTrial`` allows for multiple + models to be used to generate GeneratorRuns for that trial. + + NOTE: This method is in development. Please do not use it yet. + + Args: + experiment: Experiment, for which the generation strategy is producing + a new generator run in the course of `gen`, and to which that + generator run will be added as trial(s). Information stored on the + experiment (e.g., trial statuses) is used to determine which model + will be used to produce the generator run returned from this method. + data: Optional data to be passed to the underlying model's `gen`, which + is called within this method and actually produces the resulting + generator run. By default, data is all data on the `experiment`. + n: Integer representing how total arms to generate for this trial. + pending_observations: A map from metric name to pending + observations for that metric, used by some models to avoid + resuggesting points that are currently being evaluated. + + Returns: + A list of ```GeneratorRuns``` for a single trial. + """ + # TODO: @mgarrard merge into gen method, just starting here to derisk + grs = [] + continue_gen_for_trial = True + + while continue_gen_for_trial: + grs.extend( + self._gen_multiple( + experiment=experiment, + num_generator_runs=1, + data=data, + n=n, + pending_observations=pending_observations, + **kwargs, + ) + ) + continue_gen_for_trial = self._should_continue_gen_for_trial() + return grs + def gen_for_multiple_trials_with_multiple_models( self, experiment: Experiment, @@ -544,6 +592,13 @@ def _validate_and_set_node_graph(self, nodes: List[GenerationNode]) -> None: node._generation_strategy = self # validate `transition_criterion` + # TODO[@mgarrard]: Validate that all TCs in one "transition edge" (so all TCs + # from one node to another) have the same `continue_trial_generation` setting. + # Since multiple TCs together constitute one "transition edge", not having all + # TCs on such an "edge" indicate the same resulting state (continuing + # generation for same trial vs. stopping it after generating from current node) + # would indicate a malformed generation node DAG definition and therefore a + # malformed `GenerationStrategy`. contains_a_transition_to_argument = False for node in self._nodes: for transition_criteria in node.transition_criteria: @@ -708,6 +763,31 @@ def _gen_multiple( ) return generator_runs + def _should_continue_gen_for_trial(self) -> bool: + """Determine if we should continue generating for the current trial, or end + generation for the current trial. Note that generating more would involve + transitioning to a next node, because each node generates once per call to + ``GenerationStrategy.gen_with_multiple_nodes``. + + Returns: + A boolean which represents if generation for a trial is complete + """ + should_transition, next_node = self._curr.should_transition_to_next_node( + raise_data_required_error=False + ) + # if we should not transition nodes, we should stop generation for this trial. + if not should_transition: + return False + + # if we will transition nodes, check if the transition criterion which define + # the transition from this node to the next node indicate that we should + # continue generating in the same trial, otherwise end the generation. + assert next_node is not None + return all( + tc.continue_trial_generation + for tc in self._curr.transition_edges[next_node] + ) + # ------------------------- Model selection logic helpers. ------------------------- def _fit_current_model(self, data: Optional[Data]) -> None: diff --git a/ax/modelbridge/tests/test_generation_node.py b/ax/modelbridge/tests/test_generation_node.py index 115d38be506..9639c82f867 100644 --- a/ax/modelbridge/tests/test_generation_node.py +++ b/ax/modelbridge/tests/test_generation_node.py @@ -172,7 +172,8 @@ def test_node_string_representation(self) -> None: "'only_in_statuses': [.RUNNING], " "'not_in_statuses': None, 'transition_to': None, " "'block_transition_if_unmet': True, 'block_gen_if_met': False, " - "'use_all_trials_in_exp': False})])" + "'use_all_trials_in_exp': False, " + "'continue_trial_generation': False})])" ), ) diff --git a/ax/modelbridge/tests/test_generation_strategy.py b/ax/modelbridge/tests/test_generation_strategy.py index e14e2def1d7..8bec74f7d24 100644 --- a/ax/modelbridge/tests/test_generation_strategy.py +++ b/ax/modelbridge/tests/test_generation_strategy.py @@ -1045,7 +1045,7 @@ def test_gen_for_multiple_trials_with_multiple_models_with_fixed_features( GenerationStep( model=Models.BOTORCH_MODULAR, model_kwargs={ - # this will cause and error if the model + # this will cause an error if the model # doesn't get fixed features "transforms": ST_MTGP_trans, **self.step_model_kwargs, @@ -1445,34 +1445,44 @@ def test_node_gs_with_auto_transitions(self) -> None: node_name="sobol_2", model_specs=[self.sobol_model_spec], transition_criteria=[ - AutoTransitionAfterGenCriterion(transition_to="gpei_2") + AutoTransitionAfterGenCriterion(transition_to="sobol_3") ], ), GenerationNode( - node_name="gpei_2", - model_specs=[self.gpei_model_spec], + node_name="sobol_3", + model_specs=[self.sobol_model_spec], transition_criteria=[ - AutoTransitionAfterGenCriterion(transition_to="gpei") + AutoTransitionAfterGenCriterion( + transition_to="gpei", + block_transition_if_unmet=True, + continue_trial_generation=False, + ) ], ), ], ) exp = get_branin_experiment() + # for the first trial, we start on sobol, we generate the trial, but it hasn't + # been run yet, so we remain on sobol, after the trial is run, the subsequent + # trials should be from node gpei, sobol_2, and sobol_3 + self.assertEqual(gs.current_node_name, "sobol") + trial0 = exp.new_batch_trial(generator_runs=gs.gen_with_multiple_nodes(exp)) self.assertEqual(gs.current_node_name, "sobol") - exp.new_trial(generator_run=gs.gen(exp)).run() # while here, test the last generator run property on node self.assertEqual(gs.current_node.node_that_generated_last_gr, "sobol") - gs.gen(exp) - self.assertEqual(gs.current_node_name, "gpei") - gs.gen(exp) - self.assertEqual(gs.current_node_name, "sobol_2") - gs.gen(exp) - self.assertEqual(gs.current_node_name, "gpei_2") - gs.gen(exp) - self.assertEqual(gs.current_node_name, "gpei") - # TODO: @mgarrard modify below test when gen handles multiple nodes + trial0.run() + for _i in range(0, 2): + trial = exp.new_batch_trial(generator_runs=gs.gen_with_multiple_nodes(exp)) + self.assertEqual(gs.current_node_name, "sobol_3") + self.assertEqual(len(trial.generator_runs), 3) + self.assertEqual(trial.generator_runs[0]._generation_node_name, "gpei") + self.assertEqual(trial.generator_runs[1]._generation_node_name, "sobol_2") + self.assertEqual(trial.generator_runs[2]._generation_node_name, "sobol_3") + + def test_node_gs_with_auto_transitions_three_phase(self) -> None: + exp = get_branin_experiment() gs_2 = GenerationStrategy( nodes=[ GenerationNode( @@ -1484,44 +1494,68 @@ def test_node_gs_with_auto_transitions(self) -> None: node_name="gpei", model_specs=[self.gpei_model_spec], transition_criteria=[ - AutoTransitionAfterGenCriterion(transition_to="sobol_2") + AutoTransitionAfterGenCriterion( + transition_to="sobol_2", + ) ], ), GenerationNode( node_name="sobol_2", model_specs=[self.sobol_model_spec], transition_criteria=[ - AutoTransitionAfterGenCriterion(transition_to="gpei_2") + AutoTransitionAfterGenCriterion(transition_to="sobol_3") ], ), GenerationNode( - node_name="gpei_2", - model_specs=[self.gpei_model_spec], + node_name="sobol_3", + model_specs=[self.sobol_model_spec], transition_criteria=[ MaxTrials( threshold=2, - transition_to="sobol_3", + transition_to="sobol_4", block_transition_if_unmet=True, only_in_statuses=[TrialStatus.RUNNING], use_all_trials_in_exp=True, - ) + ), + AutoTransitionAfterGenCriterion( + transition_to="gpei", + block_transition_if_unmet=True, + continue_trial_generation=False, + ), ], ), GenerationNode( - node_name="sobol_3", + node_name="sobol_4", model_specs=[self.sobol_model_spec], ), ], ) + + # for the first trial, we start on sobol, we generate the trial, but it hasn't + # been run yet, so we remain on sobol + self.assertEqual(gs_2.current_node_name, "sobol") + trial0 = exp.new_batch_trial(generator_runs=gs_2.gen_with_multiple_nodes(exp)) self.assertEqual(gs_2.current_node_name, "sobol") - exp.new_trial(generator_run=gs_2.gen(exp)).run() - self.assertEqual(gs_2.current_node_name, "gpei") - gs_2.gen(exp) - self.assertEqual(gs_2.current_node_name, "sobol_2") - gs_2.gen(exp) # noqa - self.assertEqual(gs_2.current_node_name, "gpei_2") - exp.new_trial(generator_run=gs_2.gen(exp)).run() - self.assertEqual(gs_2.current_node_name, "sobol_3") + trial0.run() + + # after trial 0 is run, we create a trial with nodes gpei, sobol_2, and sobol_3 + # However, the sobol_3 criterion requires that we have two running trials. We + # don't move onto sobol_4 until we have two running trials, instead we reset + # to the last first node in a trial. + for _i in range(0, 2): + trial = exp.new_batch_trial( + generator_runs=gs_2.gen_with_multiple_nodes(exp) + ) + self.assertEqual(gs_2.current_node_name, "sobol_3") + self.assertEqual(len(trial.generator_runs), 3) + self.assertEqual(trial.generator_runs[0]._generation_node_name, "gpei") + self.assertEqual(trial.generator_runs[1]._generation_node_name, "sobol_2") + self.assertEqual(trial.generator_runs[2]._generation_node_name, "sobol_3") + + # after running the next trial should be made from sobol 4 + trial.run() + trial = exp.new_batch_trial(generator_runs=gs_2.gen_with_multiple_nodes(exp)) + self.assertEqual(trial.generator_runs[0]._generation_node_name, "sobol_4") # ------------- Testing helpers (put tests above this line) ------------- diff --git a/ax/modelbridge/tests/test_transition_criterion.py b/ax/modelbridge/tests/test_transition_criterion.py index 5ed9fafdac8..3971473b983 100644 --- a/ax/modelbridge/tests/test_transition_criterion.py +++ b/ax/modelbridge/tests/test_transition_criterion.py @@ -428,7 +428,8 @@ def test_repr(self) -> None: + "'transition_to': 'GenerationStep_1', " + "'block_transition_if_unmet': False, " + "'block_gen_if_met': True, " - + "'use_all_trials_in_exp': False})", + + "'use_all_trials_in_exp': False, " + + "'continue_trial_generation': False})", ) minimum_trials_in_status_criterion = MinTrials( only_in_statuses=[TrialStatus.COMPLETED, TrialStatus.EARLY_STOPPED], @@ -446,7 +447,8 @@ def test_repr(self) -> None: + "'transition_to': 'GenerationStep_2', " + "'block_transition_if_unmet': False, " + "'block_gen_if_met': True, " - + "'use_all_trials_in_exp': False})", + + "'use_all_trials_in_exp': False, " + + "'continue_trial_generation': False})", ) minimum_preference_occurances_criterion = MinimumPreferenceOccurances( metric_name="m1", threshold=3 @@ -483,12 +485,15 @@ def test_repr(self) -> None: + "'transition_to': 'GenerationStep_2', " + "'block_transition_if_unmet': False, " + "'block_gen_if_met': True, " - + "'use_all_trials_in_exp': False})", + + "'use_all_trials_in_exp': False, " + + "'continue_trial_generation': True})", ) auto_transition = AutoTransitionAfterGenCriterion( transition_to="GenerationStep_2" ) self.assertEqual( str(auto_transition), - "AutoTransitionAfterGenCriterion({'transition_to': 'GenerationStep_2'})", + "AutoTransitionAfterGenCriterion({'transition_to': 'GenerationStep_2', " + + "'block_transition_if_unmet': True, " + + "'continue_trial_generation': True})", ) diff --git a/ax/modelbridge/transition_criterion.py b/ax/modelbridge/transition_criterion.py index 71306fafc3f..f3263a76dc4 100644 --- a/ax/modelbridge/transition_criterion.py +++ b/ax/modelbridge/transition_criterion.py @@ -40,6 +40,12 @@ class TransitionCriterion(SortableBase, SerializationMixin): being able to transition to another node. Ex: MaxGenerationParallelism defaults to setting this to False since we can complete and move on from this node without ever reaching its threshold. + complete_trial_generation: A flag to indicate that all generation for a given + trial is completed. This is necessary because in ```BatchTrial``` there + are multiple arms per trial, and we enable generation of arms within a + batch from different ```GenerationNodes```. This flag should be set to + True for the last node in a set of ```GenerationNodes``` expected to + create a given ```BatchTrial```. """ _transition_to: Optional[str] = None @@ -49,10 +55,12 @@ def __init__( transition_to: Optional[str] = None, block_transition_if_unmet: Optional[bool] = True, block_gen_if_met: Optional[bool] = False, + continue_trial_generation: Optional[bool] = False, ) -> None: self._transition_to = transition_to self.block_transition_if_unmet = block_transition_if_unmet self.block_gen_if_met = block_gen_if_met + self.continue_trial_generation = continue_trial_generation @property def transition_to(self) -> Optional[str]: @@ -104,10 +112,29 @@ class AutoTransitionAfterGenCriterion(TransitionCriterion): Args: transition_to: The name of the GenerationNode the GenerationStrategy should transition to next. + block_transition_if_unmet: A flag to prevent the node from completing and + being able to transition to another node. Ex: This criterion defaults to + setting this to True to ensure we validate a GeneratorRun is generated by + the current GenerationNode. + complete_trial_generation: A flag to indicate that all generation for a given + trial is completed. This is necessary because in ```BatchTrial``` there + are multiple arms per trial, and we enable generation of arms within a + batch from different ```GenerationNodes```. This flag should be set to + True for the last node in a set of ```GenerationNodes``` expected to + create a given ```BatchTrial```. """ - def __init__(self, transition_to: str) -> None: - super().__init__(transition_to=transition_to) + def __init__( + self, + transition_to: str, + block_transition_if_unmet: Optional[bool] = True, + continue_trial_generation: Optional[bool] = True, + ) -> None: + super().__init__( + transition_to=transition_to, + block_transition_if_unmet=block_transition_if_unmet, + continue_trial_generation=continue_trial_generation, + ) def is_met( self, @@ -116,7 +143,9 @@ def is_met( node_that_generated_last_gr: Optional[str] = None, curr_node_name: Optional[str] = None, ) -> bool: - """Return true as soon as any trial is generated by this GenerationNode.""" + """Return True as soon as any GeneratorRun is generated by this + GenerationNode. + """ return node_that_generated_last_gr == curr_node_name def block_continued_generation_error( @@ -154,6 +183,12 @@ class TrialBasedCriterion(TransitionCriterion): transition to when this criterion is met, if it exists. use_all_trials_in_exp: A flag to use all trials in the experiment, instead of only those generated by the current GenerationNode. + complete_trial_generation: A flag to indicate that all generation for a given + trial is completed. This is necessary because in ```BatchTrial``` there + are multiple arms per trial, and we enable generation of arms within a + batch from different ```GenerationNodes```. This flag should be set to + True for the last node in a set of ```GenerationNodes``` expected to + create a given ```BatchTrial```. """ def __init__( @@ -165,6 +200,7 @@ def __init__( not_in_statuses: Optional[List[TrialStatus]] = None, transition_to: Optional[str] = None, use_all_trials_in_exp: Optional[bool] = False, + continue_trial_generation: Optional[bool] = False, ) -> None: self.threshold = threshold self.only_in_statuses = only_in_statuses @@ -174,6 +210,7 @@ def __init__( transition_to=transition_to, block_transition_if_unmet=block_transition_if_unmet, block_gen_if_met=block_gen_if_met, + continue_trial_generation=continue_trial_generation, ) def experiment_trials_by_status( @@ -306,6 +343,14 @@ class MaxGenerationParallelism(TrialBasedCriterion): until MinimumTrialsInStatus is met (thus overriding MaxTrials). use_all_trials_in_exp: A flag to use all trials in the experiment, instead of only those generated by the current GenerationNode. + complete_trial_generation: A flag to indicate that all generation for a given + trial is completed. This is necessary because in ```BatchTrial``` there + are multiple arms per trial, and we enable generation of arms within a + batch from different ```GenerationNodes```. This flag should be set to + True for the last node in a set of ```GenerationNodes``` expected to + create a given ```BatchTrial```. Defautls to False for + MaxGenerationParallelism since this criterion isn't currently used for + node -> node or trial -> trial transition. """ def __init__( @@ -317,6 +362,7 @@ def __init__( block_transition_if_unmet: Optional[bool] = False, block_gen_if_met: Optional[bool] = True, use_all_trials_in_exp: Optional[bool] = False, + continue_trial_generation: Optional[bool] = True, ) -> None: super().__init__( threshold=threshold, @@ -326,6 +372,7 @@ def __init__( block_gen_if_met=block_gen_if_met, block_transition_if_unmet=block_transition_if_unmet, use_all_trials_in_exp=use_all_trials_in_exp, + continue_trial_generation=continue_trial_generation, ) def block_continued_generation_error( @@ -380,6 +427,12 @@ class MaxTrials(TrialBasedCriterion): until MinimumTrialsInStatus is met (thus overriding MaxTrials). use_all_trials_in_exp: A flag to use all trials in the experiment, instead of only those generated by the current GenerationNode. + complete_trial_generation: A flag to indicate that all generation for a given + trial is completed. This is necessary because in ```BatchTrial``` there + are multiple arms per trial, and we enable generation of arms within a + batch from different ```GenerationNodes```. This flag should be set to + True for the last node in a set of ```GenerationNodes``` expected to + create a given ```BatchTrial```. """ def __init__( @@ -391,6 +444,7 @@ def __init__( block_transition_if_unmet: Optional[bool] = True, block_gen_if_met: Optional[bool] = False, use_all_trials_in_exp: Optional[bool] = False, + continue_trial_generation: Optional[bool] = False, ) -> None: super().__init__( threshold=threshold, @@ -400,6 +454,7 @@ def __init__( block_gen_if_met=block_gen_if_met, block_transition_if_unmet=block_transition_if_unmet, use_all_trials_in_exp=use_all_trials_in_exp, + continue_trial_generation=continue_trial_generation, ) def block_continued_generation_error( @@ -448,6 +503,12 @@ class MinTrials(TrialBasedCriterion): until MinimumTrialsInStatus is met (thus overriding MaxTrials). use_all_trials_in_exp: A flag to use all trials in the experiment, instead of only those generated by the current GenerationNode. + complete_trial_generation: A flag to indicate that all generation for a given + trial is completed. This is necessary because in ```BatchTrial``` there + are multiple arms per trial, and we enable generation of arms within a + batch from different ```GenerationNodes```. This flag should be set to + True for the last node in a set of ```GenerationNodes``` expected to + create a given ```BatchTrial```. """ def __init__( @@ -459,6 +520,7 @@ def __init__( block_transition_if_unmet: Optional[bool] = True, block_gen_if_met: Optional[bool] = False, use_all_trials_in_exp: Optional[bool] = False, + continue_trial_generation: Optional[bool] = False, ) -> None: super().__init__( threshold=threshold, @@ -468,6 +530,7 @@ def __init__( block_gen_if_met=block_gen_if_met, block_transition_if_unmet=block_transition_if_unmet, use_all_trials_in_exp=use_all_trials_in_exp, + continue_trial_generation=continue_trial_generation, ) def block_continued_generation_error(