diff --git a/ax/benchmark/benchmark_problem.py b/ax/benchmark/benchmark_problem.py index d7cf78cc46a..efd6b1acf11 100644 --- a/ax/benchmark/benchmark_problem.py +++ b/ax/benchmark/benchmark_problem.py @@ -206,13 +206,12 @@ class MultiObjectiveBenchmarkProblem(BenchmarkProblem): A `BenchmarkProblem` that supports multiple objectives. For multi-objective problems, `optimal_value` indicates the maximum - hypervolume attainable with the given `reference_point`. + hypervolume attainable with the objective thresholds provided on the + `optimization_config`. - For argument descriptions, see `BenchmarkProblem`; it additionally takes a `runner` - and a `reference_point`. + For argument descriptions, see `BenchmarkProblem`. """ - reference_point: List[float] optimization_config: MultiObjectiveOptimizationConfig @@ -289,5 +288,4 @@ def create_multi_objective_problem_from_botorch( observe_noise_stds=observe_noise_sd, has_ground_truth=problem.has_ground_truth, optimal_value=test_problem.max_hv, - reference_point=test_problem._ref_point, ) diff --git a/ax/benchmark/problems/surrogate.py b/ax/benchmark/problems/surrogate.py index a165216b252..4d9864540c1 100644 --- a/ax/benchmark/problems/surrogate.py +++ b/ax/benchmark/problems/surrogate.py @@ -4,12 +4,19 @@ # LICENSE file in the root directory of this source tree. # pyre-strict +""" +Benchmark problems based on surrogates. + +These problems might appear to function identically to their non-surrogate +counterparts, `BenchmarkProblem` and `MultiObjectiveBenchmarkProblem`, aside +from the restriction that their runners are of type `SurrogateRunner`. However, +they are treated specially within JSON storage because surrogates cannot be +easily serialized. +""" from dataclasses import dataclass, field -from typing import List from ax.benchmark.benchmark_problem import BenchmarkProblem - from ax.benchmark.runners.surrogate import SurrogateRunner from ax.core.optimization_config import MultiObjectiveOptimizationConfig @@ -21,6 +28,8 @@ class SurrogateBenchmarkProblemBase(BenchmarkProblem): Its `runner` is a `SurrogateRunner`, which allows for the surrogate to be constructed lazily and datasets to be downloaded lazily. + + For argument descriptions, see `BenchmarkProblem`. """ runner: SurrogateRunner = field(repr=False) @@ -34,9 +43,10 @@ class SOOSurrogateBenchmarkProblem(SurrogateBenchmarkProblemBase): class MOOSurrogateBenchmarkProblem(SurrogateBenchmarkProblemBase): """ Has the same attributes/properties as a `MultiObjectiveBenchmarkProblem`, - but its runner is not constructed until needed, to allow for deferring - constructing the surrogate and downloading data. + but its `runner` is a `SurrogateRunner`, which allows for the surrogate to be + constructed lazily and datasets to be downloaded lazily. + + For argument descriptions, see `BenchmarkProblem`. """ optimization_config: MultiObjectiveOptimizationConfig - reference_point: List[float] diff --git a/ax/benchmark/tests/test_benchmark_problem.py b/ax/benchmark/tests/test_benchmark_problem.py index cdd5a931a80..f9b210d2077 100644 --- a/ax/benchmark/tests/test_benchmark_problem.py +++ b/ax/benchmark/tests/test_benchmark_problem.py @@ -203,7 +203,11 @@ def test_moo_from_botorch(self) -> None: # Test hypervolume self.assertEqual(branin_currin_problem.optimal_value, test_problem._max_hv) - self.assertEqual(branin_currin_problem.reference_point, test_problem._ref_point) + opt_config = branin_currin_problem.optimization_config + reference_point = [ + threshold.bound for threshold in opt_config.objective_thresholds + ] + self.assertEqual(reference_point, test_problem._ref_point) def test_moo_from_botorch_constrained(self) -> None: with self.assertRaisesRegex( diff --git a/ax/utils/testing/benchmark_stubs.py b/ax/utils/testing/benchmark_stubs.py index 7614df239b5..905b917c77a 100644 --- a/ax/utils/testing/benchmark_stubs.py +++ b/ax/utils/testing/benchmark_stubs.py @@ -148,7 +148,6 @@ def get_moo_surrogate() -> MOOSurrogateBenchmarkProblem: num_trials=10, observe_noise_stds=True, optimal_value=1.0, - reference_point=[], runner=runner, is_noiseless=runner.is_noiseless, )