Skip to content

Commit

Permalink
update BestPointMixin to support BatchTrial in benchmarks (facebook#2014
Browse files Browse the repository at this point in the history
)

Summary:

see title

Differential Revision: D51534384
  • Loading branch information
sdaulton authored and facebook-github-bot committed Nov 22, 2023
1 parent 50790cc commit 1a39708
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 38 deletions.
20 changes: 20 additions & 0 deletions ax/service/tests/test_best_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from ax.utils.common.testutils import TestCase
from ax.utils.common.typeutils import checked_cast, not_none
from ax.utils.testing.core_stubs import (
get_experiment_with_batch_trial,
get_experiment_with_observations,
get_experiment_with_trial,
)
Expand Down Expand Up @@ -95,6 +96,25 @@ def test_get_trace(self) -> None:
exp = get_experiment_with_trial()
self.assertEqual(get_trace(exp), [])

# test batch trial
exp = get_experiment_with_batch_trial()
trial = exp.trials[0]
exp.optimization_config.outcome_constraints[0].relative = False
trial.mark_running(no_runner_required=True).mark_completed()
for i, arm in enumerate(trial.arms):
df_dict = [
{
"trial_index": 0,
"metric_name": m,
"arm_name": arm.name,
"mean": float(i),
"sem": 0.0,
}
for m in exp.metrics.keys()
]
exp.attach_data(Data(df=pd.DataFrame.from_records(df_dict)))
self.assertEqual(get_trace(exp), [len(trial.arms) - 1])

def test_get_hypervolume(self) -> None:
# W/ empty data.
exp = get_experiment_with_trial()
Expand Down
45 changes: 29 additions & 16 deletions ax/service/tests/test_best_point_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,39 +372,47 @@ def test_extract_Y_from_data(self) -> None:
],
dim=-1,
)
Y = extract_Y_from_data(
Y, trial_indices = extract_Y_from_data(
experiment=experiment,
metric_names=["foo", "bar"],
)
expected_trial_indices = torch.arange(20)
self.assertTrue(torch.allclose(Y, expected_Y))
self.assertTrue(torch.equal(trial_indices, expected_trial_indices))
# Check that it respects ordering of metric names.
Y = extract_Y_from_data(
Y, trial_indices = extract_Y_from_data(
experiment=experiment,
metric_names=["bar", "foo"],
)
self.assertTrue(torch.allclose(Y, expected_Y[:, [1, 0]]))
self.assertTrue(torch.equal(trial_indices, expected_trial_indices))
# Extract partial metrics.
Y = extract_Y_from_data(experiment=experiment, metric_names=["bar"])
Y, trial_indices = extract_Y_from_data(
experiment=experiment, metric_names=["bar"]
)
self.assertTrue(torch.allclose(Y, expected_Y[:, [1]]))
self.assertTrue(torch.equal(trial_indices, expected_trial_indices))
# Works with messed up ordering of data.
clone_dicts = df_dicts.copy()
random.shuffle(clone_dicts)
experiment._data_by_trial = {}
experiment.attach_data(Data(df=pd.DataFrame.from_records(clone_dicts)))
Y = extract_Y_from_data(
Y, trial_indices = extract_Y_from_data(
experiment=experiment,
metric_names=["foo", "bar"],
)
self.assertTrue(torch.allclose(Y, expected_Y))
self.assertTrue(torch.equal(trial_indices, expected_trial_indices))

# Check that it skips trials that are not completed.
experiment.trials[0].mark_running(no_runner_required=True, unsafe=True)
experiment.trials[1].mark_abandoned(unsafe=True)
Y = extract_Y_from_data(
Y, trial_indices = extract_Y_from_data(
experiment=experiment,
metric_names=["foo", "bar"],
)
self.assertTrue(torch.allclose(Y, expected_Y[2:]))
self.assertTrue(torch.equal(trial_indices, expected_trial_indices[2:]))

# Error with missing data.
with self.assertRaisesRegex(
Expand All @@ -420,11 +428,10 @@ def test_extract_Y_from_data(self) -> None:

# Error with extra data.
with self.assertRaisesRegex(
UserInputError, "Trial data has more than one row per metric. "
UserInputError, "Trial data has more than one row per arm, metric pair. "
):
# Skipping first 5 data points since first two trials are not completed.
base_df = pd.DataFrame.from_records(df_dicts[5:])

extract_Y_from_data(
experiment=experiment,
metric_names=["foo", "bar"],
Expand All @@ -433,15 +440,21 @@ def test_extract_Y_from_data(self) -> None:

# Check that it errors with BatchTrial.
experiment = get_branin_experiment()
BatchTrial(experiment=experiment, index=0).mark_running(
no_runner_required=True
).mark_completed()
with self.assertRaisesRegex(UnsupportedError, "BatchTrials are not supported."):
extract_Y_from_data(
experiment=experiment,
metric_names=["foo", "bar"],
data=Data(df=pd.DataFrame.from_records(df_0)),
)
batch_trial = BatchTrial(experiment=experiment, index=0)
batch_trial.add_arm(Arm(name="0_0", parameters={"x1": 0.0, "x2": 0.0}))
batch_trial.add_arm(Arm(name="0_1", parameters={"x1": 1.0, "x2": 0.0}))
batch_trial.mark_running(no_runner_required=True).mark_completed()
pd_df_0 = pd.DataFrame.from_records(df_0)
pd_df_1 = pd.DataFrame.from_records(df_dicts[2:4])
pd_df_1["arm_name"] = "0_1"
pd_df_1["trial_index"] = 0
Y, trial_indices = extract_Y_from_data(
experiment=experiment,
metric_names=["foo", "bar"],
data=Data(df=pd.concat([pd_df_0, pd_df_1])),
)
self.assertTrue(torch.allclose(Y, expected_Y[:2]))
self.assertTrue(torch.equal(trial_indices, torch.zeros(2, dtype=torch.long)))

def test_is_row_feasible(self) -> None:
exp = get_experiment_with_observations(
Expand Down
28 changes: 17 additions & 11 deletions ax/service/utils/best_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,7 +780,7 @@ def extract_Y_from_data(
experiment: Experiment,
metric_names: List[str],
data: Optional[Data] = None,
) -> Tensor:
) -> Tuple[Tensor, Tensor]:
r"""Converts the experiment observation data into a tensor.
NOTE: This requires block design for observations. It will
Expand All @@ -796,11 +796,14 @@ def extract_Y_from_data(
each `trial_index` in the `data`.
Returns:
A tensor of observed metrics.
A two-element Tuple containing a tensor of observed metrics and a
tensor of trial_indices.
"""
df = data.df if data is not None else experiment.lookup_data().df
if len(df) == 0:
return torch.empty(0, len(metric_names), dtype=torch.double)
return torch.empty(0, len(metric_names), dtype=torch.double), torch.empty(
0, dtype=torch.long
)

trials_to_use = []
data_to_use = df[df["metric_name"].isin(metric_names)]
Expand All @@ -810,12 +813,10 @@ def extract_Y_from_data(
if trial.status not in [TrialStatus.COMPLETED, TrialStatus.EARLY_STOPPED]:
# Skip trials that are not completed or early stopped.
continue
if isinstance(trial, BatchTrial):
raise UnsupportedError("BatchTrials are not supported.")
trials_to_use.append(trial_idx)
if len(trial_data) > len(set(trial_data["metric_name"])):
if len(trial_data) > len(set(trial_data["metric_name"])) * len(trial.arms):
raise UserInputError(
"Trial data has more than one row per metric. "
"Trial data has more than one row per arm, metric pair. "
f"Got\n\n{trial_data}\n\nfor trial {trial_idx}."
)
# We have already ensured that `trial_data` has no metrics not in
Expand All @@ -830,13 +831,18 @@ def extract_Y_from_data(
keeps = df["trial_index"].isin(trials_to_use)

if not keeps.any():
return torch.empty(0, len(metric_names), dtype=torch.double)
return torch.empty(0, len(metric_names), dtype=torch.double), torch.empty(
0, dtype=torch.long
)

data_as_wide = df[keeps].pivot(
columns="metric_name", index="trial_index", values="mean"
columns="metric_name", index=["trial_index", "arm_name"], values="mean"
)[metric_names]

return torch.tensor(data_as_wide.to_numpy()).to(torch.double)
means = torch.tensor(data_as_wide.to_numpy()).to(torch.double)
trial_indices = torch.tensor(
data_as_wide.reset_index()["trial_index"].to_numpy(), dtype=torch.long
)
return means, trial_indices


def _objective_threshold_from_nadir(
Expand Down
34 changes: 23 additions & 11 deletions ax/service/utils/best_point_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,9 @@ def _get_trace(
metric_names.update({cons.metric.name})
metric_names = list(metric_names)
# Convert data into a tensor.
Y = extract_Y_from_data(experiment=experiment, metric_names=metric_names)
Y, trial_indices = extract_Y_from_data(
experiment=experiment, metric_names=metric_names
)
if Y.numel() == 0:
return []

Expand Down Expand Up @@ -508,26 +510,36 @@ def _get_trace(
feas = torch.all(torch.stack([c(Y) <= 0 for c in cons_tfs], dim=-1), dim=-1)
# Set the infeasible points to reference point or the worst observed value.
Y_obj[~feas] = infeas_value
num_trials = int(trial_indices.max().item())
if optimization_config.is_moo_problem:
# Compute the hypervolume trace.
partitioning = DominatedPartitioning(
ref_point=weighted_objective_thresholds.double()
)
# compute hv at each iteration
# compute hv for each iteration (trial_index)
hvs = []
for Yi in Y_obj.split(1):
cumulative_Y = torch.empty(
0, Y_obj.shape[1], dtype=Y_obj.dtype, device=Y_obj.device
)
for trial_index in range(num_trials + 1):
new_Y = Y_obj[trial_indices == trial_index]
cumulative_Y = torch.cat([cumulative_Y, new_Y], dim=0)
# update with new point
partitioning.update(Y=Yi)
partitioning.update(Y=cumulative_Y)
hv = partitioning.compute_hypervolume().item()
hvs.append(hv)
return hvs
else:
# Find the best observed value.
raw_maximum = np.maximum.accumulate(Y_obj.cpu().numpy())
if optimization_config.objective.minimize:
# Negate the result if it is a minimization problem.
raw_maximum = -raw_maximum
return raw_maximum.tolist()
running_max = float("-inf")
raw_maximum = np.zeros(num_trials + 1)
# Find the best observed value for each iterations
for trial_index in range(num_trials + 1):
new_Y = Y_obj[trial_indices == trial_index]
running_max = max(running_max, new_Y.max().item())
raw_maximum[trial_index] = running_max
if optimization_config.objective.minimize:
# Negate the result if it is a minimization problem.
raw_maximum = -raw_maximum
return raw_maximum.tolist()

@staticmethod
def _get_trace_by_progression(
Expand Down

0 comments on commit 1a39708

Please sign in to comment.