Migrate Jenatton to use BenchmarkRunner and BenchmarkMetric (facebook#2676)

Summary:
Pull Request resolved: facebook#2676

This PR:
- Has Jenatton use `ParamBasedTestProblem` so that it can use `ParamBasedTestProblemRunner`, and has it use `BenchmarkMetric`; gets rid of the specialized Jenatton runners and metrics. This lets Jenatton handle noisy problems, whether or not noise levels are observed, like other benchmark problems, and will make it easy to add constraints or benefit from other new functionality. (See the conceptual sketch after this list.)
- Does *not* clean up the now-unnecessary Jenatton metric file; that happens in the next diff.
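
Roughly, the division of labor after this change looks like the sketch below. This is a conceptual illustration only, not the actual Ax runner/metric code, and the helper name `run_arm_sketch` is made up: the test problem supplies the noiseless value, the runner injects Gaussian noise, and the metric reports the noise level only when it is meant to be observed.

# Conceptual sketch -- not the Ax implementation. The problem-specific class
# only defines the noiseless objective; noise handling is generic.
import numpy as np

def run_arm_sketch(test_problem, parameters, noise_std=0.0, observe_noise_sd=False):
    y_true = float(test_problem.evaluate_true(parameters))  # problem-specific part
    y_obs = y_true + noise_std * np.random.randn()  # generic runner adds observation noise
    sem = noise_std if observe_noise_sd else None  # generic metric reports the SEM only if observed
    return y_obs, sem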

Differential Revision: D61502458

Reviewed By: Balandat
esantorella authored and facebook-github-bot committed Aug 21, 2024
1 parent 4f5c338 commit 1c263b5
Showing 6 changed files with 236 additions and 153 deletions.
69 changes: 1 addition & 68 deletions ax/benchmark/metrics/jenatton.py
@@ -5,78 +5,11 @@

# pyre-strict

-from __future__ import annotations
+from typing import Optional

-from typing import Any, Optional

-import numpy as np
-import pandas as pd
-from ax.benchmark.metrics.base import BenchmarkMetricBase, GroundTruthMetricMixin
-from ax.core.base_trial import BaseTrial
-from ax.core.data import Data
-from ax.core.metric import MetricFetchE, MetricFetchResult
-from ax.utils.common.result import Err, Ok
-from ax.utils.common.typeutils import not_none


-class JenattonMetric(BenchmarkMetricBase):
-    """Jenatton metric for hierarchical search spaces."""

-    has_ground_truth: bool = True

-    def __init__(
-        self,
-        name: str = "jenatton",
-        noise_std: float = 0.0,
-        observe_noise_sd: bool = False,
-    ) -> None:
-        super().__init__(name=name)
-        self.noise_std = noise_std
-        self.observe_noise_sd = observe_noise_sd
-        self.lower_is_better = True

-    def fetch_trial_data(self, trial: BaseTrial, **kwargs: Any) -> MetricFetchResult:
-        try:
-            mean = [
-                jenatton_test_function(**arm.parameters)  # pyre-ignore [6]
-                for _, arm in trial.arms_by_name.items()
-            ]
-            if self.noise_std != 0:
-                mean = [m + self.noise_std * np.random.randn() for m in mean]
-            df = pd.DataFrame(
-                {
-                    "arm_name": [name for name, _ in trial.arms_by_name.items()],
-                    "metric_name": self.name,
-                    "mean": mean,
-                    "sem": self.noise_std if self.observe_noise_sd else None,
-                    "trial_index": trial.index,
-                }
-            )
-            return Ok(value=Data(df=df))

-        except Exception as e:
-            return Err(
-                MetricFetchE(message=f"Failed to fetch {self.name}", exception=e)
-            )

-    def make_ground_truth_metric(self) -> GroundTruthJenattonMetric:
-        return GroundTruthJenattonMetric(original_metric=self)


-class GroundTruthJenattonMetric(JenattonMetric, GroundTruthMetricMixin):
-    def __init__(self, original_metric: JenattonMetric) -> None:
-        """
-        Args:
-            original_metric: The original JenattonMetric to which this metric
-                corresponds.
-        """
-        super().__init__(
-            name=self.get_ground_truth_name(original_metric),
-            noise_std=0.0,
-            observe_noise_sd=False,
-        )


def jenatton_test_function(
    x1: Optional[int] = None,
    x2: Optional[int] = None,
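
For reference, the body of jenatton_test_function is collapsed in the diff above. Below is a minimal sketch of the tree-structured function, written from the Jenatton et al. (2017) formulation cited in the new docstring rather than copied from the file; the helper name and the handling of missing parameters are illustrative only. With x4 (or x6) and r8 (or r9) at zero, its minimum is 0.1, matching the `optimal_value` set in the new Jenatton problem below.

# Sketch of the Jenatton et al. (2017) test function; illustrative, not the repository code.
# x1-x3 choose a root-to-leaf path in a depth-3 binary tree, x4-x7 are leaf values,
# and r8/r9 are shared auxiliary variables.
def jenatton_sketch(x1, x2, x3, x4, x5, x6, x7, r8, r9):
    if x1 == 0:
        if x2 == 0:
            return x4**2 + 0.1 + r8  # left-left leaf
        return x5**2 + 0.2 + r8  # left-right leaf
    if x3 == 0:
        return x6**2 + 0.1 + r9  # right-left leaf
    return x7**2 + 0.2 + r9  # right-right leaf
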
56 changes: 47 additions & 9 deletions ax/benchmark/problems/synthetic/hss/jenatton.py
@@ -5,18 +5,52 @@

# pyre-strict

+from dataclasses import dataclass
+from typing import Optional

+import torch
from ax.benchmark.benchmark_problem import BenchmarkProblem
-from ax.benchmark.metrics.jenatton import JenattonMetric
+from ax.benchmark.metrics.benchmark import BenchmarkMetric
+from ax.benchmark.metrics.jenatton import jenatton_test_function
+from ax.benchmark.runners.botorch_test import (
+    ParamBasedTestProblem,
+    ParamBasedTestProblemRunner,
+)
from ax.core.objective import Objective
from ax.core.optimization_config import OptimizationConfig
from ax.core.parameter import ChoiceParameter, ParameterType, RangeParameter
from ax.core.search_space import HierarchicalSearchSpace
-from ax.runners.synthetic import SyntheticRunner
+from ax.core.types import TParameterization


+@dataclass(kw_only=True)
+class Jenatton(ParamBasedTestProblem):
+    r"""Jenatton test function for hierarchical search spaces.
+    This function is taken from:
+    R. Jenatton, C. Archambeau, J. González, and M. Seeger. Bayesian
+    optimization with tree-structured dependencies. ICML 2017.
+    """

+    noise_std: Optional[float] = None
+    negate: bool = False
+    num_objectives: int = 1
+    optimal_value: float = 0.1
+    _is_constrained: bool = False

+    def evaluate_true(self, params: TParameterization) -> torch.Tensor:
+        # pyre-fixme: Incompatible parameter type [6]: In call
+        # `jenatton_test_function`, for 1st positional argument, expected
+        # `Optional[float]` but got `Union[None, bool, float, int, str]`.
+        value = jenatton_test_function(**params)
+        return torch.tensor(value)


def get_jenatton_benchmark_problem(
    num_trials: int = 50,
    observe_noise_sd: bool = False,
+    noise_std: float = 0.0,
) -> BenchmarkProblem:
    search_space = HierarchicalSearchSpace(
        parameters=[
@@ -55,24 +89,28 @@ def get_jenatton_benchmark_problem(
            ),
        ]
    )
+    name = "Jenatton" + ("_observed_noise" if observe_noise_sd else "")

    optimization_config = OptimizationConfig(
        objective=Objective(
-            metric=JenattonMetric(observe_noise_sd=observe_noise_sd),
+            metric=BenchmarkMetric(
+                name=name, observe_noise_sd=observe_noise_sd, lower_is_better=True
+            ),
            minimize=True,
        )
    )

-    name = "Jenatton" + ("_observed_noise" if observe_noise_sd else "")

    return BenchmarkProblem(
        name=name,
        search_space=search_space,
        optimization_config=optimization_config,
-        runner=SyntheticRunner(),
+        runner=ParamBasedTestProblemRunner(
+            test_problem_class=Jenatton,
+            test_problem_kwargs={"noise_std": noise_std},
+            outcome_names=[name],
+        ),
        num_trials=num_trials,
-        is_noiseless=True,
+        is_noiseless=noise_std == 0.0,
        observe_noise_stds=observe_noise_sd,
        has_ground_truth=True,
-        optimal_value=0.1,
+        optimal_value=Jenatton.optimal_value,
    )
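
A hypothetical usage sketch of the updated factory follows; the argument values are chosen for illustration, and the attribute access assumes `BenchmarkProblem` exposes its constructor fields.

from ax.benchmark.problems.synthetic.hss.jenatton import get_jenatton_benchmark_problem

# Noisy variant in which the noise level is reported to the optimizer.
problem = get_jenatton_benchmark_problem(
    num_trials=30, noise_std=0.1, observe_noise_sd=True
)
print(problem.name)           # "Jenatton_observed_noise"
print(problem.optimal_value)  # 0.1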