Skip to content

Commit

Permalink
node <> model 3/n: update should_transition_to_next_node (facebook#2447)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebook#2447

This diff builds on the previous diff, and uses the new property that we added for TC to decide if the gs should move forward.

If all criterion that specificy a transition edge between nodes are met, we move to that node. We progress through tc in order -- meaning when constructing a gs tc ordering is important.

Reviewed By: lena-kashtelyan

Differential Revision: D57132018

fbshipit-source-id: 5f7da7d631eab07a37988dd0196d096068234908
  • Loading branch information
mgarrard authored and facebook-github-bot committed May 13, 2024
1 parent 74838f2 commit 8926797
Show file tree
Hide file tree
Showing 2 changed files with 144 additions and 63 deletions.
91 changes: 44 additions & 47 deletions ax/modelbridge/generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,11 @@ def should_transition_to_next_node(
self, raise_data_required_error: bool = True
) -> Tuple[bool, Optional[str]]:
"""Checks whether we should transition to the next node based on this node's
TransitionCriterion
TransitionCriterion.
Important: This method relies on the ``transition_criterion`` of this node to
be listed in order of importance. Ex: a fallback transition should come after
the primary transition in the transition criterion list.
Args:
raise_data_required_error: Whether to raise ``DataRequiredError`` in the
Expand All @@ -439,54 +443,47 @@ def should_transition_to_next_node(
if len(self.transition_criteria) == 0:
return False, None

transition_blocking = [
tc for tc in self.transition_criteria if tc.block_transition_if_unmet
]
transition_blocking_met = all(
tc.is_met(
experiment=self.experiment,
trials_from_node=self.trials_from_node,
curr_node_name=self.node_name,
node_that_generated_last_gr=(
self.generation_strategy.last_generator_run._generation_node_name
if self.generation_strategy.last_generator_run is not None
else None
),
# for each edge in node DAG, check if the transition criterion are met, if so
# transition to the next node defined by that edge.
for next_node, all_tc in self.transition_edges.items():
transition_blocking = [tc for tc in all_tc if tc.block_transition_if_unmet]
gs_lgr = self.generation_strategy.last_generator_run
transition_blocking_met = all(
tc.is_met(
experiment=self.experiment,
trials_from_node=self.trials_from_node,
curr_node_name=self.node_name,
# TODO @mgarrard: should we instead pass a backpointer to gs/node
node_that_generated_last_gr=(
gs_lgr._generation_node_name if gs_lgr is not None else None
),
)
for tc in transition_blocking
)
for tc in transition_blocking
)
# Raise any necessary generation errors: for any met criterion,
# call its `block_continued_generation_error` method if not all
# transition-blocking criteria are met. The method might not raise an
# error, depending on its implementation on given criterion, so the error
# from the first met one that does block continued generation, will be raised.
if not transition_blocking_met:
for tc in self.transition_criteria:
if (
tc.is_met(self.experiment, trials_from_node=self.trials_from_node)
and raise_data_required_error
):
tc.block_continued_generation_error(
node_name=self.node_name,
model_name=self.model_to_gen_from_name,
experiment=self.experiment,
trials_from_node=self.trials_from_node,
)

# Determine transition state
if len(transition_blocking) > 0 and transition_blocking_met:
next_nodes = [
c.transition_to
for c in transition_blocking
if c._transition_to is not None
]
if len(set(next_nodes)) > 1:
# TODO: support intelligent selection between multiple transition nodes
raise NotImplementedError(
"Cannot currently select between multiple nodes to transition to."
)
else:
return True, next_nodes[0]
# Raise any necessary generation errors: for any met criterion,
# call its `block_continued_generation_error` method if not all
# transition-blocking criteria are met. The method might not raise an
# error, depending on its implementation on given criterion, so the error
# from the first met one that does block continued generation, will raise.
# TODO: @mgarrard see if we can replace MaxGenerationParallelism with a
# transition to self and rework this error block.
if not transition_blocking_met:
for tc in all_tc:
if (
tc.is_met(
self.experiment, trials_from_node=self.trials_from_node
)
and raise_data_required_error
):
tc.block_continued_generation_error(
node_name=self.node_name,
model_name=self.model_to_gen_from_name,
experiment=self.experiment,
trials_from_node=self.trials_from_node,
)
if len(transition_blocking) > 0 and transition_blocking_met:
return True, next_node

return False, None

Expand Down
116 changes: 100 additions & 16 deletions ax/modelbridge/tests/test_generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,11 +186,53 @@ def setUp(self) -> None:
transition_criteria=self.gpei_criterion,
model_specs=[self.gpei_model_spec],
)

self.sobol_GPEI_GS_nodes = GenerationStrategy(
name="Sobol+GPEI_Nodes",
nodes=[self.sobol_node, self.gpei_node],
)
self.gpei_to_sobol2_max = MaxTrials(
threshold=1,
transition_to="sobol_2",
block_transition_if_unmet=True,
only_in_statuses=[TrialStatus.RUNNING],
use_all_trials_in_exp=True,
)
self.gpei_to_sobol2_min = MinTrials(
threshold=1,
transition_to="sobol_2",
block_transition_if_unmet=True,
only_in_statuses=[TrialStatus.COMPLETED],
use_all_trials_in_exp=True,
)
self.gpei_to_sobol_auto = AutoTransitionAfterGenCriterion(
transition_to="sobol_3"
)
self.competing_tc_gs = GenerationStrategy(
nodes=[
GenerationNode(
node_name="sobol",
model_specs=[self.sobol_model_spec],
transition_criteria=self.single_running_trial_criterion,
),
GenerationNode(
node_name="gpei",
model_specs=[self.gpei_model_spec],
transition_criteria=[
self.gpei_to_sobol2_max,
self.gpei_to_sobol2_min,
self.gpei_to_sobol_auto,
],
),
GenerationNode(
node_name="sobol_2",
model_specs=[self.sobol_model_spec],
),
GenerationNode(
node_name="sobol_3",
model_specs=[self.sobol_model_spec],
),
],
)

def tearDown(self) -> None:
self.torch_model_bridge_patcher.stop()
Expand Down Expand Up @@ -1280,23 +1322,65 @@ def test_generation_strategy_eq_no_print(self) -> None:
gs2 = self.basic_sobol_gpei_gs
self.assertEqual(gs1, gs2)

def test_gs_with_competing_transition_edges(self) -> None:
"""Test that a ```GenerationStrategy``` with a node with competing transition
edges correctly transitions.
"""
# this gs has a single sobol node which transitions to gpei. If the MaxTrials
# and MinTrials criterion are met, the transition to sobol_2 should occur,
# otherwise, should transition to sobol_3
gs = self.competing_tc_gs
exp = get_branin_experiment()

# check that gpei will move to sobol_3 when MaxTrials and MinTrials are unmet
exp.new_trial(generator_run=gs.gen(exp)).run()
gs.gen(exp)
self.assertEqual(gs.current_node_name, "gpei")
gs.gen(exp)
self.assertEqual(gs.current_node_name, "sobol_3")

def test_gs_with_competing_transition_edges_2(self) -> None:
"""Test that a ```GenerationStrategy``` with a node with competing transition
edges correctly transitions.
"""
# this gs has a single sobol node which transitions to gpei. If the MaxTrials
# and MinTrials criterion are met, the transition to sobol_2 should occur,
# otherwise, should transition to sobol_3
gs = self.competing_tc_gs
exp = get_branin_experiment()

# check that gpei will move to sobol_3 when MaxTrials is met and MinTrials
# is unmet
exp.new_trial(generator_run=gs.gen(exp)).run()
exp.new_trial(generator_run=gs.gen(exp)).run()
self.assertEqual(gs.current_node_name, "gpei")
gs.gen(exp)
self.assertEqual(gs.current_node_name, "sobol_3")

def test_gs_with_competing_transition_edges_3(self) -> None:
"""Test that a ```GenerationStrategy``` with a node with competing transition
edges correctly transitions.
"""
# this gs has a single sobol node which transitions to gpei. If the MaxTrials
# and MinTrials criterion are met, the transition to sobol_2 should occur,
# otherwise, should transition to sobol_3
gs = self.competing_tc_gs
exp = get_branin_experiment()

# check that gpei will transition to sobol_2 when MaxTrials is met and
# MinTrials are met
trial = exp.new_trial(generator_run=gs.gen(exp)).run()
exp.new_trial(generator_run=gs.gen(exp)).run()
self.assertEqual(gs.current_node_name, "gpei")
trial.mark_completed()
gs.gen(exp)
self.assertEqual(gs.current_node_name, "sobol_2")

def test_transition_edges(self) -> None:
"""Test transition_edges property of ``GenerationNode``"""
# this gs has a single sobol node which transitions to gpei. If the MaxTrials
# and MinTrials criterion are met, the transition to sobol_2 should occur,
# otherwise, should transition back to sobol.
gpei_to_sobol2_max = MaxTrials(
threshold=2,
transition_to="sobol_2",
block_transition_if_unmet=True,
only_in_statuses=[TrialStatus.RUNNING],
)
gpei_to_sobol2_min = MinTrials(
threshold=1,
transition_to="sobol_2",
block_transition_if_unmet=True,
only_in_statuses=[TrialStatus.COMPLETED],
)
gpei_to_sobol_auto = AutoTransitionAfterGenCriterion(transition_to="sobol")
gs = GenerationStrategy(
nodes=[
Expand All @@ -1309,8 +1393,8 @@ def test_transition_edges(self) -> None:
node_name="gpei",
model_specs=[self.gpei_model_spec],
transition_criteria=[
gpei_to_sobol2_max,
gpei_to_sobol2_min,
self.gpei_to_sobol2_max,
self.gpei_to_sobol2_min,
gpei_to_sobol_auto,
],
),
Expand All @@ -1330,7 +1414,7 @@ def test_transition_edges(self) -> None:
self.assertEqual(
gs._curr.transition_edges,
{
"sobol_2": [gpei_to_sobol2_max, gpei_to_sobol2_min],
"sobol_2": [self.gpei_to_sobol2_max, self.gpei_to_sobol2_min],
"sobol": [gpei_to_sobol_auto],
},
)
Expand Down

0 comments on commit 8926797

Please sign in to comment.