diff --git a/ax/benchmark/benchmark.py b/ax/benchmark/benchmark.py index 029889722bd..e3a69a9a36c 100644 --- a/ax/benchmark/benchmark.py +++ b/ax/benchmark/benchmark.py @@ -27,10 +27,7 @@ import numpy as np from ax.benchmark.benchmark_method import BenchmarkMethod -from ax.benchmark.benchmark_problem import ( - BenchmarkProblemProtocol, - BenchmarkProblemWithKnownOptimum, -) +from ax.benchmark.benchmark_problem import BenchmarkProblemProtocol from ax.benchmark.benchmark_result import AggregatedBenchmarkResult, BenchmarkResult from ax.benchmark.metrics.base import BenchmarkMetricBase, GroundTruthMetricMixin from ax.core.experiment import Experiment @@ -60,9 +57,7 @@ def compute_score_trace( # Use the first GenerationStep's best found point as baseline. Sometimes (ex. in # a timeout) the first GenerationStep will not have not completed and we will not # have enough trials; in this case we do not score. - if (len(optimization_trace) <= num_baseline_trials) or not isinstance( - problem, BenchmarkProblemWithKnownOptimum - ): + if len(optimization_trace) <= num_baseline_trials: return np.full(len(optimization_trace), np.nan) optimum = problem.optimal_value baseline = optimization_trace[num_baseline_trials - 1] diff --git a/ax/benchmark/benchmark_problem.py b/ax/benchmark/benchmark_problem.py index c97b26699b6..7490fc3d9a9 100644 --- a/ax/benchmark/benchmark_problem.py +++ b/ax/benchmark/benchmark_problem.py @@ -10,7 +10,6 @@ # `BenchmarkProblem` as return type annotation, used for serialization and rendering # in the UI. -import abc from dataclasses import dataclass, field from typing import ( Any, @@ -92,14 +91,7 @@ class BenchmarkProblemProtocol(Protocol): bool, Dict[str, bool] ] # Whether we observe the observation noise level has_ground_truth: bool # if True, evals (w/o synthetic noise) are determinstic - - @abc.abstractproperty - def runner(self) -> Runner: - pass # pragma: no cover - - -@runtime_checkable -class BenchmarkProblemWithKnownOptimum(Protocol): + runner: Runner optimal_value: float @@ -109,7 +101,8 @@ class BenchmarkProblem(Base): Problem against which diffrent methods can be benchmarked. Defines how data is generated, the objective (via the OptimizationConfig), - and the SearchSpace. + and the SearchSpace. Does not define the runner, which must be handled by + subclasses. Args: name: Can be generated programmatically with `_get_name`.