Skip to content

Commit

Permalink
Support gen nodes and gensteps (facebook#2024)
Browse files Browse the repository at this point in the history
Summary:

Support both steps and nodes at the generation strategy level

This diff does the following:
Supports GenerationNodes at the level of GenerationStrategy. This is the big hurrah diff!

upcoming:
(1) update the storage to include nodes independently (and not just as part of step)
(2) delete now unused GenStep functions
(3) final pass on all the doc strings
(4) add transition criterion to the repr string + some of the other fields that havent made it yet on GeneratinoNode
(5) Do a final pass of the generationStrategy/GenerationNode files to see what else can be migrated/condensed
(6) rename transiton criterion to action criterion
(7) remove conditionals for legacy usecase

Differential Revision: D51120002
  • Loading branch information
Mia Garrard authored and facebook-github-bot committed Nov 29, 2023
1 parent 001c50c commit 6dd77c8
Show file tree
Hide file tree
Showing 10 changed files with 300 additions and 35 deletions.
11 changes: 11 additions & 0 deletions ax/exceptions/generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,14 @@ class GenerationStrategyRepeatedPoints(GenerationStrategyCompleted):
"""

pass


class GenerationStrategyMisconfiguredException(AxError):
"""Special exception indicating that the generation strategy is misconfigured."""

def __init__(self, error_info: Optional[str]) -> None:
super().__init__(
"This GenerationStrategy was unable to be initialized properly. Please "
+ "check the documentation, and adjust the configuration accordingly. "
+ f"{error_info}"
)
159 changes: 135 additions & 24 deletions ax/modelbridge/generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,13 @@
from ax.core.observation import ObservationFeatures
from ax.core.utils import extend_pending_observations
from ax.exceptions.core import DataRequiredError, UserInputError
from ax.exceptions.generation_strategy import GenerationStrategyCompleted
from ax.exceptions.generation_strategy import (
GenerationStrategyCompleted,
GenerationStrategyMisconfiguredException,
)

from ax.modelbridge.base import ModelBridge
from ax.modelbridge.generation_node import GenerationStep
from ax.modelbridge.generation_node import GenerationNode, GenerationStep
from ax.modelbridge.registry import _extract_model_state_after_gen, ModelRegistryBase
from ax.utils.common.base import Base
from ax.utils.common.logger import _round_floats_for_logging, get_logger
Expand Down Expand Up @@ -56,7 +59,8 @@ class GenerationStrategy(Base):

_name: Optional[str]
_steps: List[GenerationStep]
_curr: GenerationStep # Current step in the strategy.
_nodes: List[GenerationNode]
_curr: GenerationNode # Current step in the strategy.
# Whether all models in this GS are in Models registry enum.
_uses_registered_models: bool
# All generator runs created through this generation strategy, in chronological
Expand All @@ -70,22 +74,57 @@ class GenerationStrategy(Base):
_seen_trial_indices_by_status = None
_model: Optional[ModelBridge] = None # Current model.

def __init__(self, steps: List[GenerationStep], name: Optional[str] = None) -> None:
assert isinstance(steps, list) and all(
isinstance(s, GenerationStep) for s in steps
), "Steps must be a GenerationStep list."
def __init__(
self,
steps: Optional[List[GenerationStep]] = None,
name: Optional[str] = None,
nodes: Optional[List[GenerationNode]] = None,
) -> None:
self._name = name
self._steps = steps
self._uses_registered_models = True
self._generator_runs = []

# validate that only one of steps or nodes is provided
if not ((steps is None) ^ (nodes is None)):
raise GenerationStrategyMisconfiguredException(
error_info="GenerationStrategy must contain either steps or nodes."
)
if steps is not None:
self._steps = steps
if nodes is not None:
self._nodes = nodes

if isinstance(steps, list) and all(
isinstance(s, GenerationStep) for s in steps
):
self._validate_and_set_step_sequence(steps=steps)
elif isinstance(nodes, list) and all(
isinstance(n, GenerationNode) for n in nodes
):
self._validate_and_set_node_graph(nodes=nodes)
else:
raise GenerationStrategyMisconfiguredException(
error_info="Steps must either be a GenerationStep list or a "
+ "GenerationNode list."
)

if not self._uses_registered_models:
logger.info(
"Using model via callable function, "
"so optimization is not resumable if interrupted."
)
self._seen_trial_indices_by_status = None

def _validate_and_set_step_sequence(self, steps: List[GenerationStep]) -> None:
"""Initialize and validate that the sequence of steps."""
for idx, step in enumerate(self._steps):
if step.num_trials == -1 and len(step.completion_criteria) < 1:
if idx < len(self._steps) - 1:
raise UserInputError(
"Only last step in generation strategy can have `num_trials` "
"set to -1 to indicate that the model in the step should "
"be used to generate new trials indefinitely unless "
"completion critera present."
"Only last step in generation strategy can have "
"`num_trials` set to -1 to indicate that the model in "
"the step shouldbe used to generate new trials "
"indefinitely unless completion critera present."
)
elif step.num_trials < 1 and step.num_trials != -1:
raise UserInputError(
Expand All @@ -94,15 +133,16 @@ def __init__(self, steps: List[GenerationStep], name: Optional[str] = None) -> N
)
if step.max_parallelism is not None and step.max_parallelism < 1:
raise UserInputError(
"Maximum parallelism should be None (if no limit) or a positive"
f" number. Got: {step.max_parallelism} for step {step.model_name}."
"Maximum parallelism should be None (if no limit) or "
f"a positive number. Got: {step.max_parallelism} for "
f"step {step.model_name}."
)
# TODO[mgarrard]: Validate node name uniqueness when adding node support,
# uniqueness is gaurenteed for steps currently due to list structure.

step._node_name = f"GenerationStep_{str(idx)}"
step.index = idx

# Set transition_to field for all but the last step, which remains null.
# Set transition_to field for all but the last step, which remains
# null.
if idx != len(self._steps):
for transition_criteria in step.transition_criteria:
if (
Expand All @@ -115,13 +155,55 @@ def __init__(self, steps: List[GenerationStep], name: Optional[str] = None) -> N
step._generation_strategy = self
if not isinstance(step.model, ModelRegistryBase):
self._uses_registered_models = False
if not self._uses_registered_models:
logger.info(
"Using model via callable function, "
"so optimization is not resumable if interrupted."
)
self._curr = steps[0]
self._seen_trial_indices_by_status = None

def _validate_and_set_node_graph(self, nodes: List[GenerationNode]) -> None:
"""Initialize and validate the graph of nodes."""
node_names = []
for node in self._nodes:
# validate that all node names are unique
if node.node_name in node_names:
raise GenerationStrategyMisconfiguredException(
error_info="All node names in a GenerationStrategy "
+ "must be unique."
)

node_names.append(node.node_name)
node._generation_strategy = self
fitted_model = None
try:
fitted_model = node.fitted_model
except UserInputError:
pass

if fitted_model is not None and not isinstance(
fitted_model, ModelRegistryBase
):
self._uses_registered_models = False

# validate `transition_criterion`
contains_a_transition_to_argument = False
for node in self._nodes:
for transition_criteria in node.transition_criteria:
if transition_criteria.transition_to is not None:
contains_a_transition_to_argument = True
if transition_criteria.transition_to not in node_names:
raise GenerationStrategyMisconfiguredException(
error_info=f"`transition_to` argument "
f"{transition_criteria.transition_to} does not "
"correspond to any node in this GenerationStrategy."
)

# validate that at least one node has transition_to field
if len(self._nodes) > 1 and not contains_a_transition_to_argument:
logger.warning(
"None of the nodes in this GenerationStrategy "
"contain a `transition_to` argument in their transition_criteria. "
"Therefore, the GenerationStrategy will not be able to "
"move from one node to another. Please add a "
"`transition_to` argument."
)
self._curr = nodes[0]

@property
def name(self) -> str:
Expand Down Expand Up @@ -152,8 +234,37 @@ def model_transitions(self) -> List[int]:
@property
def current_step(self) -> GenerationStep:
"""Current generation step."""
assert isinstance(self._curr, GenerationStep), (
"The current object is not a GenerationStep, you may be looking for the"
" current_node property."
)
return self._curr

@property
def current_node(self) -> GenerationNode:
"""Current generation step."""
assert isinstance(self._curr, GenerationNode), (
"The current object is not a GenerationNode, you may be looking for the"
" current_step property."
)
return self._curr

@property
def current_step_index(self) -> int:
"""Returns the index of the current generation step. This attribute
is replaced by node_name in newer GenerationStrategies but surfaced here
for backward compatibility.
"""
assert isinstance(
self._curr, GenerationStep
), "current_step_index only works with GenerationStep"
node_names_for_all_steps = [step._node_name for step in self._steps]
assert (
self._curr.node_name in node_names_for_all_steps
), "The current step is not found in the list of steps"

return node_names_for_all_steps.index(self._curr.node_name)

@property
def model(self) -> Optional[ModelBridge]:
"""Current model in this strategy. Returns None if no model has been set
Expand Down Expand Up @@ -512,7 +623,7 @@ def _get_model_state_from_last_generator_run(self) -> Dict[str, Any]:
return model_state_on_lgr

if all(isinstance(s, GenerationStep) for s in self._steps):
grs_equal = lgr._generation_step_index == self._curr.index
grs_equal = lgr._generation_step_index == self.current_step_index
else:
grs_equal = lgr._generation_node_name == self._curr.node_name

Expand Down
4 changes: 2 additions & 2 deletions ax/modelbridge/tests/test_completion_criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def test_single_criterion(self) -> None:
)
)

self.assertEqual(generation_strategy._curr.model, Models.GPEI)
self.assertEqual(generation_strategy._curr.model_enum, Models.GPEI)

def test_many_criteria(self) -> None:
criteria = [
Expand Down Expand Up @@ -145,4 +145,4 @@ def test_many_criteria(self) -> None:
)
)

self.assertEqual(generation_strategy._curr.model, Models.GPEI)
self.assertEqual(generation_strategy._curr.model_enum, Models.GPEI)
Loading

0 comments on commit 6dd77c8

Please sign in to comment.