Skip to content

Commit

Permalink
model <> node 4/n: Support multiple node generation for single trial (#…
Browse files Browse the repository at this point in the history
…2428)

Summary:

This diff enables multiple nodes to be used to generate a single batch trial. Right now the limitations are that: 
(1) currently each node only contributes 1 gr to the trial -- we will extend this to n grs in a future diff 

To do this we added:
(1) should_move_trials method
(2) updates to the transition criterion args

Differential Revision: D56743651
  • Loading branch information
mgarrard authored and facebook-github-bot committed May 9, 2024
1 parent 47ea0cf commit 4e15316
Show file tree
Hide file tree
Showing 5 changed files with 244 additions and 40 deletions.
90 changes: 89 additions & 1 deletion ax/modelbridge/generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,6 @@ class GenerationStrategy(GenerationStrategyInterface):
# Experiment, for which this generation strategy has generated trials, if
# it exists.
_experiment: Optional[Experiment] = None
# Trial indices as last seen by the model; updated in `_model` property setter.
_model: Optional[ModelBridge] = None # Current model.

def __init__(
Expand Down Expand Up @@ -376,6 +375,55 @@ def gen(
**kwargs,
)[0]

def gen_with_multiple_nodes(
self,
experiment: Experiment,
data: Optional[Data] = None,
n: int = 1, # Total arms to generate
pending_observations: Optional[Dict[str, List[ObservationFeatures]]] = None,
**kwargs: Any, # TODO: @mgarrard Ensure correct dispatch to nodes
) -> List[GeneratorRun]:
"""Produces a List of GeneratorRuns for a single trial, either ```Trial``` or
```BatchTrial```, and if producing a ```BatchTrial`` allows for multiple
models to be used to generate GeneratorRuns for that trial.
NOTE: This method is in development. Please do not use it yet.
Args:
experiment: Experiment, for which the generation strategy is producing
a new generator run in the course of `gen`, and to which that
generator run will be added as trial(s). Information stored on the
experiment (e.g., trial statuses) is used to determine which model
will be used to produce the generator run returned from this method.
data: Optional data to be passed to the underlying model's `gen`, which
is called within this method and actually produces the resulting
generator run. By default, data is all data on the `experiment`.
n: Integer representing how total arms to generate for this trial.
pending_observations: A map from metric name to pending
observations for that metric, used by some models to avoid
resuggesting points that are currently being evaluated.
Returns:
A list of ```GeneratorRuns``` for a single trial.
"""
# TODO: @mgarrard merge into gen method, just starting here to derisk
grs = []
gen_for_complete_trial = False

while not gen_for_complete_trial:
grs.extend(
self._gen_multiple(
experiment=experiment,
num_generator_runs=1,
data=data,
n=n,
pending_observations=pending_observations,
**kwargs,
)
)
gen_for_complete_trial = self._will_gen_for_new_trial()
return grs

def gen_for_multiple_trials_with_multiple_models(
self,
experiment: Experiment,
Expand Down Expand Up @@ -708,6 +756,46 @@ def _gen_multiple(
)
return generator_runs

def _will_gen_for_new_trial(self) -> bool:
"""Uses the transition criterion defined on ```GenerationNode``` to determine
if generation for this trial is complete, and we should move to the next trial.
Returns:
A boolean which represents if generation for a trial is complete
"""
# if no criterion defined, always move to the next trial
if len(self._curr.transition_criteria) == 0:
return True

trial_moving_criterion = [
tc for tc in self._curr.transition_criteria if tc.complete_trial_generation
]
if len(trial_moving_criterion) > 0:
trial_tc = [
tc for tc in trial_moving_criterion if tc.block_transition_if_unmet
]

if len(trial_tc) > 0:
return any(
tc.is_met(
experiment=self._curr.experiment,
trials_from_node=self._curr.trials_from_node,
node_name=self._curr.node_name,
node_that_generated_last_gr=(
self.last_generator_run._generation_node_name
if self.last_generator_run is not None
else None
),
)
for tc in trial_tc
)
# no trial_tc is defined so move to next trial
return True

# if no enforce_next_trial criterion is specified, default not moving to
# the next trial
return False

# ------------------------- Model selection logic helpers. -------------------------

def _fit_current_model(self, data: Optional[Data]) -> None:
Expand Down
3 changes: 2 additions & 1 deletion ax/modelbridge/tests/test_generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,8 @@ def test_node_string_representation(self) -> None:
"'only_in_statuses': [<enum 'TrialStatus'>.RUNNING], "
"'not_in_statuses': None, 'transition_to': None, "
"'block_transition_if_unmet': True, 'block_gen_if_met': False, "
"'use_all_trials_in_exp': False})])"
"'use_all_trials_in_exp': False, "
"'complete_trial_generation': True})])"
),
)

Expand Down
109 changes: 78 additions & 31 deletions ax/modelbridge/tests/test_generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1045,7 +1045,7 @@ def test_gen_for_multiple_trials_with_multiple_models_with_fixed_features(
GenerationStep(
model=Models.BOTORCH_MODULAR,
model_kwargs={
# this will cause and error if the model
# this will cause an error if the model
# doesn't get fixed features
"transforms": ST_MTGP_trans,
**self.step_model_kwargs,
Expand Down Expand Up @@ -1445,83 +1445,130 @@ def test_node_gs_with_auto_transitions(self) -> None:
node_name="sobol_2",
model_specs=[self.sobol_model_spec],
transition_criteria=[
AutoTransitionAfterGenCriterion(transition_to="gpei_2")
AutoTransitionAfterGenCriterion(transition_to="sobol_3")
],
),
GenerationNode(
node_name="gpei_2",
model_specs=[self.gpei_model_spec],
node_name="sobol_3",
model_specs=[self.sobol_model_spec],
transition_criteria=[
AutoTransitionAfterGenCriterion(transition_to="gpei")
AutoTransitionAfterGenCriterion(
transition_to="gpei",
block_transition_if_unmet=True,
complete_trial_generation=True,
)
],
),
],
)
exp = get_branin_experiment()
gs.experiment = exp

# for the first trial, we start on sobol, we generate the trial, but it hasn't
# been run yet, so we remain on sobol, after the trial is run, the subsequent
# trials should be from node gpei, sobol_2, and sobol_3
self.assertEqual(gs.current_node_name, "sobol")
trial0 = exp.new_batch_trial(generator_runs=gs.gen_with_multiple_nodes(exp))
self.assertEqual(gs.current_node_name, "sobol")
exp.new_trial(generator_run=gs.gen(exp)).run()
# while here, test the last generator run property on node
self.assertEqual(gs.current_node.node_that_generated_last_gr, "sobol")
gs.gen(exp)
self.assertEqual(gs.current_node_name, "gpei")
gs.gen(exp)
self.assertEqual(gs.current_node_name, "sobol_2")
gs.gen(exp)
self.assertEqual(gs.current_node_name, "gpei_2")
gs.gen(exp)
self.assertEqual(gs.current_node_name, "gpei")

# TODO: @mgarrard modify below test when gen handles multiple nodes
trial0.run()
for _i in range(0, 2):
trial = exp.new_batch_trial(generator_runs=gs.gen_with_multiple_nodes(exp))
self.assertEqual(gs.current_node_name, "sobol_3")
self.assertEqual(len(trial.generator_runs), 3)
self.assertEqual(trial.generator_runs[0]._generation_node_name, "gpei")
self.assertEqual(trial.generator_runs[1]._generation_node_name, "sobol_2")
self.assertEqual(trial.generator_runs[2]._generation_node_name, "sobol_3")

def test_node_gs_with_auto_transitions_three_phase(self) -> None:
exp = get_branin_experiment()
gs_2 = GenerationStrategy(
nodes=[
GenerationNode(
node_name="sobol",
model_specs=[self.sobol_model_spec],
transition_criteria=self.single_running_trial_criterion,
transition_criteria=[
MaxTrials(
threshold=1,
transition_to="gpei",
block_transition_if_unmet=True,
only_in_statuses=[TrialStatus.RUNNING],
),
AutoTransitionAfterGenCriterion(
transition_to="sobol", complete_trial_generation=True
),
],
),
GenerationNode(
node_name="gpei",
model_specs=[self.gpei_model_spec],
transition_criteria=[
AutoTransitionAfterGenCriterion(transition_to="sobol_2")
AutoTransitionAfterGenCriterion(
transition_to="sobol_2",
)
],
),
GenerationNode(
node_name="sobol_2",
model_specs=[self.sobol_model_spec],
transition_criteria=[
AutoTransitionAfterGenCriterion(transition_to="gpei_2")
AutoTransitionAfterGenCriterion(transition_to="sobol_3")
],
),
GenerationNode(
node_name="gpei_2",
model_specs=[self.gpei_model_spec],
node_name="sobol_3",
model_specs=[self.sobol_model_spec],
transition_criteria=[
MaxTrials(
threshold=2,
transition_to="sobol_3",
transition_to="sobol_4",
block_transition_if_unmet=True,
only_in_statuses=[TrialStatus.RUNNING],
use_all_trials_in_exp=True,
)
),
AutoTransitionAfterGenCriterion(
transition_to="gpei",
block_transition_if_unmet=True,
complete_trial_generation=True,
),
],
),
GenerationNode(
node_name="sobol_3",
node_name="sobol_4",
model_specs=[self.sobol_model_spec],
),
],
)
gs_2.experiment = exp

# for the first trial, we start on sobol, we generate the trial, but it hasn't
# been run yet, so we remain on sobol

self.assertEqual(gs_2.current_node_name, "sobol")
exp.new_trial(generator_run=gs_2.gen(exp)).run()
self.assertEqual(gs_2.current_node_name, "gpei")
gs_2.gen(exp)
self.assertEqual(gs_2.current_node_name, "sobol_2")
gs_2.gen(exp) # noqa
self.assertEqual(gs_2.current_node_name, "gpei_2")
exp.new_trial(generator_run=gs_2.gen(exp)).run()
self.assertEqual(gs_2.current_node_name, "sobol_3")
trial0 = exp.new_batch_trial(generator_runs=gs_2.gen_with_multiple_nodes(exp))
self.assertEqual(gs_2.current_node_name, "sobol")
trial0.run()

# after trial 0 is run, we create a trial with nodes gpei, sobol_2, and sobol_3
# However, the sobol_3 criterion requires that we have two running trials. We
# don't move onto sobol_4 until we have two running trials, instead we reset
# to the last first node in a trial.
for _i in range(0, 2):
trial = exp.new_batch_trial(
generator_runs=gs_2.gen_with_multiple_nodes(exp)
)
self.assertEqual(gs_2.current_node_name, "sobol_3")
self.assertEqual(len(trial.generator_runs), 3)
self.assertEqual(trial.generator_runs[0]._generation_node_name, "gpei")
self.assertEqual(trial.generator_runs[1]._generation_node_name, "sobol_2")
self.assertEqual(trial.generator_runs[2]._generation_node_name, "sobol_3")

# after running the next trial should be made from sobol 4
trial.run()
trial = exp.new_batch_trial(generator_runs=gs_2.gen_with_multiple_nodes(exp))
self.assertEqual(trial.generator_runs[0]._generation_node_name, "sobol_4")

# ------------- Testing helpers (put tests above this line) -------------

Expand Down
13 changes: 9 additions & 4 deletions ax/modelbridge/tests/test_transition_criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,8 @@ def test_repr(self) -> None:
+ "'transition_to': 'GenerationStep_1', "
+ "'block_transition_if_unmet': False, "
+ "'block_gen_if_met': True, "
+ "'use_all_trials_in_exp': False})",
+ "'use_all_trials_in_exp': False, "
+ "'complete_trial_generation': True})",
)
minimum_trials_in_status_criterion = MinTrials(
only_in_statuses=[TrialStatus.COMPLETED, TrialStatus.EARLY_STOPPED],
Expand All @@ -446,7 +447,8 @@ def test_repr(self) -> None:
+ "'transition_to': 'GenerationStep_2', "
+ "'block_transition_if_unmet': False, "
+ "'block_gen_if_met': True, "
+ "'use_all_trials_in_exp': False})",
+ "'use_all_trials_in_exp': False, "
+ "'complete_trial_generation': True})",
)
minimum_preference_occurances_criterion = MinimumPreferenceOccurances(
metric_name="m1", threshold=3
Expand Down Expand Up @@ -483,12 +485,15 @@ def test_repr(self) -> None:
+ "'transition_to': 'GenerationStep_2', "
+ "'block_transition_if_unmet': False, "
+ "'block_gen_if_met': True, "
+ "'use_all_trials_in_exp': False})",
+ "'use_all_trials_in_exp': False, "
+ "'complete_trial_generation': False})",
)
auto_transition = AutoTransitionAfterGenCriterion(
transition_to="GenerationStep_2"
)
self.assertEqual(
str(auto_transition),
"AutoTransitionAfterGenCriterion({'transition_to': 'GenerationStep_2'})",
"AutoTransitionAfterGenCriterion({'transition_to': 'GenerationStep_2', "
+ "'complete_trial_generation': False, "
+ "'block_transition_if_unmet': True})",
)
Loading

0 comments on commit 4e15316

Please sign in to comment.