Skip to content

Commit

Permalink
model <> node 4/n: Support multiple node generation for single trial (#…
Browse files Browse the repository at this point in the history
…2428)

Summary:

This diff enables multiple nodes to be used to generate a single batch trial. Right now the limitations are that: 
(1) currently each node only contributes 1 gr to the trial -- we will extend this to n grs in a future diff 

To do this we added:
(1) should_move_trials method
(2) updates to the transition criterion args

Differential Revision: D56743651
  • Loading branch information
mgarrard authored and facebook-github-bot committed May 10, 2024
1 parent ec742c9 commit 0cabf69
Show file tree
Hide file tree
Showing 5 changed files with 218 additions and 39 deletions.
75 changes: 74 additions & 1 deletion ax/modelbridge/generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,6 @@ class GenerationStrategy(GenerationStrategyInterface):
# Experiment, for which this generation strategy has generated trials, if
# it exists.
_experiment: Optional[Experiment] = None
# Trial indices as last seen by the model; updated in `_model` property setter.
_model: Optional[ModelBridge] = None # Current model.

def __init__(
Expand Down Expand Up @@ -376,6 +375,55 @@ def gen(
**kwargs,
)[0]

def gen_with_multiple_nodes(
self,
experiment: Experiment,
data: Optional[Data] = None,
n: int = 1, # Total arms to generate
pending_observations: Optional[Dict[str, List[ObservationFeatures]]] = None,
**kwargs: Any, # TODO: @mgarrard Ensure correct dispatch to nodes
) -> List[GeneratorRun]:
"""Produces a List of GeneratorRuns for a single trial, either ```Trial``` or
```BatchTrial```, and if producing a ```BatchTrial`` allows for multiple
models to be used to generate GeneratorRuns for that trial.
NOTE: This method is in development. Please do not use it yet.
Args:
experiment: Experiment, for which the generation strategy is producing
a new generator run in the course of `gen`, and to which that
generator run will be added as trial(s). Information stored on the
experiment (e.g., trial statuses) is used to determine which model
will be used to produce the generator run returned from this method.
data: Optional data to be passed to the underlying model's `gen`, which
is called within this method and actually produces the resulting
generator run. By default, data is all data on the `experiment`.
n: Integer representing how total arms to generate for this trial.
pending_observations: A map from metric name to pending
observations for that metric, used by some models to avoid
resuggesting points that are currently being evaluated.
Returns:
A list of ```GeneratorRuns``` for a single trial.
"""
# TODO: @mgarrard merge into gen method, just starting here to derisk
grs = []
continue_gen_for_trial = True

while continue_gen_for_trial:
grs.extend(
self._gen_multiple(
experiment=experiment,
num_generator_runs=1,
data=data,
n=n,
pending_observations=pending_observations,
**kwargs,
)
)
continue_gen_for_trial = self._should_continue_gen_for_trial()
return grs

def gen_for_multiple_trials_with_multiple_models(
self,
experiment: Experiment,
Expand Down Expand Up @@ -708,6 +756,31 @@ def _gen_multiple(
)
return generator_runs

def _should_continue_gen_for_trial(self) -> bool:
"""Determine if we should continue generating for the current trial, or end
generation for the current trial.
Returns:
A boolean which represents if generation for a trial is complete
"""
should_transition, next_node = self._curr.should_transition_to_next_node(
raise_data_required_error=False
)
# if we should not transition nodes, we should stop generation for this trial.
if not should_transition:
return False

# if we will transition nodes, check if the transition criterion which define
# the transition from this node to the next node indicate that we should
# continue generating in the same trial, otherwise end the generation.
assert next_node is not None
if all(
tc.continue_trial_generation
for tc in self._curr.transition_edges[next_node]
):
return True
return False

# ------------------------- Model selection logic helpers. -------------------------

def _fit_current_model(self, data: Optional[Data]) -> None:
Expand Down
3 changes: 2 additions & 1 deletion ax/modelbridge/tests/test_generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,8 @@ def test_node_string_representation(self) -> None:
"'only_in_statuses': [<enum 'TrialStatus'>.RUNNING], "
"'not_in_statuses': None, 'transition_to': None, "
"'block_transition_if_unmet': True, 'block_gen_if_met': False, "
"'use_all_trials_in_exp': False})])"
"'use_all_trials_in_exp': False, "
"'continue_trial_generation': False})])"
),
)

Expand Down
97 changes: 67 additions & 30 deletions ax/modelbridge/tests/test_generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1045,7 +1045,7 @@ def test_gen_for_multiple_trials_with_multiple_models_with_fixed_features(
GenerationStep(
model=Models.BOTORCH_MODULAR,
model_kwargs={
# this will cause and error if the model
# this will cause an error if the model
# doesn't get fixed features
"transforms": ST_MTGP_trans,
**self.step_model_kwargs,
Expand Down Expand Up @@ -1445,34 +1445,45 @@ def test_node_gs_with_auto_transitions(self) -> None:
node_name="sobol_2",
model_specs=[self.sobol_model_spec],
transition_criteria=[
AutoTransitionAfterGenCriterion(transition_to="gpei_2")
AutoTransitionAfterGenCriterion(transition_to="sobol_3")
],
),
GenerationNode(
node_name="gpei_2",
model_specs=[self.gpei_model_spec],
node_name="sobol_3",
model_specs=[self.sobol_model_spec],
transition_criteria=[
AutoTransitionAfterGenCriterion(transition_to="gpei")
AutoTransitionAfterGenCriterion(
transition_to="gpei",
block_transition_if_unmet=True,
continue_trial_generation=False,
)
],
),
],
)
exp = get_branin_experiment()
gs.experiment = exp

# for the first trial, we start on sobol, we generate the trial, but it hasn't
# been run yet, so we remain on sobol, after the trial is run, the subsequent
# trials should be from node gpei, sobol_2, and sobol_3
self.assertEqual(gs.current_node_name, "sobol")
trial0 = exp.new_batch_trial(generator_runs=gs.gen_with_multiple_nodes(exp))
self.assertEqual(gs.current_node_name, "sobol")
exp.new_trial(generator_run=gs.gen(exp)).run()
# while here, test the last generator run property on node
self.assertEqual(gs.current_node.node_that_generated_last_gr, "sobol")
gs.gen(exp)
self.assertEqual(gs.current_node_name, "gpei")
gs.gen(exp)
self.assertEqual(gs.current_node_name, "sobol_2")
gs.gen(exp)
self.assertEqual(gs.current_node_name, "gpei_2")
gs.gen(exp)
self.assertEqual(gs.current_node_name, "gpei")

# TODO: @mgarrard modify below test when gen handles multiple nodes
trial0.run()
for _i in range(0, 2):
trial = exp.new_batch_trial(generator_runs=gs.gen_with_multiple_nodes(exp))
self.assertEqual(gs.current_node_name, "sobol_3")
self.assertEqual(len(trial.generator_runs), 3)
self.assertEqual(trial.generator_runs[0]._generation_node_name, "gpei")
self.assertEqual(trial.generator_runs[1]._generation_node_name, "sobol_2")
self.assertEqual(trial.generator_runs[2]._generation_node_name, "sobol_3")

def test_node_gs_with_auto_transitions_three_phase(self) -> None:
exp = get_branin_experiment()
gs_2 = GenerationStrategy(
nodes=[
GenerationNode(
Expand All @@ -1484,44 +1495,70 @@ def test_node_gs_with_auto_transitions(self) -> None:
node_name="gpei",
model_specs=[self.gpei_model_spec],
transition_criteria=[
AutoTransitionAfterGenCriterion(transition_to="sobol_2")
AutoTransitionAfterGenCriterion(
transition_to="sobol_2",
)
],
),
GenerationNode(
node_name="sobol_2",
model_specs=[self.sobol_model_spec],
transition_criteria=[
AutoTransitionAfterGenCriterion(transition_to="gpei_2")
AutoTransitionAfterGenCriterion(transition_to="sobol_3")
],
),
GenerationNode(
node_name="gpei_2",
model_specs=[self.gpei_model_spec],
node_name="sobol_3",
model_specs=[self.sobol_model_spec],
transition_criteria=[
MaxTrials(
threshold=2,
transition_to="sobol_3",
transition_to="sobol_4",
block_transition_if_unmet=True,
only_in_statuses=[TrialStatus.RUNNING],
use_all_trials_in_exp=True,
)
),
AutoTransitionAfterGenCriterion(
transition_to="gpei",
block_transition_if_unmet=True,
continue_trial_generation=False,
),
],
),
GenerationNode(
node_name="sobol_3",
node_name="sobol_4",
model_specs=[self.sobol_model_spec],
),
],
)
gs_2.experiment = exp

# for the first trial, we start on sobol, we generate the trial, but it hasn't
# been run yet, so we remain on sobol

self.assertEqual(gs_2.current_node_name, "sobol")
trial0 = exp.new_batch_trial(generator_runs=gs_2.gen_with_multiple_nodes(exp))
self.assertEqual(gs_2.current_node_name, "sobol")
exp.new_trial(generator_run=gs_2.gen(exp)).run()
self.assertEqual(gs_2.current_node_name, "gpei")
gs_2.gen(exp)
self.assertEqual(gs_2.current_node_name, "sobol_2")
gs_2.gen(exp) # noqa
self.assertEqual(gs_2.current_node_name, "gpei_2")
exp.new_trial(generator_run=gs_2.gen(exp)).run()
self.assertEqual(gs_2.current_node_name, "sobol_3")
trial0.run()

# after trial 0 is run, we create a trial with nodes gpei, sobol_2, and sobol_3
# However, the sobol_3 criterion requires that we have two running trials. We
# don't move onto sobol_4 until we have two running trials, instead we reset
# to the last first node in a trial.
for _i in range(0, 2):
trial = exp.new_batch_trial(
generator_runs=gs_2.gen_with_multiple_nodes(exp)
)
self.assertEqual(gs_2.current_node_name, "sobol_3")
self.assertEqual(len(trial.generator_runs), 3)
self.assertEqual(trial.generator_runs[0]._generation_node_name, "gpei")
self.assertEqual(trial.generator_runs[1]._generation_node_name, "sobol_2")
self.assertEqual(trial.generator_runs[2]._generation_node_name, "sobol_3")

# after running the next trial should be made from sobol 4
trial.run()
trial = exp.new_batch_trial(generator_runs=gs_2.gen_with_multiple_nodes(exp))
self.assertEqual(trial.generator_runs[0]._generation_node_name, "sobol_4")

# ------------- Testing helpers (put tests above this line) -------------

Expand Down
13 changes: 9 additions & 4 deletions ax/modelbridge/tests/test_transition_criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,8 @@ def test_repr(self) -> None:
+ "'transition_to': 'GenerationStep_1', "
+ "'block_transition_if_unmet': False, "
+ "'block_gen_if_met': True, "
+ "'use_all_trials_in_exp': False})",
+ "'use_all_trials_in_exp': False, "
+ "'continue_trial_generation': False})",
)
minimum_trials_in_status_criterion = MinTrials(
only_in_statuses=[TrialStatus.COMPLETED, TrialStatus.EARLY_STOPPED],
Expand All @@ -446,7 +447,8 @@ def test_repr(self) -> None:
+ "'transition_to': 'GenerationStep_2', "
+ "'block_transition_if_unmet': False, "
+ "'block_gen_if_met': True, "
+ "'use_all_trials_in_exp': False})",
+ "'use_all_trials_in_exp': False, "
+ "'continue_trial_generation': False})",
)
minimum_preference_occurances_criterion = MinimumPreferenceOccurances(
metric_name="m1", threshold=3
Expand Down Expand Up @@ -483,12 +485,15 @@ def test_repr(self) -> None:
+ "'transition_to': 'GenerationStep_2', "
+ "'block_transition_if_unmet': False, "
+ "'block_gen_if_met': True, "
+ "'use_all_trials_in_exp': False})",
+ "'use_all_trials_in_exp': False, "
+ "'continue_trial_generation': True})",
)
auto_transition = AutoTransitionAfterGenCriterion(
transition_to="GenerationStep_2"
)
self.assertEqual(
str(auto_transition),
"AutoTransitionAfterGenCriterion({'transition_to': 'GenerationStep_2'})",
"AutoTransitionAfterGenCriterion({'transition_to': 'GenerationStep_2', "
+ "'block_transition_if_unmet': True, "
+ "'continue_trial_generation': True})",
)
Loading

0 comments on commit 0cabf69

Please sign in to comment.