Skip to content

Commit

Permalink
Add extract_pending_observations function that auto-deploys to the …
Browse files Browse the repository at this point in the history
…correct pending-points function for the use case (#2039)

Summary:

For automated extraction of pending points, it would be ideal to not have to separately call one of two functions: get_pending... and get_pending_..._based_on_trial_status. I think we only want to go the latter route when: all trials are 1-arm and there are many trials.
Planning to use this for pending points extraction from within the ExternalGenerationStrategy (in time we might use this in other places, too, e.g. the Scheduler).

Differential Revision: D51684966
  • Loading branch information
lena-kashtelyan authored and facebook-github-bot committed Dec 5, 2023
1 parent 79133c7 commit e777720
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 5 deletions.
48 changes: 48 additions & 0 deletions ax/core/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from copy import deepcopy
from unittest.mock import patch

import numpy as np
Expand All @@ -21,6 +22,7 @@
from ax.core.types import ComparisonOp
from ax.core.utils import (
best_feasible_objective,
extract_pending_observations,
get_missing_metrics,
get_missing_metrics_by_name,
get_model_times,
Expand Down Expand Up @@ -541,3 +543,49 @@ def test_get_pending_observation_features_based_on_trial_status_hss(self) -> Non
self.hss_arm.signature
],
)

def test_extract_pending_observations(self) -> None:
exp_with_many_trials = get_experiment()
for _ in range(100):
exp_with_many_trials.new_trial().add_arm(self.arm)

exp_with_many_trials_and_batch = deepcopy(exp_with_many_trials)
exp_with_many_trials_and_batch.new_batch_trial().add_arm(self.arm)

m = extract_pending_observations.__module__
with patch(f"{m}.get_pending_observation_features") as mock_pending, patch(
f"{m}.get_pending_observation_features_based_on_trial_status"
) as mock_pending_ts:
# Check the typical case: few trials, we can use regular `get_pending...`.
extract_pending_observations(experiment=self.experiment)
mock_pending.assert_called_once_with(
experiment=self.experiment, include_out_of_design_points=False
)
mock_pending.reset_mock()

# Check out-of-design filter propagation.
extract_pending_observations(
experiment=self.experiment, include_out_of_design_points=True
)
mock_pending.assert_called_once_with(
experiment=self.experiment, include_out_of_design_points=True
)
mock_pending.reset_mock()

# Check many-trials case and out-of-design filter propagation.
extract_pending_observations(
experiment=exp_with_many_trials, include_out_of_design_points=True
)
mock_pending_ts.assert_called_once_with(
experiment=exp_with_many_trials, include_out_of_design_points=True
)

# Check "many-trials but batch trial present" case
# and out-of-design filter propagation.
extract_pending_observations(
experiment=exp_with_many_trials_and_batch,
include_out_of_design_points=True,
)
mock_pending_ts.assert_called_once_with(
experiment=exp_with_many_trials, include_out_of_design_points=True
)
51 changes: 46 additions & 5 deletions ax/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@

TArmTrial = Tuple[str, int]

# Threshold for switching to pending points extraction based on trial status.
MANY_TRIALS_IN_EXPERIMENT = 100


# --------------------------- Data intergrity utils. ---------------------------

Expand Down Expand Up @@ -198,15 +201,52 @@ def get_model_times(experiment: Experiment) -> Tuple[float, float]:
# -------------------- Pending observations extraction utils. ---------------------


def extract_pending_observations(
experiment: Experiment,
include_out_of_design_points: bool = False,
) -> Optional[Dict[str, List[ObservationFeatures]]]:
"""Computes a list of pending observation features (corresponding to:
- arms that have been generated and run in the course of the experiment,
but have not been completed with data,
- arms that have been abandoned or belong to abandoned trials).
This function dispatches to:
- ``get_pending_observation_features`` if experiment is using
``BatchTrial``-s or has fewer than 100 trials,
- ``get_pending_observation_features_based_on_trial_status`` if
experiment is using ``Trial``-s and has more than 100 trials.
``get_pending_observation_features_based_on_trial_status`` is a faster
way to compute pending observations, but it is not guaranteed to be
accurate for ``BatchTrial`` settings and makes assumptions, e.g.
arms in ``COMPLETED`` trial never being pending. See docstring of
that function for more details.
NOTE: Pending observation features are passed to the model to
instruct it to not generate the same points again.
"""
if len(experiment.trials) >= MANY_TRIALS_IN_EXPERIMENT and all(
isinstance(t, Trial) for t in experiment.trials.values()
):
return get_pending_observation_features_based_on_trial_status(
experiment=experiment,
include_out_of_design_points=include_out_of_design_points,
)

return get_pending_observation_features(
experiment=experiment, include_out_of_design_points=include_out_of_design_points
)


def get_pending_observation_features(
experiment: Experiment,
*,
include_out_of_design_points: bool = False,
) -> Optional[Dict[str, List[ObservationFeatures]]]:
"""Computes a list of pending observation features (corresponding to arms that
have been generated and deployed in the course of the experiment, but have not
been completed with data or to arms that have been abandoned or belong to
abandoned trials).
"""Computes a list of pending observation features (corresponding to:
- arms that have been generated and run in the course of the experiment,
but have not been completed with data,
- arms that have been abandoned or belong to abandoned trials).
NOTE: Pending observation features are passed to the model to
instruct it to not generate the same points again.
Expand Down Expand Up @@ -273,7 +313,8 @@ def _is_in_design(arm: Arm) -> bool:
return pending_features if any(x for x in pending_features.values()) else None


# TODO: allow user to pass search space which overrides that on the experiment.
# TODO: allow user to pass search space which overrides that on the experiment
# (to use for the `include_out_of_design_points` check)
def get_pending_observation_features_based_on_trial_status(
experiment: Experiment,
include_out_of_design_points: bool = False,
Expand Down

0 comments on commit e777720

Please sign in to comment.