From 02df1211989a668e792658191fbb914db5ded32d Mon Sep 17 00:00:00 2001 From: Jerry Lin Date: Thu, 15 Aug 2024 18:05:43 -0700 Subject: [PATCH] Add `auxiliary_experiments_by_purpose` to Experiment (#2634) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/2634 Add `auxiliary_experiments_by_purpose ` to `Experiment` object. This change allows for `Experiment` to hold a dictionary of list of auxiliary experiments for different purposes. This is helpful for methods such as transfer learning or preference learning where we are using other experiments as auxiliary information (e.g., in modelbridge) to facilitate the optimization of the experiment of interest. Reviewed By: saitcakmak Differential Revision: D60542566 fbshipit-source-id: 4df696e584138691dc0797eb29b359dfcc2e0f59 --- ax/core/auxiliary.py | 6 ++++ ax/core/experiment.py | 15 +++++++-- ax/core/tests/test_auxiliary.py | 8 +++-- ax/core/tests/test_experiment.py | 56 +++++++++++++++++++++++++++++++ ax/service/utils/instantiation.py | 12 +++++-- 5 files changed, 90 insertions(+), 7 deletions(-) diff --git a/ax/core/auxiliary.py b/ax/core/auxiliary.py index 507af38025c..748ad187711 100644 --- a/ax/core/auxiliary.py +++ b/ax/core/auxiliary.py @@ -7,6 +7,7 @@ from __future__ import annotations +from enum import Enum, unique from typing import Optional, TYPE_CHECKING from ax.core.data import Data @@ -39,3 +40,8 @@ def _unique_id(self) -> str: # here is only needed w.r.t. parent object ("main experiment", for which # this will be an auxiliary source for). return self.experiment.name + + +@unique +class AuxiliaryExperimentPurpose(Enum): + pass diff --git a/ax/core/experiment.py b/ax/core/experiment.py index 9063bf55928..3cd8a16d5b3 100644 --- a/ax/core/experiment.py +++ b/ax/core/experiment.py @@ -15,11 +15,13 @@ from collections.abc import Hashable, Iterable, Mapping from datetime import datetime from functools import partial, reduce + from typing import Any, Optional import ax.core.observation as observation import pandas as pd from ax.core.arm import Arm +from ax.core.auxiliary import AuxiliaryExperiment, AuxiliaryExperimentPurpose from ax.core.base_trial import BaseTrial, DEFAULT_STATUSES_TO_WARM_START, TrialStatus from ax.core.batch_trial import BatchTrial, LifecycleStage from ax.core.data import Data @@ -79,6 +81,9 @@ def __init__( experiment_type: Optional[str] = None, properties: Optional[dict[str, Any]] = None, default_data_type: Optional[DataType] = None, + auxiliary_experiments_by_purpose: Optional[ + dict[AuxiliaryExperimentPurpose, list[AuxiliaryExperiment]] + ] = None, ) -> None: """Inits Experiment. @@ -94,6 +99,8 @@ def __init__( experiment_type: The class of experiments this one belongs to. properties: Dictionary of this experiment's properties. default_data_type: Enum representing the data type this experiment uses. + auxiliary_experiments_by_purpose: Dictionary of auxiliary experiments + for different purposes (e.g., transfer learning). """ # appease pyre self._search_space: SearchSpace @@ -127,6 +134,10 @@ def __init__( self._arms_by_signature: dict[str, Arm] = {} self._arms_by_name: dict[str, Arm] = {} + self.auxiliary_experiments_by_purpose: dict[ + AuxiliaryExperimentPurpose, list[AuxiliaryExperiment] + ] = (auxiliary_experiments_by_purpose or {}) + self.add_tracking_metrics(tracking_metrics or []) # call setters defined below @@ -1020,14 +1031,14 @@ def trials_by_status(self) -> dict[TrialStatus, list[BaseTrial]]: @property def trials_expecting_data(self) -> list[BaseTrial]: - """List[BaseTrial]: the list of all trials for which data has arrived + """list[BaseTrial]: the list of all trials for which data has arrived or is expected to arrive. """ return [trial for trial in self.trials.values() if trial.status.expecting_data] @property def completed_trials(self) -> list[BaseTrial]: - """List[BaseTrial]: the list of all trials for which data has arrived + """list[BaseTrial]: the list of all trials for which data has arrived or is expected to arrive. """ return self.trials_by_status[TrialStatus.COMPLETED] diff --git a/ax/core/tests/test_auxiliary.py b/ax/core/tests/test_auxiliary.py index 8a816f6a1a7..5df413ad71f 100644 --- a/ax/core/tests/test_auxiliary.py +++ b/ax/core/tests/test_auxiliary.py @@ -19,6 +19,8 @@ def test_AuxiliaryExperiment(self) -> None: self.assertEqual(aux_exp.experiment, exp) self.assertEqual(aux_exp.data, data) - aux_exp = AuxiliaryExperiment(experiment=exp, data=exp.lookup_data()) - self.assertEqual(aux_exp.experiment, exp) - self.assertEqual(aux_exp.data, data) + another_aux_exp = AuxiliaryExperiment( + experiment=exp, data=exp.lookup_data() + ) + self.assertEqual(another_aux_exp.experiment, exp) + self.assertEqual(another_aux_exp.data, data) diff --git a/ax/core/tests/test_experiment.py b/ax/core/tests/test_experiment.py index 0f2dfcda85f..4f16a0fefd8 100644 --- a/ax/core/tests/test_experiment.py +++ b/ax/core/tests/test_experiment.py @@ -8,11 +8,13 @@ import logging from collections import OrderedDict +from enum import unique from unittest.mock import MagicMock, patch import pandas as pd from ax.core import BatchTrial, Experiment, Trial from ax.core.arm import Arm +from ax.core.auxiliary import AuxiliaryExperiment, AuxiliaryExperimentPurpose from ax.core.base_trial import TrialStatus from ax.core.data import Data from ax.core.map_data import MapData @@ -52,6 +54,7 @@ get_branin_search_space, get_data, get_experiment, + get_experiment_with_data, get_experiment_with_map_data_type, get_optimization_config, get_scalarized_outcome_constraint, @@ -1470,3 +1473,56 @@ def test_it_does_not_take_both_single_and_multiple_gr_ars(self) -> None: generator_run=gr1, generator_runs=[gr2], ) + + def test_experiment_with_aux_experiments(self) -> None: + @unique + class TestAuxiliaryExperimentPurpose(AuxiliaryExperimentPurpose): + MyAuxExpPurpose = "my_auxiliary_experiment_purpose" + MyOtherAuxExpPurpose = "my_other_auxiliary_experiment_purpose" + + for get_exp_func in [get_experiment, get_experiment_with_data]: + exp = get_exp_func() + data = exp.lookup_data() + + aux_exp = AuxiliaryExperiment(experiment=exp) + another_aux_exp = AuxiliaryExperiment(experiment=exp, data=data) + + # init experiment with auxiliary experiments + exp_w_aux_exp = Experiment( + name="test", + search_space=get_search_space(), + auxiliary_experiments_by_purpose={ + TestAuxiliaryExperimentPurpose.MyAuxExpPurpose: [aux_exp], + }, + ) + + # in-place modification of auxiliary experiments + exp_w_aux_exp.auxiliary_experiments_by_purpose[ + TestAuxiliaryExperimentPurpose.MyOtherAuxExpPurpose + ] = [aux_exp] + self.assertEqual( + exp_w_aux_exp.auxiliary_experiments_by_purpose, + { + TestAuxiliaryExperimentPurpose.MyAuxExpPurpose: [aux_exp], + TestAuxiliaryExperimentPurpose.MyOtherAuxExpPurpose: [aux_exp], + }, + ) + + # test setter + exp_w_aux_exp.auxiliary_experiments_by_purpose = { + TestAuxiliaryExperimentPurpose.MyAuxExpPurpose: [aux_exp], + TestAuxiliaryExperimentPurpose.MyOtherAuxExpPurpose: [ + aux_exp, + another_aux_exp, + ], + } + self.assertEqual( + exp_w_aux_exp.auxiliary_experiments_by_purpose, + { + TestAuxiliaryExperimentPurpose.MyAuxExpPurpose: [aux_exp], + TestAuxiliaryExperimentPurpose.MyOtherAuxExpPurpose: [ + aux_exp, + another_aux_exp, + ], + }, + ) diff --git a/ax/service/utils/instantiation.py b/ax/service/utils/instantiation.py index e66085632c4..640b5edb953 100644 --- a/ax/service/utils/instantiation.py +++ b/ax/service/utils/instantiation.py @@ -14,6 +14,7 @@ from typing import Any, Optional, Union from ax.core.arm import Arm +from ax.core.auxiliary import AuxiliaryExperiment, AuxiliaryExperimentPurpose from ax.core.experiment import DataType, Experiment from ax.core.metric import Metric from ax.core.objective import MultiObjective, Objective @@ -784,6 +785,9 @@ def make_experiment( objective_thresholds: Optional[list[str]] = None, support_intermediate_data: bool = False, immutable_search_space_and_opt_config: bool = True, + auxiliary_experiments_by_purpose: Optional[ + dict[AuxiliaryExperimentPurpose, list[AuxiliaryExperiment]] + ] = None, is_test: bool = False, ) -> Experiment: """Instantiation wrapper that allows for Ax `Experiment` creation @@ -823,6 +827,8 @@ def make_experiment( a product in which it is used), if any. tracking_metric_names: Names of additional tracking metrics not used for optimization. + metric_definitions: A mapping of metric names to extra kwargs to pass + to that metric objectives: Mapping from an objective name to "minimize" or "maximize" representing the direction for that objective. objective_thresholds: A list of objective threshold constraints for multi- @@ -835,10 +841,11 @@ def make_experiment( Defaults to True. If set to True, we won't store or load copies of the search space and optimization config on each generator run, which will improve storage performance. + auxiliary_experiments_by_purpose: Dictionary of auxiliary experiments for + different use cases (e.g., transfer learning). is_test: Whether this experiment will be a test experiment (useful for marking test experiments in storage etc). Defaults to False. - metric_definitions: A mapping of metric names to extra kwargs to pass - to that metric + """ status_quo_arm = None if status_quo is None else Arm(parameters=status_quo) @@ -889,6 +896,7 @@ def make_experiment( tracking_metrics=tracking_metrics, default_data_type=default_data_type, properties=properties, + auxiliary_experiments_by_purpose=auxiliary_experiments_by_purpose, is_test=is_test, )