diff --git a/ax/benchmark/benchmark.py b/ax/benchmark/benchmark.py
index 029889722bd..d32c3ed8d3f 100644
--- a/ax/benchmark/benchmark.py
+++ b/ax/benchmark/benchmark.py
@@ -27,10 +27,7 @@
 import numpy as np
 
 from ax.benchmark.benchmark_method import BenchmarkMethod
-from ax.benchmark.benchmark_problem import (
-    BenchmarkProblemProtocol,
-    BenchmarkProblemWithKnownOptimum,
-)
+from ax.benchmark.benchmark_problem import BenchmarkProblem
 from ax.benchmark.benchmark_result import AggregatedBenchmarkResult, BenchmarkResult
 from ax.benchmark.metrics.base import BenchmarkMetricBase, GroundTruthMetricMixin
 from ax.core.experiment import Experiment
@@ -53,16 +50,14 @@
 def compute_score_trace(
     optimization_trace: np.ndarray,
     num_baseline_trials: int,
-    problem: BenchmarkProblemProtocol,
+    problem: BenchmarkProblem,
 ) -> np.ndarray:
     """Computes a score trace from the optimization trace."""
 
     # Use the first GenerationStep's best found point as baseline. Sometimes (ex. in
     # a timeout) the first GenerationStep will not have not completed and we will not
     # have enough trials; in this case we do not score.
-    if (len(optimization_trace) <= num_baseline_trials) or not isinstance(
-        problem, BenchmarkProblemWithKnownOptimum
-    ):
+    if len(optimization_trace) <= num_baseline_trials:
         return np.full(len(optimization_trace), np.nan)
     optimum = problem.optimal_value
     baseline = optimization_trace[num_baseline_trials - 1]
@@ -77,7 +72,7 @@
 
 
 def _create_benchmark_experiment(
-    problem: BenchmarkProblemProtocol, method_name: str
+    problem: BenchmarkProblem, method_name: str
 ) -> Experiment:
     """Creates an empty experiment for the given problem and method.
 
@@ -117,7 +112,7 @@
 
 
 def benchmark_replication(
-    problem: BenchmarkProblemProtocol,
+    problem: BenchmarkProblem,
     method: BenchmarkMethod,
     seed: int,
 ) -> BenchmarkResult:
@@ -192,7 +187,7 @@
 
 
 def benchmark_one_method_problem(
-    problem: BenchmarkProblemProtocol,
+    problem: BenchmarkProblem,
     method: BenchmarkMethod,
     seeds: Iterable[int],
 ) -> AggregatedBenchmarkResult:
@@ -205,7 +200,7 @@
 
 
 def benchmark_multiple_problems_methods(
-    problems: Iterable[BenchmarkProblemProtocol],
+    problems: Iterable[BenchmarkProblem],
     methods: Iterable[BenchmarkMethod],
     seeds: Iterable[int],
 ) -> List[AggregatedBenchmarkResult]:
@@ -222,7 +217,7 @@
 
 
 def make_ground_truth_metrics(
-    problem: BenchmarkProblemProtocol,
+    problem: BenchmarkProblem,
     include_tracking_metrics: bool = True,
 ) -> Dict[str, Metric]:
     """Makes a ground truth version for each metric defined on the problem.
diff --git a/ax/benchmark/benchmark_problem.py b/ax/benchmark/benchmark_problem.py
index 604bb30c4c5..b1f36b3bf8f 100644
--- a/ax/benchmark/benchmark_problem.py
+++ b/ax/benchmark/benchmark_problem.py
@@ -11,17 +11,7 @@
 # in the UI.
 
 from dataclasses import dataclass, field
-from typing import (
-    Any,
-    Dict,
-    List,
-    Optional,
-    Protocol,
-    runtime_checkable,
-    Type,
-    TypeVar,
-    Union,
-)
+from typing import Any, Dict, List, Optional, Type, TypeVar, Union
 
 from ax.benchmark.metrics.base import BenchmarkMetricBase
 
@@ -72,33 +62,6 @@ def _get_name(
     return f"{base_name}{observed_noise}{dim_str}"
 
 
-@runtime_checkable
-class BenchmarkProblemProtocol(Protocol):
-    """
-    Specifies the interface any benchmark problem must adhere to.
-
-    Classes implementing this interface include BenchmarkProblem,
-    SurrogateBenchmarkProblem, and MOOSurrogateBenchmarkProblem.
-    """
-
-    name: str
-    search_space: SearchSpace
-    optimization_config: OptimizationConfig
-    num_trials: int
-    tracking_metrics: List[BenchmarkMetricBase]
-    is_noiseless: bool  # If True, evaluations are deterministic
-    observe_noise_stds: Union[
-        bool, Dict[str, bool]
-    ]  # Whether we observe the observation noise level
-    has_ground_truth: bool  # if True, evals (w/o synthetic noise) are determinstic
-    runner: Runner
-
-
-@runtime_checkable
-class BenchmarkProblemWithKnownOptimum(Protocol):
-    optimal_value: float
-
-
 @dataclass(kw_only=True, repr=True)
 class BenchmarkProblem(Base):
     """
diff --git a/ax/benchmark/tests/problems/test_surrogate_problems.py b/ax/benchmark/tests/problems/test_surrogate_problems.py
index 81b772eaae8..c9c2a334096 100644
--- a/ax/benchmark/tests/problems/test_surrogate_problems.py
+++ b/ax/benchmark/tests/problems/test_surrogate_problems.py
@@ -8,7 +8,7 @@
 
 import numpy as np
 from ax.benchmark.benchmark import compute_score_trace
-from ax.benchmark.benchmark_problem import BenchmarkProblemProtocol
+from ax.benchmark.benchmark_problem import BenchmarkProblem
 from ax.utils.common.testutils import TestCase
 from ax.utils.testing.benchmark_stubs import get_moo_surrogate, get_soo_surrogate
 
@@ -19,12 +19,12 @@ def setUp(self) -> None:
         # print max output so errors in 'repr' can be fully shown
         self.maxDiff = None
 
-    def test_conforms_to_protocol(self) -> None:
+    def test_conforms_to_api(self) -> None:
         sbp = get_soo_surrogate()
-        self.assertIsInstance(sbp, BenchmarkProblemProtocol)
+        self.assertIsInstance(sbp, BenchmarkProblem)
 
         mbp = get_moo_surrogate()
-        self.assertIsInstance(mbp, BenchmarkProblemProtocol)
+        self.assertIsInstance(mbp, BenchmarkProblem)
 
     def test_repr(self) -> None: