Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support gen nodes and gensteps #2024

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions ax/exceptions/generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,14 @@ class GenerationStrategyRepeatedPoints(GenerationStrategyCompleted):
"""

pass


class GenerationStrategyMisconfiguredException(AxError):
"""Special exception indicating that the generation strategy is misconfigured."""

def __init__(self, error_info: Optional[str]) -> None:
super().__init__(
"This GenerationStrategy was unable to be initialized properly. Please "
+ "check the documentation, and adjust the configuration accordingly. "
+ f"{error_info}"
)
1 change: 1 addition & 0 deletions ax/modelbridge/generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ class GenerationNode:

# Optional specifications
_model_spec_to_gen_from: Optional[ModelSpec] = None
# TODO: @mgarrard should this be a dict criterion_class name -> criterion mapping?
_transition_criteria: Optional[Sequence[TransitionCriterion]]

# [TODO] Handle experiment passing more eloquently by enforcing experiment
Expand Down
270 changes: 217 additions & 53 deletions ax/modelbridge/generation_strategy.py

Large diffs are not rendered by default.

10 changes: 8 additions & 2 deletions ax/modelbridge/model_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
CVResult,
)
from ax.modelbridge.registry import ModelRegistryBase
from ax.utils.common.base import Base
from ax.utils.common.base import SortableBase
from ax.utils.common.kwargs import (
consolidate_kwargs,
filter_kwargs,
Expand All @@ -48,7 +48,7 @@ def default(self, o: Any) -> str:


@dataclass
class ModelSpec(Base):
class ModelSpec(SortableBase):
model_enum: ModelRegistryBase
# Kwargs to pass into the `Model` + `ModelBridge` constructors in
# `ModelRegistryBase.__call__`.
Expand Down Expand Up @@ -288,6 +288,12 @@ def __hash__(self) -> int:
def __eq__(self, other: ModelSpec) -> bool:
return repr(self) == repr(other)

@property
def _unique_id(self) -> str:
"""Returns the unique ID of this model spec"""
# TODO @mgarrard verify that this is unique enough
return str(hash(self))


@dataclass
class FactoryFunctionModelSpec(ModelSpec):
Expand Down
4 changes: 2 additions & 2 deletions ax/modelbridge/tests/test_completion_criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def test_single_criterion(self) -> None:
)
)

self.assertEqual(generation_strategy._curr.model, Models.GPEI)
self.assertEqual(generation_strategy._curr.model_enum, Models.GPEI)

def test_many_criteria(self) -> None:
criteria = [
Expand Down Expand Up @@ -145,4 +145,4 @@ def test_many_criteria(self) -> None:
)
)

self.assertEqual(generation_strategy._curr.model, Models.GPEI)
self.assertEqual(generation_strategy._curr.model_enum, Models.GPEI)
260 changes: 256 additions & 4 deletions ax/modelbridge/tests/test_generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import cast, List
from unittest import mock
from unittest.mock import MagicMock, patch
Expand All @@ -21,16 +22,23 @@
from ax.exceptions.core import DataRequiredError, UserInputError
from ax.exceptions.generation_strategy import (
GenerationStrategyCompleted,
GenerationStrategyMisconfiguredException,
GenerationStrategyRepeatedPoints,
MaxParallelismReachedException,
)
from ax.modelbridge.discrete import DiscreteModelBridge
from ax.modelbridge.factory import get_sobol
from ax.modelbridge.generation_node import GenerationNode
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.model_spec import ModelSpec
from ax.modelbridge.random import RandomModelBridge
from ax.modelbridge.registry import Cont_X_trans, MODEL_KEY_TO_MODEL_SETUP, Models
from ax.modelbridge.torch import TorchModelBridge
from ax.modelbridge.transition_criterion import (
MaxGenerationParallelism,
MaxTrials,
MinTrials,
)
from ax.models.random.sobol import SobolGenerator
from ax.utils.common.equality import same_elements
from ax.utils.common.mock import mock_patch_method_original
Expand Down Expand Up @@ -416,8 +424,8 @@ def test_clone_reset(self) -> None:
]
)
ftgs._curr = ftgs._steps[1]
self.assertEqual(ftgs._curr.index, 1)
self.assertEqual(ftgs.clone_reset()._curr.index, 0)
self.assertEqual(ftgs.current_step_index, 1)
self.assertEqual(ftgs.clone_reset().current_step_index, 0)

def test_kwargs_passed(self) -> None:
gs = GenerationStrategy(
Expand Down Expand Up @@ -527,10 +535,12 @@ def test_trials_as_df(self) -> None:
# attach necessary trials to fill up the Generation Strategy
trial = exp.new_trial(sobol_generation_strategy.gen(experiment=exp))
self.assertEqual(
sobol_generation_strategy.trials_as_df.head()["Generation Step"][0], 0
sobol_generation_strategy.trials_as_df.head()["Generation Step"][0],
"GenerationStep_0",
)
self.assertEqual(
sobol_generation_strategy.trials_as_df.head()["Generation Step"][2], 1
sobol_generation_strategy.trials_as_df.head()["Generation Step"][2],
"GenerationStep_1",
)

def test_max_parallelism_reached(self) -> None:
Expand Down Expand Up @@ -883,6 +893,248 @@ def test_gen_for_multiple_trials_with_multiple_models(self) -> None:
for p in original_pending[m]:
self.assertIn(p, pending[m])

# ---------- Tests for GenerationStrategies composed of GenerationNodes --------
def test_gs_setup_with_nodes(self) -> None:
"""Test GS initalization and validation with nodes"""
sobol_model_spec = ModelSpec(
model_enum=Models.SOBOL,
model_kwargs={},
model_gen_kwargs={"n": 2},
)
node_1_criterion = [
MaxTrials(
threshold=4,
block_gen_if_met=False,
transition_to="node_2",
only_in_statuses=None,
not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED],
),
MinTrials(
only_in_statuses=[TrialStatus.COMPLETED, TrialStatus.EARLY_STOPPED],
threshold=2,
transition_to="node_2",
),
MaxGenerationParallelism(
threshold=1,
only_in_statuses=[TrialStatus.RUNNING],
block_gen_if_met=True,
block_transition_if_unmet=False,
),
]

# check error raised if node names are not unique
with self.assertRaisesRegex(
GenerationStrategyMisconfiguredException, "All node names"
):
GenerationStrategy(
nodes=[
GenerationNode(
node_name="node_1",
transition_criteria=node_1_criterion,
model_specs=[sobol_model_spec],
),
GenerationNode(
node_name="node_1",
model_specs=[sobol_model_spec],
),
],
)
# check error raised if transition to arguemnt is not valid
with self.assertRaisesRegex(
GenerationStrategyMisconfiguredException, "`transition_to` argument"
):
GenerationStrategy(
nodes=[
GenerationNode(
node_name="node_1",
transition_criteria=node_1_criterion,
model_specs=[sobol_model_spec],
),
GenerationNode(
node_name="node_3",
model_specs=[sobol_model_spec],
),
],
)

# check error raised if provided both steps and nodes
with self.assertRaisesRegex(
GenerationStrategyMisconfiguredException, "either steps or nodes"
):
GenerationStrategy(
nodes=[
GenerationNode(
node_name="node_1",
transition_criteria=node_1_criterion,
model_specs=[sobol_model_spec],
),
GenerationNode(
node_name="node_3",
model_specs=[sobol_model_spec],
),
],
steps=[
GenerationStep(
model=Models.SOBOL,
num_trials=5,
model_kwargs=self.step_model_kwargs,
),
GenerationStep(
model=Models.GPEI,
num_trials=-1,
model_kwargs=self.step_model_kwargs,
),
],
)

# check error raised if provided both steps and nodes under node list
with self.assertRaisesRegex(
GenerationStrategyMisconfiguredException, "must either be a GenerationStep"
):
GenerationStrategy(
nodes=[
GenerationNode(
node_name="node_1",
transition_criteria=node_1_criterion,
model_specs=[sobol_model_spec],
),
GenerationStep(
model=Models.SOBOL,
num_trials=5,
model_kwargs=self.step_model_kwargs,
),
GenerationNode(
node_name="node_2",
model_specs=[sobol_model_spec],
),
],
)
# check that warning is logged if no nodes have transition arguments
with self.assertLogs(GenerationStrategy.__module__, logging.WARNING) as logger:
warning_msg = (
"None of the nodes in this GenerationStrategy "
"contain a `transition_to` argument in their transition_criteria. "
)
GenerationStrategy(
nodes=[
GenerationNode(
node_name="node_1",
model_specs=[sobol_model_spec],
),
GenerationNode(
node_name="node_3",
model_specs=[sobol_model_spec],
),
],
)
self.assertTrue(
any(warning_msg in output for output in logger.output),
logger.output,
)

def test_gs_with_generation_nodes(self) -> None:
"Simple test of a SOBOL + GPEI GenerationStrategy composed of GenerationNodes"
sobol_criterion = [
MaxTrials(
threshold=5,
transition_to="GPEI_node",
block_gen_if_met=True,
only_in_statuses=None,
not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED],
)
]
gpei_criterion = [
MaxTrials(
threshold=2,
transition_to=None,
block_gen_if_met=True,
only_in_statuses=None,
not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED],
)
]
sobol_model_spec = ModelSpec(
model_enum=Models.SOBOL,
model_kwargs=self.step_model_kwargs,
model_gen_kwargs={},
)
gpei_model_spec = ModelSpec(
model_enum=Models.GPEI,
model_kwargs=self.step_model_kwargs,
model_gen_kwargs={},
)
sobol_node = GenerationNode(
node_name="sobol_node",
transition_criteria=sobol_criterion,
model_specs=[sobol_model_spec],
gen_unlimited_trials=False,
)
gpei_node = GenerationNode(
node_name="GPEI_node",
transition_criteria=gpei_criterion,
model_specs=[gpei_model_spec],
gen_unlimited_trials=False,
)

sobol_GPEI_GS_nodes = GenerationStrategy(
name="Sobol+GPEI_Nodes",
nodes=[sobol_node, gpei_node],
)
exp = get_branin_experiment()
self.assertEqual(sobol_GPEI_GS_nodes.name, "Sobol+GPEI_Nodes")
self.assertEqual(sobol_GPEI_GS_nodes.model_transitions, [5])

for i in range(7):
g = sobol_GPEI_GS_nodes.gen(exp)
exp.new_trial(generator_run=g).run()
self.assertEqual(len(sobol_GPEI_GS_nodes._generator_runs), i + 1)
if i > 4:
self.mock_torch_model_bridge.assert_called()
else:
self.assertEqual(g._model_key, "Sobol")
mkw = g._model_kwargs
self.assertIsNotNone(mkw)
if i > 0:
# Generated points are randomized, so checking that they're there.
self.assertIsNotNone(mkw.get("generated_points"))
else:
# This is the first GR, there should be no generated points yet.
self.assertIsNone(mkw.get("generated_points"))
# Remove the randomized generated points to compare the rest.
mkw = mkw.copy()
del mkw["generated_points"]
self.assertEqual(
mkw,
{
"seed": None,
"deduplicate": True,
"init_position": i,
"scramble": True,
"fallback_to_sample_polytope": False,
},
)
self.assertEqual(
g._bridge_kwargs,
{
"optimization_config": None,
"status_quo_features": None,
"status_quo_name": None,
"transform_configs": None,
"transforms": Cont_X_trans,
"fit_out_of_design": False,
"fit_abandoned": False,
"fit_tracking_metrics": True,
"fit_on_init": True,
},
)
ms = g._model_state_after_gen
self.assertIsNotNone(ms)
# Generated points are randomized, so just checking that they are there.
self.assertIn("generated_points", ms)
# Remove the randomized generated points to compare the rest.
ms = ms.copy()
del ms["generated_points"]
self.assertEqual(ms, {"init_position": i + 1})

# ------------- Testing helpers (put tests above this line) -------------

def _run_GS_for_N_rounds(
Expand Down
2 changes: 1 addition & 1 deletion ax/modelbridge/tests/test_transition_criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def test_minimum_preference_criterion(self) -> None:
raise_data_required_error=False
)
)
self.assertEqual(generation_strategy._curr.model, Models.GPEI)
self.assertEqual(generation_strategy._curr.model_enum, Models.GPEI)

def test_default_step_criterion_setup(self) -> None:
"""This test ensures that the default completion criterion for GenerationSteps
Expand Down
Loading