diff --git a/ax/core/auxiliary.py b/ax/core/auxiliary.py index 507af38025c..748ad187711 100644 --- a/ax/core/auxiliary.py +++ b/ax/core/auxiliary.py @@ -7,6 +7,7 @@ from __future__ import annotations +from enum import Enum, unique from typing import Optional, TYPE_CHECKING from ax.core.data import Data @@ -39,3 +40,8 @@ def _unique_id(self) -> str: # here is only needed w.r.t. parent object ("main experiment", for which # this will be an auxiliary source for). return self.experiment.name + + +@unique +class AuxiliaryExperimentPurpose(Enum): + pass diff --git a/ax/core/experiment.py b/ax/core/experiment.py index 9063bf55928..3cd8a16d5b3 100644 --- a/ax/core/experiment.py +++ b/ax/core/experiment.py @@ -15,11 +15,13 @@ from collections.abc import Hashable, Iterable, Mapping from datetime import datetime from functools import partial, reduce + from typing import Any, Optional import ax.core.observation as observation import pandas as pd from ax.core.arm import Arm +from ax.core.auxiliary import AuxiliaryExperiment, AuxiliaryExperimentPurpose from ax.core.base_trial import BaseTrial, DEFAULT_STATUSES_TO_WARM_START, TrialStatus from ax.core.batch_trial import BatchTrial, LifecycleStage from ax.core.data import Data @@ -79,6 +81,9 @@ def __init__( experiment_type: Optional[str] = None, properties: Optional[dict[str, Any]] = None, default_data_type: Optional[DataType] = None, + auxiliary_experiments_by_purpose: Optional[ + dict[AuxiliaryExperimentPurpose, list[AuxiliaryExperiment]] + ] = None, ) -> None: """Inits Experiment. @@ -94,6 +99,8 @@ def __init__( experiment_type: The class of experiments this one belongs to. properties: Dictionary of this experiment's properties. default_data_type: Enum representing the data type this experiment uses. + auxiliary_experiments_by_purpose: Dictionary of auxiliary experiments + for different purposes (e.g., transfer learning). """ # appease pyre self._search_space: SearchSpace @@ -127,6 +134,10 @@ def __init__( self._arms_by_signature: dict[str, Arm] = {} self._arms_by_name: dict[str, Arm] = {} + self.auxiliary_experiments_by_purpose: dict[ + AuxiliaryExperimentPurpose, list[AuxiliaryExperiment] + ] = (auxiliary_experiments_by_purpose or {}) + self.add_tracking_metrics(tracking_metrics or []) # call setters defined below @@ -1020,14 +1031,14 @@ def trials_by_status(self) -> dict[TrialStatus, list[BaseTrial]]: @property def trials_expecting_data(self) -> list[BaseTrial]: - """List[BaseTrial]: the list of all trials for which data has arrived + """list[BaseTrial]: the list of all trials for which data has arrived or is expected to arrive. """ return [trial for trial in self.trials.values() if trial.status.expecting_data] @property def completed_trials(self) -> list[BaseTrial]: - """List[BaseTrial]: the list of all trials for which data has arrived + """list[BaseTrial]: the list of all trials for which data has arrived or is expected to arrive. """ return self.trials_by_status[TrialStatus.COMPLETED] diff --git a/ax/core/tests/test_auxiliary.py b/ax/core/tests/test_auxiliary.py index 8a816f6a1a7..5df413ad71f 100644 --- a/ax/core/tests/test_auxiliary.py +++ b/ax/core/tests/test_auxiliary.py @@ -19,6 +19,8 @@ def test_AuxiliaryExperiment(self) -> None: self.assertEqual(aux_exp.experiment, exp) self.assertEqual(aux_exp.data, data) - aux_exp = AuxiliaryExperiment(experiment=exp, data=exp.lookup_data()) - self.assertEqual(aux_exp.experiment, exp) - self.assertEqual(aux_exp.data, data) + another_aux_exp = AuxiliaryExperiment( + experiment=exp, data=exp.lookup_data() + ) + self.assertEqual(another_aux_exp.experiment, exp) + self.assertEqual(another_aux_exp.data, data) diff --git a/ax/core/tests/test_experiment.py b/ax/core/tests/test_experiment.py index 0f2dfcda85f..4f16a0fefd8 100644 --- a/ax/core/tests/test_experiment.py +++ b/ax/core/tests/test_experiment.py @@ -8,11 +8,13 @@ import logging from collections import OrderedDict +from enum import unique from unittest.mock import MagicMock, patch import pandas as pd from ax.core import BatchTrial, Experiment, Trial from ax.core.arm import Arm +from ax.core.auxiliary import AuxiliaryExperiment, AuxiliaryExperimentPurpose from ax.core.base_trial import TrialStatus from ax.core.data import Data from ax.core.map_data import MapData @@ -52,6 +54,7 @@ get_branin_search_space, get_data, get_experiment, + get_experiment_with_data, get_experiment_with_map_data_type, get_optimization_config, get_scalarized_outcome_constraint, @@ -1470,3 +1473,56 @@ def test_it_does_not_take_both_single_and_multiple_gr_ars(self) -> None: generator_run=gr1, generator_runs=[gr2], ) + + def test_experiment_with_aux_experiments(self) -> None: + @unique + class TestAuxiliaryExperimentPurpose(AuxiliaryExperimentPurpose): + MyAuxExpPurpose = "my_auxiliary_experiment_purpose" + MyOtherAuxExpPurpose = "my_other_auxiliary_experiment_purpose" + + for get_exp_func in [get_experiment, get_experiment_with_data]: + exp = get_exp_func() + data = exp.lookup_data() + + aux_exp = AuxiliaryExperiment(experiment=exp) + another_aux_exp = AuxiliaryExperiment(experiment=exp, data=data) + + # init experiment with auxiliary experiments + exp_w_aux_exp = Experiment( + name="test", + search_space=get_search_space(), + auxiliary_experiments_by_purpose={ + TestAuxiliaryExperimentPurpose.MyAuxExpPurpose: [aux_exp], + }, + ) + + # in-place modification of auxiliary experiments + exp_w_aux_exp.auxiliary_experiments_by_purpose[ + TestAuxiliaryExperimentPurpose.MyOtherAuxExpPurpose + ] = [aux_exp] + self.assertEqual( + exp_w_aux_exp.auxiliary_experiments_by_purpose, + { + TestAuxiliaryExperimentPurpose.MyAuxExpPurpose: [aux_exp], + TestAuxiliaryExperimentPurpose.MyOtherAuxExpPurpose: [aux_exp], + }, + ) + + # test setter + exp_w_aux_exp.auxiliary_experiments_by_purpose = { + TestAuxiliaryExperimentPurpose.MyAuxExpPurpose: [aux_exp], + TestAuxiliaryExperimentPurpose.MyOtherAuxExpPurpose: [ + aux_exp, + another_aux_exp, + ], + } + self.assertEqual( + exp_w_aux_exp.auxiliary_experiments_by_purpose, + { + TestAuxiliaryExperimentPurpose.MyAuxExpPurpose: [aux_exp], + TestAuxiliaryExperimentPurpose.MyOtherAuxExpPurpose: [ + aux_exp, + another_aux_exp, + ], + }, + ) diff --git a/ax/service/utils/instantiation.py b/ax/service/utils/instantiation.py index e66085632c4..640b5edb953 100644 --- a/ax/service/utils/instantiation.py +++ b/ax/service/utils/instantiation.py @@ -14,6 +14,7 @@ from typing import Any, Optional, Union from ax.core.arm import Arm +from ax.core.auxiliary import AuxiliaryExperiment, AuxiliaryExperimentPurpose from ax.core.experiment import DataType, Experiment from ax.core.metric import Metric from ax.core.objective import MultiObjective, Objective @@ -784,6 +785,9 @@ def make_experiment( objective_thresholds: Optional[list[str]] = None, support_intermediate_data: bool = False, immutable_search_space_and_opt_config: bool = True, + auxiliary_experiments_by_purpose: Optional[ + dict[AuxiliaryExperimentPurpose, list[AuxiliaryExperiment]] + ] = None, is_test: bool = False, ) -> Experiment: """Instantiation wrapper that allows for Ax `Experiment` creation @@ -823,6 +827,8 @@ def make_experiment( a product in which it is used), if any. tracking_metric_names: Names of additional tracking metrics not used for optimization. + metric_definitions: A mapping of metric names to extra kwargs to pass + to that metric objectives: Mapping from an objective name to "minimize" or "maximize" representing the direction for that objective. objective_thresholds: A list of objective threshold constraints for multi- @@ -835,10 +841,11 @@ def make_experiment( Defaults to True. If set to True, we won't store or load copies of the search space and optimization config on each generator run, which will improve storage performance. + auxiliary_experiments_by_purpose: Dictionary of auxiliary experiments for + different use cases (e.g., transfer learning). is_test: Whether this experiment will be a test experiment (useful for marking test experiments in storage etc). Defaults to False. - metric_definitions: A mapping of metric names to extra kwargs to pass - to that metric + """ status_quo_arm = None if status_quo is None else Arm(parameters=status_quo) @@ -889,6 +896,7 @@ def make_experiment( tracking_metrics=tracking_metrics, default_data_type=default_data_type, properties=properties, + auxiliary_experiments_by_purpose=auxiliary_experiments_by_purpose, is_test=is_test, )