diff --git a/ax/core/utils.py b/ax/core/utils.py
index bd593ea1cf2..7a110a55333 100644
--- a/ax/core/utils.py
+++ b/ax/core/utils.py
@@ -4,19 +4,23 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+from collections import defaultdict
 from typing import Dict, Iterable, List, NamedTuple, Optional, Set, Tuple
 
 import numpy as np
-from ax.core.base_trial import BaseTrial
+from ax.core.base_trial import BaseTrial, TrialStatus
 from ax.core.batch_trial import BatchTrial
 from ax.core.data import Data
 from ax.core.experiment import Experiment
 from ax.core.generator_run import GeneratorRun
 from ax.core.objective import MultiObjective
+
+from ax.core.observation import ObservationFeatures
 from ax.core.optimization_config import OptimizationConfig
 from ax.core.trial import Trial
 from ax.core.types import ComparisonOp
+from ax.utils.common.typeutils import not_none
 from pyre_extensions import none_throws
 
 TArmTrial = Tuple[str, int]
 
@@ -188,3 +192,154 @@ def get_model_times(experiment: Experiment) -> Tuple[float, float]:
     fit_time = sum((t for t in fit_times if t is not None))
     gen_time = sum((t for t in gen_times if t is not None))
     return fit_time, gen_time
+
+
+# -------------------- Pending observations extraction utils. ---------------------
+
+
+def get_pending_observation_features(
+    experiment: Experiment, include_failed_as_pending: bool = False
+) -> Optional[Dict[str, List[ObservationFeatures]]]:
+    """Computes a list of pending observation features (corresponding to arms that
+    have been generated and deployed in the course of the experiment but have not
+    yet been completed with data, or to arms that have been abandoned or belong
+    to abandoned trials).
+
+    NOTE: Pending observation features are passed to the model to
+    instruct it not to generate the same points again.
+
+    Args:
+        experiment: Experiment for which to compute the pending features.
+        include_failed_as_pending: Whether to include failed trials as pending
+            (for example, to avoid the model suggesting them again).
+
+    Returns:
+        An optional mapping from metric names to a list of observation features,
+        pending for that metric (i.e. arms that do not yet have data for it).
+        If there are no pending features for any of the metrics, None is returned.
+    """
+    pending_features = {}
+    # Note that this assumes that if a metric appears in fetched data, the trial is
+    # not pending for the metric. Where only the most recent data matters, this
+    # works, but logic to also check previously added data objects may be needed.
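+    # An arm is considered pending for a metric if its trial was abandoned, or if
+    # the trial is deployed (or failed, when include_failed_as_pending=True) and
+    # the metric is missing from the trial's most recent data. Individually
+    # abandoned arms of batch trials are always considered pending.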
+    for trial_index, trial in experiment.trials.items():
+        dat = trial.lookup_data()
+        for metric_name in experiment.metrics:
+            if metric_name not in pending_features:
+                pending_features[metric_name] = []
+            include_since_failed = include_failed_as_pending and trial.status.is_failed
+            if isinstance(trial, BatchTrial):
+                if trial.status.is_abandoned or (
+                    (trial.status.is_deployed or include_since_failed)
+                    and metric_name not in dat.df.metric_name.values
+                    and trial.arms is not None
+                ):
+                    for arm in trial.arms:
+                        not_none(pending_features.get(metric_name)).append(
+                            ObservationFeatures.from_arm(
+                                arm=arm,
+                                trial_index=np.int64(trial_index),
+                                metadata=trial._get_candidate_metadata(
+                                    arm_name=arm.name
+                                ),
+                            )
+                        )
+                abandoned_arms = trial.abandoned_arms
+                for abandoned_arm in abandoned_arms:
+                    not_none(pending_features.get(metric_name)).append(
+                        ObservationFeatures.from_arm(
+                            arm=abandoned_arm,
+                            trial_index=np.int64(trial_index),
+                            metadata=trial._get_candidate_metadata(
+                                arm_name=abandoned_arm.name
+                            ),
+                        )
+                    )
+
+            if isinstance(trial, Trial):
+                if trial.status.is_abandoned or (
+                    (trial.status.is_deployed or include_since_failed)
+                    and metric_name not in dat.df.metric_name.values
+                    and trial.arm is not None
+                ):
+                    not_none(pending_features.get(metric_name)).append(
+                        ObservationFeatures.from_arm(
+                            arm=not_none(trial.arm),
+                            trial_index=np.int64(trial_index),
+                            metadata=trial._get_candidate_metadata(
+                                arm_name=not_none(trial.arm).name
+                            ),
+                        )
+                    )
+    return pending_features if any(x for x in pending_features.values()) else None
+
+
+# TODO: allow user to pass search space which overrides that on the experiment.
+def get_pending_observation_features_based_on_trial_status(
+    experiment: Experiment,
+    include_out_of_design_points: bool = False,
+) -> Optional[Dict[str, List[ObservationFeatures]]]:
+    """A faster analogue of ``get_pending_observation_features`` that makes
+    assumptions about the trials in the experiment in order to speed up the
+    extraction of pending points.
+
+    Assumptions:
+
+    * All arms in all trials in ``STAGED``, ``RUNNING``, and ``ABANDONED`` statuses
+      are to be considered pending for all outcomes.
+    * All arms in all trials in other statuses are to be considered not pending for
+      all outcomes.
+
+    This entails:
+
+    * No actual data-fetching for trials to determine whether arms in them are
+      pending for specific outcomes.
+    * Even if data is present for some outcomes in ``RUNNING`` trials, their arms
+      will still be considered pending for those outcomes.
+
+    NOTE: This function should not be used to extract pending features in field
+    experiments, where arms in running trials should not be considered pending if
+    there is data for those arms.
+
+    Args:
+        experiment: Experiment for which to compute the pending features.
+
+    Returns:
+        An optional mapping from metric names to a list of observation features,
+        pending for that metric (i.e. arms that do not yet have data for it).
+        If there are no pending features for any of the metrics, None is returned.
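+
+    Example (an illustrative sketch, assuming ``exp`` is an ``Experiment`` with
+    staged, running, or abandoned trials):
+
+        >>> pending = get_pending_observation_features_based_on_trial_status(exp)
+        >>> # Maps each metric name to features for all arms in those trials,
+        >>> # or is None if no arms are pending.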
+    """
+    pending_features = defaultdict(list)
+    for status in [TrialStatus.STAGED, TrialStatus.RUNNING, TrialStatus.ABANDONED]:
+        for trial in experiment.trials_by_status[status]:
+            for metric_name in experiment.metrics:
+                for arm in trial.arms:
+                    if (
+                        not include_out_of_design_points
+                        and not experiment.search_space.check_membership(arm.parameters)
+                    ):
+                        continue
+                    pending_features[metric_name].append(
+                        ObservationFeatures.from_arm(
+                            arm=arm,
+                            trial_index=np.int64(trial.index),
+                            metadata=trial._get_candidate_metadata(arm_name=arm.name),
+                        )
+                    )
+    return dict(pending_features) if any(x for x in pending_features.values()) else None
+
+
+def extend_pending_observations(
+    experiment: Experiment,
+    pending_observations: Dict[str, List[ObservationFeatures]],
+    generator_run: GeneratorRun,
+) -> None:
+    """Extend the given pending observations dict (from metric name to pending
+    observations for that metric) with the arms in a given generator run.
+    """
+    for m in experiment.metrics:
+        if m not in pending_observations:
+            pending_observations[m] = []
+        pending_observations[m].extend(
+            ObservationFeatures.from_arm(a) for a in generator_run.arms
+        )
diff --git a/ax/modelbridge/generation_strategy.py b/ax/modelbridge/generation_strategy.py
index defb1ffc582..e2f129b14e9 100644
--- a/ax/modelbridge/generation_strategy.py
+++ b/ax/modelbridge/generation_strategy.py
@@ -16,12 +16,12 @@
 
 from ax.core.experiment import Experiment
 from ax.core.generator_run import GeneratorRun
 from ax.core.observation import ObservationFeatures
+from ax.core.utils import extend_pending_observations
 from ax.exceptions.core import DataRequiredError, NoDataError, UserInputError
 from ax.exceptions.generation_strategy import GenerationStrategyCompleted
 from ax.modelbridge.base import ModelBridge
 from ax.modelbridge.generation_node import GenerationStep
-from ax.modelbridge.modelbridge_utils import extend_pending_observations
 from ax.modelbridge.registry import _extract_model_state_after_gen, ModelRegistryBase
 from ax.utils.common.base import Base
 from ax.utils.common.logger import _round_floats_for_logging, get_logger
diff --git a/ax/modelbridge/modelbridge_utils.py b/ax/modelbridge/modelbridge_utils.py
index 82288e4a358..c110a8c2c8c 100644
--- a/ax/modelbridge/modelbridge_utils.py
+++ b/ax/modelbridge/modelbridge_utils.py
@@ -8,7 +8,6 @@
 
 import warnings
 
-from collections import defaultdict
 from copy import deepcopy
 from functools import partial
 
@@ -30,11 +29,8 @@
 
 import numpy as np
 import torch
-from ax.core.base_trial import TrialStatus
-from ax.core.batch_trial import BatchTrial
 from ax.core.data import Data
 from ax.core.experiment import Experiment
-from ax.core.generator_run import GeneratorRun
 from ax.core.objective import MultiObjective, Objective, ScalarizedObjective
 from ax.core.observation import Observation, ObservationData, ObservationFeatures
 from ax.core.optimization_config import (
@@ -56,8 +52,12 @@
     SearchSpace,
     SearchSpaceDigest,
 )
-from ax.core.trial import Trial
 from ax.core.types import TBounds, TCandidateMetadata
+
+from ax.core.utils import (  # noqa F402: Temporary import for backward compatibility.
+    get_pending_observation_features,  # noqa F401
+    get_pending_observation_features_based_on_trial_status,  # noqa F401
+)
 from ax.exceptions.core import DataRequiredError, UnsupportedError, UserInputError
 from ax.modelbridge.transforms.base import Transform
 from ax.modelbridge.transforms.utils import (
@@ -659,154 +659,6 @@ def _roundtrip_transform(x: np.ndarray) -> np.ndarray:
     return _roundtrip_transform
 
 
-def get_pending_observation_features(
-    experiment: Experiment, include_failed_as_pending: bool = False
-) -> Optional[Dict[str, List[ObservationFeatures]]]:
-    """Computes a list of pending observation features (corresponding to arms that
-    have been generated and deployed in the course of the experiment, but have not
-    been completed with data or to arms that have been abandoned or belong to
-    abandoned trials).
-
-    NOTE: Pending observation features are passed to the model to
-    instruct it to not generate the same points again.
-
-    Args:
-        experiment: Experiment, pending features on which we seek to compute.
-        include_failed_as_pending: Whether to include failed trials as pending
-            (for example, to avoid the model suggesting them again).
-
-    Returns:
-        An optional mapping from metric names to a list of observation features,
-        pending for that metric (i.e. do not have evaluation data for that metric).
-        If there are no pending features for any of the metrics, return is None.
-    """
-    pending_features = {}
-    # Note that this assumes that if a metric appears in fetched data, the trial is
-    # not pending for the metric. Where only the most recent data matters, this will
-    # work, but may need to add logic to check previously added data objects, too.
-    for trial_index, trial in experiment.trials.items():
-        dat = trial.lookup_data()
-        for metric_name in experiment.metrics:
-            if metric_name not in pending_features:
-                pending_features[metric_name] = []
-            include_since_failed = include_failed_as_pending and trial.status.is_failed
-            if isinstance(trial, BatchTrial):
-                if trial.status.is_abandoned or (
-                    (trial.status.is_deployed or include_since_failed)
-                    and metric_name not in dat.df.metric_name.values
-                    and trial.arms is not None
-                ):
-                    for arm in trial.arms:
-                        not_none(pending_features.get(metric_name)).append(
-                            ObservationFeatures.from_arm(
-                                arm=arm,
-                                trial_index=np.int64(trial_index),
-                                metadata=trial._get_candidate_metadata(
-                                    arm_name=arm.name
-                                ),
-                            )
-                        )
-                abandoned_arms = trial.abandoned_arms
-                for abandoned_arm in abandoned_arms:
-                    not_none(pending_features.get(metric_name)).append(
-                        ObservationFeatures.from_arm(
-                            arm=abandoned_arm,
-                            trial_index=np.int64(trial_index),
-                            metadata=trial._get_candidate_metadata(
-                                arm_name=abandoned_arm.name
-                            ),
-                        )
-                    )
-
-            if isinstance(trial, Trial):
-                if trial.status.is_abandoned or (
-                    (trial.status.is_deployed or include_since_failed)
-                    and metric_name not in dat.df.metric_name.values
-                    and trial.arm is not None
-                ):
-                    not_none(pending_features.get(metric_name)).append(
-                        ObservationFeatures.from_arm(
-                            arm=not_none(trial.arm),
-                            trial_index=np.int64(trial_index),
-                            metadata=trial._get_candidate_metadata(
-                                arm_name=not_none(trial.arm).name
-                            ),
-                        )
-                    )
-    return pending_features if any(x for x in pending_features.values()) else None
-
-
-# TODO: allow user to pass search space which overrides that on the experiment.
-def get_pending_observation_features_based_on_trial_status(
-    experiment: Experiment,
-    include_out_of_design_points: bool = False,
-) -> Optional[Dict[str, List[ObservationFeatures]]]:
-    """A faster analogue of ``get_pending_observation_features`` that makes
-    assumptions about trials in experiment in order to speed up extraction
-    of pending points.
-
-    Assumptions:
-
-    * All arms in all trials in ``STAGED,`` ``RUNNING`` and ``ABANDONED`` statuses
-      are to be considered pending for all outcomes.
-    * All arms in all trials in other statuses are to be considered not pending for
-      all outcomes.
-
-    This entails:
-
-    * No actual data-fetching for trials to determine whether arms in them are pending
-      for specific outcomes.
-    * Even if data is present for some outcomes in ``RUNNING`` trials, their arms will
-      still be considered pending for those outcomes.
-
-    NOTE: This function should not be used to extract pending features in field
-    experiments, where arms in running trials should not be considered pending if
-    there is data for those arms.
-
-    Args:
-        experiment: Experiment, pending features on which we seek to compute.
-
-    Returns:
-        An optional mapping from metric names to a list of observation features,
-        pending for that metric (i.e. do not have evaluation data for that metric).
-        If there are no pending features for any of the metrics, return is None.
-    """
-    pending_features = defaultdict(list)
-    for status in [TrialStatus.STAGED, TrialStatus.RUNNING, TrialStatus.ABANDONED]:
-        for trial in experiment.trials_by_status[status]:
-            for metric_name in experiment.metrics:
-                for arm in trial.arms:
-                    if (
-                        not include_out_of_design_points
-                        and not experiment.search_space.check_membership(arm.parameters)
-                    ):
-                        continue
-                    pending_features[metric_name].append(
-                        ObservationFeatures.from_arm(
-                            arm=arm,
-                            trial_index=np.int64(trial.index),
-                            metadata=trial._get_candidate_metadata(arm_name=arm.name),
-                        )
-                    )
-    return dict(pending_features) if any(x for x in pending_features.values()) else None
-
-
-def extend_pending_observations(
-    experiment: Experiment,
-    pending_observations: Dict[str, List[ObservationFeatures]],
-    generator_run: GeneratorRun,
-) -> None:
-    """Extend given pending observations dict (from metric name to observations
-    that are pending for that metric), with arms in a given generator run.
-    """
-    for m in experiment.metrics:
-        if m not in pending_observations:
-            pending_observations[m] = []
-        pending_observations[m].extend(
-            ObservationFeatures.from_arm(a) for a in generator_run.arms
-        )
-
-
 def get_pareto_frontier_and_configs(
     modelbridge: modelbridge_module.torch.TorchModelBridge,
     observation_features: List[ObservationFeatures],
diff --git a/ax/modelbridge/tests/test_generation_strategy.py b/ax/modelbridge/tests/test_generation_strategy.py
index f0eb7dac71f..2f07c6d8329 100644
--- a/ax/modelbridge/tests/test_generation_strategy.py
+++ b/ax/modelbridge/tests/test_generation_strategy.py
@@ -17,6 +17,9 @@
 from ax.core.observation import ObservationFeatures
 from ax.core.parameter import ChoiceParameter, FixedParameter, Parameter, ParameterType
 from ax.core.search_space import SearchSpace
+from ax.core.utils import (
+    get_pending_observation_features_based_on_trial_status as get_pending,
+)
 from ax.exceptions.core import DataRequiredError, UserInputError
 from ax.exceptions.generation_strategy import (
     GenerationStrategyCompleted,
@@ -27,9 +30,6 @@
 from ax.modelbridge.factory import get_sobol
 from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
 from ax.modelbridge.model_spec import ModelSpec
-from ax.modelbridge.modelbridge_utils import (
-    get_pending_observation_features_based_on_trial_status as get_pending,
-)
 from ax.modelbridge.random import RandomModelBridge
 from ax.modelbridge.registry import Cont_X_trans, MODEL_KEY_TO_MODEL_SETUP, Models
 from ax.modelbridge.torch import TorchModelBridge
diff --git a/ax/modelbridge/tests/test_utils.py b/ax/modelbridge/tests/test_utils.py
index 5c453bdb64d..d0a46f0c4d5 100644
--- a/ax/modelbridge/tests/test_utils.py
+++ b/ax/modelbridge/tests/test_utils.py
@@ -21,11 +21,13 @@
     ScalarizedOutcomeConstraint,
 )
 from ax.core.types import ComparisonOp
+from ax.core.utils import (
+    get_pending_observation_features,
+    get_pending_observation_features_based_on_trial_status as get_pending_status,
+)
 from ax.modelbridge.modelbridge_utils import (
     extract_objective_thresholds,
     extract_outcome_constraints,
-    get_pending_observation_features,
-    get_pending_observation_features_based_on_trial_status as get_pending_status,
     observation_data_to_array,
     pending_observations_as_array_list,
 )
diff --git a/ax/service/ax_client.py b/ax/service/ax_client.py
index fc5ee6d68a9..8bda0cfd095 100644
--- a/ax/service/ax_client.py
+++ b/ax/service/ax_client.py
@@ -48,6 +48,7 @@
     TParameterization,
     TParamValue,
 )
+from ax.core.utils import get_pending_observation_features_based_on_trial_status
 from ax.early_stopping.strategies import BaseEarlyStoppingStrategy
 from ax.early_stopping.utils import estimate_early_stopping_savings
 from ax.exceptions.constants import CHOLESKY_ERROR_ANNOTATION
@@ -62,9 +63,6 @@
 from ax.global_stopping.strategies.improvement import constraint_satisfaction
 from ax.modelbridge.dispatch_utils import choose_generation_strategy
 from ax.modelbridge.generation_strategy import GenerationStrategy
-from ax.modelbridge.modelbridge_utils import (
-    get_pending_observation_features_based_on_trial_status,
-)
 from ax.modelbridge.prediction_utils import predict_by_features
 from ax.plot.base import AxPlotConfig
 from ax.plot.contour import plot_contour
@@ -1793,8 +1791,7 @@ def _find_last_trial_with_parameterization(
     @classmethod
     def _get_pending_observation_features(
         cls,
-        # pyre-fixme[2]: Parameter must be annotated.
-        experiment,
+        experiment: Experiment,
     ) -> Optional[Dict[str, List[ObservationFeatures]]]:
         """Extract pending points for the given experiment.
 
diff --git a/ax/service/managed_loop.py b/ax/service/managed_loop.py
index 0e47e145ce7..6dc593ccd7c 100644
--- a/ax/service/managed_loop.py
+++ b/ax/service/managed_loop.py
@@ -22,12 +22,12 @@
     TModelPredictArm,
     TParameterization,
 )
+from ax.core.utils import get_pending_observation_features
 from ax.exceptions.constants import CHOLESKY_ERROR_ANNOTATION
 from ax.exceptions.core import SearchSpaceExhausted, UserInputError
 from ax.modelbridge.base import ModelBridge
 from ax.modelbridge.dispatch_utils import choose_generation_strategy
 from ax.modelbridge.generation_strategy import GenerationStrategy
-from ax.modelbridge.modelbridge_utils import get_pending_observation_features
 from ax.modelbridge.registry import Models
 from ax.service.utils.best_point import (
     get_best_parameters_from_model_predictions,
diff --git a/ax/service/scheduler.py b/ax/service/scheduler.py
index 357f1fadcd1..6df0aa5dd25 100644
--- a/ax/service/scheduler.py
+++ b/ax/service/scheduler.py
@@ -41,6 +41,8 @@
 from ax.core.outcome_constraint import ObjectiveThreshold
 from ax.core.runner import Runner
 from ax.core.types import TModelPredictArm, TParameterization
+from ax.core.utils import get_pending_observation_features_based_on_trial_status
+
 from ax.early_stopping.utils import estimate_early_stopping_savings
 from ax.exceptions.core import (
     AxError,
@@ -52,10 +54,6 @@
 from ax.exceptions.generation_strategy import MaxParallelismReachedException
 from ax.modelbridge.base import ModelBridge
 from ax.modelbridge.generation_strategy import GenerationStrategy
-
-from ax.modelbridge.modelbridge_utils import (
-    get_pending_observation_features_based_on_trial_status,
-)
 from ax.plot.pareto_utils import infer_reference_point_from_experiment
 from ax.service.utils.best_point_mixin import BestPointMixin
 from ax.service.utils.scheduler_options import SchedulerOptions, TrialType
diff --git a/ax/service/tests/test_scheduler.py b/ax/service/tests/test_scheduler.py
index 8533294f986..7a589b00c20 100644
--- a/ax/service/tests/test_scheduler.py
+++ b/ax/service/tests/test_scheduler.py
@@ -19,14 +19,12 @@
 from ax.core.metric import Metric
 from ax.core.objective import Objective
 from ax.core.optimization_config import OptimizationConfig
+from ax.core.utils import get_pending_observation_features_based_on_trial_status
 from ax.early_stopping.strategies import BaseEarlyStoppingStrategy
 from ax.exceptions.core import OptimizationComplete, UnsupportedError, UserInputError
 from ax.metrics.branin import BraninMetric
 from ax.modelbridge.dispatch_utils import choose_generation_strategy
 from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
-from ax.modelbridge.modelbridge_utils import (
-    get_pending_observation_features_based_on_trial_status,
-)
 from ax.modelbridge.registry import Models
 from ax.runners.synthetic import SyntheticRunner
 from ax.service.scheduler import (
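For context, the relocated utilities are typically used together when generating
candidates. A minimal sketch, assuming an already-configured ``Experiment``
(``exp``) and ``GenerationStrategy`` (``gs``); both names are illustrative:

    from ax.core.utils import (
        extend_pending_observations,
        get_pending_observation_features_based_on_trial_status,
    )

    # Cheaply extract pending points from trial statuses alone (no data fetching).
    pending = get_pending_observation_features_based_on_trial_status(exp) or {}

    # Pass the pending points to the generation strategy so the model does not
    # re-suggest them.
    generator_run = gs.gen(experiment=exp, pending_observations=pending)

    # Register the newly generated arms as pending before any further `gen`
    # calls in the same loop iteration, so duplicates are not suggested.
    extend_pending_observations(
        experiment=exp,
        pending_observations=pending,
        generator_run=generator_run,
    )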