From 057644b7e5d9b5e324d995e3241aa9deeaaadc42 Mon Sep 17 00:00:00 2001 From: Lena Kashtelyan Date: Fri, 1 Dec 2023 14:08:57 -0800 Subject: [PATCH] Add `extract_pending_observations` function that auto-dispatches to the correct pending-points function for the use case Summary: For automated extraction of pending points, it would be ideal to not have to separately call one of two functions: `get_pending_observation_features` and `get_pending_observation_features_based_on_trial_status`. I think we only want to go the latter route when all trials are 1-arm and there are many trials. Planning to use this for pending points extraction from within the ExternalGenerationStrategy (in time we might use this in other places, too, e.g. the Scheduler). Differential Revision: D51684966 --- ax/core/tests/test_utils.py | 48 ++++++++++++++++++++++++++++++++++ ax/core/utils.py | 51 +++++++++++++++++++++++++++++++++---- 2 files changed, 94 insertions(+), 5 deletions(-) diff --git a/ax/core/tests/test_utils.py b/ax/core/tests/test_utils.py index 4904df54906..72269542fa7 100644 --- a/ax/core/tests/test_utils.py +++ b/ax/core/tests/test_utils.py @@ -4,6 +4,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+from copy import deepcopy from unittest.mock import patch import numpy as np @@ -21,6 +22,7 @@ from ax.core.types import ComparisonOp from ax.core.utils import ( best_feasible_objective, + extract_pending_observations, get_missing_metrics, get_missing_metrics_by_name, get_model_times, @@ -541,3 +543,49 @@ def test_get_pending_observation_features_based_on_trial_status_hss(self) -> Non self.hss_arm.signature ], ) + + def test_extract_pending_observations(self) -> None: + exp_with_many_trials = get_experiment() + for _ in range(100): + exp_with_many_trials.new_trial().add_arm(self.arm) + + exp_with_many_trials_and_batch = deepcopy(exp_with_many_trials) + exp_with_many_trials_and_batch.new_batch_trial().add_arm(self.arm) + + m = extract_pending_observations.__module__ + with patch(f"{m}.get_pending_observation_features") as mock_pending, patch( + f"{m}.get_pending_observation_features_based_on_trial_status" + ) as mock_pending_ts: + # Check the typical case: few trials, we can use regular `get_pending...`. + extract_pending_observations(experiment=self.experiment) + mock_pending.assert_called_once_with( + experiment=self.experiment, include_out_of_design_points=False + ) + mock_pending.reset_mock() + + # Check out-of-design filter propagation. + extract_pending_observations( + experiment=self.experiment, include_out_of_design_points=True + ) + mock_pending.assert_called_once_with( + experiment=self.experiment, include_out_of_design_points=True + ) + mock_pending.reset_mock() + + # Check many-trials case and out-of-design filter propagation. + extract_pending_observations( + experiment=exp_with_many_trials, include_out_of_design_points=True + ) + mock_pending_ts.assert_called_once_with( + experiment=exp_with_many_trials, include_out_of_design_points=True + ) + + # Check "many-trials but batch trial present" case + # and out-of-design filter propagation. 
+ extract_pending_observations( + experiment=exp_with_many_trials_and_batch, + include_out_of_design_points=True, + ) + mock_pending_ts.assert_called_once_with( + experiment=exp_with_many_trials, include_out_of_design_points=True + ) diff --git a/ax/core/utils.py b/ax/core/utils.py index 7f3bfe6e128..c0c829288bc 100644 --- a/ax/core/utils.py +++ b/ax/core/utils.py @@ -26,6 +26,9 @@ TArmTrial = Tuple[str, int] +# Threshold for switching to pending points extraction based on trial status. +MANY_TRIALS_IN_EXPERIMENT = 100 + # --------------------------- Data intergrity utils. --------------------------- @@ -198,15 +201,52 @@ def get_model_times(experiment: Experiment) -> Tuple[float, float]: # -------------------- Pending observations extraction utils. --------------------- +def extract_pending_observations( + experiment: Experiment, + include_out_of_design_points: bool = False, +) -> Optional[Dict[str, List[ObservationFeatures]]]: + """Computes a list of pending observation features (corresponding to: + - arms that have been generated and run in the course of the experiment, + but have not been completed with data, + - arms that have been abandoned or belong to abandoned trials). + + This function dispatches to: + - ``get_pending_observation_features`` if experiment is using + ``BatchTrial``-s or has fewer than 100 trials, + - ``get_pending_observation_features_based_on_trial_status`` if + experiment is using ``Trial``-s and has 100 or more trials. + + ``get_pending_observation_features_based_on_trial_status`` is a faster + way to compute pending observations, but it is not guaranteed to be + accurate for ``BatchTrial`` settings and makes assumptions, e.g. + arms in ``COMPLETED`` trial never being pending. See docstring of + that function for more details. + + NOTE: Pending observation features are passed to the model to + instruct it to not generate the same points again. 
+ """ + if len(experiment.trials) >= MANY_TRIALS_IN_EXPERIMENT and all( + isinstance(t, Trial) for t in experiment.trials.values() + ): + return get_pending_observation_features_based_on_trial_status( + experiment=experiment, + include_out_of_design_points=include_out_of_design_points, + ) + + return get_pending_observation_features( + experiment=experiment, include_out_of_design_points=include_out_of_design_points + ) + + def get_pending_observation_features( experiment: Experiment, *, include_out_of_design_points: bool = False, ) -> Optional[Dict[str, List[ObservationFeatures]]]: - """Computes a list of pending observation features (corresponding to arms that - have been generated and deployed in the course of the experiment, but have not - been completed with data or to arms that have been abandoned or belong to - abandoned trials). + """Computes a list of pending observation features (corresponding to: + - arms that have been generated and run in the course of the experiment, + but have not been completed with data, + - arms that have been abandoned or belong to abandoned trials). NOTE: Pending observation features are passed to the model to instruct it to not generate the same points again. @@ -273,7 +313,8 @@ def _is_in_design(arm: Arm) -> bool: return pending_features if any(x for x in pending_features.values()) else None -# TODO: allow user to pass search space which overrides that on the experiment. +# TODO: allow user to pass search space which overrides that on the experiment +# (to use for the `include_out_of_design_points` check) def get_pending_observation_features_based_on_trial_status( experiment: Experiment, include_out_of_design_points: bool = False,