From 66dd7df67abab5b48c6e8797643b858359577070 Mon Sep 17 00:00:00 2001 From: Mia Garrard Date: Tue, 28 Nov 2023 12:18:27 -0800 Subject: [PATCH] Replace get_generator_run_limit function on GenerationStep with new method on GenerationNode (#2018) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/2018 This diff does the following: Replaces the `get_generator_run_limit()` method on GenerationStep with `generator_run_limit` on GenerationNode. The new method relies on transition criterion to determine the number of generator runs, and only checks criterion that are trial based. I actually think this may not need to be expanded because the trial based criterion seem the most related to new generator run creation, but it could be expanded easily in the future if a usecase requires doing so. upcoming: (0) Finish removing GenerationStep methods in (1) delete functions from GenStep that aren't needed anymore (2) update the storage to include nodes independently (and not just as part of step) (3) final pass on all the doc strings (4) add transition criterion to the repr string + some of the other fields that havent made it yet on GeneratinoNode (5) Do a final pass of the generationStrategy/GenerationNode files to see what else can be migrated/condensed (6) rename transiton criterion to action criterion Reviewed By: lena-kashtelyan Differential Revision: D51169425 fbshipit-source-id: 8949b093ebc4b96c61e941748c3dafc0415dfa8a --- ax/modelbridge/generation_node.py | 37 ++++++++++++++++++++ ax/modelbridge/generation_strategy.py | 2 +- ax/modelbridge/tests/test_generation_node.py | 36 +++++++++++++++++++ 3 files changed, 74 insertions(+), 1 deletion(-) diff --git a/ax/modelbridge/generation_node.py b/ax/modelbridge/generation_node.py index 5c3a3a541f2..457180d8845 100644 --- a/ax/modelbridge/generation_node.py +++ b/ax/modelbridge/generation_node.py @@ -38,6 +38,7 @@ MaxTrials, MinTrials, TransitionCriterion, + TrialBasedCriterion, ) from ax.utils.common.base import Base, SortableBase from ax.utils.common.logger import get_logger @@ -467,6 +468,42 @@ def should_transition_to_next_node( return True, transition_nodes[0] return False, None + def generator_run_limit(self) -> int: + """How many generator runs can this generation strategy generate right now, + assuming each one of them becomes its own trial. Only considers + `transition_criteria` that are TrialBasedCriterion. + + Returns: + - the number of generator runs that can currently be produced, with -1 + meaning unlimited generator runs, + """ + # TODO @mgarrard remove filter when legacy usecases are updated + valid_criterion = [] + for criterion in self.transition_criteria: + if criterion.criterion_class not in { + "MinAsks", + "RunIndefinitely", + }: + valid_criterion.append(criterion) + + gen_blocking_criterion_delta_from_threshold = [ + criterion.num_till_threshold( + experiment=self.experiment, trials_from_node=self.trials_from_node + ) + for criterion in valid_criterion + if criterion.block_gen_if_met and isinstance(criterion, TrialBasedCriterion) + ] + + if len(gen_blocking_criterion_delta_from_threshold) == 0: + if not self.gen_unlimited_trials: + logger.warning( + "Even though this node is not flagged for generation of unlimited " + "trials, there are no generation blocking criterion, therefore, " + "unlimited trials will be generated." + ) + return -1 + return min(gen_blocking_criterion_delta_from_threshold) + def __repr__(self) -> str: "String representation of this GenerationNode" # add model specs diff --git a/ax/modelbridge/generation_strategy.py b/ax/modelbridge/generation_strategy.py index 0ccddcb3868..af2cdcfc5d2 100644 --- a/ax/modelbridge/generation_strategy.py +++ b/ax/modelbridge/generation_strategy.py @@ -308,7 +308,7 @@ def current_generator_run_limit( return 0, True # if the generation strategy is not complete, optimization is not complete - return self._curr.get_generator_run_limit(), False + return self._curr.generator_run_limit(), False def clone_reset(self) -> GenerationStrategy: """Copy this generation strategy without it's state.""" diff --git a/ax/modelbridge/tests/test_generation_node.py b/ax/modelbridge/tests/test_generation_node.py index 2970078cdf6..a4206b6aeee 100644 --- a/ax/modelbridge/tests/test_generation_node.py +++ b/ax/modelbridge/tests/test_generation_node.py @@ -4,6 +4,8 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import logging +from logging import Logger from unittest.mock import patch, PropertyMock from ax.core.observation import ObservationFeatures @@ -17,10 +19,13 @@ from ax.modelbridge.generation_node import GenerationNode, GenerationStep from ax.modelbridge.model_spec import FactoryFunctionModelSpec, ModelSpec from ax.modelbridge.registry import Models +from ax.utils.common.logger import get_logger from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import get_branin_experiment from ax.utils.testing.mock import fast_botorch_optimize +logger: Logger = get_logger(__name__) + class TestGenerationNode(TestCase): def setUp(self) -> None: @@ -182,6 +187,37 @@ def test_multiple_same_fixed_features(self) -> None: ) self.assertEqual(node.fixed_features, ObservationFeatures(parameters={"x": 0})) + def test_generator_run_limit_unlimited_without_flag(self) -> None: + """This tests checks that when the `gen_unlimited_trials` flag is false + but there are no generation blocking criteria, then the generator run limit + is set to -1 and a warning is logged. + """ + node = GenerationNode( + node_name="test", + model_specs=[ + ModelSpec( + model_enum=Models.GPEI, + model_kwargs={}, + model_gen_kwargs={ + "n": -1, + "fixed_features": ObservationFeatures(parameters={"x": 0}), + }, + ), + ], + gen_unlimited_trials=False, + ) + warning_msg = ( + "Even though this node is not flagged for generation of unlimited " + "trials, there are no generation blocking criterion, therefore, " + "unlimited trials will be generated." + ) + with self.assertLogs(GenerationNode.__module__, logging.WARNING) as logger: + self.assertEqual(node.generator_run_limit(), -1) + self.assertTrue( + any(warning_msg in output for output in logger.output), + logger.output, + ) + class TestGenerationStep(TestCase): def setUp(self) -> None: