From ff5fed4c098e438f7ac929fedc9951926d7d25cf Mon Sep 17 00:00:00 2001 From: Mia Garrard Date: Wed, 6 Dec 2023 13:24:06 -0800 Subject: [PATCH] Support GenerationNodes and GenerationSteps in GenerationStrategy and default to GenerationNodes (#2024) Summary: This diff does the following: Supports GenerationNodes at the level of GenerationStrategy. This is the big hurrah diff! upcoming: (0) Add decorator for functions that are only supported in steps (1) update the storage to include nodes independently (and not just as part of step) (2) delete now unused GenStep functions (3) final pass on all the doc strings and variables -- lots to clean up here (4) add transition criterion to the repr string + some of the other fields that haven't made it yet on GenerationNode (5) Do a final pass of the generationStrategy/GenerationNode files to see what else can be migrated/condensed (6) rename transition criterion to action criterion (7) remove conditionals for legacy usecase (8) clean up any lingering todos Reviewed By: lena-kashtelyan Differential Revision: D51120002 --- ax/exceptions/generation_strategy.py | 11 + ax/modelbridge/generation_node.py | 1 + ax/modelbridge/generation_strategy.py | 270 ++++++++++++++---- ax/modelbridge/model_spec.py | 11 +- .../tests/test_completion_criterion.py | 4 +- .../tests/test_generation_strategy.py | 260 ++++++++++++++++- .../tests/test_transition_criterion.py | 2 +- ax/modelbridge/transition_criterion.py | 13 +- ax/service/tests/test_ax_client.py | 11 +- ax/service/utils/best_point_mixin.py | 4 +- ax/storage/json_store/encoders.py | 2 +- ax/storage/sqa_store/encoder.py | 2 +- ax/storage/sqa_store/save.py | 2 +- ax/storage/sqa_store/tests/test_sqa_store.py | 4 + ax/storage/sqa_store/utils.py | 1 + 15 files changed, 526 insertions(+), 72 deletions(-) diff --git a/ax/exceptions/generation_strategy.py b/ax/exceptions/generation_strategy.py index 9cf9de29094..0233008a753 100644 --- a/ax/exceptions/generation_strategy.py +++ 
b/ax/exceptions/generation_strategy.py @@ -55,3 +55,14 @@ class GenerationStrategyRepeatedPoints(GenerationStrategyCompleted): """ pass + + +class GenerationStrategyMisconfiguredException(AxError): + """Special exception indicating that the generation strategy is misconfigured.""" + + def __init__(self, error_info: Optional[str]) -> None: + super().__init__( + "This GenerationStrategy was unable to be initialized properly. Please " + + "check the documentation, and adjust the configuration accordingly. " + + f"{error_info}" + ) diff --git a/ax/modelbridge/generation_node.py b/ax/modelbridge/generation_node.py index 5c7011d072d..9cf5fe3b8a6 100644 --- a/ax/modelbridge/generation_node.py +++ b/ax/modelbridge/generation_node.py @@ -100,6 +100,7 @@ class GenerationNode: # Optional specifications _model_spec_to_gen_from: Optional[ModelSpec] = None + # TODO: @mgarrard should this be a dict criterion_class name -> criterion mapping? _transition_criteria: Optional[Sequence[TransitionCriterion]] # [TODO] Handle experiment passing more eloquently by enforcing experiment diff --git a/ax/modelbridge/generation_strategy.py b/ax/modelbridge/generation_strategy.py index 0b4a28fc544..cd342f573e6 100644 --- a/ax/modelbridge/generation_strategy.py +++ b/ax/modelbridge/generation_strategy.py @@ -22,11 +22,16 @@ get_pending_observation_features_based_on_trial_status, ) from ax.exceptions.core import DataRequiredError, UserInputError -from ax.exceptions.generation_strategy import GenerationStrategyCompleted +from ax.exceptions.generation_strategy import ( + GenerationStrategyCompleted, + GenerationStrategyMisconfiguredException, +) from ax.modelbridge.base import ModelBridge -from ax.modelbridge.generation_node import GenerationStep +from ax.modelbridge.generation_node import GenerationNode, GenerationStep +from ax.modelbridge.model_spec import FactoryFunctionModelSpec from ax.modelbridge.registry import _extract_model_state_after_gen, ModelRegistryBase +from 
ax.modelbridge.transition_criterion import TrialBasedCriterion from ax.utils.common.logger import _round_floats_for_logging, get_logger from ax.utils.common.typeutils import checked_cast, not_none @@ -58,8 +63,8 @@ class GenerationStrategy(GenerationStrategyInterface): """ _name: Optional[str] - _steps: List[GenerationStep] - _curr: GenerationStep # Current step in the strategy. + _nodes: List[GenerationNode] + _curr: GenerationNode # Current step in the strategy. # Whether all models in this GS are in Models registry enum. _uses_registered_models: bool # All generator runs created through this generation strategy, in chronological @@ -73,22 +78,69 @@ class GenerationStrategy(GenerationStrategyInterface): _seen_trial_indices_by_status = None _model: Optional[ModelBridge] = None # Current model. - def __init__(self, steps: List[GenerationStep], name: Optional[str] = None) -> None: - assert isinstance(steps, list) and all( - isinstance(s, GenerationStep) for s in steps - ), "Steps must be a GenerationStep list." + def __init__( + self, + steps: Optional[List[GenerationStep]] = None, + name: Optional[str] = None, + nodes: Optional[List[GenerationNode]] = None, + ) -> None: self._name = name - self._steps = steps self._uses_registered_models = True self._generator_runs = [] - for idx, step in enumerate(self._steps): + + # validate that only one of steps or nodes is provided + if not ((steps is None) ^ (nodes is None)): + raise GenerationStrategyMisconfiguredException( + error_info="GenerationStrategy must contain either steps or nodes." 
+ ) + # pyre-ignore[8] + self._nodes = steps if steps is not None else nodes + node_based_strategy = self.is_node_based(nodes=self._nodes) + + if isinstance(steps, list) and not node_based_strategy: + # pyre-ignore[6] + self._validate_and_set_step_sequence(steps=self._nodes) + elif isinstance(nodes, list) and node_based_strategy: + self._validate_and_set_node_graph(nodes=nodes) + else: + raise GenerationStrategyMisconfiguredException( + error_info="Steps must either be a GenerationStep list or a " + + "GenerationNode list." + ) + self._uses_registered_models = not any( + isinstance(ms, FactoryFunctionModelSpec) + for node in self._nodes + for ms in node.model_specs + ) + if not self._uses_registered_models: + logger.info( + "Using model via callable function, " + "so optimization is not resumable if interrupted." + ) + self._seen_trial_indices_by_status = None + + def _validate_and_set_step_sequence(self, steps: List[GenerationStep]) -> None: + """Initialize and validate the steps provided to this GenerationStrategy. + + Some GenerationStrategies are composed of GenerationStep objects, but we also + need to initialize the correct GenerationNode representation for these steps. + This function validates: + 1. That only the last step has num_trials=-1, which indicates unlimited + trial generation is possible. + 2. That each step's num_trials attrivute is either positive or -1 + 3. That each step's max_parallelism attribute is either None or positive + It then sets the corect TransitionCriterion and node_name attributes on the + underlying GenerationNode objects. 
+ """ + for idx, step in enumerate(steps): + assert isinstance(step, GenerationStep) if step.num_trials == -1 and len(step.completion_criteria) < 1: if idx < len(self._steps) - 1: raise UserInputError( - "Only last step in generation strategy can have `num_trials` " - "set to -1 to indicate that the model in the step should " - "be used to generate new trials indefinitely unless " - "completion critera present." + "Only last step in generation strategy can have " + "`num_trials` set to -1 to indicate that the model in " + "the step shouldbe used to generate new trials " + "indefinitely unless completion critera present." ) elif step.num_trials < 1 and step.num_trials != -1: raise UserInputError( @@ -97,15 +149,16 @@ def __init__(self, steps: List[GenerationStep], name: Optional[str] = None) -> N ) if step.max_parallelism is not None and step.max_parallelism < 1: raise UserInputError( - "Maximum parallelism should be None (if no limit) or a positive" - f" number. Got: {step.max_parallelism} for step {step.model_name}." + "Maximum parallelism should be None (if no limit) or " + f"a positive number. Got: {step.max_parallelism} for " + f"step {step.model_name}." ) - # TODO[mgarrard]: Validate node name uniqueness when adding node support, - # uniqueness is gaurenteed for steps currently due to list structure. + step._node_name = f"GenerationStep_{str(idx)}" step.index = idx - # Set transition_to field for all but the last step, which remains null. + # Set transition_to field for all but the last step, which remains + # null. 
if idx != len(self._steps): for transition_criteria in step.transition_criteria: if ( @@ -116,15 +169,69 @@ def __init__(self, steps: List[GenerationStep], name: Optional[str] = None) -> N f"GenerationStep_{str(idx + 1)}" ) step._generation_strategy = self - if not isinstance(step.model, ModelRegistryBase): - self._uses_registered_models = False - if not self._uses_registered_models: - logger.info( - "Using model via callable function, " - "so optimization is not resumable if interrupted." - ) self._curr = steps[0] - self._seen_trial_indices_by_status = None + + def _validate_and_set_node_graph(self, nodes: List[GenerationNode]) -> None: + """Initialize and validate the node graph provided to this GenerationStrategy. + + This function validates: + 1. That all nodes have unique names. + 2. That there is at least one node with a transition_to field. + 3. That all `transition_to` attributes on a TransitionCriterion point to + another node in the same GenerationStrategy. + 4. Warns if no nodes contain a transition criterion + """ + node_names = [] + for node in self._nodes: + # validate that all node names are unique + if node.node_name in node_names: + raise GenerationStrategyMisconfiguredException( + error_info="All node names in a GenerationStrategy " + + "must be unique." + ) + + node_names.append(node.node_name) + node._generation_strategy = self + + # validate `transition_criterion` + contains_a_transition_to_argument = False + for node in self._nodes: + for transition_criteria in node.transition_criteria: + if transition_criteria.transition_to is not None: + contains_a_transition_to_argument = True + if transition_criteria.transition_to not in node_names: + raise GenerationStrategyMisconfiguredException( + error_info=f"`transition_to` argument " + f"{transition_criteria.transition_to} does not " + "correspond to any node in this GenerationStrategy." 
+ ) + + # validate that at least one node has transition_to field + if len(self._nodes) > 1 and not contains_a_transition_to_argument: + logger.warning( + "None of the nodes in this GenerationStrategy " + "contain a `transition_to` argument in their transition_criteria. " + "Therefore, the GenerationStrategy will not be able to " + "move from one node to another. Please add a " + "`transition_to` argument." + ) + self._curr = nodes[0] + + @property + def _steps(self) -> List[GenerationStep]: + """List of generation steps.""" + assert all( + isinstance(n, GenerationStep) for n in self._nodes + ), "Attempting to set steps to non-GenerationStep objects." + return self._nodes # pyre-ignore[7] + + def is_node_based(self, nodes: List[GenerationNode]) -> bool: + """Whether this strategy consists of GenerationNodes or GenerationSteps. + This is useful for determining initialization properties and other logic. + """ + if any(isinstance(n, GenerationStep) for n in nodes): + return False + return True @property def name(self) -> str: @@ -134,7 +241,7 @@ def name(self) -> str: if self._name is not None: return not_none(self._name) - factory_names = (step.model_name for step in self._steps) + factory_names = (node.model_spec_to_gen_from.model_key for node in self._nodes) # Trim the "get_" beginning of the factory function if it's there. 
factory_names = (n[4:] if n[:4] == "get_" else n for n in factory_names) self._name = "+".join(factory_names) @@ -149,14 +256,59 @@ def name(self, name: str) -> None: def model_transitions(self) -> List[int]: """List of trial indices where a transition happened from one model to another.""" - gen_changes = [step.num_trials for step in self._steps] + # TODO @mgarrard to support GenerationNodes here, which is non-trival + # since nodes are dynamic and may only support past model_transitions + gen_changes = [] + for node in self._nodes: + for criterion in node.transition_criteria: + if ( + isinstance(criterion, TrialBasedCriterion) + and criterion.criterion_class == "MaxTrials" + ): + gen_changes.append(criterion.threshold) + + # if the last node has unlimited generation, do not remeove the last + # transition point in the list + if self._nodes[-1].gen_unlimited_trials: + return [sum(gen_changes[: i + 1]) for i in range(len(gen_changes))] return [sum(gen_changes[: i + 1]) for i in range(len(gen_changes))][:-1] @property def current_step(self) -> GenerationStep: """Current generation step.""" + if not isinstance(self._curr, GenerationStep): + raise TypeError( + "The current object is not a GenerationStep, you may be looking " + "for the current_node property." + ) return self._curr + @property + def current_node(self) -> GenerationNode: + """Current generation node.""" + if not isinstance(self._curr, GenerationNode): + raise TypeError( + "The current object is not a GenerationNode, you may be looking for the" + " current_step property." + ) + return self._curr + + @property + def current_step_index(self) -> int: + """Returns the index of the current generation step. This attribute + is replaced by node_name in newer GenerationStrategies but surfaced here + for backward compatibility. 
+ """ + assert isinstance( + self._curr, GenerationStep + ), "current_step_index only works with GenerationStep" + node_names_for_all_steps = [step._node_name for step in self._nodes] + assert ( + self._curr.node_name in node_names_for_all_steps + ), "The current step is not found in the list of steps" + + return node_names_for_all_steps.index(self._curr.node_name) + @property def model(self) -> Optional[ModelBridge]: """Current model in this strategy. Returns None if no model has been set @@ -208,15 +360,15 @@ def trials_as_df(self) -> Optional[pd.DataFrame]: ) if self._experiment is None or all( - len(trials) == 0 - for step in self._steps - for trials in step.trial_indices.values() + len(step.trials_from_node) == 0 for step in self._nodes ): return None records = [ { - "Generation Step": step.index, - "Generation Model": self._steps[step.index].model_name, + "Generation Step": step.node_name, + "Generation Model": self._nodes[ + step_idx + ].model_spec_to_gen_from.model_key, "Trial Index": trial_idx, "Trial Status": self.experiment.trials[trial_idx].status.name, "Arm Parameterizations": { @@ -224,9 +376,8 @@ def trials_as_df(self) -> Optional[pd.DataFrame]: for arm in self.experiment.trials[trial_idx].arms }, } - for step in self._steps - for _, trials in step.trial_indices.items() - for trial_idx in trials + for step_idx, step in enumerate(self._nodes) + for trial_idx in step.trials_from_node ] return pd.DataFrame.from_records(records).reindex( columns=[ @@ -248,7 +399,7 @@ def gen( ) -> GeneratorRun: """Produce the next points in the experiment. Additional kwargs passed to this method are propagated directly to the underlying model's `gen`, along - with the `model_gen_kwargs` set on the current generation step. + with the `model_gen_kwargs` set on the current generation node. 
NOTE: Each generator run returned from this function must become a single trial on the experiment to comply with assumptions made in generation @@ -332,7 +483,7 @@ def gen_for_multiple_trials_with_multiple_models( def current_generator_run_limit( self, ) -> Tuple[int, bool]: - """First check if we can move the generation strategy to the next step, which + """First check if we can move the generation strategy to the next , which is safe, as the next call to ``gen`` will just pick up from there. Then determine how many generator runs this generation strategy can generate right now, assuming each one of them becomes its own trial, and whether optimization @@ -354,6 +505,14 @@ def current_generator_run_limit( def clone_reset(self) -> GenerationStrategy: """Copy this generation strategy without it's state.""" + if self.is_node_based(nodes=self._nodes): + nodes = deepcopy(self._nodes) + for n in nodes: + # Unset the generation strategy back-pointer, so the nodes are not + # associated with any generation strategy. 
+ n._generation_strategy = None + return GenerationStrategy(name=self.name, nodes=nodes) + steps = deepcopy(self._steps) for s in steps: # Unset the generation strategy back-pointer, so the steps are not @@ -370,19 +529,24 @@ def _unset_non_persistent_state_fields(self) -> None: """ self._seen_trial_indices_by_status = None self._model = None - for s in self._steps: + for s in self._nodes: s._model_spec_to_gen_from = None def __repr__(self) -> str: """String representation of this generation strategy.""" repr = f"GenerationStrategy(name='{self.name}', steps=[" - remaining_trials = "subsequent" if len(self._steps) > 1 else "all" - for step in self._steps: - num_trials = ( - f"{step.num_trials}" if step.num_trials != -1 else remaining_trials - ) + remaining_trials = "subsequent" if len(self._nodes) > 1 else "all" + for step in self._nodes: + # TODO @mgarrard handle this more gracefully for more general nodes + num_trials = remaining_trials + for criterion in step.transition_criteria: + if criterion.criterion_class == "MaxTrials" and isinstance( + criterion, TrialBasedCriterion + ): + num_trials = criterion.threshold + try: - model_name = step.model_name + model_name = step.model_spec_to_gen_from.model_key except TypeError: model_name = "model with unknown name" @@ -480,15 +644,15 @@ def _gen_multiple( # ------------------------- Model selection logic helpers. ------------------------- def _fit_current_model(self, data: Optional[Data]) -> None: - """Fits or update the model on the current generation step (does not move - between generation steps). + """Fits or update the model on the current generation node (does not move + between generation nodes). Args: data: Optional ``Data`` to fit or update with; if not specified, generation strategy will obtain the data via ``experiment.lookup_data``. 
""" data = self.experiment.lookup_data() if data is None else data - # If last generator run's index matches the current step, extract + # If last generator run's index matches the current node, extract # model state from last generator run and pass it to the model # being instantiated in this function. model_state_on_lgr = self._get_model_state_from_last_generator_run() @@ -525,14 +689,14 @@ def _maybe_move_to_next_step(self, raise_data_required_error: bool = True) -> bo ), "This node should never attempt to transition since" " it can generate unlimited trials" # TODO: @mgarrard clean up with legacy usecase removal - if all(node.is_completed for node in self._steps) and "AEPsych" not in str( + if all(node.is_completed for node in self._nodes) and "AEPsych" not in str( self._curr ): raise GenerationStrategyCompleted( f"Generation strategy {self} generated all the trials as " "specified in its steps." ) - for step in self._steps: + for step in self._nodes: if step.node_name == next_node: self._curr = step # Moving to the next step also entails unsetting this GS's model @@ -553,8 +717,8 @@ def _get_model_state_from_last_generator_run(self) -> Dict[str, Any]: if lgr is None: return model_state_on_lgr - if all(isinstance(s, GenerationStep) for s in self._steps): - grs_equal = lgr._generation_step_index == self._curr.index + if not self.is_node_based(nodes=self._nodes): + grs_equal = lgr._generation_step_index == self.current_step_index else: grs_equal = lgr._generation_node_name == self._curr.node_name diff --git a/ax/modelbridge/model_spec.py b/ax/modelbridge/model_spec.py index 8180ce1304c..15becceac9c 100644 --- a/ax/modelbridge/model_spec.py +++ b/ax/modelbridge/model_spec.py @@ -10,6 +10,7 @@ import warnings from copy import deepcopy from dataclasses import dataclass +from datetime import datetime from typing import Any, Callable, Dict, List, Optional, Tuple from ax.core.data import Data @@ -27,7 +28,7 @@ CVResult, ) from ax.modelbridge.registry import 
ModelRegistryBase -from ax.utils.common.base import Base +from ax.utils.common.base import SortableBase from ax.utils.common.kwargs import ( consolidate_kwargs, filter_kwargs, @@ -48,7 +49,7 @@ def default(self, o: Any) -> str: @dataclass -class ModelSpec(Base): +class ModelSpec(SortableBase): model_enum: ModelRegistryBase # Kwargs to pass into the `Model` + `ModelBridge` constructors in # `ModelRegistryBase.__call__`. @@ -288,6 +289,12 @@ def __hash__(self) -> int: def __eq__(self, other: ModelSpec) -> bool: return repr(self) == repr(other) + @property + def _unique_id(self) -> str: + """Returns the unique ID of this model spec""" + # TODO @mgarrard verify that this is unique enough + return str(hash(self)) + @dataclass class FactoryFunctionModelSpec(ModelSpec): diff --git a/ax/modelbridge/tests/test_completion_criterion.py b/ax/modelbridge/tests/test_completion_criterion.py index 4af57c64573..db2e869fdbc 100644 --- a/ax/modelbridge/tests/test_completion_criterion.py +++ b/ax/modelbridge/tests/test_completion_criterion.py @@ -71,7 +71,7 @@ def test_single_criterion(self) -> None: ) ) - self.assertEqual(generation_strategy._curr.model, Models.GPEI) + self.assertEqual(generation_strategy._curr.model_enum, Models.GPEI) def test_many_criteria(self) -> None: criteria = [ @@ -145,4 +145,4 @@ def test_many_criteria(self) -> None: ) ) - self.assertEqual(generation_strategy._curr.model, Models.GPEI) + self.assertEqual(generation_strategy._curr.model_enum, Models.GPEI) diff --git a/ax/modelbridge/tests/test_generation_strategy.py b/ax/modelbridge/tests/test_generation_strategy.py index adf118dfdfc..06948642363 100644 --- a/ax/modelbridge/tests/test_generation_strategy.py +++ b/ax/modelbridge/tests/test_generation_strategy.py @@ -4,6 +4,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+import logging from typing import cast, List from unittest import mock from unittest.mock import MagicMock, patch @@ -21,16 +22,23 @@ from ax.exceptions.core import DataRequiredError, UserInputError from ax.exceptions.generation_strategy import ( GenerationStrategyCompleted, + GenerationStrategyMisconfiguredException, GenerationStrategyRepeatedPoints, MaxParallelismReachedException, ) from ax.modelbridge.discrete import DiscreteModelBridge from ax.modelbridge.factory import get_sobol +from ax.modelbridge.generation_node import GenerationNode from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy from ax.modelbridge.model_spec import ModelSpec from ax.modelbridge.random import RandomModelBridge from ax.modelbridge.registry import Cont_X_trans, MODEL_KEY_TO_MODEL_SETUP, Models from ax.modelbridge.torch import TorchModelBridge +from ax.modelbridge.transition_criterion import ( + MaxGenerationParallelism, + MaxTrials, + MinTrials, +) from ax.models.random.sobol import SobolGenerator from ax.utils.common.equality import same_elements from ax.utils.common.mock import mock_patch_method_original @@ -416,8 +424,8 @@ def test_clone_reset(self) -> None: ] ) ftgs._curr = ftgs._steps[1] - self.assertEqual(ftgs._curr.index, 1) - self.assertEqual(ftgs.clone_reset()._curr.index, 0) + self.assertEqual(ftgs.current_step_index, 1) + self.assertEqual(ftgs.clone_reset().current_step_index, 0) def test_kwargs_passed(self) -> None: gs = GenerationStrategy( @@ -527,10 +535,12 @@ def test_trials_as_df(self) -> None: # attach necessary trials to fill up the Generation Strategy trial = exp.new_trial(sobol_generation_strategy.gen(experiment=exp)) self.assertEqual( - sobol_generation_strategy.trials_as_df.head()["Generation Step"][0], 0 + sobol_generation_strategy.trials_as_df.head()["Generation Step"][0], + "GenerationStep_0", ) self.assertEqual( - sobol_generation_strategy.trials_as_df.head()["Generation Step"][2], 1 + 
sobol_generation_strategy.trials_as_df.head()["Generation Step"][2], + "GenerationStep_1", ) def test_max_parallelism_reached(self) -> None: @@ -883,6 +893,248 @@ def test_gen_for_multiple_trials_with_multiple_models(self) -> None: for p in original_pending[m]: self.assertIn(p, pending[m]) + # ---------- Tests for GenerationStrategies composed of GenerationNodes -------- + def test_gs_setup_with_nodes(self) -> None: + """Test GS initalization and validation with nodes""" + sobol_model_spec = ModelSpec( + model_enum=Models.SOBOL, + model_kwargs={}, + model_gen_kwargs={"n": 2}, + ) + node_1_criterion = [ + MaxTrials( + threshold=4, + block_gen_if_met=False, + transition_to="node_2", + only_in_statuses=None, + not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED], + ), + MinTrials( + only_in_statuses=[TrialStatus.COMPLETED, TrialStatus.EARLY_STOPPED], + threshold=2, + transition_to="node_2", + ), + MaxGenerationParallelism( + threshold=1, + only_in_statuses=[TrialStatus.RUNNING], + block_gen_if_met=True, + block_transition_if_unmet=False, + ), + ] + + # check error raised if node names are not unique + with self.assertRaisesRegex( + GenerationStrategyMisconfiguredException, "All node names" + ): + GenerationStrategy( + nodes=[ + GenerationNode( + node_name="node_1", + transition_criteria=node_1_criterion, + model_specs=[sobol_model_spec], + ), + GenerationNode( + node_name="node_1", + model_specs=[sobol_model_spec], + ), + ], + ) + # check error raised if transition to arguemnt is not valid + with self.assertRaisesRegex( + GenerationStrategyMisconfiguredException, "`transition_to` argument" + ): + GenerationStrategy( + nodes=[ + GenerationNode( + node_name="node_1", + transition_criteria=node_1_criterion, + model_specs=[sobol_model_spec], + ), + GenerationNode( + node_name="node_3", + model_specs=[sobol_model_spec], + ), + ], + ) + + # check error raised if provided both steps and nodes + with self.assertRaisesRegex( + GenerationStrategyMisconfiguredException, 
"either steps or nodes" + ): + GenerationStrategy( + nodes=[ + GenerationNode( + node_name="node_1", + transition_criteria=node_1_criterion, + model_specs=[sobol_model_spec], + ), + GenerationNode( + node_name="node_3", + model_specs=[sobol_model_spec], + ), + ], + steps=[ + GenerationStep( + model=Models.SOBOL, + num_trials=5, + model_kwargs=self.step_model_kwargs, + ), + GenerationStep( + model=Models.GPEI, + num_trials=-1, + model_kwargs=self.step_model_kwargs, + ), + ], + ) + + # check error raised if provided both steps and nodes under node list + with self.assertRaisesRegex( + GenerationStrategyMisconfiguredException, "must either be a GenerationStep" + ): + GenerationStrategy( + nodes=[ + GenerationNode( + node_name="node_1", + transition_criteria=node_1_criterion, + model_specs=[sobol_model_spec], + ), + GenerationStep( + model=Models.SOBOL, + num_trials=5, + model_kwargs=self.step_model_kwargs, + ), + GenerationNode( + node_name="node_2", + model_specs=[sobol_model_spec], + ), + ], + ) + # check that warning is logged if no nodes have transition arguments + with self.assertLogs(GenerationStrategy.__module__, logging.WARNING) as logger: + warning_msg = ( + "None of the nodes in this GenerationStrategy " + "contain a `transition_to` argument in their transition_criteria. 
" + ) + GenerationStrategy( + nodes=[ + GenerationNode( + node_name="node_1", + model_specs=[sobol_model_spec], + ), + GenerationNode( + node_name="node_3", + model_specs=[sobol_model_spec], + ), + ], + ) + self.assertTrue( + any(warning_msg in output for output in logger.output), + logger.output, + ) + + def test_gs_with_generation_nodes(self) -> None: + "Simple test of a SOBOL + GPEI GenerationStrategy composed of GenerationNodes" + sobol_criterion = [ + MaxTrials( + threshold=5, + transition_to="GPEI_node", + block_gen_if_met=True, + only_in_statuses=None, + not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED], + ) + ] + gpei_criterion = [ + MaxTrials( + threshold=2, + transition_to=None, + block_gen_if_met=True, + only_in_statuses=None, + not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED], + ) + ] + sobol_model_spec = ModelSpec( + model_enum=Models.SOBOL, + model_kwargs=self.step_model_kwargs, + model_gen_kwargs={}, + ) + gpei_model_spec = ModelSpec( + model_enum=Models.GPEI, + model_kwargs=self.step_model_kwargs, + model_gen_kwargs={}, + ) + sobol_node = GenerationNode( + node_name="sobol_node", + transition_criteria=sobol_criterion, + model_specs=[sobol_model_spec], + gen_unlimited_trials=False, + ) + gpei_node = GenerationNode( + node_name="GPEI_node", + transition_criteria=gpei_criterion, + model_specs=[gpei_model_spec], + gen_unlimited_trials=False, + ) + + sobol_GPEI_GS_nodes = GenerationStrategy( + name="Sobol+GPEI_Nodes", + nodes=[sobol_node, gpei_node], + ) + exp = get_branin_experiment() + self.assertEqual(sobol_GPEI_GS_nodes.name, "Sobol+GPEI_Nodes") + self.assertEqual(sobol_GPEI_GS_nodes.model_transitions, [5]) + + for i in range(7): + g = sobol_GPEI_GS_nodes.gen(exp) + exp.new_trial(generator_run=g).run() + self.assertEqual(len(sobol_GPEI_GS_nodes._generator_runs), i + 1) + if i > 4: + self.mock_torch_model_bridge.assert_called() + else: + self.assertEqual(g._model_key, "Sobol") + mkw = g._model_kwargs + self.assertIsNotNone(mkw) + 
if i > 0: + # Generated points are randomized, so checking that they're there. + self.assertIsNotNone(mkw.get("generated_points")) + else: + # This is the first GR, there should be no generated points yet. + self.assertIsNone(mkw.get("generated_points")) + # Remove the randomized generated points to compare the rest. + mkw = mkw.copy() + del mkw["generated_points"] + self.assertEqual( + mkw, + { + "seed": None, + "deduplicate": True, + "init_position": i, + "scramble": True, + "fallback_to_sample_polytope": False, + }, + ) + self.assertEqual( + g._bridge_kwargs, + { + "optimization_config": None, + "status_quo_features": None, + "status_quo_name": None, + "transform_configs": None, + "transforms": Cont_X_trans, + "fit_out_of_design": False, + "fit_abandoned": False, + "fit_tracking_metrics": True, + "fit_on_init": True, + }, + ) + ms = g._model_state_after_gen + self.assertIsNotNone(ms) + # Generated points are randomized, so just checking that they are there. + self.assertIn("generated_points", ms) + # Remove the randomized generated points to compare the rest. 
+ ms = ms.copy() + del ms["generated_points"] + self.assertEqual(ms, {"init_position": i + 1}) + # ------------- Testing helpers (put tests above this line) ------------- def _run_GS_for_N_rounds( diff --git a/ax/modelbridge/tests/test_transition_criterion.py b/ax/modelbridge/tests/test_transition_criterion.py index ae6c16100f4..9fbde23aeca 100644 --- a/ax/modelbridge/tests/test_transition_criterion.py +++ b/ax/modelbridge/tests/test_transition_criterion.py @@ -74,7 +74,7 @@ def test_minimum_preference_criterion(self) -> None: raise_data_required_error=False ) ) - self.assertEqual(generation_strategy._curr.model, Models.GPEI) + self.assertEqual(generation_strategy._curr.model_enum, Models.GPEI) def test_default_step_criterion_setup(self) -> None: """This test ensures that the default completion criterion for GenerationSteps diff --git a/ax/modelbridge/transition_criterion.py b/ax/modelbridge/transition_criterion.py index c9b5e7851dd..4592cdcb8ee 100644 --- a/ax/modelbridge/transition_criterion.py +++ b/ax/modelbridge/transition_criterion.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. 
from abc import abstractmethod +from datetime import datetime from logging import Logger from typing import List, Optional, Set @@ -12,14 +13,14 @@ from ax.core.experiment import Experiment from ax.exceptions.generation_strategy import MaxParallelismReachedException from ax.modelbridge.generation_strategy import DataRequiredError -from ax.utils.common.base import Base +from ax.utils.common.base import SortableBase from ax.utils.common.logger import get_logger from ax.utils.common.serialization import SerializationMixin, serialize_init_args logger: Logger = get_logger(__name__) -class TransitionCriterion(Base, SerializationMixin): +class TransitionCriterion(SortableBase, SerializationMixin): # TODO: @mgarrard rename to ActionCriterion """ Simple class to descibe a condition which must be met for this GenerationNode to @@ -66,7 +67,7 @@ def is_met( """If the criterion of this TransitionCriterion is met, returns True.""" pass - @abstractmethod + # TODO: @mgarrard add back abstractmethod once legacy usecases are updated def block_continued_generation_error( self, node_name: Optional[str], @@ -85,6 +86,12 @@ def criterion_class(self) -> str: def __repr__(self) -> str: return f"{self.criterion_class}({serialize_init_args(obj=self)})" + @property + def _unique_id(self) -> str: + """Unique id for this TransitionCriterion.""" + # TODO @mgarrard validate that this is unique enough + return str(self) + class TrialBasedCriterion(TransitionCriterion): """Common class for action criterion that are based on trial information.""" diff --git a/ax/service/tests/test_ax_client.py b/ax/service/tests/test_ax_client.py index 10440f2102e..3e0b6f01fc3 100644 --- a/ax/service/tests/test_ax_client.py +++ b/ax/service/tests/test_ax_client.py @@ -1983,6 +1983,7 @@ def test_sqa_storage(self) -> None: # set during next model fitting call), so we unset them on the original GS as # well. 
gs._unset_non_persistent_state_fields() + ax_client.generation_strategy._unset_non_persistent_state_fields() self.assertEqual(gs, ax_client.generation_strategy) with self.assertRaises(ValueError): # Overwriting existing experiment. @@ -2342,7 +2343,10 @@ def helper_test_get_pareto_optimal_points( ax_client.generation_strategy._fit_current_model( data=ax_client.experiment.lookup_data() ) - self.assertEqual(ax_client.generation_strategy._curr.model_name, "BoTorch") + self.assertEqual( + ax_client.generation_strategy._curr.model_spec_to_gen_from.model_key, + "BoTorch", + ) # Check calling get_best_parameters fails (user must call # get_pareto_optimal_parameters). @@ -2409,7 +2413,10 @@ def helper_test_get_pareto_optimal_points_from_sobol_step( ax_client, _ = get_branin_currin_optimization_with_N_sobol_trials( num_trials=20, minimize=minimize, outcome_constraints=outcome_constraints ) - self.assertEqual(ax_client.generation_strategy._curr.model_name, "Sobol") + self.assertEqual( + ax_client.generation_strategy._curr.model_spec_to_gen_from.model_key, + "Sobol", + ) cfg = not_none(ax_client.experiment.optimization_config) assert isinstance(cfg, MultiObjectiveOptimizationConfig) diff --git a/ax/service/utils/best_point_mixin.py b/ax/service/utils/best_point_mixin.py index 5702709de2a..586483a8937 100644 --- a/ax/service/utils/best_point_mixin.py +++ b/ax/service/utils/best_point_mixin.py @@ -266,7 +266,7 @@ def _get_best_trial( # TODO[drfreund]: Find a way to include data for last trial in the # calculation of best parameters. 
if use_model_predictions: - current_model = generation_strategy._curr.model + current_model = generation_strategy._curr.model_enum # Cover for the case where source of `self._curr.model` was not a `Models` # enum but a factory function, in which case we cannot do # `get_model_from_generator_run` (since we don't have model type and inputs @@ -381,7 +381,7 @@ def _get_hypervolume( ) if use_model_predictions: - current_model = generation_strategy._curr.model + current_model = generation_strategy._curr.model_enum # Cover for the case where source of `self._curr.model` was not a `Models` # enum but a factory function, in which case we cannot do # `get_model_from_generator_run` (since we don't have model type and inputs diff --git a/ax/storage/json_store/encoders.py b/ax/storage/json_store/encoders.py index 30cc6b534ca..f019772cad5 100644 --- a/ax/storage/json_store/encoders.py +++ b/ax/storage/json_store/encoders.py @@ -494,7 +494,7 @@ def generation_strategy_to_dict( "db_id": generation_strategy._db_id, "name": generation_strategy.name, "steps": generation_strategy._steps, - "curr_index": generation_strategy._curr.index, + "curr_index": generation_strategy.current_step_index, "generator_runs": generation_strategy._generator_runs, "had_initialized_model": generation_strategy.model is not None, "experiment": generation_strategy._experiment, diff --git a/ax/storage/sqa_store/encoder.py b/ax/storage/sqa_store/encoder.py index 449b5abd205..0b2029f5811 100644 --- a/ax/storage/sqa_store/encoder.py +++ b/ax/storage/sqa_store/encoder.py @@ -858,7 +858,7 @@ def generation_strategy_to_sqa( encoder_registry=self.config.json_encoder_registry, class_encoder_registry=self.config.json_class_encoder_registry, ), - curr_index=generation_strategy._curr.index, + curr_index=generation_strategy.current_step_index, generator_runs=generator_runs_sqa, experiment_id=experiment_id, ) diff --git a/ax/storage/sqa_store/save.py b/ax/storage/sqa_store/save.py index e0c902858ec..8ff4d28acd1 100644 
--- a/ax/storage/sqa_store/save.py +++ b/ax/storage/sqa_store/save.py @@ -328,7 +328,7 @@ def _update_generation_strategy( with session_scope() as session: session.query(gs_sqa_class).filter_by(id=gs_id).update( { - "curr_index": generation_strategy._curr.index, + "curr_index": generation_strategy.current_step_index, "experiment_id": experiment_id, } ) diff --git a/ax/storage/sqa_store/tests/test_sqa_store.py b/ax/storage/sqa_store/tests/test_sqa_store.py index 848f262313d..83a9a8a335f 100644 --- a/ax/storage/sqa_store/tests/test_sqa_store.py +++ b/ax/storage/sqa_store/tests/test_sqa_store.py @@ -1270,6 +1270,10 @@ def test_EncodeDecodeGenerationStrategy(self) -> None: # pyre-fixme[6]: For 1st param expected `int` but got `Optional[int]`. gs_id=generation_strategy._db_id ) + # Some fields of the reloaded GS are not expected to be set (both will be + # set during next model fitting call), so we unset them on the original GS as + # well. + generation_strategy._unset_non_persistent_state_fields() self.assertEqual(generation_strategy, new_generation_strategy) self.assertIsNone(generation_strategy._experiment) diff --git a/ax/storage/sqa_store/utils.py b/ax/storage/sqa_store/utils.py index 899461c0494..09a5fe4dae7 100644 --- a/ax/storage/sqa_store/utils.py +++ b/ax/storage/sqa_store/utils.py @@ -35,6 +35,7 @@ "_seen_trial_indices_by_status", "_steps", "analysis_scheduler", + "_nodes", } SKIP_ATTRS_ERROR_SUFFIX = "Consider adding to COPY_DB_IDS_ATTRS_TO_SKIP if appropriate."