diff --git a/ax/benchmark/problems/surrogate.py b/ax/benchmark/problems/surrogate.py
index 4f08ecb84f8..c9f39836509 100644
--- a/ax/benchmark/problems/surrogate.py
+++ b/ax/benchmark/problems/surrogate.py
@@ -5,7 +5,7 @@
 
 # pyre-strict
 
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Union
 
 from ax.benchmark.metrics.base import BenchmarkMetricBase
 
@@ -14,21 +14,16 @@
     MultiObjectiveOptimizationConfig,
     OptimizationConfig,
 )
-from ax.core.runner import Runner
 from ax.core.search_space import SearchSpace
-from ax.modelbridge.torch import TorchModelBridge
 from ax.utils.common.base import Base
-from ax.utils.common.equality import equality_typechecker
-from ax.utils.common.typeutils import checked_cast, not_none
-from botorch.utils.datasets import SupervisedDataset
 
 
 class SurrogateBenchmarkProblemBase(Base):
     """
     Base class for SOOSurrogateBenchmarkProblem and MOOSurrogateBenchmarkProblem.
 
-    Its `runner` is created lazily, when `runner` is accessed or `set_runner` is
-    called, to defer construction of the surrogate and downloading of datasets.
+    Its `runner` is a `SurrogateRunner`, which allows for the surrogate to be
+    constructed lazily and datasets to be downloaded lazily.
     """
 
     def __init__(
@@ -38,14 +33,10 @@ def __init__(
         search_space: SearchSpace,
         optimization_config: OptimizationConfig,
         num_trials: int,
-        outcome_names: List[str],
+        runner: SurrogateRunner,
+        is_noiseless: bool,
         observe_noise_stds: Union[bool, Dict[str, bool]] = False,
-        noise_stds: Union[float, Dict[str, float]] = 0.0,
-        get_surrogate_and_datasets: Optional[
-            Callable[[], Tuple[TorchModelBridge, List[SupervisedDataset]]]
-        ] = None,
         tracking_metrics: Optional[List[BenchmarkMetricBase]] = None,
-        _runner: Optional[Runner] = None,
     ) -> None:
         """Construct a `SurrogateBenchmarkProblemBase` instance.
 
@@ -54,80 +45,31 @@ def __init__(
             search_space: The search space to optimize over.
             optimization_config: THe optimization config for the problem.
             num_trials: The number of trials to run.
-            outcome_names: The names of the metrics the benchmark problem
-                produces outcome observations for.
+            runner: A `SurrogateRunner`, allowing for lazy construction of the
+                surrogate and datasets.
             observe_noise_stds: Whether or not to observe the observation noise
                 level for each metric. If True/False, observe the the noise standard
                 deviation for all/no metrics. If a dictionary, specify this for
                 individual metrics (metrics not appearing in the dictionary will be
                 assumed to not provide observation noise levels).
-            noise_stds: The standard deviation(s) of the observation noise(s).
-                If a single value is provided, it is used for all metrics. Providing
-                a dictionary allows specifying different noise levels for different
-                metrics (metrics not appearing in the dictionary will be assumed to
-                be noiseless - but not necessarily be known to the problem to be
-                noiseless).
-            get_surrogate_and_datasets: A factory function that retunrs the Surrogate
-                and a list of datasets to be used by the surrogate.
             tracking_metrics: Additional tracking metrics to compute during the
                 optimization (not used to inform the optimization).
         """
-        if get_surrogate_and_datasets is None and _runner is None:
-            raise ValueError(
-                "Either `get_surrogate_and_datasets` or `_runner` required."
-            )
         self.name = name
         self.search_space = search_space
         self.optimization_config = optimization_config
         self.num_trials = num_trials
-        self.outcome_names = outcome_names
         self.observe_noise_stds = observe_noise_stds
-        self.noise_stds = noise_stds
-        self.get_surrogate_and_datasets = get_surrogate_and_datasets
         self.tracking_metrics: List[BenchmarkMetricBase] = tracking_metrics or []
-        self._runner = _runner
-
-    @property
-    def is_noiseless(self) -> bool:
-        if self.noise_stds is None:
-            return True
-        if isinstance(self.noise_stds, float):
-            return self.noise_stds == 0.0
-        return all(std == 0.0 for std in checked_cast(dict, self.noise_stds).values())
+        self.runner = runner
+        self.is_noiseless = is_noiseless
 
     @property
     def has_ground_truth(self) -> bool:
         # All surrogate-based problems have a ground truth
         return True
 
-    @equality_typechecker
-    def __eq__(self, other: Base) -> bool:
-        if type(other) is not type(self):
-            return False
-
-        # Checking the whole datasets' equality here would be too expensive to be
-        # worth it; just check names instead
-        return self.name == other.name
-
-    def set_runner(self) -> None:
-        surrogate, datasets = not_none(self.get_surrogate_and_datasets)()
-
-        self._runner = SurrogateRunner(
-            name=self.name,
-            surrogate=surrogate,
-            datasets=datasets,
-            search_space=self.search_space,
-            outcome_names=self.outcome_names,
-            noise_stds=self.noise_stds,
-        )
-
-    @property
-    def runner(self) -> Runner:
-        if self._runner is None:
-            self.set_runner()
-        return not_none(self._runner)
-
     def __repr__(self) -> str:
         """
         Return a string representation that includes only the attributes that
@@ -140,7 +82,7 @@ def __repr__(self) -> str:
             f"num_trials={self.num_trials}, "
             f"is_noiseless={self.is_noiseless}, "
            f"observe_noise_stds={self.observe_noise_stds}, "
-            f"noise_stds={self.noise_stds}, "
+            f"noise_stds={self.runner.noise_stds}, "
             f"tracking_metrics={self.tracking_metrics})"
         )
@@ -161,26 +103,18 @@ def __init__(
         search_space: SearchSpace,
         optimization_config: OptimizationConfig,
         num_trials: int,
-        outcome_names: List[str],
+        runner: SurrogateRunner,
+        is_noiseless: bool,
         observe_noise_stds: Union[bool, Dict[str, bool]] = False,
-        noise_stds: Union[float, Dict[str, float]] = 0.0,
-        get_surrogate_and_datasets: Optional[
-            Callable[[], Tuple[TorchModelBridge, List[SupervisedDataset]]]
-        ] = None,
-        tracking_metrics: Optional[List[BenchmarkMetricBase]] = None,
-        _runner: Optional[Runner] = None,
     ) -> None:
         super().__init__(
             name=name,
             search_space=search_space,
             optimization_config=optimization_config,
             num_trials=num_trials,
-            outcome_names=outcome_names,
             observe_noise_stds=observe_noise_stds,
-            noise_stds=noise_stds,
-            get_surrogate_and_datasets=get_surrogate_and_datasets,
-            tracking_metrics=tracking_metrics,
-            _runner=_runner,
+            runner=runner,
+            is_noiseless=is_noiseless,
         )
         self.optimal_value = optimal_value
@@ -204,26 +138,20 @@ def __init__(
         search_space: SearchSpace,
         optimization_config: MultiObjectiveOptimizationConfig,
         num_trials: int,
-        outcome_names: List[str],
+        runner: SurrogateRunner,
+        is_noiseless: bool,
         observe_noise_stds: Union[bool, Dict[str, bool]] = False,
-        noise_stds: Union[float, Dict[str, float]] = 0.0,
-        get_surrogate_and_datasets: Optional[
-            Callable[[], Tuple[TorchModelBridge, List[SupervisedDataset]]]
-        ] = None,
         tracking_metrics: Optional[List[BenchmarkMetricBase]] = None,
-        _runner: Optional[Runner] = None,
     ) -> None:
         super().__init__(
             name=name,
             search_space=search_space,
             optimization_config=optimization_config,
             num_trials=num_trials,
-            outcome_names=outcome_names,
             observe_noise_stds=observe_noise_stds,
-            noise_stds=noise_stds,
-            get_surrogate_and_datasets=get_surrogate_and_datasets,
             tracking_metrics=tracking_metrics,
-            _runner=_runner,
+            runner=runner,
+            is_noiseless=is_noiseless,
         )
         self.reference_point = reference_point
         self.optimal_value = optimal_value
diff --git a/ax/benchmark/runners/surrogate.py b/ax/benchmark/runners/surrogate.py
index f64a5d1dd15..f685307d83a 100644
--- a/ax/benchmark/runners/surrogate.py
+++ b/ax/benchmark/runners/surrogate.py
@@ -6,7 +6,7 @@
 # pyre-strict
 
 import warnings
-from typing import Any, Dict, Iterable, List, Optional, Set, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from ax.benchmark.runners.base import BenchmarkRunner
@@ -15,20 +15,27 @@
 from ax.core.observation import ObservationFeatures
 from ax.core.search_space import SearchSpace
 from ax.modelbridge.torch import TorchModelBridge
+from ax.utils.common.base import Base
+from ax.utils.common.equality import equality_typechecker
 from ax.utils.common.serialization import TClassDecoderRegistry, TDecoderRegistry
 from botorch.utils.datasets import SupervisedDataset
+from pyre_extensions import assert_is_instance, none_throws
 from torch import Tensor
 
 
 class SurrogateRunner(BenchmarkRunner):
     def __init__(
         self,
+        *,
         name: str,
-        surrogate: TorchModelBridge,
-        datasets: List[SupervisedDataset],
         search_space: SearchSpace,
         outcome_names: List[str],
+        surrogate: Optional[TorchModelBridge] = None,
+        datasets: Optional[List[SupervisedDataset]] = None,
         noise_stds: Union[float, Dict[str, float]] = 0.0,
+        get_surrogate_and_datasets: Optional[
+            Callable[[], Tuple[TorchModelBridge, List[SupervisedDataset]]]
+        ] = None,
     ) -> None:
         """Runner for surrogate benchmark problems.
 
@@ -45,15 +52,42 @@ def __init__(
                 is added to all outputs. Alternatively, a dictionary mapping outcome
                 names to noise standard deviations can be provided to specify different
                 noise levels for different outputs.
+            get_surrogate_and_datasets: Function that returns the surrogate and
+                datasets, to allow for lazy construction. If
+                `get_surrogate_and_datasets` is not provided, `surrogate` and
+                `datasets` must be provided, and vice versa.
         """
+        if get_surrogate_and_datasets is None and (
+            surrogate is None or datasets is None
+        ):
+            raise ValueError(
+                "If `get_surrogate_and_datasets` is not provided, `surrogate` "
+                "and `datasets` must be provided, and vice versa."
+            )
+        self.get_surrogate_and_datasets = get_surrogate_and_datasets
         self.name = name
-        self.surrogate = surrogate
+        self._surrogate = surrogate
         self._outcome_names = outcome_names
-        self.datasets = datasets
+        self._datasets = datasets
         self.search_space = search_space
         self.noise_stds = noise_stds
         self.statuses: Dict[int, TrialStatus] = {}
 
+    def set_surrogate_and_datasets(self) -> None:
+        self._surrogate, self._datasets = none_throws(self.get_surrogate_and_datasets)()
+
+    @property
+    def surrogate(self) -> TorchModelBridge:
+        if self._surrogate is None:
+            self.set_surrogate_and_datasets()
+        return none_throws(self._surrogate)
+
+    @property
+    def datasets(self) -> List[SupervisedDataset]:
+        if self._datasets is None:
+            self.set_surrogate_and_datasets()
+        return none_throws(self._datasets)
+
     @property
     def outcome_names(self) -> List[str]:
         return self._outcome_names
@@ -131,3 +165,22 @@ def deserialize_init_args(
         class_decoder_registry: Optional[TClassDecoderRegistry] = None,
     ) -> Dict[str, Any]:
         return {}
+
+    @property
+    def is_noiseless(self) -> bool:
+        if self.noise_stds is None:
+            return True
+        if isinstance(self.noise_stds, float):
+            return self.noise_stds == 0.0
+        return all(
+            std == 0.0 for std in assert_is_instance(self.noise_stds, dict).values()
+        )
+
+    @equality_typechecker
+    def __eq__(self, other: Base) -> bool:
+        if type(other) is not type(self):
+            return False
+
+        # Checking the whole datasets' equality here would be too expensive to be
+        # worth it; just check names instead
+        return self.name == other.name
diff --git a/ax/benchmark/tests/problems/test_surrogate_problems.py b/ax/benchmark/tests/problems/test_surrogate_problems.py
index 6d54784e0e0..c9211d57ce3 100644
--- a/ax/benchmark/tests/problems/test_surrogate_problems.py
+++ b/ax/benchmark/tests/problems/test_surrogate_problems.py
@@ -9,12 +9,15 @@
 import numpy as np
 from ax.benchmark.benchmark import compute_score_trace
 from ax.benchmark.benchmark_problem import BenchmarkProblemProtocol
-from ax.core.runner import Runner
 from ax.utils.common.testutils import TestCase
 from ax.utils.testing.benchmark_stubs import get_moo_surrogate, get_soo_surrogate
 
 
 class TestSurrogateProblems(TestCase):
+    def setUp(self) -> None:
+        super().setUp()
+        self.maxDiff = None
+
     def test_conforms_to_protocol(self) -> None:
         sbp = get_soo_surrogate()
         self.assertIsInstance(sbp, BenchmarkProblemProtocol)
@@ -22,11 +25,9 @@ def test_conforms_to_protocol(self) -> None:
         mbp = get_moo_surrogate()
         self.assertIsInstance(mbp, BenchmarkProblemProtocol)
 
-    def test_lazy_instantiation(self) -> None:
+    def test_repr(self) -> None:
 
-        # test instantiation from init
         sbp = get_soo_surrogate()
-        # test __repr__ method
 
         expected_repr = (
             "SOOSurrogateBenchmarkProblem(name=test, "
@@ -38,23 +39,6 @@ def test_lazy_instantiation(self) -> None:
         )
         self.assertEqual(repr(sbp), expected_repr)
 
-        self.assertIsNone(sbp._runner)
-        # sets runner
-        self.assertIsInstance(sbp.runner, Runner)
-
-        self.assertIsNotNone(sbp._runner)
-        self.assertIsNotNone(sbp.runner)
-
-        # repeat for MOO
-        sbp = get_moo_surrogate()
-
-        self.assertIsNone(sbp._runner)
-        # sets runner
-        self.assertIsInstance(sbp.runner, Runner)
-
-        self.assertIsNotNone(sbp._runner)
-        self.assertIsNotNone(sbp.runner)
-
     def test_compute_score_trace(self) -> None:
         soo_problem = get_soo_surrogate()
         score_trace = compute_score_trace(
diff --git a/ax/benchmark/tests/runners/test_surrogate_runner.py b/ax/benchmark/tests/runners/test_surrogate_runner.py
index 0fdf4e65154..b9eb7681bbc 100644
--- a/ax/benchmark/tests/runners/test_surrogate_runner.py
+++ b/ax/benchmark/tests/runners/test_surrogate_runner.py
@@ -8,10 +8,12 @@
 from unittest.mock import MagicMock
 
 import torch
-from ax.benchmark.problems.surrogate import SurrogateRunner
+from ax.benchmark.runners.surrogate import SurrogateRunner
 from ax.core.parameter import ParameterType, RangeParameter
 from ax.core.search_space import SearchSpace
+from ax.modelbridge.torch import TorchModelBridge
 from ax.utils.common.testutils import TestCase
+from ax.utils.testing.benchmark_stubs import get_soo_surrogate
 
 
 class TestSurrogateRunner(TestCase):
@@ -43,3 +45,14 @@ def test_surrogate_runner(self) -> None:
         self.assertIs(runner.surrogate, surrogate)
         self.assertEqual(runner.outcome_names, ["dummy_metric"])
         self.assertEqual(runner.noise_stds, noise_std)
+
+    def test_lazy_instantiation(self) -> None:
+        problem = get_soo_surrogate()
+
+        self.assertIsNone(problem.runner._surrogate)
+        self.assertIsNone(problem.runner._datasets)
+
+        # sets datasets and surrogate
+        self.assertIsInstance(problem.runner.surrogate, TorchModelBridge)
+        self.assertIsNotNone(problem.runner._surrogate)
+        self.assertIsNotNone(problem.runner._datasets)
diff --git a/ax/benchmark/tests/test_benchmark.py b/ax/benchmark/tests/test_benchmark.py
index dc24816effb..91ce8af336b 100644
--- a/ax/benchmark/tests/test_benchmark.py
+++ b/ax/benchmark/tests/test_benchmark.py
@@ -300,7 +300,6 @@ def test_replication_sobol_surrogate(self) -> None:
             ("moo", get_moo_surrogate()),
         ]:
             with self.subTest(name, problem=problem):
-                surrogate, datasets = not_none(problem.get_surrogate_and_datasets)()
                 res = benchmark_replication(problem=problem, method=method, seed=0)
 
                 self.assertEqual(
diff --git a/ax/utils/testing/benchmark_stubs.py b/ax/utils/testing/benchmark_stubs.py
index 170334b0258..aca50f03259 100644
--- a/ax/utils/testing/benchmark_stubs.py
+++ b/ax/utils/testing/benchmark_stubs.py
@@ -20,6 +20,7 @@
     MOOSurrogateBenchmarkProblem,
     SOOSurrogateBenchmarkProblem,
 )
+from ax.benchmark.runners.surrogate import SurrogateRunner
 from ax.core.experiment import Experiment
 from ax.core.optimization_config import (
     MultiObjectiveOptimizationConfig,
@@ -110,6 +111,12 @@ def get_soo_surrogate() -> SOOSurrogateBenchmarkProblem:
         data=experiment.lookup_data(),
         transforms=[],
     )
+    runner = SurrogateRunner(
+        name="test",
+        search_space=experiment.search_space,
+        outcome_names=["branin"],
+        get_surrogate_and_datasets=lambda: (surrogate, []),
+    )
     return SOOSurrogateBenchmarkProblem(
         name="test",
         search_space=experiment.search_space,
@@ -117,10 +124,10 @@ def get_soo_surrogate() -> SOOSurrogateBenchmarkProblem:
             OptimizationConfig, experiment.optimization_config
         ),
         num_trials=6,
-        outcome_names=["branin"],
         observe_noise_stds=True,
-        get_surrogate_and_datasets=lambda: (surrogate, []),
         optimal_value=0.0,
+        runner=runner,
+        is_noiseless=runner.is_noiseless,
     )
@@ -133,6 +140,13 @@ def get_moo_surrogate() -> MOOSurrogateBenchmarkProblem:
         data=experiment.lookup_data(),
         transforms=[],
     )
+
+    runner = SurrogateRunner(
+        name="test",
+        search_space=experiment.search_space,
+        outcome_names=["branin_a", "branin_b"],
+        get_surrogate_and_datasets=lambda: (surrogate, []),
+    )
     return MOOSurrogateBenchmarkProblem(
         name="test",
         search_space=experiment.search_space,
@@ -140,11 +154,11 @@ def get_moo_surrogate() -> MOOSurrogateBenchmarkProblem:
             MultiObjectiveOptimizationConfig, experiment.optimization_config
         ),
         num_trials=10,
-        outcome_names=["branin_a", "branin_b"],
         observe_noise_stds=True,
-        get_surrogate_and_datasets=lambda: (surrogate, []),
         optimal_value=1.0,
         reference_point=[],
+        runner=runner,
+        is_noiseless=runner.is_noiseless,
     )
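
Usage sketch (not part of the patch): how a lazily constructed SurrogateRunner is intended to be used, assuming the signature added above. The factory name, outcome names, and search space below are illustrative placeholders; the point is that the callable passed as get_surrogate_and_datasets is only invoked the first time runner.surrogate or runner.datasets is accessed.

    from typing import List, Tuple

    from ax.benchmark.runners.surrogate import SurrogateRunner
    from ax.core.parameter import ParameterType, RangeParameter
    from ax.core.search_space import SearchSpace
    from ax.modelbridge.torch import TorchModelBridge
    from botorch.utils.datasets import SupervisedDataset


    def load_datasets_and_fit_surrogate() -> Tuple[TorchModelBridge, List[SupervisedDataset]]:
        # Hypothetical, potentially expensive factory: download the datasets and
        # fit the surrogate model here, returning (surrogate, datasets).
        ...


    search_space = SearchSpace(
        parameters=[
            RangeParameter(
                name="x", parameter_type=ParameterType.FLOAT, lower=0.0, upper=1.0
            )
        ]
    )

    # Constructing the runner is cheap: the factory is not called here. It is
    # only invoked on first access of `runner.surrogate` or `runner.datasets`.
    runner = SurrogateRunner(
        name="example",
        search_space=search_space,
        outcome_names=["objective"],
        noise_stds=0.0,
        get_surrogate_and_datasets=load_datasets_and_fit_surrogate,
    )

A benchmark problem can then be constructed with runner=runner and is_noiseless=runner.is_noiseless, as in the updated benchmark_stubs.py above.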