Skip to content

Commit

Permalink
Create Scheduler.generate_candidates() function (facebook#2640)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebook#2640

Add `Scheduler.generate_candidates()` method which calls
- poll and fetch
- get next trial
- eventaully gen report
- save new trials

Differential Revision: D59606488
  • Loading branch information
Daniel Cohen authored and facebook-github-bot committed Aug 8, 2024
1 parent c1e0a3a commit 10cb6ae
Show file tree
Hide file tree
Showing 5 changed files with 367 additions and 6 deletions.
7 changes: 7 additions & 0 deletions ax/exceptions/generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,10 @@ def __init__(self, error_info: Optional[str]) -> None:
+ "check the documentation, and adjust the configuration accordingly. "
+ f"{error_info}"
)


class OptimizationConfigRequired(ValueError):
"""Error indicating that an candidate generation cannot be completed
because an optimization config was not provided."""

pass
3 changes: 2 additions & 1 deletion ax/modelbridge/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from ax.core.search_space import SearchSpace
from ax.core.types import TCandidateMetadata, TModelPredictArm
from ax.exceptions.core import DataRequiredError, UnsupportedError
from ax.exceptions.generation_strategy import OptimizationConfigRequired
from ax.modelbridge.base import gen_arms, GenResults, ModelBridge
from ax.modelbridge.modelbridge_utils import (
array_to_observation_data,
Expand Down Expand Up @@ -804,7 +805,7 @@ def _get_transformed_model_gen_args(
search_space=search_space, param_names=self.parameters
)
if optimization_config is None:
raise ValueError(
raise OptimizationConfigRequired(
f"{self.__class__.__name__} requires an OptimizationConfig "
"to be specified"
)
Expand Down
43 changes: 43 additions & 0 deletions ax/service/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
from ax.exceptions.generation_strategy import (
AxGenerationException,
MaxParallelismReachedException,
OptimizationConfigRequired,
)
from ax.modelbridge.base import ModelBridge
from ax.modelbridge.generation_strategy import GenerationStrategy
Expand Down Expand Up @@ -1708,6 +1709,15 @@ def _get_next_trials(self, num_trials: int = 1, n: int = 1) -> List[BaseTrial]:
)
self.logger.debug(f"Message from generation strategy: {err}")
return []
except OptimizationConfigRequired as err:
if self._log_next_no_trials_reason:
self.logger.info(
"Generated all trials that can be generated currently. "
"`generation_strategy` requires an optimization config "
"to be set before generating more trials."
)
self.logger.debug(f"Message from generation strategy: {err}")
return []

if self.options.trial_type == TrialType.TRIAL and any(
len(generator_run_list[0].arms) > 1 or len(generator_run_list) > 1
Expand Down Expand Up @@ -1737,6 +1747,39 @@ def _get_next_trials(self, num_trials: int = 1, n: int = 1) -> List[BaseTrial]:
trials.append(trial)
return trials

def generate_candidates(
self,
num_trials: int = 1,
reduce_state_generator_runs: bool = False,
) -> List[BaseTrial]:
"""Fetch the latest data and generate new candidate trials.
Args:
num_trials: Number of candidate trials to generate.
reduce_state_generator_runs: Flag to determine
whether to save model state for every generator run (default)
or to only save model state on the final generator run of each
batch.
Returns:
List of trials, empty if generation is not possible.
"""
self.poll_and_process_results()
new_trials = self._get_next_trials(
num_trials=num_trials,
n=self.options.batch_size or 1,
)
if len(new_trials) > 0:
new_generator_runs = [gr for t in new_trials for gr in t.generator_runs]
self._save_or_update_trials_and_generation_strategy_if_possible(
experiment=self.experiment,
trials=new_trials,
generation_strategy=self.generation_strategy,
new_generator_runs=new_generator_runs,
reduce_state_generator_runs=reduce_state_generator_runs,
)
return new_trials

def _gen_new_trials_from_generation_strategy(
self,
num_trials: int,
Expand Down
Loading

0 comments on commit 10cb6ae

Please sign in to comment.