From 1a397088dd139a1c0b5515f3019f6beaf6996025 Mon Sep 17 00:00:00 2001 From: Sam Daulton Date: Wed, 22 Nov 2023 12:17:17 -0800 Subject: [PATCH] update BestPointMixin to support BatchTrial in benchmarks (#2014) Summary: see title Differential Revision: D51534384 --- ax/service/tests/test_best_point.py | 20 ++++++++++ ax/service/tests/test_best_point_utils.py | 45 +++++++++++++++-------- ax/service/utils/best_point.py | 28 ++++++++------ ax/service/utils/best_point_mixin.py | 34 +++++++++++------ 4 files changed, 89 insertions(+), 38 deletions(-) diff --git a/ax/service/tests/test_best_point.py b/ax/service/tests/test_best_point.py index b706f207c1f..051fbf368c5 100644 --- a/ax/service/tests/test_best_point.py +++ b/ax/service/tests/test_best_point.py @@ -16,6 +16,7 @@ from ax.utils.common.testutils import TestCase from ax.utils.common.typeutils import checked_cast, not_none from ax.utils.testing.core_stubs import ( + get_experiment_with_batch_trial, get_experiment_with_observations, get_experiment_with_trial, ) @@ -95,6 +96,25 @@ def test_get_trace(self) -> None: exp = get_experiment_with_trial() self.assertEqual(get_trace(exp), []) + # test batch trial + exp = get_experiment_with_batch_trial() + trial = exp.trials[0] + exp.optimization_config.outcome_constraints[0].relative = False + trial.mark_running(no_runner_required=True).mark_completed() + for i, arm in enumerate(trial.arms): + df_dict = [ + { + "trial_index": 0, + "metric_name": m, + "arm_name": arm.name, + "mean": float(i), + "sem": 0.0, + } + for m in exp.metrics.keys() + ] + exp.attach_data(Data(df=pd.DataFrame.from_records(df_dict))) + self.assertEqual(get_trace(exp), [len(trial.arms) - 1]) + def test_get_hypervolume(self) -> None: # W/ empty data. exp = get_experiment_with_trial() diff --git a/ax/service/tests/test_best_point_utils.py b/ax/service/tests/test_best_point_utils.py index e0ff4064fec..7f42584bbf0 100644 --- a/ax/service/tests/test_best_point_utils.py +++ b/ax/service/tests/test_best_point_utils.py @@ -372,39 +372,47 @@ def test_extract_Y_from_data(self) -> None: ], dim=-1, ) - Y = extract_Y_from_data( + Y, trial_indices = extract_Y_from_data( experiment=experiment, metric_names=["foo", "bar"], ) + expected_trial_indices = torch.arange(20) self.assertTrue(torch.allclose(Y, expected_Y)) + self.assertTrue(torch.equal(trial_indices, expected_trial_indices)) # Check that it respects ordering of metric names. - Y = extract_Y_from_data( + Y, trial_indices = extract_Y_from_data( experiment=experiment, metric_names=["bar", "foo"], ) self.assertTrue(torch.allclose(Y, expected_Y[:, [1, 0]])) + self.assertTrue(torch.equal(trial_indices, expected_trial_indices)) # Extract partial metrics. - Y = extract_Y_from_data(experiment=experiment, metric_names=["bar"]) + Y, trial_indices = extract_Y_from_data( + experiment=experiment, metric_names=["bar"] + ) self.assertTrue(torch.allclose(Y, expected_Y[:, [1]])) + self.assertTrue(torch.equal(trial_indices, expected_trial_indices)) # Works with messed up ordering of data. clone_dicts = df_dicts.copy() random.shuffle(clone_dicts) experiment._data_by_trial = {} experiment.attach_data(Data(df=pd.DataFrame.from_records(clone_dicts))) - Y = extract_Y_from_data( + Y, trial_indices = extract_Y_from_data( experiment=experiment, metric_names=["foo", "bar"], ) self.assertTrue(torch.allclose(Y, expected_Y)) + self.assertTrue(torch.equal(trial_indices, expected_trial_indices)) # Check that it skips trials that are not completed. experiment.trials[0].mark_running(no_runner_required=True, unsafe=True) experiment.trials[1].mark_abandoned(unsafe=True) - Y = extract_Y_from_data( + Y, trial_indices = extract_Y_from_data( experiment=experiment, metric_names=["foo", "bar"], ) self.assertTrue(torch.allclose(Y, expected_Y[2:])) + self.assertTrue(torch.equal(trial_indices, expected_trial_indices[2:])) # Error with missing data. with self.assertRaisesRegex( @@ -420,11 +428,10 @@ def test_extract_Y_from_data(self) -> None: # Error with extra data. with self.assertRaisesRegex( - UserInputError, "Trial data has more than one row per metric. " + UserInputError, "Trial data has more than one row per arm, metric pair. " ): # Skipping first 5 data points since first two trials are not completed. base_df = pd.DataFrame.from_records(df_dicts[5:]) - extract_Y_from_data( experiment=experiment, metric_names=["foo", "bar"], @@ -433,15 +440,21 @@ def test_extract_Y_from_data(self) -> None: # Check that it errors with BatchTrial. experiment = get_branin_experiment() - BatchTrial(experiment=experiment, index=0).mark_running( - no_runner_required=True - ).mark_completed() - with self.assertRaisesRegex(UnsupportedError, "BatchTrials are not supported."): - extract_Y_from_data( - experiment=experiment, - metric_names=["foo", "bar"], - data=Data(df=pd.DataFrame.from_records(df_0)), - ) + batch_trial = BatchTrial(experiment=experiment, index=0) + batch_trial.add_arm(Arm(name="0_0", parameters={"x1": 0.0, "x2": 0.0})) + batch_trial.add_arm(Arm(name="0_1", parameters={"x1": 1.0, "x2": 0.0})) + batch_trial.mark_running(no_runner_required=True).mark_completed() + pd_df_0 = pd.DataFrame.from_records(df_0) + pd_df_1 = pd.DataFrame.from_records(df_dicts[2:4]) + pd_df_1["arm_name"] = "0_1" + pd_df_1["trial_index"] = 0 + Y, trial_indices = extract_Y_from_data( + experiment=experiment, + metric_names=["foo", "bar"], + data=Data(df=pd.concat([pd_df_0, pd_df_1])), + ) + self.assertTrue(torch.allclose(Y, expected_Y[:2])) + self.assertTrue(torch.equal(trial_indices, torch.zeros(2, dtype=torch.long))) def test_is_row_feasible(self) -> None: exp = get_experiment_with_observations( diff --git a/ax/service/utils/best_point.py b/ax/service/utils/best_point.py index 2af8ebe5201..67131dc5d00 100644 --- a/ax/service/utils/best_point.py +++ b/ax/service/utils/best_point.py @@ -780,7 +780,7 @@ def extract_Y_from_data( experiment: Experiment, metric_names: List[str], data: Optional[Data] = None, -) -> Tensor: +) -> Tuple[Tensor, Tensor]: r"""Converts the experiment observation data into a tensor. NOTE: This requires block design for observations. It will @@ -796,11 +796,14 @@ def extract_Y_from_data( each `trial_index` in the `data`. Returns: - A tensor of observed metrics. + A two-element Tuple containing a tensor of observed metrics and a + tensor of trial_indices. """ df = data.df if data is not None else experiment.lookup_data().df if len(df) == 0: - return torch.empty(0, len(metric_names), dtype=torch.double) + return torch.empty(0, len(metric_names), dtype=torch.double), torch.empty( + 0, dtype=torch.long + ) trials_to_use = [] data_to_use = df[df["metric_name"].isin(metric_names)] @@ -810,12 +813,10 @@ def extract_Y_from_data( if trial.status not in [TrialStatus.COMPLETED, TrialStatus.EARLY_STOPPED]: # Skip trials that are not completed or early stopped. continue - if isinstance(trial, BatchTrial): - raise UnsupportedError("BatchTrials are not supported.") trials_to_use.append(trial_idx) - if len(trial_data) > len(set(trial_data["metric_name"])): + if len(trial_data) > len(set(trial_data["metric_name"])) * len(trial.arms): raise UserInputError( - "Trial data has more than one row per metric. " + "Trial data has more than one row per arm, metric pair. " f"Got\n\n{trial_data}\n\nfor trial {trial_idx}." ) # We have already ensured that `trial_data` has no metrics not in @@ -830,13 +831,18 @@ def extract_Y_from_data( keeps = df["trial_index"].isin(trials_to_use) if not keeps.any(): - return torch.empty(0, len(metric_names), dtype=torch.double) + return torch.empty(0, len(metric_names), dtype=torch.double), torch.empty( + 0, dtype=torch.long + ) data_as_wide = df[keeps].pivot( - columns="metric_name", index="trial_index", values="mean" + columns="metric_name", index=["trial_index", "arm_name"], values="mean" )[metric_names] - - return torch.tensor(data_as_wide.to_numpy()).to(torch.double) + means = torch.tensor(data_as_wide.to_numpy()).to(torch.double) + trial_indices = torch.tensor( + data_as_wide.reset_index()["trial_index"].to_numpy(), dtype=torch.long + ) + return means, trial_indices def _objective_threshold_from_nadir( diff --git a/ax/service/utils/best_point_mixin.py b/ax/service/utils/best_point_mixin.py index 4a85def54a5..dddc30bf7cd 100644 --- a/ax/service/utils/best_point_mixin.py +++ b/ax/service/utils/best_point_mixin.py @@ -437,7 +437,9 @@ def _get_trace( metric_names.update({cons.metric.name}) metric_names = list(metric_names) # Convert data into a tensor. - Y = extract_Y_from_data(experiment=experiment, metric_names=metric_names) + Y, trial_indices = extract_Y_from_data( + experiment=experiment, metric_names=metric_names + ) if Y.numel() == 0: return [] @@ -508,26 +510,36 @@ def _get_trace( feas = torch.all(torch.stack([c(Y) <= 0 for c in cons_tfs], dim=-1), dim=-1) # Set the infeasible points to reference point or the worst observed value. Y_obj[~feas] = infeas_value + num_trials = int(trial_indices.max().item()) if optimization_config.is_moo_problem: # Compute the hypervolume trace. partitioning = DominatedPartitioning( ref_point=weighted_objective_thresholds.double() ) - # compute hv at each iteration + # compute hv for each iteration (trial_index) hvs = [] - for Yi in Y_obj.split(1): + cumulative_Y = torch.empty( + 0, Y_obj.shape[1], dtype=Y_obj.dtype, device=Y_obj.device + ) + for trial_index in range(num_trials + 1): + new_Y = Y_obj[trial_indices == trial_index] + cumulative_Y = torch.cat([cumulative_Y, new_Y], dim=0) # update with new point - partitioning.update(Y=Yi) + partitioning.update(Y=cumulative_Y) hv = partitioning.compute_hypervolume().item() hvs.append(hv) return hvs - else: - # Find the best observed value. - raw_maximum = np.maximum.accumulate(Y_obj.cpu().numpy()) - if optimization_config.objective.minimize: - # Negate the result if it is a minimization problem. - raw_maximum = -raw_maximum - return raw_maximum.tolist() + running_max = float("-inf") + raw_maximum = np.zeros(num_trials + 1) + # Find the best observed value for each iterations + for trial_index in range(num_trials + 1): + new_Y = Y_obj[trial_indices == trial_index] + running_max = max(running_max, new_Y.max().item()) + raw_maximum[trial_index] = running_max + if optimization_config.objective.minimize: + # Negate the result if it is a minimization problem. + raw_maximum = -raw_maximum + return raw_maximum.tolist() @staticmethod def _get_trace_by_progression(