Skip to content

Commit

Permalink
Move pending point utils to core Ax (facebook#2006)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebook#2006

No reason for these to live in model bridge + we should be able to import them in GSInterface

Reviewed By: Balandat

Differential Revision: D51437606

fbshipit-source-id: 417555d7ead0869868bb03cf3d862c4df842bc5c
  • Loading branch information
Lena Kashtelyan authored and facebook-github-bot committed Nov 20, 2023
1 parent 375bf47 commit 00b7a02
Show file tree
Hide file tree
Showing 9 changed files with 175 additions and 173 deletions.
157 changes: 156 additions & 1 deletion ax/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,23 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from collections import defaultdict
from typing import Dict, Iterable, List, NamedTuple, Optional, Set, Tuple

import numpy as np
from ax.core.base_trial import BaseTrial
from ax.core.base_trial import BaseTrial, TrialStatus
from ax.core.batch_trial import BatchTrial
from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.core.generator_run import GeneratorRun
from ax.core.objective import MultiObjective

from ax.core.observation import ObservationFeatures
from ax.core.optimization_config import OptimizationConfig

from ax.core.trial import Trial
from ax.core.types import ComparisonOp
from ax.utils.common.typeutils import not_none
from pyre_extensions import none_throws

TArmTrial = Tuple[str, int]
Expand Down Expand Up @@ -188,3 +192,154 @@ def get_model_times(experiment: Experiment) -> Tuple[float, float]:
fit_time = sum((t for t in fit_times if t is not None))
gen_time = sum((t for t in gen_times if t is not None))
return fit_time, gen_time


# -------------------- Pending observations extraction utils. ---------------------


def get_pending_observation_features(
experiment: Experiment, include_failed_as_pending: bool = False
) -> Optional[Dict[str, List[ObservationFeatures]]]:
"""Computes a list of pending observation features (corresponding to arms that
have been generated and deployed in the course of the experiment, but have not
been completed with data or to arms that have been abandoned or belong to
abandoned trials).
NOTE: Pending observation features are passed to the model to
instruct it to not generate the same points again.
Args:
experiment: Experiment, pending features on which we seek to compute.
include_failed_as_pending: Whether to include failed trials as pending
(for example, to avoid the model suggesting them again).
Returns:
An optional mapping from metric names to a list of observation features,
pending for that metric (i.e. do not have evaluation data for that metric).
If there are no pending features for any of the metrics, return is None.
"""
pending_features = {}
# Note that this assumes that if a metric appears in fetched data, the trial is
# not pending for the metric. Where only the most recent data matters, this will
# work, but may need to add logic to check previously added data objects, too.
for trial_index, trial in experiment.trials.items():
dat = trial.lookup_data()
for metric_name in experiment.metrics:
if metric_name not in pending_features:
pending_features[metric_name] = []
include_since_failed = include_failed_as_pending and trial.status.is_failed
if isinstance(trial, BatchTrial):
if trial.status.is_abandoned or (
(trial.status.is_deployed or include_since_failed)
and metric_name not in dat.df.metric_name.values
and trial.arms is not None
):
for arm in trial.arms:
not_none(pending_features.get(metric_name)).append(
ObservationFeatures.from_arm(
arm=arm,
trial_index=np.int64(trial_index),
metadata=trial._get_candidate_metadata(
arm_name=arm.name
),
)
)
abandoned_arms = trial.abandoned_arms
for abandoned_arm in abandoned_arms:
not_none(pending_features.get(metric_name)).append(
ObservationFeatures.from_arm(
arm=abandoned_arm,
trial_index=np.int64(trial_index),
metadata=trial._get_candidate_metadata(
arm_name=abandoned_arm.name
),
)
)

if isinstance(trial, Trial):
if trial.status.is_abandoned or (
(trial.status.is_deployed or include_since_failed)
and metric_name not in dat.df.metric_name.values
and trial.arm is not None
):
not_none(pending_features.get(metric_name)).append(
ObservationFeatures.from_arm(
arm=not_none(trial.arm),
trial_index=np.int64(trial_index),
metadata=trial._get_candidate_metadata(
arm_name=not_none(trial.arm).name
),
)
)
return pending_features if any(x for x in pending_features.values()) else None


# TODO: allow user to pass search space which overrides that on the experiment.
def get_pending_observation_features_based_on_trial_status(
experiment: Experiment,
include_out_of_design_points: bool = False,
) -> Optional[Dict[str, List[ObservationFeatures]]]:
"""A faster analogue of ``get_pending_observation_features`` that makes
assumptions about trials in experiment in order to speed up extraction
of pending points.
Assumptions:
* All arms in all trials in ``STAGED,`` ``RUNNING`` and ``ABANDONED`` statuses
are to be considered pending for all outcomes.
* All arms in all trials in other statuses are to be considered not pending for
all outcomes.
This entails:
* No actual data-fetching for trials to determine whether arms in them are pending
for specific outcomes.
* Even if data is present for some outcomes in ``RUNNING`` trials, their arms will
still be considered pending for those outcomes.
NOTE: This function should not be used to extract pending features in field
experiments, where arms in running trials should not be considered pending if
there is data for those arms.
Args:
experiment: Experiment, pending features on which we seek to compute.
Returns:
An optional mapping from metric names to a list of observation features,
pending for that metric (i.e. do not have evaluation data for that metric).
If there are no pending features for any of the metrics, return is None.
"""
pending_features = defaultdict(list)
for status in [TrialStatus.STAGED, TrialStatus.RUNNING, TrialStatus.ABANDONED]:
for trial in experiment.trials_by_status[status]:
for metric_name in experiment.metrics:
for arm in trial.arms:
if (
not include_out_of_design_points
and not experiment.search_space.check_membership(arm.parameters)
):
continue
pending_features[metric_name].append(
ObservationFeatures.from_arm(
arm=arm,
trial_index=np.int64(trial.index),
metadata=trial._get_candidate_metadata(arm_name=arm.name),
)
)
return dict(pending_features) if any(x for x in pending_features.values()) else None


def extend_pending_observations(
experiment: Experiment,
pending_observations: Dict[str, List[ObservationFeatures]],
generator_run: GeneratorRun,
) -> None:
"""Extend given pending observations dict (from metric name to observations
that are pending for that metric), with arms in a given generator run.
"""
for m in experiment.metrics:
if m not in pending_observations:
pending_observations[m] = []
pending_observations[m].extend(
ObservationFeatures.from_arm(a) for a in generator_run.arms
)
2 changes: 1 addition & 1 deletion ax/modelbridge/generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
from ax.core.experiment import Experiment
from ax.core.generator_run import GeneratorRun
from ax.core.observation import ObservationFeatures
from ax.core.utils import extend_pending_observations
from ax.exceptions.core import DataRequiredError, NoDataError, UserInputError
from ax.exceptions.generation_strategy import GenerationStrategyCompleted

from ax.modelbridge.base import ModelBridge
from ax.modelbridge.generation_node import GenerationStep
from ax.modelbridge.modelbridge_utils import extend_pending_observations
from ax.modelbridge.registry import _extract_model_state_after_gen, ModelRegistryBase
from ax.utils.common.base import Base
from ax.utils.common.logger import _round_floats_for_logging, get_logger
Expand Down
158 changes: 5 additions & 153 deletions ax/modelbridge/modelbridge_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

import warnings

from collections import defaultdict
from copy import deepcopy
from functools import partial

Expand All @@ -30,11 +29,8 @@

import numpy as np
import torch
from ax.core.base_trial import TrialStatus
from ax.core.batch_trial import BatchTrial
from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.core.generator_run import GeneratorRun
from ax.core.objective import MultiObjective, Objective, ScalarizedObjective
from ax.core.observation import Observation, ObservationData, ObservationFeatures
from ax.core.optimization_config import (
Expand All @@ -56,8 +52,12 @@
SearchSpace,
SearchSpaceDigest,
)
from ax.core.trial import Trial
from ax.core.types import TBounds, TCandidateMetadata

from ax.core.utils import ( # noqa F402: Temporary import for backward compatibility.
get_pending_observation_features, # noqa F401
get_pending_observation_features_based_on_trial_status, # noqa F401
)
from ax.exceptions.core import DataRequiredError, UnsupportedError, UserInputError
from ax.modelbridge.transforms.base import Transform
from ax.modelbridge.transforms.utils import (
Expand Down Expand Up @@ -659,154 +659,6 @@ def _roundtrip_transform(x: np.ndarray) -> np.ndarray:
return _roundtrip_transform


def get_pending_observation_features(
experiment: Experiment, include_failed_as_pending: bool = False
) -> Optional[Dict[str, List[ObservationFeatures]]]:
"""Computes a list of pending observation features (corresponding to arms that
have been generated and deployed in the course of the experiment, but have not
been completed with data or to arms that have been abandoned or belong to
abandoned trials).
NOTE: Pending observation features are passed to the model to
instruct it to not generate the same points again.
Args:
experiment: Experiment, pending features on which we seek to compute.
include_failed_as_pending: Whether to include failed trials as pending
(for example, to avoid the model suggesting them again).
Returns:
An optional mapping from metric names to a list of observation features,
pending for that metric (i.e. do not have evaluation data for that metric).
If there are no pending features for any of the metrics, return is None.
"""
pending_features = {}
# Note that this assumes that if a metric appears in fetched data, the trial is
# not pending for the metric. Where only the most recent data matters, this will
# work, but may need to add logic to check previously added data objects, too.
for trial_index, trial in experiment.trials.items():
dat = trial.lookup_data()
for metric_name in experiment.metrics:
if metric_name not in pending_features:
pending_features[metric_name] = []
include_since_failed = include_failed_as_pending and trial.status.is_failed
if isinstance(trial, BatchTrial):
if trial.status.is_abandoned or (
(trial.status.is_deployed or include_since_failed)
and metric_name not in dat.df.metric_name.values
and trial.arms is not None
):
for arm in trial.arms:
not_none(pending_features.get(metric_name)).append(
ObservationFeatures.from_arm(
arm=arm,
trial_index=np.int64(trial_index),
metadata=trial._get_candidate_metadata(
arm_name=arm.name
),
)
)
abandoned_arms = trial.abandoned_arms
for abandoned_arm in abandoned_arms:
not_none(pending_features.get(metric_name)).append(
ObservationFeatures.from_arm(
arm=abandoned_arm,
trial_index=np.int64(trial_index),
metadata=trial._get_candidate_metadata(
arm_name=abandoned_arm.name
),
)
)

if isinstance(trial, Trial):
if trial.status.is_abandoned or (
(trial.status.is_deployed or include_since_failed)
and metric_name not in dat.df.metric_name.values
and trial.arm is not None
):
not_none(pending_features.get(metric_name)).append(
ObservationFeatures.from_arm(
arm=not_none(trial.arm),
trial_index=np.int64(trial_index),
metadata=trial._get_candidate_metadata(
arm_name=not_none(trial.arm).name
),
)
)
return pending_features if any(x for x in pending_features.values()) else None


# TODO: allow user to pass search space which overrides that on the experiment.
def get_pending_observation_features_based_on_trial_status(
experiment: Experiment,
include_out_of_design_points: bool = False,
) -> Optional[Dict[str, List[ObservationFeatures]]]:
"""A faster analogue of ``get_pending_observation_features`` that makes
assumptions about trials in experiment in order to speed up extraction
of pending points.
Assumptions:
* All arms in all trials in ``STAGED,`` ``RUNNING`` and ``ABANDONED`` statuses
are to be considered pending for all outcomes.
* All arms in all trials in other statuses are to be considered not pending for
all outcomes.
This entails:
* No actual data-fetching for trials to determine whether arms in them are pending
for specific outcomes.
* Even if data is present for some outcomes in ``RUNNING`` trials, their arms will
still be considered pending for those outcomes.
NOTE: This function should not be used to extract pending features in field
experiments, where arms in running trials should not be considered pending if
there is data for those arms.
Args:
experiment: Experiment, pending features on which we seek to compute.
Returns:
An optional mapping from metric names to a list of observation features,
pending for that metric (i.e. do not have evaluation data for that metric).
If there are no pending features for any of the metrics, return is None.
"""
pending_features = defaultdict(list)
for status in [TrialStatus.STAGED, TrialStatus.RUNNING, TrialStatus.ABANDONED]:
for trial in experiment.trials_by_status[status]:
for metric_name in experiment.metrics:
for arm in trial.arms:
if (
not include_out_of_design_points
and not experiment.search_space.check_membership(arm.parameters)
):
continue
pending_features[metric_name].append(
ObservationFeatures.from_arm(
arm=arm,
trial_index=np.int64(trial.index),
metadata=trial._get_candidate_metadata(arm_name=arm.name),
)
)
return dict(pending_features) if any(x for x in pending_features.values()) else None


def extend_pending_observations(
experiment: Experiment,
pending_observations: Dict[str, List[ObservationFeatures]],
generator_run: GeneratorRun,
) -> None:
"""Extend given pending observations dict (from metric name to observations
that are pending for that metric), with arms in a given generator run.
"""
for m in experiment.metrics:
if m not in pending_observations:
pending_observations[m] = []
pending_observations[m].extend(
ObservationFeatures.from_arm(a) for a in generator_run.arms
)


def get_pareto_frontier_and_configs(
modelbridge: modelbridge_module.torch.TorchModelBridge,
observation_features: List[ObservationFeatures],
Expand Down
Loading

0 comments on commit 00b7a02

Please sign in to comment.