diff --git a/ax/modelbridge/generation_node.py b/ax/modelbridge/generation_node.py index 77c09d81e0b..b011a614f66 100644 --- a/ax/modelbridge/generation_node.py +++ b/ax/modelbridge/generation_node.py @@ -8,6 +8,8 @@ from __future__ import annotations +from collections import defaultdict + from dataclasses import dataclass, field from logging import Logger from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union @@ -153,7 +155,7 @@ def generation_strategy(self) -> modelbridge.generation_strategy.GenerationStrat """Returns a backpointer to the GenerationStrategy, useful for obtaining the experiment associated with this GenerationStrategy""" # TODO: @mgarrard remove this property once we make experiment a required - # arguement on GenerationStrategy + # argument on GenerationStrategy if self._generation_strategy is None: raise ValueError( "Generation strategy has not been initialized on this node." @@ -296,7 +298,7 @@ def gen( ) assert generator_run is not None, ( "The GeneratorRun is None which is an unexpected state of this" - " GenerationStrategy. This occured on GenerationNode: {self.node_name}." + " GenerationStrategy. This occurred on GenerationNode: {self.node_name}." ) generator_run._generation_node_name = self.node_name return generator_run @@ -397,6 +399,27 @@ def node_that_generated_last_gr(self) -> Optional[str]: else None ) + @property + def transition_edges(self) -> Dict[str, List[TransitionCriterion]]: + """Returns a dictionary mapping the next ```GenerationNode``` to the + TransitionCriteria that define the transition that that node. + + Ex: if the transition from the current node to node x is defined by MaxTrials + and MinTrials criterion then the return would be {'x': [MaxTrials, MinTrials]}. + + Returns: + Dict[str, List[TransitionCriterion]]: A dictionary mapping the next + ```GenerationNode``` to the ```TransitionCriterion``` that are associated + with it. + """ + if self.transition_criteria is None: + return {} + + tc_edges = defaultdict(list) + for tc in self.transition_criteria: + tc_edges[tc.transition_to].append(tc) + return tc_edges + def should_transition_to_next_node( self, raise_data_required_error: bool = True ) -> Tuple[bool, Optional[str]]: @@ -409,7 +432,8 @@ def should_transition_to_next_node( check how many generator runs (to be made into trials) can be produced, but not actually producing them yet. Returns: - bool: Whether we should transition to the next node. + Tuple[bool, Optional[str]]: Whether we should transition to the next node + and the name of the next node. """ # if no transition criteria are defined, this node can generate unlimited trials if len(self.transition_criteria) == 0: diff --git a/ax/modelbridge/tests/test_generation_strategy.py b/ax/modelbridge/tests/test_generation_strategy.py index b86314ae885..2eb0bcc8d8a 100644 --- a/ax/modelbridge/tests/test_generation_strategy.py +++ b/ax/modelbridge/tests/test_generation_strategy.py @@ -1029,7 +1029,7 @@ def test_gen_for_multiple_trials_with_multiple_models_with_fixed_features( # ---------- Tests for GenerationStrategies composed of GenerationNodes -------- def test_gs_setup_with_nodes(self) -> None: - """Test GS initalization and validation with nodes""" + """Test GS initialization and validation with nodes""" node_1_criterion = [ MaxTrials( threshold=4, @@ -1280,6 +1280,61 @@ def test_generation_strategy_eq_no_print(self) -> None: gs2 = self.basic_sobol_gpei_gs self.assertEqual(gs1, gs2) + def test_transition_edges(self) -> None: + """Test transition_edges property of ``GenerationNode``""" + # this gs has a single sobol node which transitions to gpei. If the MaxTrials + # and MinTrials criterion are met, the transition to sobol_2 should occur, + # otherwise, should transition back to sobol. + gpei_to_sobol2_max = MaxTrials( + threshold=2, + transition_to="sobol_2", + block_transition_if_unmet=True, + only_in_statuses=[TrialStatus.RUNNING], + ) + gpei_to_sobol2_min = MinTrials( + threshold=1, + transition_to="sobol_2", + block_transition_if_unmet=True, + only_in_statuses=[TrialStatus.COMPLETED], + ) + gpei_to_sobol_auto = AutoTransitionAfterGenCriterion(transition_to="sobol") + gs = GenerationStrategy( + nodes=[ + GenerationNode( + node_name="sobol", + model_specs=[self.sobol_model_spec], + transition_criteria=self.single_running_trial_criterion, + ), + GenerationNode( + node_name="gpei", + model_specs=[self.gpei_model_spec], + transition_criteria=[ + gpei_to_sobol2_max, + gpei_to_sobol2_min, + gpei_to_sobol_auto, + ], + ), + GenerationNode( + node_name="sobol_2", + model_specs=[self.sobol_model_spec], + ), + ], + ) + exp = get_branin_experiment() + self.assertEqual( + gs._curr.transition_edges, {"gpei": self.single_running_trial_criterion} + ) + exp.new_trial(generator_run=gs.gen(exp)).run() + gs.gen(exp) + self.assertEqual(gs.current_node_name, "gpei") + self.assertEqual( + gs._curr.transition_edges, + { + "sobol_2": [gpei_to_sobol2_max, gpei_to_sobol2_min], + "sobol": [gpei_to_sobol_auto], + }, + ) + def test_node_gs_with_auto_transitions(self) -> None: """Test that node-based generation strategies which leverage AutoTransitionAfterGen criterion correctly transition and create trials.