Move lazy construction of a surrogate from problem to runner (facebook#2603)

Summary:
Pull Request resolved: facebook#2603

Context: Surrogate benchmark problems allow for downloading datasets and constructing a surrogate lazily. Since the surrogates and datasets are only needed by the `Runner`, it makes sense to confine this logic to `SurrogateRunner`. This gives surrogate benchmark problems an interface that is much closer to that of non-surrogate benchmark problems. In the future, we should be able to get down to just one `BenchmarkProblem` class.

This PR:
* Moves lazy construction of surrogates from the `Problem` to the `Runner`.
* Moves corresponding unit tests from the problem's file to the runner's.
* Removes the attribute `noise_stds` from the problem, since it duplicates the same attribute on the runner and doesn't conform to the interface of other benchmark problems.
* Requires `is_noiseless` to be provided at problem initialization, so that surrogate problems have the same interface as other problems, and adds an attribute `SurrogateRunner.is_noiseless` so that this value is easy to provide. (See the sketch below for the resulting interface.)
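
For illustration, a minimal sketch of the resulting interface. It is not taken from this diff; `make_surrogate_and_datasets`, `search_space`, and `opt_config` are hypothetical stand-ins, and the argument values are placeholders.

from ax.benchmark.problems.surrogate import SOOSurrogateBenchmarkProblem
from ax.benchmark.runners.surrogate import SurrogateRunner

# The runner now owns lazy construction: the surrogate and datasets are
# built only when `runner.surrogate` or `runner.datasets` is first accessed.
runner = SurrogateRunner(
    name="test",
    search_space=search_space,  # hypothetical SearchSpace
    outcome_names=["objective"],
    get_surrogate_and_datasets=make_surrogate_and_datasets,  # hypothetical factory
)

# The problem takes a fully constructed runner and an explicit `is_noiseless`,
# matching the interface of non-surrogate benchmark problems.
problem = SOOSurrogateBenchmarkProblem(
    name="test",
    search_space=search_space,
    optimization_config=opt_config,  # hypothetical OptimizationConfig
    optimal_value=0.0,  # placeholder
    num_trials=10,
    runner=runner,
    is_noiseless=runner.is_noiseless,  # new attribute added in this PR
)

Until `runner.surrogate` or `runner.datasets` is accessed, `runner._surrogate` and `runner._datasets` stay `None`, which is what the new `test_lazy_instantiation` test in `test_surrogate_runner.py` exercises.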

Differential Revision: D60266288
esantorella authored and facebook-github-bot committed Jul 26, 2024
1 parent 1ab07a7 commit edec0ee
Showing 6 changed files with 113 additions and 122 deletions.
108 changes: 18 additions & 90 deletions ax/benchmark/problems/surrogate.py
@@ -5,7 +5,7 @@

# pyre-strict

from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Union

from ax.benchmark.metrics.base import BenchmarkMetricBase

@@ -14,21 +14,16 @@
MultiObjectiveOptimizationConfig,
OptimizationConfig,
)
from ax.core.runner import Runner
from ax.core.search_space import SearchSpace
from ax.modelbridge.torch import TorchModelBridge
from ax.utils.common.base import Base
from ax.utils.common.equality import equality_typechecker
from ax.utils.common.typeutils import checked_cast, not_none
from botorch.utils.datasets import SupervisedDataset


class SurrogateBenchmarkProblemBase(Base):
"""
Base class for SOOSurrogateBenchmarkProblem and MOOSurrogateBenchmarkProblem.
Its `runner` is created lazily, when `runner` is accessed or `set_runner` is
called, to defer construction of the surrogate and downloading of datasets.
Its `runner` is a `SurrogateRunner`, which allows for the surrogate to be
constructed lazily and datasets to be downloaded lazily.
"""

def __init__(
@@ -38,14 +33,10 @@ def __init__(
search_space: SearchSpace,
optimization_config: OptimizationConfig,
num_trials: int,
outcome_names: List[str],
runner: SurrogateRunner,
is_noiseless: bool,
observe_noise_stds: Union[bool, Dict[str, bool]] = False,
noise_stds: Union[float, Dict[str, float]] = 0.0,
get_surrogate_and_datasets: Optional[
Callable[[], Tuple[TorchModelBridge, List[SupervisedDataset]]]
] = None,
tracking_metrics: Optional[List[BenchmarkMetricBase]] = None,
_runner: Optional[Runner] = None,
) -> None:
"""Construct a `SurrogateBenchmarkProblemBase` instance.
@@ -54,80 +45,31 @@
search_space: The search space to optimize over.
optimization_config: The optimization config for the problem.
num_trials: The number of trials to run.
outcome_names: The names of the metrics the benchmark problem
produces outcome observations for.
runner: A `SurrogateRunner`, allowing for lazy construction of the
surrogate and datasets.
observe_noise_stds: Whether or not to observe the observation noise
level for each metric. If True/False, observe the noise standard
deviation for all/no metrics. If a dictionary, specify this for
individual metrics (metrics not appearing in the dictionary will
be assumed to not provide observation noise levels).
noise_stds: The standard deviation(s) of the observation noise(s).
If a single value is provided, it is used for all metrics. Providing
a dictionary allows specifying different noise levels for different
metrics (metrics not appearing in the dictionary will be assumed to
be noiseless - but not necessarily known to the problem to be
noiseless).
get_surrogate_and_datasets: A factory function that returns the Surrogate
and a list of datasets to be used by the surrogate.
tracking_metrics: Additional tracking metrics to compute during the
optimization (not used to inform the optimization).
"""

if get_surrogate_and_datasets is None and _runner is None:
raise ValueError(
"Either `get_surrogate_and_datasets` or `_runner` required."
)
self.name = name
self.search_space = search_space
self.optimization_config = optimization_config
self.num_trials = num_trials
self.outcome_names = outcome_names
self.observe_noise_stds = observe_noise_stds
self.noise_stds = noise_stds
self.get_surrogate_and_datasets = get_surrogate_and_datasets
self.tracking_metrics: List[BenchmarkMetricBase] = tracking_metrics or []
self._runner = _runner

@property
def is_noiseless(self) -> bool:
if self.noise_stds is None:
return True
if isinstance(self.noise_stds, float):
return self.noise_stds == 0.0
return all(std == 0.0 for std in checked_cast(dict, self.noise_stds).values())
self.runner = runner
self.is_noiseless = is_noiseless

@property
def has_ground_truth(self) -> bool:
# All surrogate-based problems have a ground truth
return True

@equality_typechecker
def __eq__(self, other: Base) -> bool:
if type(other) is not type(self):
return False

# Checking the whole datasets' equality here would be too expensive to be
# worth it; just check names instead
return self.name == other.name

def set_runner(self) -> None:
surrogate, datasets = not_none(self.get_surrogate_and_datasets)()

self._runner = SurrogateRunner(
name=self.name,
surrogate=surrogate,
datasets=datasets,
search_space=self.search_space,
outcome_names=self.outcome_names,
noise_stds=self.noise_stds,
)

@property
def runner(self) -> Runner:
if self._runner is None:
self.set_runner()
return not_none(self._runner)

def __repr__(self) -> str:
"""
Return a string representation that includes only the attributes that
@@ -140,7 +82,7 @@ def __repr__(self) -> str:
f"num_trials={self.num_trials}, "
f"is_noiseless={self.is_noiseless}, "
f"observe_noise_stds={self.observe_noise_stds}, "
f"noise_stds={self.noise_stds}, "
f"noise_stds={self.runner.noise_stds}, "
f"tracking_metrics={self.tracking_metrics})"
)

@@ -161,26 +103,18 @@ def __init__(
search_space: SearchSpace,
optimization_config: OptimizationConfig,
num_trials: int,
outcome_names: List[str],
runner: SurrogateRunner,
is_noiseless: bool,
observe_noise_stds: Union[bool, Dict[str, bool]] = False,
noise_stds: Union[float, Dict[str, float]] = 0.0,
get_surrogate_and_datasets: Optional[
Callable[[], Tuple[TorchModelBridge, List[SupervisedDataset]]]
] = None,
tracking_metrics: Optional[List[BenchmarkMetricBase]] = None,
_runner: Optional[Runner] = None,
) -> None:
super().__init__(
name=name,
search_space=search_space,
optimization_config=optimization_config,
num_trials=num_trials,
outcome_names=outcome_names,
observe_noise_stds=observe_noise_stds,
noise_stds=noise_stds,
get_surrogate_and_datasets=get_surrogate_and_datasets,
tracking_metrics=tracking_metrics,
_runner=_runner,
runner=runner,
is_noiseless=is_noiseless,
)
self.optimal_value = optimal_value

@@ -204,26 +138,20 @@ def __init__(
search_space: SearchSpace,
optimization_config: MultiObjectiveOptimizationConfig,
num_trials: int,
outcome_names: List[str],
runner: SurrogateRunner,
is_noiseless: bool,
observe_noise_stds: Union[bool, Dict[str, bool]] = False,
noise_stds: Union[float, Dict[str, float]] = 0.0,
get_surrogate_and_datasets: Optional[
Callable[[], Tuple[TorchModelBridge, List[SupervisedDataset]]]
] = None,
tracking_metrics: Optional[List[BenchmarkMetricBase]] = None,
_runner: Optional[Runner] = None,
) -> None:
super().__init__(
name=name,
search_space=search_space,
optimization_config=optimization_config,
num_trials=num_trials,
outcome_names=outcome_names,
observe_noise_stds=observe_noise_stds,
noise_stds=noise_stds,
get_surrogate_and_datasets=get_surrogate_and_datasets,
tracking_metrics=tracking_metrics,
_runner=_runner,
runner=runner,
is_noiseless=is_noiseless,
)
self.reference_point = reference_point
self.optimal_value = optimal_value
63 changes: 58 additions & 5 deletions ax/benchmark/runners/surrogate.py
@@ -6,7 +6,7 @@
# pyre-strict

import warnings
from typing import Any, Dict, Iterable, List, Optional, Set, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union

import torch
from ax.benchmark.runners.base import BenchmarkRunner
@@ -15,20 +15,27 @@
from ax.core.observation import ObservationFeatures
from ax.core.search_space import SearchSpace
from ax.modelbridge.torch import TorchModelBridge
from ax.utils.common.base import Base
from ax.utils.common.equality import equality_typechecker
from ax.utils.common.serialization import TClassDecoderRegistry, TDecoderRegistry
from botorch.utils.datasets import SupervisedDataset
from pyre_extensions import assert_is_instance, none_throws
from torch import Tensor


class SurrogateRunner(BenchmarkRunner):
def __init__(
self,
*,
name: str,
surrogate: TorchModelBridge,
datasets: List[SupervisedDataset],
search_space: SearchSpace,
outcome_names: List[str],
surrogate: Optional[TorchModelBridge] = None,
datasets: Optional[List[SupervisedDataset]] = None,
noise_stds: Union[float, Dict[str, float]] = 0.0,
get_surrogate_and_datasets: Optional[
Callable[[], Tuple[TorchModelBridge, List[SupervisedDataset]]]
] = None,
) -> None:
"""Runner for surrogate benchmark problems.
@@ -45,15 +52,42 @@ def __init__(
is added to all outputs. Alternatively, a dictionary mapping outcome
names to noise standard deviations can be provided to specify different
noise levels for different outputs.
get_surrogate_and_datasets: Function that returns the surrogate and
datasets, to allow for lazy construction. If
`get_surrogate_and_datasets` is not provided, `surrogate` and
`datasets` must be provided, and vice versa.
"""
if get_surrogate_and_datasets is None and (
surrogate is None or datasets is None
):
raise ValueError(
"If get_surrogate_and_datasets is provided, surrogate and "
"datasets must not be provided, and vice versa."
)
self.get_surrogate_and_datasets = get_surrogate_and_datasets
self.name = name
self.surrogate = surrogate
self._surrogate = surrogate
self._outcome_names = outcome_names
self.datasets = datasets
self._datasets = datasets
self.search_space = search_space
self.noise_stds = noise_stds
self.statuses: Dict[int, TrialStatus] = {}

def set_surrogate_and_datasets(self) -> None:
self._surrogate, self._datasets = none_throws(self.get_surrogate_and_datasets)()

@property
def surrogate(self) -> TorchModelBridge:
if self._surrogate is None:
self.set_surrogate_and_datasets()
return none_throws(self._surrogate)

@property
def datasets(self) -> List[SupervisedDataset]:
if self._datasets is None:
self.set_surrogate_and_datasets()
return none_throws(self._datasets)

@property
def outcome_names(self) -> List[str]:
return self._outcome_names
@@ -131,3 +165,22 @@ def deserialize_init_args(
class_decoder_registry: Optional[TClassDecoderRegistry] = None,
) -> Dict[str, Any]:
return {}

@property
def is_noiseless(self) -> bool:
if self.noise_stds is None:
return True
if isinstance(self.noise_stds, float):
return self.noise_stds == 0.0
return all(
std == 0.0 for std in assert_is_instance(self.noise_stds, dict).values()
)

@equality_typechecker
def __eq__(self, other: Base) -> bool:
if type(other) is not type(self):
return False

# Checking the whole datasets' equality here would be too expensive to be
# worth it; just check names instead
return self.name == other.name
26 changes: 5 additions & 21 deletions ax/benchmark/tests/problems/test_surrogate_problems.py
@@ -9,24 +9,25 @@
import numpy as np
from ax.benchmark.benchmark import compute_score_trace
from ax.benchmark.benchmark_problem import BenchmarkProblemProtocol
from ax.core.runner import Runner
from ax.utils.common.testutils import TestCase
from ax.utils.testing.benchmark_stubs import get_moo_surrogate, get_soo_surrogate


class TestSurrogateProblems(TestCase):
def setUp(self) -> None:
super().setUp()
self.maxDiff = None

def test_conforms_to_protocol(self) -> None:
sbp = get_soo_surrogate()
self.assertIsInstance(sbp, BenchmarkProblemProtocol)

mbp = get_moo_surrogate()
self.assertIsInstance(mbp, BenchmarkProblemProtocol)

def test_lazy_instantiation(self) -> None:
def test_repr(self) -> None:

# test instantiation from init
sbp = get_soo_surrogate()
# test __repr__ method

expected_repr = (
"SOOSurrogateBenchmarkProblem(name=test, "
@@ -38,23 +39,6 @@ def test_lazy_instantiation(self) -> None:
)
self.assertEqual(repr(sbp), expected_repr)

self.assertIsNone(sbp._runner)
# sets runner
self.assertIsInstance(sbp.runner, Runner)

self.assertIsNotNone(sbp._runner)
self.assertIsNotNone(sbp.runner)

# repeat for MOO
sbp = get_moo_surrogate()

self.assertIsNone(sbp._runner)
# sets runner
self.assertIsInstance(sbp.runner, Runner)

self.assertIsNotNone(sbp._runner)
self.assertIsNotNone(sbp.runner)

def test_compute_score_trace(self) -> None:
soo_problem = get_soo_surrogate()
score_trace = compute_score_trace(
Expand Down
15 changes: 14 additions & 1 deletion ax/benchmark/tests/runners/test_surrogate_runner.py
@@ -8,10 +8,12 @@
from unittest.mock import MagicMock

import torch
from ax.benchmark.problems.surrogate import SurrogateRunner
from ax.benchmark.runners.surrogate import SurrogateRunner
from ax.core.parameter import ParameterType, RangeParameter
from ax.core.search_space import SearchSpace
from ax.modelbridge.torch import TorchModelBridge
from ax.utils.common.testutils import TestCase
from ax.utils.testing.benchmark_stubs import get_soo_surrogate


class TestSurrogateRunner(TestCase):
@@ -43,3 +45,14 @@ def test_surrogate_runner(self) -> None:
self.assertIs(runner.surrogate, surrogate)
self.assertEqual(runner.outcome_names, ["dummy_metric"])
self.assertEqual(runner.noise_stds, noise_std)

def test_lazy_instantiation(self) -> None:
problem = get_soo_surrogate()

self.assertIsNone(problem.runner._surrogate)
self.assertIsNone(problem.runner._datasets)

# sets datasets and surrogate
self.assertIsInstance(problem.runner.surrogate, TorchModelBridge)
self.assertIsNotNone(problem.runner._surrogate)
self.assertIsNotNone(problem.runner._datasets)