Skip to content

Commit

Permalink
model <> node 2/n: Add transition edges property to GenerationNode (f…
Browse files Browse the repository at this point in the history
…acebook#2446)

Summary:

As part of modifying the generation node dag we want to support multiple transition edges now (meaning one node x could transition either to node y or node z depending on what criterion are met). This property creates a dict with {'next_node_name': [tc]} where tc is the transition criterion to move the gs from the current node to the `next node`

The reason we need this, is currently and in any future i can envision, to transition to another node all criterion for that edge that are transition blocking must be met. In the next diff I modify the should_transition method. I just like bite sized diffs :)

Reviewed By: lena-kashtelyan

Differential Revision: D57127974
  • Loading branch information
mgarrard authored and facebook-github-bot committed May 9, 2024
1 parent d330708 commit da5b96e
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 4 deletions.
30 changes: 27 additions & 3 deletions ax/modelbridge/generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

from __future__ import annotations

from collections import defaultdict

from dataclasses import dataclass, field
from logging import Logger
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
Expand Down Expand Up @@ -153,7 +155,7 @@ def generation_strategy(self) -> modelbridge.generation_strategy.GenerationStrat
"""Returns a backpointer to the GenerationStrategy, useful for obtaining the
experiment associated with this GenerationStrategy"""
# TODO: @mgarrard remove this property once we make experiment a required
# arguement on GenerationStrategy
# argument on GenerationStrategy
if self._generation_strategy is None:
raise ValueError(
"Generation strategy has not been initialized on this node."
Expand Down Expand Up @@ -296,7 +298,7 @@ def gen(
)
assert generator_run is not None, (
"The GeneratorRun is None which is an unexpected state of this"
" GenerationStrategy. This occured on GenerationNode: {self.node_name}."
" GenerationStrategy. This occurred on GenerationNode: {self.node_name}."
)
generator_run._generation_node_name = self.node_name
return generator_run
Expand Down Expand Up @@ -397,6 +399,27 @@ def node_that_generated_last_gr(self) -> Optional[str]:
else None
)

@property
def transition_edges(self) -> Dict[str, List[TransitionCriterion]]:
"""Returns a dictionary mapping the next ```GenerationNode``` to the
TransitionCriteria that define the transition that that node.
Ex: if the transition from the current node to node x is defined by MaxTrials
and MinTrials criterion then the return would be {'x': [MaxTrials, MinTrials]}.
Returns:
Dict[str, List[TransitionCriterion]]: A dictionary mapping the next
```GenerationNode``` to the ```TransitionCriterion``` that are associated
with it.
"""
if self.transition_criteria is None:
return {}

tc_edges = defaultdict(list)
for tc in self.transition_criteria:
tc_edges[tc.transition_to].append(tc)
return tc_edges

def should_transition_to_next_node(
self, raise_data_required_error: bool = True
) -> Tuple[bool, Optional[str]]:
Expand All @@ -409,7 +432,8 @@ def should_transition_to_next_node(
check how many generator runs (to be made into trials) can be produced,
but not actually producing them yet.
Returns:
bool: Whether we should transition to the next node.
Tuple[bool, Optional[str]]: Whether we should transition to the next node
and the name of the next node.
"""
# if no transition criteria are defined, this node can generate unlimited trials
if len(self.transition_criteria) == 0:
Expand Down
57 changes: 56 additions & 1 deletion ax/modelbridge/tests/test_generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1029,7 +1029,7 @@ def test_gen_for_multiple_trials_with_multiple_models_with_fixed_features(

# ---------- Tests for GenerationStrategies composed of GenerationNodes --------
def test_gs_setup_with_nodes(self) -> None:
"""Test GS initalization and validation with nodes"""
"""Test GS initialization and validation with nodes"""
node_1_criterion = [
MaxTrials(
threshold=4,
Expand Down Expand Up @@ -1280,6 +1280,61 @@ def test_generation_strategy_eq_no_print(self) -> None:
gs2 = self.basic_sobol_gpei_gs
self.assertEqual(gs1, gs2)

def test_transition_edges(self) -> None:
"""Test transition_edges property of ``GenerationNode``"""
# this gs has a single sobol node which transitions to gpei. If the MaxTrials
# and MinTrials criterion are met, the transition to sobol_2 should occur,
# otherwise, should transition back to sobol.
gpei_to_sobol2_max = MaxTrials(
threshold=2,
transition_to="sobol_2",
block_transition_if_unmet=True,
only_in_statuses=[TrialStatus.RUNNING],
)
gpei_to_sobol2_min = MinTrials(
threshold=1,
transition_to="sobol_2",
block_transition_if_unmet=True,
only_in_statuses=[TrialStatus.COMPLETED],
)
gpei_to_sobol_auto = AutoTransitionAfterGenCriterion(transition_to="sobol")
gs = GenerationStrategy(
nodes=[
GenerationNode(
node_name="sobol",
model_specs=[self.sobol_model_spec],
transition_criteria=self.single_running_trial_criterion,
),
GenerationNode(
node_name="gpei",
model_specs=[self.gpei_model_spec],
transition_criteria=[
gpei_to_sobol2_max,
gpei_to_sobol2_min,
gpei_to_sobol_auto,
],
),
GenerationNode(
node_name="sobol_2",
model_specs=[self.sobol_model_spec],
),
],
)
exp = get_branin_experiment()
self.assertEqual(
gs._curr.transition_edges, {"gpei": self.single_running_trial_criterion}
)
exp.new_trial(generator_run=gs.gen(exp)).run()
gs.gen(exp)
self.assertEqual(gs.current_node_name, "gpei")
self.assertEqual(
gs._curr.transition_edges,
{
"sobol_2": [gpei_to_sobol2_max, gpei_to_sobol2_min],
"sobol": [gpei_to_sobol_auto],
},
)

def test_node_gs_with_auto_transitions(self) -> None:
"""Test that node-based generation strategies which leverage
AutoTransitionAfterGen criterion correctly transition and create trials.
Expand Down

0 comments on commit da5b96e

Please sign in to comment.