diff --git a/ax/modelbridge/generation_node.py b/ax/modelbridge/generation_node.py index b011a614f66..d8c6cfccb2c 100644 --- a/ax/modelbridge/generation_node.py +++ b/ax/modelbridge/generation_node.py @@ -424,7 +424,11 @@ def should_transition_to_next_node( self, raise_data_required_error: bool = True ) -> Tuple[bool, Optional[str]]: """Checks whether we should transition to the next node based on this node's - TransitionCriterion + TransitionCriterion. + + Important: This method relies on the `transition_criterion` of this node to be + listed in order of importance. Ex: a fallback transition should come after the + primary transition in the transition criterion list. Args: raise_data_required_error: Whether to raise ``DataRequiredError`` in the @@ -439,54 +443,44 @@ def should_transition_to_next_node( if len(self.transition_criteria) == 0: return False, None - transition_blocking = [ - tc for tc in self.transition_criteria if tc.block_transition_if_unmet - ] - transition_blocking_met = all( - tc.is_met( - experiment=self.experiment, - trials_from_node=self.trials_from_node, - node_name=self.node_name, - node_that_generated_last_gr=( - self.generation_strategy.last_generator_run._generation_node_name - if self.generation_strategy.last_generator_run is not None - else None - ), + # for each edge in node dag, check if the transition criterion are met, if so + # transition to the next node defined by that edge. + for next_node, all_tc in self.transition_edges.items(): + transition_blocking = [tc for tc in all_tc if tc.block_transition_if_unmet] + gs_lgr = self.generation_strategy.last_generator_run + transition_blocking_met = all( + tc.is_met( + experiment=self.experiment, + trials_from_node=self.trials_from_node, + node_name=self.node_name, + node_that_generated_last_gr=( + gs_lgr._generation_node_name if gs_lgr is not None else None + ), + ) + for tc in transition_blocking ) - for tc in transition_blocking - ) - # Raise any necessary generation errors: for any met criterion, - # call its `block_continued_generation_error` method if not all - # transition-blocking criteria are met. The method might not raise an - # error, depending on its implementation on given criterion, so the error - # from the first met one that does block continued generation, will be raised. - if not transition_blocking_met: - for tc in self.transition_criteria: - if ( - tc.is_met(self.experiment, trials_from_node=self.trials_from_node) - and raise_data_required_error - ): - tc.block_continued_generation_error( - node_name=self.node_name, - model_name=self.model_to_gen_from_name, - experiment=self.experiment, - trials_from_node=self.trials_from_node, - ) - # Determine transition state - if len(transition_blocking) > 0 and transition_blocking_met: - next_nodes = [ - c.transition_to - for c in transition_blocking - if c._transition_to is not None - ] - if len(set(next_nodes)) > 1: - # TODO: support intelligent selection between multiple transition nodes - raise NotImplementedError( - "Cannot currently select between multiple nodes to transition to." - ) - else: - return True, next_nodes[0] + # Raise any necessary generation errors: for any met criterion, + # call its `block_continued_generation_error` method if not all + # transition-blocking criteria are met. The method might not raise an + # error, depending on its implementation on given criterion, so the error + # from the first met one that does block continued generation, will raise. + if not transition_blocking_met: + for tc in all_tc: + if ( + tc.is_met( + self.experiment, trials_from_node=self.trials_from_node + ) + and raise_data_required_error + ): + tc.block_continued_generation_error( + node_name=self.node_name, + model_name=self.model_to_gen_from_name, + experiment=self.experiment, + trials_from_node=self.trials_from_node, + ) + if len(transition_blocking) > 0 and transition_blocking_met: + return True, next_node return False, None diff --git a/ax/modelbridge/tests/test_generation_strategy.py b/ax/modelbridge/tests/test_generation_strategy.py index 2eb0bcc8d8a..e14e2def1d7 100644 --- a/ax/modelbridge/tests/test_generation_strategy.py +++ b/ax/modelbridge/tests/test_generation_strategy.py @@ -186,11 +186,53 @@ def setUp(self) -> None: transition_criteria=self.gpei_criterion, model_specs=[self.gpei_model_spec], ) - self.sobol_GPEI_GS_nodes = GenerationStrategy( name="Sobol+GPEI_Nodes", nodes=[self.sobol_node, self.gpei_node], ) + self.gpei_to_sobol2_max = MaxTrials( + threshold=1, + transition_to="sobol_2", + block_transition_if_unmet=True, + only_in_statuses=[TrialStatus.RUNNING], + use_all_trials_in_exp=True, + ) + self.gpei_to_sobol2_min = MinTrials( + threshold=1, + transition_to="sobol_2", + block_transition_if_unmet=True, + only_in_statuses=[TrialStatus.COMPLETED], + use_all_trials_in_exp=True, + ) + self.gpei_to_sobol_auto = AutoTransitionAfterGenCriterion( + transition_to="sobol_3" + ) + self.competing_tc_gs = GenerationStrategy( + nodes=[ + GenerationNode( + node_name="sobol", + model_specs=[self.sobol_model_spec], + transition_criteria=self.single_running_trial_criterion, + ), + GenerationNode( + node_name="gpei", + model_specs=[self.gpei_model_spec], + transition_criteria=[ + self.gpei_to_sobol2_max, + self.gpei_to_sobol2_min, + self.gpei_to_sobol_auto, + ], + ), + GenerationNode( + node_name="sobol_2", + model_specs=[self.sobol_model_spec], + ), + GenerationNode( + node_name="sobol_3", + model_specs=[self.sobol_model_spec], + ), + ], + ) def tearDown(self) -> None: self.torch_model_bridge_patcher.stop() @@ -1280,23 +1322,65 @@ def test_generation_strategy_eq_no_print(self) -> None: gs2 = self.basic_sobol_gpei_gs self.assertEqual(gs1, gs2) + def test_gs_with_competing_transition_edges(self) -> None: + """Test that a ```GenerationStrategy``` with a node with competing transition + edges correctly transitions. + """ + # this gs has a single sobol node which transitions to gpei. If the MaxTrials + # and MinTrials criterion are met, the transition to sobol_2 should occur, + # otherwise, should transition to sobol_3 + gs = self.competing_tc_gs + exp = get_branin_experiment() + + # check that gpei will move to sobol_3 when MaxTrials and MinTrials are unmet + exp.new_trial(generator_run=gs.gen(exp)).run() + gs.gen(exp) + self.assertEqual(gs.current_node_name, "gpei") + gs.gen(exp) + self.assertEqual(gs.current_node_name, "sobol_3") + + def test_gs_with_competing_transition_edges_2(self) -> None: + """Test that a ```GenerationStrategy``` with a node with competing transition + edges correctly transitions. + """ + # this gs has a single sobol node which transitions to gpei. If the MaxTrials + # and MinTrials criterion are met, the transition to sobol_2 should occur, + # otherwise, should transition to sobol_3 + gs = self.competing_tc_gs + exp = get_branin_experiment() + + # check that gpei will move to sobol_3 when MaxTrials is met and MinTrials + # is unmet + exp.new_trial(generator_run=gs.gen(exp)).run() + exp.new_trial(generator_run=gs.gen(exp)).run() + self.assertEqual(gs.current_node_name, "gpei") + gs.gen(exp) + self.assertEqual(gs.current_node_name, "sobol_3") + + def test_gs_with_competing_transition_edges_3(self) -> None: + """Test that a ```GenerationStrategy``` with a node with competing transition + edges correctly transitions. + """ + # this gs has a single sobol node which transitions to gpei. If the MaxTrials + # and MinTrials criterion are met, the transition to sobol_2 should occur, + # otherwise, should transition to sobol_3 + gs = self.competing_tc_gs + exp = get_branin_experiment() + + # check that gpei will transition to sobol_2 when MaxTrials is met and + # MinTrials are met + trial = exp.new_trial(generator_run=gs.gen(exp)).run() + exp.new_trial(generator_run=gs.gen(exp)).run() + self.assertEqual(gs.current_node_name, "gpei") + trial.mark_completed() + gs.gen(exp) + self.assertEqual(gs.current_node_name, "sobol_2") + def test_transition_edges(self) -> None: """Test transition_edges property of ``GenerationNode``""" # this gs has a single sobol node which transitions to gpei. If the MaxTrials # and MinTrials criterion are met, the transition to sobol_2 should occur, # otherwise, should transition back to sobol. - gpei_to_sobol2_max = MaxTrials( - threshold=2, - transition_to="sobol_2", - block_transition_if_unmet=True, - only_in_statuses=[TrialStatus.RUNNING], - ) - gpei_to_sobol2_min = MinTrials( - threshold=1, - transition_to="sobol_2", - block_transition_if_unmet=True, - only_in_statuses=[TrialStatus.COMPLETED], - ) gpei_to_sobol_auto = AutoTransitionAfterGenCriterion(transition_to="sobol") gs = GenerationStrategy( nodes=[ @@ -1309,8 +1393,8 @@ def test_transition_edges(self) -> None: node_name="gpei", model_specs=[self.gpei_model_spec], transition_criteria=[ - gpei_to_sobol2_max, - gpei_to_sobol2_min, + self.gpei_to_sobol2_max, + self.gpei_to_sobol2_min, gpei_to_sobol_auto, ], ), @@ -1330,7 +1414,7 @@ def test_transition_edges(self) -> None: self.assertEqual( gs._curr.transition_edges, { - "sobol_2": [gpei_to_sobol2_max, gpei_to_sobol2_min], + "sobol_2": [self.gpei_to_sobol2_max, self.gpei_to_sobol2_min], "sobol": [gpei_to_sobol_auto], }, )