Skip to content

Commit

Permalink
Add auxiliary_experiments_by_purpose to Experiment (#2634)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2634

Add `auxiliary_experiments_by_purpose ` to `Experiment` object.

This change allows for `Experiment` to hold a dictionary of list of auxiliary experiments for different purposes. This is helpful for methods such as transfer learning or preference learning where we are using other experiments as auxiliary information (e.g., in modelbridge) to facilitate the optimization of the experiment of interest.

Reviewed By: saitcakmak

Differential Revision: D60542566

fbshipit-source-id: 4df696e584138691dc0797eb29b359dfcc2e0f59
  • Loading branch information
ItsMrLin authored and facebook-github-bot committed Aug 16, 2024
1 parent 3f82b1d commit 02df121
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 7 deletions.
6 changes: 6 additions & 0 deletions ax/core/auxiliary.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from __future__ import annotations

from enum import Enum, unique
from typing import Optional, TYPE_CHECKING

from ax.core.data import Data
Expand Down Expand Up @@ -39,3 +40,8 @@ def _unique_id(self) -> str:
# here is only needed w.r.t. parent object ("main experiment", for which
# this will be an auxiliary source for).
return self.experiment.name


@unique
class AuxiliaryExperimentPurpose(Enum):
pass
15 changes: 13 additions & 2 deletions ax/core/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@
from collections.abc import Hashable, Iterable, Mapping
from datetime import datetime
from functools import partial, reduce

from typing import Any, Optional

import ax.core.observation as observation
import pandas as pd
from ax.core.arm import Arm
from ax.core.auxiliary import AuxiliaryExperiment, AuxiliaryExperimentPurpose
from ax.core.base_trial import BaseTrial, DEFAULT_STATUSES_TO_WARM_START, TrialStatus
from ax.core.batch_trial import BatchTrial, LifecycleStage
from ax.core.data import Data
Expand Down Expand Up @@ -79,6 +81,9 @@ def __init__(
experiment_type: Optional[str] = None,
properties: Optional[dict[str, Any]] = None,
default_data_type: Optional[DataType] = None,
auxiliary_experiments_by_purpose: Optional[
dict[AuxiliaryExperimentPurpose, list[AuxiliaryExperiment]]
] = None,
) -> None:
"""Inits Experiment.
Expand All @@ -94,6 +99,8 @@ def __init__(
experiment_type: The class of experiments this one belongs to.
properties: Dictionary of this experiment's properties.
default_data_type: Enum representing the data type this experiment uses.
auxiliary_experiments_by_purpose: Dictionary of auxiliary experiments
for different purposes (e.g., transfer learning).
"""
# appease pyre
self._search_space: SearchSpace
Expand Down Expand Up @@ -127,6 +134,10 @@ def __init__(
self._arms_by_signature: dict[str, Arm] = {}
self._arms_by_name: dict[str, Arm] = {}

self.auxiliary_experiments_by_purpose: dict[
AuxiliaryExperimentPurpose, list[AuxiliaryExperiment]
] = (auxiliary_experiments_by_purpose or {})

self.add_tracking_metrics(tracking_metrics or [])

# call setters defined below
Expand Down Expand Up @@ -1020,14 +1031,14 @@ def trials_by_status(self) -> dict[TrialStatus, list[BaseTrial]]:

@property
def trials_expecting_data(self) -> list[BaseTrial]:
"""List[BaseTrial]: the list of all trials for which data has arrived
"""list[BaseTrial]: the list of all trials for which data has arrived
or is expected to arrive.
"""
return [trial for trial in self.trials.values() if trial.status.expecting_data]

@property
def completed_trials(self) -> list[BaseTrial]:
"""List[BaseTrial]: the list of all trials for which data has arrived
"""list[BaseTrial]: the list of all trials for which data has arrived
or is expected to arrive.
"""
return self.trials_by_status[TrialStatus.COMPLETED]
Expand Down
8 changes: 5 additions & 3 deletions ax/core/tests/test_auxiliary.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ def test_AuxiliaryExperiment(self) -> None:
self.assertEqual(aux_exp.experiment, exp)
self.assertEqual(aux_exp.data, data)

aux_exp = AuxiliaryExperiment(experiment=exp, data=exp.lookup_data())
self.assertEqual(aux_exp.experiment, exp)
self.assertEqual(aux_exp.data, data)
another_aux_exp = AuxiliaryExperiment(
experiment=exp, data=exp.lookup_data()
)
self.assertEqual(another_aux_exp.experiment, exp)
self.assertEqual(another_aux_exp.data, data)
56 changes: 56 additions & 0 deletions ax/core/tests/test_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@

import logging
from collections import OrderedDict
from enum import unique
from unittest.mock import MagicMock, patch

import pandas as pd
from ax.core import BatchTrial, Experiment, Trial
from ax.core.arm import Arm
from ax.core.auxiliary import AuxiliaryExperiment, AuxiliaryExperimentPurpose
from ax.core.base_trial import TrialStatus
from ax.core.data import Data
from ax.core.map_data import MapData
Expand Down Expand Up @@ -52,6 +54,7 @@
get_branin_search_space,
get_data,
get_experiment,
get_experiment_with_data,
get_experiment_with_map_data_type,
get_optimization_config,
get_scalarized_outcome_constraint,
Expand Down Expand Up @@ -1470,3 +1473,56 @@ def test_it_does_not_take_both_single_and_multiple_gr_ars(self) -> None:
generator_run=gr1,
generator_runs=[gr2],
)

def test_experiment_with_aux_experiments(self) -> None:
@unique
class TestAuxiliaryExperimentPurpose(AuxiliaryExperimentPurpose):
MyAuxExpPurpose = "my_auxiliary_experiment_purpose"
MyOtherAuxExpPurpose = "my_other_auxiliary_experiment_purpose"

for get_exp_func in [get_experiment, get_experiment_with_data]:
exp = get_exp_func()
data = exp.lookup_data()

aux_exp = AuxiliaryExperiment(experiment=exp)
another_aux_exp = AuxiliaryExperiment(experiment=exp, data=data)

# init experiment with auxiliary experiments
exp_w_aux_exp = Experiment(
name="test",
search_space=get_search_space(),
auxiliary_experiments_by_purpose={
TestAuxiliaryExperimentPurpose.MyAuxExpPurpose: [aux_exp],
},
)

# in-place modification of auxiliary experiments
exp_w_aux_exp.auxiliary_experiments_by_purpose[
TestAuxiliaryExperimentPurpose.MyOtherAuxExpPurpose
] = [aux_exp]
self.assertEqual(
exp_w_aux_exp.auxiliary_experiments_by_purpose,
{
TestAuxiliaryExperimentPurpose.MyAuxExpPurpose: [aux_exp],
TestAuxiliaryExperimentPurpose.MyOtherAuxExpPurpose: [aux_exp],
},
)

# test setter
exp_w_aux_exp.auxiliary_experiments_by_purpose = {
TestAuxiliaryExperimentPurpose.MyAuxExpPurpose: [aux_exp],
TestAuxiliaryExperimentPurpose.MyOtherAuxExpPurpose: [
aux_exp,
another_aux_exp,
],
}
self.assertEqual(
exp_w_aux_exp.auxiliary_experiments_by_purpose,
{
TestAuxiliaryExperimentPurpose.MyAuxExpPurpose: [aux_exp],
TestAuxiliaryExperimentPurpose.MyOtherAuxExpPurpose: [
aux_exp,
another_aux_exp,
],
},
)
12 changes: 10 additions & 2 deletions ax/service/utils/instantiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import Any, Optional, Union

from ax.core.arm import Arm
from ax.core.auxiliary import AuxiliaryExperiment, AuxiliaryExperimentPurpose
from ax.core.experiment import DataType, Experiment
from ax.core.metric import Metric
from ax.core.objective import MultiObjective, Objective
Expand Down Expand Up @@ -784,6 +785,9 @@ def make_experiment(
objective_thresholds: Optional[list[str]] = None,
support_intermediate_data: bool = False,
immutable_search_space_and_opt_config: bool = True,
auxiliary_experiments_by_purpose: Optional[
dict[AuxiliaryExperimentPurpose, list[AuxiliaryExperiment]]
] = None,
is_test: bool = False,
) -> Experiment:
"""Instantiation wrapper that allows for Ax `Experiment` creation
Expand Down Expand Up @@ -823,6 +827,8 @@ def make_experiment(
a product in which it is used), if any.
tracking_metric_names: Names of additional tracking metrics not used for
optimization.
metric_definitions: A mapping of metric names to extra kwargs to pass
to that metric
objectives: Mapping from an objective name to "minimize" or "maximize"
representing the direction for that objective.
objective_thresholds: A list of objective threshold constraints for multi-
Expand All @@ -835,10 +841,11 @@ def make_experiment(
Defaults to True. If set to True, we won't store or load copies of the
search space and optimization config on each generator run, which will
improve storage performance.
auxiliary_experiments_by_purpose: Dictionary of auxiliary experiments for
different use cases (e.g., transfer learning).
is_test: Whether this experiment will be a test experiment (useful for
marking test experiments in storage etc). Defaults to False.
metric_definitions: A mapping of metric names to extra kwargs to pass
to that metric
"""
status_quo_arm = None if status_quo is None else Arm(parameters=status_quo)

Expand Down Expand Up @@ -889,6 +896,7 @@ def make_experiment(
tracking_metrics=tracking_metrics,
default_data_type=default_data_type,
properties=properties,
auxiliary_experiments_by_purpose=auxiliary_experiments_by_purpose,
is_test=is_test,
)

Expand Down

0 comments on commit 02df121

Please sign in to comment.