Skip to content

Commit

Permalink
Support GenerationNodes and GenerationSteps in GenerationStrategy and…
Browse files Browse the repository at this point in the history
… default to GenerationNodes (facebook#2024)

Summary:

This diff does the following:
Supports GenerationNodes at the level of GenerationStrategy. This is the big hurrah diff! 

upcoming:
(0) Add decorator for functions that are only supported in steps
(1) update the storage to include nodes independently (and not just as part of step)
(2) delete now unused GenStep functions
(3) final pass on all the doc strings and variables -- lots to clean up here
(4) add transition criterion to the repr string + some of the other fields that havent made it yet on GeneratinoNode
(5) Do a final pass of the generationStrategy/GenerationNode files to see what else can be migrated/condensed
(6) rename transiton criterion to action criterion
(7) remove conditionals for legacy usecase
(8) clean up any lingering todos

Reviewed By: lena-kashtelyan

Differential Revision: D51120002
  • Loading branch information
Mia Garrard authored and facebook-github-bot committed Dec 6, 2023
1 parent f76f721 commit ff5fed4
Show file tree
Hide file tree
Showing 15 changed files with 526 additions and 72 deletions.
11 changes: 11 additions & 0 deletions ax/exceptions/generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,14 @@ class GenerationStrategyRepeatedPoints(GenerationStrategyCompleted):
"""

pass


class GenerationStrategyMisconfiguredException(AxError):
"""Special exception indicating that the generation strategy is misconfigured."""

def __init__(self, error_info: Optional[str]) -> None:
super().__init__(
"This GenerationStrategy was unable to be initialized properly. Please "
+ "check the documentation, and adjust the configuration accordingly. "
+ f"{error_info}"
)
1 change: 1 addition & 0 deletions ax/modelbridge/generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ class GenerationNode:

# Optional specifications
_model_spec_to_gen_from: Optional[ModelSpec] = None
# TODO: @mgarrard should this be a dict criterion_class name -> criterion mapping?
_transition_criteria: Optional[Sequence[TransitionCriterion]]

# [TODO] Handle experiment passing more eloquently by enforcing experiment
Expand Down
270 changes: 217 additions & 53 deletions ax/modelbridge/generation_strategy.py

Large diffs are not rendered by default.

11 changes: 9 additions & 2 deletions ax/modelbridge/model_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import warnings
from copy import deepcopy
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Tuple

from ax.core.data import Data
Expand All @@ -27,7 +28,7 @@
CVResult,
)
from ax.modelbridge.registry import ModelRegistryBase
from ax.utils.common.base import Base
from ax.utils.common.base import SortableBase
from ax.utils.common.kwargs import (
consolidate_kwargs,
filter_kwargs,
Expand All @@ -48,7 +49,7 @@ def default(self, o: Any) -> str:


@dataclass
class ModelSpec(Base):
class ModelSpec(SortableBase):
model_enum: ModelRegistryBase
# Kwargs to pass into the `Model` + `ModelBridge` constructors in
# `ModelRegistryBase.__call__`.
Expand Down Expand Up @@ -288,6 +289,12 @@ def __hash__(self) -> int:
def __eq__(self, other: ModelSpec) -> bool:
return repr(self) == repr(other)

@property
def _unique_id(self) -> str:
"""Returns the unique ID of this model spec"""
# TODO @mgarrard verify that this is unique enough
return str(hash(self))


@dataclass
class FactoryFunctionModelSpec(ModelSpec):
Expand Down
4 changes: 2 additions & 2 deletions ax/modelbridge/tests/test_completion_criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def test_single_criterion(self) -> None:
)
)

self.assertEqual(generation_strategy._curr.model, Models.GPEI)
self.assertEqual(generation_strategy._curr.model_enum, Models.GPEI)

def test_many_criteria(self) -> None:
criteria = [
Expand Down Expand Up @@ -145,4 +145,4 @@ def test_many_criteria(self) -> None:
)
)

self.assertEqual(generation_strategy._curr.model, Models.GPEI)
self.assertEqual(generation_strategy._curr.model_enum, Models.GPEI)
260 changes: 256 additions & 4 deletions ax/modelbridge/tests/test_generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import cast, List
from unittest import mock
from unittest.mock import MagicMock, patch
Expand All @@ -21,16 +22,23 @@
from ax.exceptions.core import DataRequiredError, UserInputError
from ax.exceptions.generation_strategy import (
GenerationStrategyCompleted,
GenerationStrategyMisconfiguredException,
GenerationStrategyRepeatedPoints,
MaxParallelismReachedException,
)
from ax.modelbridge.discrete import DiscreteModelBridge
from ax.modelbridge.factory import get_sobol
from ax.modelbridge.generation_node import GenerationNode
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.model_spec import ModelSpec
from ax.modelbridge.random import RandomModelBridge
from ax.modelbridge.registry import Cont_X_trans, MODEL_KEY_TO_MODEL_SETUP, Models
from ax.modelbridge.torch import TorchModelBridge
from ax.modelbridge.transition_criterion import (
MaxGenerationParallelism,
MaxTrials,
MinTrials,
)
from ax.models.random.sobol import SobolGenerator
from ax.utils.common.equality import same_elements
from ax.utils.common.mock import mock_patch_method_original
Expand Down Expand Up @@ -416,8 +424,8 @@ def test_clone_reset(self) -> None:
]
)
ftgs._curr = ftgs._steps[1]
self.assertEqual(ftgs._curr.index, 1)
self.assertEqual(ftgs.clone_reset()._curr.index, 0)
self.assertEqual(ftgs.current_step_index, 1)
self.assertEqual(ftgs.clone_reset().current_step_index, 0)

def test_kwargs_passed(self) -> None:
gs = GenerationStrategy(
Expand Down Expand Up @@ -527,10 +535,12 @@ def test_trials_as_df(self) -> None:
# attach necessary trials to fill up the Generation Strategy
trial = exp.new_trial(sobol_generation_strategy.gen(experiment=exp))
self.assertEqual(
sobol_generation_strategy.trials_as_df.head()["Generation Step"][0], 0
sobol_generation_strategy.trials_as_df.head()["Generation Step"][0],
"GenerationStep_0",
)
self.assertEqual(
sobol_generation_strategy.trials_as_df.head()["Generation Step"][2], 1
sobol_generation_strategy.trials_as_df.head()["Generation Step"][2],
"GenerationStep_1",
)

def test_max_parallelism_reached(self) -> None:
Expand Down Expand Up @@ -883,6 +893,248 @@ def test_gen_for_multiple_trials_with_multiple_models(self) -> None:
for p in original_pending[m]:
self.assertIn(p, pending[m])

# ---------- Tests for GenerationStrategies composed of GenerationNodes --------
def test_gs_setup_with_nodes(self) -> None:
"""Test GS initalization and validation with nodes"""
sobol_model_spec = ModelSpec(
model_enum=Models.SOBOL,
model_kwargs={},
model_gen_kwargs={"n": 2},
)
node_1_criterion = [
MaxTrials(
threshold=4,
block_gen_if_met=False,
transition_to="node_2",
only_in_statuses=None,
not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED],
),
MinTrials(
only_in_statuses=[TrialStatus.COMPLETED, TrialStatus.EARLY_STOPPED],
threshold=2,
transition_to="node_2",
),
MaxGenerationParallelism(
threshold=1,
only_in_statuses=[TrialStatus.RUNNING],
block_gen_if_met=True,
block_transition_if_unmet=False,
),
]

# check error raised if node names are not unique
with self.assertRaisesRegex(
GenerationStrategyMisconfiguredException, "All node names"
):
GenerationStrategy(
nodes=[
GenerationNode(
node_name="node_1",
transition_criteria=node_1_criterion,
model_specs=[sobol_model_spec],
),
GenerationNode(
node_name="node_1",
model_specs=[sobol_model_spec],
),
],
)
# check error raised if transition to arguemnt is not valid
with self.assertRaisesRegex(
GenerationStrategyMisconfiguredException, "`transition_to` argument"
):
GenerationStrategy(
nodes=[
GenerationNode(
node_name="node_1",
transition_criteria=node_1_criterion,
model_specs=[sobol_model_spec],
),
GenerationNode(
node_name="node_3",
model_specs=[sobol_model_spec],
),
],
)

# check error raised if provided both steps and nodes
with self.assertRaisesRegex(
GenerationStrategyMisconfiguredException, "either steps or nodes"
):
GenerationStrategy(
nodes=[
GenerationNode(
node_name="node_1",
transition_criteria=node_1_criterion,
model_specs=[sobol_model_spec],
),
GenerationNode(
node_name="node_3",
model_specs=[sobol_model_spec],
),
],
steps=[
GenerationStep(
model=Models.SOBOL,
num_trials=5,
model_kwargs=self.step_model_kwargs,
),
GenerationStep(
model=Models.GPEI,
num_trials=-1,
model_kwargs=self.step_model_kwargs,
),
],
)

# check error raised if provided both steps and nodes under node list
with self.assertRaisesRegex(
GenerationStrategyMisconfiguredException, "must either be a GenerationStep"
):
GenerationStrategy(
nodes=[
GenerationNode(
node_name="node_1",
transition_criteria=node_1_criterion,
model_specs=[sobol_model_spec],
),
GenerationStep(
model=Models.SOBOL,
num_trials=5,
model_kwargs=self.step_model_kwargs,
),
GenerationNode(
node_name="node_2",
model_specs=[sobol_model_spec],
),
],
)
# check that warning is logged if no nodes have transition arguments
with self.assertLogs(GenerationStrategy.__module__, logging.WARNING) as logger:
warning_msg = (
"None of the nodes in this GenerationStrategy "
"contain a `transition_to` argument in their transition_criteria. "
)
GenerationStrategy(
nodes=[
GenerationNode(
node_name="node_1",
model_specs=[sobol_model_spec],
),
GenerationNode(
node_name="node_3",
model_specs=[sobol_model_spec],
),
],
)
self.assertTrue(
any(warning_msg in output for output in logger.output),
logger.output,
)

def test_gs_with_generation_nodes(self) -> None:
"Simple test of a SOBOL + GPEI GenerationStrategy composed of GenerationNodes"
sobol_criterion = [
MaxTrials(
threshold=5,
transition_to="GPEI_node",
block_gen_if_met=True,
only_in_statuses=None,
not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED],
)
]
gpei_criterion = [
MaxTrials(
threshold=2,
transition_to=None,
block_gen_if_met=True,
only_in_statuses=None,
not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED],
)
]
sobol_model_spec = ModelSpec(
model_enum=Models.SOBOL,
model_kwargs=self.step_model_kwargs,
model_gen_kwargs={},
)
gpei_model_spec = ModelSpec(
model_enum=Models.GPEI,
model_kwargs=self.step_model_kwargs,
model_gen_kwargs={},
)
sobol_node = GenerationNode(
node_name="sobol_node",
transition_criteria=sobol_criterion,
model_specs=[sobol_model_spec],
gen_unlimited_trials=False,
)
gpei_node = GenerationNode(
node_name="GPEI_node",
transition_criteria=gpei_criterion,
model_specs=[gpei_model_spec],
gen_unlimited_trials=False,
)

sobol_GPEI_GS_nodes = GenerationStrategy(
name="Sobol+GPEI_Nodes",
nodes=[sobol_node, gpei_node],
)
exp = get_branin_experiment()
self.assertEqual(sobol_GPEI_GS_nodes.name, "Sobol+GPEI_Nodes")
self.assertEqual(sobol_GPEI_GS_nodes.model_transitions, [5])

for i in range(7):
g = sobol_GPEI_GS_nodes.gen(exp)
exp.new_trial(generator_run=g).run()
self.assertEqual(len(sobol_GPEI_GS_nodes._generator_runs), i + 1)
if i > 4:
self.mock_torch_model_bridge.assert_called()
else:
self.assertEqual(g._model_key, "Sobol")
mkw = g._model_kwargs
self.assertIsNotNone(mkw)
if i > 0:
# Generated points are randomized, so checking that they're there.
self.assertIsNotNone(mkw.get("generated_points"))
else:
# This is the first GR, there should be no generated points yet.
self.assertIsNone(mkw.get("generated_points"))
# Remove the randomized generated points to compare the rest.
mkw = mkw.copy()
del mkw["generated_points"]
self.assertEqual(
mkw,
{
"seed": None,
"deduplicate": True,
"init_position": i,
"scramble": True,
"fallback_to_sample_polytope": False,
},
)
self.assertEqual(
g._bridge_kwargs,
{
"optimization_config": None,
"status_quo_features": None,
"status_quo_name": None,
"transform_configs": None,
"transforms": Cont_X_trans,
"fit_out_of_design": False,
"fit_abandoned": False,
"fit_tracking_metrics": True,
"fit_on_init": True,
},
)
ms = g._model_state_after_gen
self.assertIsNotNone(ms)
# Generated points are randomized, so just checking that they are there.
self.assertIn("generated_points", ms)
# Remove the randomized generated points to compare the rest.
ms = ms.copy()
del ms["generated_points"]
self.assertEqual(ms, {"init_position": i + 1})

# ------------- Testing helpers (put tests above this line) -------------

def _run_GS_for_N_rounds(
Expand Down
2 changes: 1 addition & 1 deletion ax/modelbridge/tests/test_transition_criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def test_minimum_preference_criterion(self) -> None:
raise_data_required_error=False
)
)
self.assertEqual(generation_strategy._curr.model, Models.GPEI)
self.assertEqual(generation_strategy._curr.model_enum, Models.GPEI)

def test_default_step_criterion_setup(self) -> None:
"""This test ensures that the default completion criterion for GenerationSteps
Expand Down
Loading

0 comments on commit ff5fed4

Please sign in to comment.