Skip to content

Commit

Permalink
Replace get_generator_run_limit function on GenerationStep with new m…
Browse files Browse the repository at this point in the history
…ethod on GenerationNode (facebook#2018)

Summary:

This diff does the following:
Replaces the `get_generator_run_limit()` method on GenerationStep with `generator_run_limit` on GenerationNode. The new method relies on transition criterion to determine the number of generator runs, and only checks criterion that are trial based. I actually think this may not need to be expanded because the trial based criterion seem the most related to new generator run creation, but it could be expanded easily in the future if a usecase requires doing so.

upcoming:
(0) Finish removing GenerationStep methods in
(1) delete functions from GenStep that aren't needed anymore
(2) update the storage to include nodes independently (and not just as part of step)
(3) final pass on all the doc strings
(4) add transition criterion to the repr string + some of the other fields that havent made it yet on GeneratinoNode
(5) Do a final pass of the generationStrategy/GenerationNode files to see what else can be migrated/condensed
(6) rename transiton criterion to action criterion

Reviewed By: lena-kashtelyan

Differential Revision: D51169425
  • Loading branch information
Mia Garrard authored and facebook-github-bot committed Nov 28, 2023
1 parent 46052a1 commit b390bee
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 1 deletion.
37 changes: 37 additions & 0 deletions ax/modelbridge/generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
MaxTrials,
MinTrials,
TransitionCriterion,
TrialBasedCriterion,
)
from ax.utils.common.base import Base, SortableBase
from ax.utils.common.logger import get_logger
Expand Down Expand Up @@ -467,6 +468,42 @@ def should_transition_to_next_node(
return True, transition_nodes[0]
return False, None

def generator_run_limit(self) -> int:
"""How many generator runs can this generation strategy generate right now,
assuming each one of them becomes its own trial. Only considers
`transition_criteria` that are TrialBasedCriterion.
Returns:
- the number of generator runs that can currently be produced, with -1
meaning unlimited generator runs,
"""
# TODO @mgarrard remove filter when legacy usecases are updated
valid_criterion = []
for criterion in self.transition_criteria:
if criterion.criterion_class not in {
"MinAsks",
"RunIndefinitely",
}:
valid_criterion.append(criterion)

gen_blocking_criterion_delta_from_threshold = [
criterion.num_till_threshold(
experiment=self.experiment, trials_from_node=self.trials_from_node
)
for criterion in valid_criterion
if criterion.block_gen_if_met and isinstance(criterion, TrialBasedCriterion)
]

if len(gen_blocking_criterion_delta_from_threshold) == 0:
if not self.gen_unlimited_trials:
logger.warning(
"Even though this node is not flagged for generation of unlimited "
"trials, there are no generation blocking criterion, therefore, "
"unlimited trials will be generated."
)
return -1
return min(gen_blocking_criterion_delta_from_threshold)

def __repr__(self) -> str:
"String representation of this GenerationNode"
# add model specs
Expand Down
2 changes: 1 addition & 1 deletion ax/modelbridge/generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def current_generator_run_limit(
return 0, True

# if the generation strategy is not complete, optimization is not complete
return self._curr.get_generator_run_limit(), False
return self._curr.generator_run_limit(), False

def clone_reset(self) -> GenerationStrategy:
"""Copy this generation strategy without it's state."""
Expand Down
36 changes: 36 additions & 0 deletions ax/modelbridge/tests/test_generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from logging import Logger
from unittest.mock import patch, PropertyMock

from ax.core.observation import ObservationFeatures
Expand All @@ -17,10 +19,13 @@
from ax.modelbridge.generation_node import GenerationNode, GenerationStep
from ax.modelbridge.model_spec import FactoryFunctionModelSpec, ModelSpec
from ax.modelbridge.registry import Models
from ax.utils.common.logger import get_logger
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import get_branin_experiment
from ax.utils.testing.mock import fast_botorch_optimize

logger: Logger = get_logger(__name__)


class TestGenerationNode(TestCase):
def setUp(self) -> None:
Expand Down Expand Up @@ -182,6 +187,37 @@ def test_multiple_same_fixed_features(self) -> None:
)
self.assertEqual(node.fixed_features, ObservationFeatures(parameters={"x": 0}))

def test_generator_run_limit_unlimited_without_flag(self) -> None:
"""This tests checks that when the `gen_unlimited_trials` flag is false
but there are no generation blocking criteria, then the generator run limit
is set to -1 and a warning is logged.
"""
node = GenerationNode(
node_name="test",
model_specs=[
ModelSpec(
model_enum=Models.GPEI,
model_kwargs={},
model_gen_kwargs={
"n": -1,
"fixed_features": ObservationFeatures(parameters={"x": 0}),
},
),
],
gen_unlimited_trials=False,
)
warning_msg = (
"Even though this node is not flagged for generation of unlimited "
"trials, there are no generation blocking criterion, therefore, "
"unlimited trials will be generated."
)
with self.assertLogs(GenerationNode.__module__, logging.WARNING) as logger:
self.assertEqual(node.generator_run_limit(), -1)
self.assertTrue(
any(warning_msg in output for output in logger.output),
logger.output,
)


class TestGenerationStep(TestCase):
def setUp(self) -> None:
Expand Down

0 comments on commit b390bee

Please sign in to comment.