From 892679711bdaf10e5837d0691b7e5b37749f0d39 Mon Sep 17 00:00:00 2001 From: Mia Garrard Date: Mon, 13 May 2024 13:20:24 -0700 Subject: [PATCH] node <> model 3/n: update should_transition_to_next_node (#2447) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/2447 This diff builds on the previous diff, and uses the new property that we added for TC to decide if the gs should move forward. If all criterion that specificy a transition edge between nodes are met, we move to that node. We progress through tc in order -- meaning when constructing a gs tc ordering is important. Reviewed By: lena-kashtelyan Differential Revision: D57132018 fbshipit-source-id: 5f7da7d631eab07a37988dd0196d096068234908 --- ax/modelbridge/generation_node.py | 91 +++++++------- .../tests/test_generation_strategy.py | 116 +++++++++++++++--- 2 files changed, 144 insertions(+), 63 deletions(-) diff --git a/ax/modelbridge/generation_node.py b/ax/modelbridge/generation_node.py index 35d67469564..e0480d7ad77 100644 --- a/ax/modelbridge/generation_node.py +++ b/ax/modelbridge/generation_node.py @@ -424,7 +424,11 @@ def should_transition_to_next_node( self, raise_data_required_error: bool = True ) -> Tuple[bool, Optional[str]]: """Checks whether we should transition to the next node based on this node's - TransitionCriterion + TransitionCriterion. + + Important: This method relies on the ``transition_criterion`` of this node to + be listed in order of importance. Ex: a fallback transition should come after + the primary transition in the transition criterion list. Args: raise_data_required_error: Whether to raise ``DataRequiredError`` in the @@ -439,54 +443,47 @@ def should_transition_to_next_node( if len(self.transition_criteria) == 0: return False, None - transition_blocking = [ - tc for tc in self.transition_criteria if tc.block_transition_if_unmet - ] - transition_blocking_met = all( - tc.is_met( - experiment=self.experiment, - trials_from_node=self.trials_from_node, - curr_node_name=self.node_name, - node_that_generated_last_gr=( - self.generation_strategy.last_generator_run._generation_node_name - if self.generation_strategy.last_generator_run is not None - else None - ), + # for each edge in node DAG, check if the transition criterion are met, if so + # transition to the next node defined by that edge. + for next_node, all_tc in self.transition_edges.items(): + transition_blocking = [tc for tc in all_tc if tc.block_transition_if_unmet] + gs_lgr = self.generation_strategy.last_generator_run + transition_blocking_met = all( + tc.is_met( + experiment=self.experiment, + trials_from_node=self.trials_from_node, + curr_node_name=self.node_name, + # TODO @mgarrard: should we instead pass a backpointer to gs/node + node_that_generated_last_gr=( + gs_lgr._generation_node_name if gs_lgr is not None else None + ), + ) + for tc in transition_blocking ) - for tc in transition_blocking - ) - # Raise any necessary generation errors: for any met criterion, - # call its `block_continued_generation_error` method if not all - # transition-blocking criteria are met. The method might not raise an - # error, depending on its implementation on given criterion, so the error - # from the first met one that does block continued generation, will be raised. - if not transition_blocking_met: - for tc in self.transition_criteria: - if ( - tc.is_met(self.experiment, trials_from_node=self.trials_from_node) - and raise_data_required_error - ): - tc.block_continued_generation_error( - node_name=self.node_name, - model_name=self.model_to_gen_from_name, - experiment=self.experiment, - trials_from_node=self.trials_from_node, - ) - # Determine transition state - if len(transition_blocking) > 0 and transition_blocking_met: - next_nodes = [ - c.transition_to - for c in transition_blocking - if c._transition_to is not None - ] - if len(set(next_nodes)) > 1: - # TODO: support intelligent selection between multiple transition nodes - raise NotImplementedError( - "Cannot currently select between multiple nodes to transition to." - ) - else: - return True, next_nodes[0] + # Raise any necessary generation errors: for any met criterion, + # call its `block_continued_generation_error` method if not all + # transition-blocking criteria are met. The method might not raise an + # error, depending on its implementation on given criterion, so the error + # from the first met one that does block continued generation, will raise. + # TODO: @mgarrard see if we can replace MaxGenerationParallelism with a + # transition to self and rework this error block. + if not transition_blocking_met: + for tc in all_tc: + if ( + tc.is_met( + self.experiment, trials_from_node=self.trials_from_node + ) + and raise_data_required_error + ): + tc.block_continued_generation_error( + node_name=self.node_name, + model_name=self.model_to_gen_from_name, + experiment=self.experiment, + trials_from_node=self.trials_from_node, + ) + if len(transition_blocking) > 0 and transition_blocking_met: + return True, next_node return False, None diff --git a/ax/modelbridge/tests/test_generation_strategy.py b/ax/modelbridge/tests/test_generation_strategy.py index 2eb0bcc8d8a..e14e2def1d7 100644 --- a/ax/modelbridge/tests/test_generation_strategy.py +++ b/ax/modelbridge/tests/test_generation_strategy.py @@ -186,11 +186,53 @@ def setUp(self) -> None: transition_criteria=self.gpei_criterion, model_specs=[self.gpei_model_spec], ) - self.sobol_GPEI_GS_nodes = GenerationStrategy( name="Sobol+GPEI_Nodes", nodes=[self.sobol_node, self.gpei_node], ) + self.gpei_to_sobol2_max = MaxTrials( + threshold=1, + transition_to="sobol_2", + block_transition_if_unmet=True, + only_in_statuses=[TrialStatus.RUNNING], + use_all_trials_in_exp=True, + ) + self.gpei_to_sobol2_min = MinTrials( + threshold=1, + transition_to="sobol_2", + block_transition_if_unmet=True, + only_in_statuses=[TrialStatus.COMPLETED], + use_all_trials_in_exp=True, + ) + self.gpei_to_sobol_auto = AutoTransitionAfterGenCriterion( + transition_to="sobol_3" + ) + self.competing_tc_gs = GenerationStrategy( + nodes=[ + GenerationNode( + node_name="sobol", + model_specs=[self.sobol_model_spec], + transition_criteria=self.single_running_trial_criterion, + ), + GenerationNode( + node_name="gpei", + model_specs=[self.gpei_model_spec], + transition_criteria=[ + self.gpei_to_sobol2_max, + self.gpei_to_sobol2_min, + self.gpei_to_sobol_auto, + ], + ), + GenerationNode( + node_name="sobol_2", + model_specs=[self.sobol_model_spec], + ), + GenerationNode( + node_name="sobol_3", + model_specs=[self.sobol_model_spec], + ), + ], + ) def tearDown(self) -> None: self.torch_model_bridge_patcher.stop() @@ -1280,23 +1322,65 @@ def test_generation_strategy_eq_no_print(self) -> None: gs2 = self.basic_sobol_gpei_gs self.assertEqual(gs1, gs2) + def test_gs_with_competing_transition_edges(self) -> None: + """Test that a ```GenerationStrategy``` with a node with competing transition + edges correctly transitions. + """ + # this gs has a single sobol node which transitions to gpei. If the MaxTrials + # and MinTrials criterion are met, the transition to sobol_2 should occur, + # otherwise, should transition to sobol_3 + gs = self.competing_tc_gs + exp = get_branin_experiment() + + # check that gpei will move to sobol_3 when MaxTrials and MinTrials are unmet + exp.new_trial(generator_run=gs.gen(exp)).run() + gs.gen(exp) + self.assertEqual(gs.current_node_name, "gpei") + gs.gen(exp) + self.assertEqual(gs.current_node_name, "sobol_3") + + def test_gs_with_competing_transition_edges_2(self) -> None: + """Test that a ```GenerationStrategy``` with a node with competing transition + edges correctly transitions. + """ + # this gs has a single sobol node which transitions to gpei. If the MaxTrials + # and MinTrials criterion are met, the transition to sobol_2 should occur, + # otherwise, should transition to sobol_3 + gs = self.competing_tc_gs + exp = get_branin_experiment() + + # check that gpei will move to sobol_3 when MaxTrials is met and MinTrials + # is unmet + exp.new_trial(generator_run=gs.gen(exp)).run() + exp.new_trial(generator_run=gs.gen(exp)).run() + self.assertEqual(gs.current_node_name, "gpei") + gs.gen(exp) + self.assertEqual(gs.current_node_name, "sobol_3") + + def test_gs_with_competing_transition_edges_3(self) -> None: + """Test that a ```GenerationStrategy``` with a node with competing transition + edges correctly transitions. + """ + # this gs has a single sobol node which transitions to gpei. If the MaxTrials + # and MinTrials criterion are met, the transition to sobol_2 should occur, + # otherwise, should transition to sobol_3 + gs = self.competing_tc_gs + exp = get_branin_experiment() + + # check that gpei will transition to sobol_2 when MaxTrials is met and + # MinTrials are met + trial = exp.new_trial(generator_run=gs.gen(exp)).run() + exp.new_trial(generator_run=gs.gen(exp)).run() + self.assertEqual(gs.current_node_name, "gpei") + trial.mark_completed() + gs.gen(exp) + self.assertEqual(gs.current_node_name, "sobol_2") + def test_transition_edges(self) -> None: """Test transition_edges property of ``GenerationNode``""" # this gs has a single sobol node which transitions to gpei. If the MaxTrials # and MinTrials criterion are met, the transition to sobol_2 should occur, # otherwise, should transition back to sobol. - gpei_to_sobol2_max = MaxTrials( - threshold=2, - transition_to="sobol_2", - block_transition_if_unmet=True, - only_in_statuses=[TrialStatus.RUNNING], - ) - gpei_to_sobol2_min = MinTrials( - threshold=1, - transition_to="sobol_2", - block_transition_if_unmet=True, - only_in_statuses=[TrialStatus.COMPLETED], - ) gpei_to_sobol_auto = AutoTransitionAfterGenCriterion(transition_to="sobol") gs = GenerationStrategy( nodes=[ @@ -1309,8 +1393,8 @@ def test_transition_edges(self) -> None: node_name="gpei", model_specs=[self.gpei_model_spec], transition_criteria=[ - gpei_to_sobol2_max, - gpei_to_sobol2_min, + self.gpei_to_sobol2_max, + self.gpei_to_sobol2_min, gpei_to_sobol_auto, ], ), @@ -1330,7 +1414,7 @@ def test_transition_edges(self) -> None: self.assertEqual( gs._curr.transition_edges, { - "sobol_2": [gpei_to_sobol2_max, gpei_to_sobol2_min], + "sobol_2": [self.gpei_to_sobol2_max, self.gpei_to_sobol2_min], "sobol": [gpei_to_sobol_auto], }, )