Skip to content

Commit

Permalink
node <> model 3/n: update should_transition_to_next_node (#2447)
Browse files Browse the repository at this point in the history
Summary:

This diff builds on the previous diff, and uses the new property that we added for TC to decide if the gs should move forward.

If all criterion that specificy a transition edge between nodes are met, we move to that node. We progress through tc in order -- meaning when constructing a gs tc ordering is important.

Differential Revision: D57132018
  • Loading branch information
mgarrard authored and facebook-github-bot committed May 9, 2024
1 parent da5b96e commit c3e429a
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 63 deletions.
88 changes: 41 additions & 47 deletions ax/modelbridge/generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,11 @@ def should_transition_to_next_node(
self, raise_data_required_error: bool = True
) -> Tuple[bool, Optional[str]]:
"""Checks whether we should transition to the next node based on this node's
TransitionCriterion
TransitionCriterion.
Important: This method relies on the `transition_criterion` of this node to be
listed in order of importance. Ex: a fallback transition should come after the
primary transition in the transition criterion list.
Args:
raise_data_required_error: Whether to raise ``DataRequiredError`` in the
Expand All @@ -439,54 +443,44 @@ def should_transition_to_next_node(
if len(self.transition_criteria) == 0:
return False, None

transition_blocking = [
tc for tc in self.transition_criteria if tc.block_transition_if_unmet
]
transition_blocking_met = all(
tc.is_met(
experiment=self.experiment,
trials_from_node=self.trials_from_node,
node_name=self.node_name,
node_that_generated_last_gr=(
self.generation_strategy.last_generator_run._generation_node_name
if self.generation_strategy.last_generator_run is not None
else None
),
# for each edge in node dag, check if the transition criterion are met, if so
# transition to the next node defined by that edge.
for next_node, all_tc in self.transition_edges.items():
transition_blocking = [tc for tc in all_tc if tc.block_transition_if_unmet]
gs_lgr = self.generation_strategy.last_generator_run
transition_blocking_met = all(
tc.is_met(
experiment=self.experiment,
trials_from_node=self.trials_from_node,
node_name=self.node_name,
node_that_generated_last_gr=(
gs_lgr._generation_node_name if gs_lgr is not None else None
),
)
for tc in transition_blocking
)
for tc in transition_blocking
)
# Raise any necessary generation errors: for any met criterion,
# call its `block_continued_generation_error` method if not all
# transition-blocking criteria are met. The method might not raise an
# error, depending on its implementation on given criterion, so the error
# from the first met one that does block continued generation, will be raised.
if not transition_blocking_met:
for tc in self.transition_criteria:
if (
tc.is_met(self.experiment, trials_from_node=self.trials_from_node)
and raise_data_required_error
):
tc.block_continued_generation_error(
node_name=self.node_name,
model_name=self.model_to_gen_from_name,
experiment=self.experiment,
trials_from_node=self.trials_from_node,
)

# Determine transition state
if len(transition_blocking) > 0 and transition_blocking_met:
next_nodes = [
c.transition_to
for c in transition_blocking
if c._transition_to is not None
]
if len(set(next_nodes)) > 1:
# TODO: support intelligent selection between multiple transition nodes
raise NotImplementedError(
"Cannot currently select between multiple nodes to transition to."
)
else:
return True, next_nodes[0]
# Raise any necessary generation errors: for any met criterion,
# call its `block_continued_generation_error` method if not all
# transition-blocking criteria are met. The method might not raise an
# error, depending on its implementation on given criterion, so the error
# from the first met one that does block continued generation, will raise.
if not transition_blocking_met:
for tc in all_tc:
if (
tc.is_met(
self.experiment, trials_from_node=self.trials_from_node
)
and raise_data_required_error
):
tc.block_continued_generation_error(
node_name=self.node_name,
model_name=self.model_to_gen_from_name,
experiment=self.experiment,
trials_from_node=self.trials_from_node,
)
if len(transition_blocking) > 0 and transition_blocking_met:
return True, next_node

return False, None

Expand Down
116 changes: 100 additions & 16 deletions ax/modelbridge/tests/test_generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,11 +186,53 @@ def setUp(self) -> None:
transition_criteria=self.gpei_criterion,
model_specs=[self.gpei_model_spec],
)

self.sobol_GPEI_GS_nodes = GenerationStrategy(
name="Sobol+GPEI_Nodes",
nodes=[self.sobol_node, self.gpei_node],
)
self.gpei_to_sobol2_max = MaxTrials(
threshold=1,
transition_to="sobol_2",
block_transition_if_unmet=True,
only_in_statuses=[TrialStatus.RUNNING],
use_all_trials_in_exp=True,
)
self.gpei_to_sobol2_min = MinTrials(
threshold=1,
transition_to="sobol_2",
block_transition_if_unmet=True,
only_in_statuses=[TrialStatus.COMPLETED],
use_all_trials_in_exp=True,
)
self.gpei_to_sobol_auto = AutoTransitionAfterGenCriterion(
transition_to="sobol_3"
)
self.competing_tc_gs = GenerationStrategy(
nodes=[
GenerationNode(
node_name="sobol",
model_specs=[self.sobol_model_spec],
transition_criteria=self.single_running_trial_criterion,
),
GenerationNode(
node_name="gpei",
model_specs=[self.gpei_model_spec],
transition_criteria=[
self.gpei_to_sobol2_max,
self.gpei_to_sobol2_min,
self.gpei_to_sobol_auto,
],
),
GenerationNode(
node_name="sobol_2",
model_specs=[self.sobol_model_spec],
),
GenerationNode(
node_name="sobol_3",
model_specs=[self.sobol_model_spec],
),
],
)

def tearDown(self) -> None:
self.torch_model_bridge_patcher.stop()
Expand Down Expand Up @@ -1280,23 +1322,65 @@ def test_generation_strategy_eq_no_print(self) -> None:
gs2 = self.basic_sobol_gpei_gs
self.assertEqual(gs1, gs2)

def test_gs_with_competing_transition_edges(self) -> None:
"""Test that a ```GenerationStrategy``` with a node with competing transition
edges correctly transitions.
"""
# this gs has a single sobol node which transitions to gpei. If the MaxTrials
# and MinTrials criterion are met, the transition to sobol_2 should occur,
# otherwise, should transition to sobol_3
gs = self.competing_tc_gs
exp = get_branin_experiment()

# check that gpei will move to sobol_3 when MaxTrials and MinTrials are unmet
exp.new_trial(generator_run=gs.gen(exp)).run()
gs.gen(exp)
self.assertEqual(gs.current_node_name, "gpei")
gs.gen(exp)
self.assertEqual(gs.current_node_name, "sobol_3")

def test_gs_with_competing_transition_edges_2(self) -> None:
"""Test that a ```GenerationStrategy``` with a node with competing transition
edges correctly transitions.
"""
# this gs has a single sobol node which transitions to gpei. If the MaxTrials
# and MinTrials criterion are met, the transition to sobol_2 should occur,
# otherwise, should transition to sobol_3
gs = self.competing_tc_gs
exp = get_branin_experiment()

# check that gpei will move to sobol_3 when MaxTrials is met and MinTrials
# is unmet
exp.new_trial(generator_run=gs.gen(exp)).run()
exp.new_trial(generator_run=gs.gen(exp)).run()
self.assertEqual(gs.current_node_name, "gpei")
gs.gen(exp)
self.assertEqual(gs.current_node_name, "sobol_3")

def test_gs_with_competing_transition_edges_3(self) -> None:
"""Test that a ```GenerationStrategy``` with a node with competing transition
edges correctly transitions.
"""
# this gs has a single sobol node which transitions to gpei. If the MaxTrials
# and MinTrials criterion are met, the transition to sobol_2 should occur,
# otherwise, should transition to sobol_3
gs = self.competing_tc_gs
exp = get_branin_experiment()

# check that gpei will transition to sobol_2 when MaxTrials is met and
# MinTrials are met
trial = exp.new_trial(generator_run=gs.gen(exp)).run()
exp.new_trial(generator_run=gs.gen(exp)).run()
self.assertEqual(gs.current_node_name, "gpei")
trial.mark_completed()
gs.gen(exp)
self.assertEqual(gs.current_node_name, "sobol_2")

def test_transition_edges(self) -> None:
"""Test transition_edges property of ``GenerationNode``"""
# this gs has a single sobol node which transitions to gpei. If the MaxTrials
# and MinTrials criterion are met, the transition to sobol_2 should occur,
# otherwise, should transition back to sobol.
gpei_to_sobol2_max = MaxTrials(
threshold=2,
transition_to="sobol_2",
block_transition_if_unmet=True,
only_in_statuses=[TrialStatus.RUNNING],
)
gpei_to_sobol2_min = MinTrials(
threshold=1,
transition_to="sobol_2",
block_transition_if_unmet=True,
only_in_statuses=[TrialStatus.COMPLETED],
)
gpei_to_sobol_auto = AutoTransitionAfterGenCriterion(transition_to="sobol")
gs = GenerationStrategy(
nodes=[
Expand All @@ -1309,8 +1393,8 @@ def test_transition_edges(self) -> None:
node_name="gpei",
model_specs=[self.gpei_model_spec],
transition_criteria=[
gpei_to_sobol2_max,
gpei_to_sobol2_min,
self.gpei_to_sobol2_max,
self.gpei_to_sobol2_min,
gpei_to_sobol_auto,
],
),
Expand All @@ -1330,7 +1414,7 @@ def test_transition_edges(self) -> None:
self.assertEqual(
gs._curr.transition_edges,
{
"sobol_2": [gpei_to_sobol2_max, gpei_to_sobol2_min],
"sobol_2": [self.gpei_to_sobol2_max, self.gpei_to_sobol2_min],
"sobol": [gpei_to_sobol_auto],
},
)
Expand Down

0 comments on commit c3e429a

Please sign in to comment.