diff --git a/ax/modelbridge/generation_node.py b/ax/modelbridge/generation_node.py index c07c91af3ba..417065827f9 100644 --- a/ax/modelbridge/generation_node.py +++ b/ax/modelbridge/generation_node.py @@ -468,7 +468,7 @@ def should_transition_to_next_node( return True, transition_nodes[0] return False, None - def generator_run_limit(self) -> int: + def generator_run_limit(self, supress_generation_errors: bool = True) -> int: """How many generator runs can this generation strategy generate right now, assuming each one of them becomes its own trial. Only considers `transition_criteria` that are TrialBasedCriterion. @@ -477,14 +477,43 @@ def generator_run_limit(self) -> int: - the number of generator runs that can currently be produced, with -1 meaning unlimited generator runs, """ + + # TODO: @mgarrard Should we consider returning `None` if there is no limit? + # TODO:@mgarrard Should we instead have `raise_generation_error`? The name + # of this method doesn't suggest that it would raise errors by default, since + # it's just finding out the limit according to the name. I know we want the + # errors in some cases, so we could call the flag `raise_error_if_cannot_gen` or + # something like that : ) + trial_based_gen_blocking_criteria = [ + criterion + for criterion in self.transition_criteria + if criterion.block_gen_if_met and isinstance(criterion, TrialBasedCriterion) + ] gen_blocking_criterion_delta_from_threshold = [ criterion.num_till_threshold( experiment=self.experiment, trials_from_node=self.trials_from_node ) - for criterion in self.transition_criteria - if criterion.block_gen_if_met and isinstance(criterion, TrialBasedCriterion) + for criterion in trial_based_gen_blocking_criteria ] + # Raise any necessary generation errors: for any met criterion, + # call its `block_continued_generation_error` method The method might not + # raise an error, depending on its implementation on given criterion, so the + # error from the first met one that does block continued generation, will be + # raised. + if not supress_generation_errors: + for criterion in trial_based_gen_blocking_criteria: + # TODO[mgarrard]: Raise a group of all the errors, from each gen- + # blocking transition criterion. + if criterion.is_met( + self.experiment, trials_from_node=self.trials_from_node + ): + criterion.block_continued_generation_error( + node_name=self.node_name, + model_name=self.model_to_gen_from_name, + experiment=self.experiment, + trials_from_node=self.trials_from_node, + ) if len(gen_blocking_criterion_delta_from_threshold) == 0: if not self.gen_unlimited_trials: logger.warning( diff --git a/ax/modelbridge/generation_strategy.py b/ax/modelbridge/generation_strategy.py index af2cdcfc5d2..c6cc06624b7 100644 --- a/ax/modelbridge/generation_strategy.py +++ b/ax/modelbridge/generation_strategy.py @@ -396,21 +396,13 @@ def _gen_multiple( self._maybe_move_to_next_step() self._fit_current_model(data=data) - # Make sure to not make too many generator runs and - # exceed maximum allowed paralellism for the step. - num_until_max_parallelism = ( - self._curr.num_remaining_trials_until_max_parallelism() - ) - if num_until_max_parallelism is not None: - num_generator_runs = min(num_generator_runs, num_until_max_parallelism) - - # Make sure not to extend number of trials expected in step. - if self._curr.enforce_num_trials and self._curr.num_trials > 0: - num_generator_runs = min( - num_generator_runs, - self._curr.num_trials - self._curr.num_can_complete, - ) - + # Get GeneratorRun limit that respects the node's transition criterion that + # affect the number of generator runs that can be produced. + gr_limit = self._curr.generator_run_limit(supress_generation_errors=False) + if gr_limit == -1: + num_generator_runs = max(num_generator_runs, 1) + else: + num_generator_runs = max(min(num_generator_runs, gr_limit), 1) generator_runs = [] pending_observations = deepcopy(pending_observations) or {} for _ in range(num_generator_runs): diff --git a/ax/service/tests/test_interactive_loop.py b/ax/service/tests/test_interactive_loop.py index 4fe08450e44..13ccc0af8c8 100644 --- a/ax/service/tests/test_interactive_loop.py +++ b/ax/service/tests/test_interactive_loop.py @@ -12,6 +12,8 @@ import numpy as np from ax.core.types import TEvaluationOutcome +from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy +from ax.modelbridge.registry import Models from ax.service.ax_client import AxClient, TParameterization from ax.service.interactive_loop import interactive_optimize_with_client from ax.utils.common.testutils import TestCase @@ -99,7 +101,11 @@ def _elicit( }, ) - ax_client = AxClient() + # GS with lo max parallelismm to induce MaxParallelismException: + generation_strategy = GenerationStrategy( + steps=[GenerationStep(model=Models.SOBOL, max_parallelism=1, num_trials=-1)] + ) + ax_client = AxClient(generation_strategy=generation_strategy) ax_client.create_experiment( name="hartmann_test_experiment", # pyre-fixme[6] @@ -116,9 +122,6 @@ def _elicit( minimize=True, ) - # Lower max parallelism to induce MaxParallelismException - ax_client.generation_strategy._steps[0].max_parallelism = 1 - with self.assertLogs(logger="ax", level=WARN) as logger: interactive_optimize_with_client( ax_client=ax_client,