diff --git a/ax/core/tests/test_utils.py b/ax/core/tests/test_utils.py
index 0d64093d0e0..8266398a7f5 100644
--- a/ax/core/tests/test_utils.py
+++ b/ax/core/tests/test_utils.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+from copy import deepcopy
 from unittest.mock import patch
 
 import numpy as np
@@ -21,6 +22,7 @@
 from ax.core.types import ComparisonOp
 from ax.core.utils import (
     best_feasible_objective,
+    extract_pending_observations,
     get_missing_metrics,
     get_missing_metrics_by_name,
     get_model_times,
@@ -541,3 +543,53 @@ def test_get_pending_observation_features_based_on_trial_status_hss(self) -> Non
                     self.hss_arm.signature
                 ],
             )
+
+    def test_extract_pending_observations(self) -> None:
+        exp_with_many_trials = get_experiment()
+        for _ in range(100):
+            exp_with_many_trials.new_trial().add_arm(self.arm)
+
+        exp_with_many_trials_and_batch = deepcopy(exp_with_many_trials)
+        exp_with_many_trials_and_batch.new_batch_trial().add_arm(self.arm)
+
+        m = extract_pending_observations.__module__
+        with patch(f"{m}.get_pending_observation_features") as mock_pending, patch(
+            f"{m}.get_pending_observation_features_based_on_trial_status"
+        ) as mock_pending_ts:
+            # Check the typical case: few trials, we can use regular `get_pending...`.
+            extract_pending_observations(experiment=self.experiment)
+            mock_pending.assert_called_once_with(
+                experiment=self.experiment, include_out_of_design_points=False
+            )
+            mock_pending.reset_mock()
+
+            # Check out-of-design filter propagation.
+            extract_pending_observations(
+                experiment=self.experiment, include_out_of_design_points=True
+            )
+            mock_pending.assert_called_once_with(
+                experiment=self.experiment, include_out_of_design_points=True
+            )
+            mock_pending.reset_mock()
+
+            # Check many-trials case and out-of-design filter propagation.
+            extract_pending_observations(
+                experiment=exp_with_many_trials, include_out_of_design_points=True
+            )
+            mock_pending_ts.assert_called_once_with(
+                experiment=exp_with_many_trials, include_out_of_design_points=True
+            )
+            mock_pending.assert_not_called()
+            mock_pending_ts.reset_mock()
+
+            # Check "many-trials but batch trial present" case
+            # and out-of-design filter propagation.
+            extract_pending_observations(
+                experiment=exp_with_many_trials_and_batch,
+                include_out_of_design_points=True,
+            )
+            mock_pending.assert_called_once_with(
+                experiment=exp_with_many_trials_and_batch,
+                include_out_of_design_points=True,
+            )
+            mock_pending_ts.assert_not_called()
diff --git a/ax/core/utils.py b/ax/core/utils.py
index c249685eff4..3d3d54ed5b1 100644
--- a/ax/core/utils.py
+++ b/ax/core/utils.py
@@ -26,6 +26,9 @@
 
 TArmTrial = Tuple[str, int]
 
+# Threshold for switching to pending points extraction based on trial status.
+MANY_TRIALS_IN_EXPERIMENT = 100
+
 # --------------------------- Data intergrity utils. ---------------------------
 
 
@@ -198,15 +201,64 @@ def get_model_times(experiment: Experiment) -> Tuple[float, float]:
 # -------------------- Pending observations extraction utils. ---------------------
 
 
+def extract_pending_observations(
+    experiment: Experiment,
+    include_out_of_design_points: bool = False,
+) -> Optional[Dict[str, List[ObservationFeatures]]]:
+    """Computes a list of pending observation features (corresponding to:
+    - arms that have been generated and run in the course of the experiment,
+      but have not been completed with data,
+    - arms that have been abandoned or belong to abandoned trials).
+
+    This function dispatches to:
+    - ``get_pending_observation_features`` if experiment is using
+      ``BatchTrial``-s or has fewer than 100 trials,
+    - ``get_pending_observation_features_based_on_trial_status`` if
+      experiment is using only ``Trial``-s and has at least 100 trials.
+
+    ``get_pending_observation_features_based_on_trial_status`` is a faster
+    way to compute pending observations, but it is not guaranteed to be
+    accurate for ``BatchTrial`` settings and makes assumptions, e.g.
+    arms in ``COMPLETED`` trial never being pending. See docstring of
+    that function for more details.
+
+    NOTE: Pending observation features are passed to the model to
+    instruct it to not generate the same points again.
+
+    Args:
+        experiment: Experiment whose trials determine the dispatch and whose
+            pending observation features are extracted.
+        include_out_of_design_points: Forwarded unchanged to whichever
+            extraction function is selected.
+
+    Returns:
+        The value returned by the selected extraction function (a dict of
+        lists of ``ObservationFeatures``, or ``None``).
+    """
+    # Use the faster status-based extraction only for large experiments made up
+    # solely of ``Trial``-s; it is not guaranteed accurate for ``BatchTrial``-s.
+    if len(experiment.trials) >= MANY_TRIALS_IN_EXPERIMENT and all(
+        isinstance(t, Trial) for t in experiment.trials.values()
+    ):
+        return get_pending_observation_features_based_on_trial_status(
+            experiment=experiment,
+            include_out_of_design_points=include_out_of_design_points,
+        )
+
+    return get_pending_observation_features(
+        experiment=experiment, include_out_of_design_points=include_out_of_design_points
+    )
+
+
 def get_pending_observation_features(
     experiment: Experiment,
     *,
     include_out_of_design_points: bool = False,
 ) -> Optional[Dict[str, List[ObservationFeatures]]]:
-    """Computes a list of pending observation features (corresponding to arms that
-    have been generated and deployed in the course of the experiment, but have not
-    been completed with data or to arms that have been abandoned or belong to
-    abandoned trials).
+    """Computes a list of pending observation features (corresponding to:
+    - arms that have been generated and run in the course of the experiment,
+      but have not been completed with data,
+    - arms that have been abandoned or belong to abandoned trials).
 
     NOTE: Pending observation features are passed to the model to
     instruct it to not generate the same points again.
@@ -273,7 +325,8 @@ def _is_in_design(arm: Arm) -> bool:
     return pending_features if any(x for x in pending_features.values()) else None
 
 
-# TODO: allow user to pass search space which overrides that on the experiment.
+# TODO: allow user to pass search space which overrides that on the experiment
+# (to use for the `include_out_of_design_points` check)
 def get_pending_observation_features_based_on_trial_status(
     experiment: Experiment,
     include_out_of_design_points: bool = False,