Skip to content

Commit

Permalink
AuxiliaryExperiment (#2632)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2632

Adding AuxiliaryExperiment to allow for experiments and models to use other experiments as auxiliary information during optimization

Meta: Design doc https://docs.google.com/document/d/1MYq4WHPDLoWp-RUyk7jXiqswKARtyoSeM68PRayPGUk/edit?usp=sharing

Reviewed By: saitcakmak, mgarrard, lena-kashtelyan

Differential Revision: D60192602
  • Loading branch information
ItsMrLin authored and facebook-github-bot committed Aug 14, 2024
1 parent 15bb8a0 commit 9975af5
Show file tree
Hide file tree
Showing 8 changed files with 100 additions and 6 deletions.
41 changes: 41 additions & 0 deletions ax/core/auxiliary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from __future__ import annotations

from typing import Optional, TYPE_CHECKING

from ax.core.data import Data
from ax.utils.common.base import SortableBase


if TYPE_CHECKING:
# import as module to make sphinx-autodoc-typehints happy
from ax import core # noqa F401


class AuxiliaryExperiment(SortableBase):
"""Class for defining an auxiliary experiment."""

def __init__(
self,
experiment: core.experiment.Experiment,
data: Optional[Data] = None,
) -> None:
"""
Lightweight container of an experiment, and its data,
that will be used as auxiliary information for another experiment.
"""
self.experiment = experiment
self.data: Data = data or experiment.lookup_data()

def _unique_id(self) -> str:
# While there can be multiple `AuxiliarySource`-s made from the same
# experiment (and thus sharing the experiment name), the uniqueness
# here is only needed w.r.t. parent object ("main experiment", for which
# this will be an auxiliary source for).
return self.experiment.name
24 changes: 24 additions & 0 deletions ax/core/tests/test_auxiliary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from ax.core.auxiliary import AuxiliaryExperiment
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import get_experiment, get_experiment_with_data


class AuxiliaryExperimentTest(TestCase):
def test_AuxiliaryExperiment(self) -> None:
for get_exp_func in [get_experiment, get_experiment_with_data]:
exp = get_exp_func()
data = exp.lookup_data()

# Test init
aux_exp = AuxiliaryExperiment(experiment=exp)
self.assertEqual(aux_exp.experiment, exp)
self.assertEqual(aux_exp.data, data)

aux_exp = AuxiliaryExperiment(experiment=exp, data=exp.lookup_data())
self.assertEqual(aux_exp.experiment, exp)
self.assertEqual(aux_exp.data, data)
3 changes: 1 addition & 2 deletions ax/core/tests/test_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,10 @@
from unittest.mock import MagicMock, patch

import pandas as pd
from ax.core import BatchTrial, Trial
from ax.core import BatchTrial, Experiment, Trial
from ax.core.arm import Arm
from ax.core.base_trial import TrialStatus
from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.core.map_data import MapData
from ax.core.map_metric import MapMetric
from ax.core.metric import Metric
Expand Down
14 changes: 12 additions & 2 deletions ax/storage/json_store/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@
from typing import Any

from ax.benchmark.problems.hpo.torchvision import PyTorchCNNTorchvisionBenchmarkProblem
from ax.core import ObservationFeatures
from ax.core import Experiment, ObservationFeatures
from ax.core.arm import Arm
from ax.core.auxiliary import AuxiliaryExperiment
from ax.core.batch_trial import BatchTrial
from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.core.generator_run import GeneratorRun
from ax.core.map_data import MapData, MapKeyInfo
from ax.core.metric import Metric
Expand Down Expand Up @@ -710,6 +710,16 @@ def risk_measure_to_dict(
}


def auxiliary_experiment_to_dict(
auxiliary_experiment: AuxiliaryExperiment,
) -> dict[str, Any]:
return {
"__type": auxiliary_experiment.__class__.__name__,
"experiment": auxiliary_experiment.experiment,
"data": auxiliary_experiment.data,
}


def pathlib_to_dict(path: Path) -> dict[str, Any]:
return {"__type": path.__class__.__name__, "pathsegments": [str(path)]}

Expand Down
8 changes: 6 additions & 2 deletions ax/storage/json_store/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@
)
from ax.benchmark.runners.botorch_test import BotorchTestProblemRunner
from ax.benchmark.runners.surrogate import SurrogateRunner
from ax.core import ObservationFeatures
from ax.core import Experiment, ObservationFeatures
from ax.core.arm import Arm
from ax.core.auxiliary import AuxiliaryExperiment
from ax.core.base_trial import TrialStatus
from ax.core.batch_trial import (
AbandonedArm,
Expand All @@ -35,7 +36,7 @@
LifecycleStage,
)
from ax.core.data import Data
from ax.core.experiment import DataType, Experiment
from ax.core.experiment import DataType
from ax.core.generator_run import GeneratorRun
from ax.core.map_data import MapData, MapKeyInfo
from ax.core.map_metric import MapMetric
Expand Down Expand Up @@ -116,6 +117,7 @@
)
from ax.storage.json_store.encoders import (
arm_to_dict,
auxiliary_experiment_to_dict,
batch_to_dict,
best_model_selector_to_dict,
botorch_component_to_dict,
Expand Down Expand Up @@ -181,6 +183,7 @@
# avoid runtime subscripting errors.
CORE_ENCODER_REGISTRY: dict[type, Callable[[Any], dict[str, Any]]] = {
Arm: arm_to_dict,
AuxiliaryExperiment: auxiliary_experiment_to_dict,
AndEarlyStoppingStrategy: logical_early_stopping_strategy_to_dict,
AugmentedBraninMetric: metric_to_dict,
AugmentedHartmann6Metric: metric_to_dict,
Expand Down Expand Up @@ -293,6 +296,7 @@
"AugmentedBraninMetric": AugmentedBraninMetric,
"AugmentedHartmann6Metric": AugmentedHartmann6Metric,
"AutoTransitionAfterGen": AutoTransitionAfterGen,
"AuxiliaryExperiment": AuxiliaryExperiment,
"Arm": Arm,
"AggregatedBenchmarkResult": AggregatedBenchmarkResult,
"BatchTrial": BatchTrial,
Expand Down
2 changes: 2 additions & 0 deletions ax/storage/json_store/tests/test_json_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
get_arm,
get_augmented_branin_metric,
get_augmented_hartmann_metric,
get_auxiliary_experiment,
get_batch_trial,
get_botorch_model,
get_botorch_model_with_default_acquisition_class,
Expand Down Expand Up @@ -141,6 +142,7 @@
("Arm", get_arm),
("AugmentedBraninMetric", get_augmented_branin_metric),
("AugmentedHartmannMetric", get_augmented_hartmann_metric),
("AuxiliaryExperiment", get_auxiliary_experiment),
("BatchTrial", get_batch_trial),
("BenchmarkMethod", get_sobol_gpei_benchmark_method),
("BenchmarkProblem", get_single_objective_benchmark_problem),
Expand Down
5 changes: 5 additions & 0 deletions ax/utils/testing/core_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import pandas as pd
import torch
from ax.core.arm import Arm
from ax.core.auxiliary import AuxiliaryExperiment
from ax.core.base_trial import BaseTrial, TrialStatus
from ax.core.batch_trial import AbandonedArm, BatchTrial
from ax.core.data import Data
Expand Down Expand Up @@ -879,6 +880,10 @@ def get_high_dimensional_branin_experiment(with_batch: bool = False) -> Experime
return exp


def get_auxiliary_experiment() -> AuxiliaryExperiment:
return AuxiliaryExperiment(experiment=get_experiment_with_data())


##############################
# Search Spaces
##############################
Expand Down
9 changes: 9 additions & 0 deletions sphinx/source/core.rst
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,15 @@ Core Classes
:undoc-members:
:show-inheritance:


`AuxiliaryExperiment`
~~~~~~~~~~~~

.. automodule:: ax.core.auxiliary
:members:
:undoc-members:
:show-inheritance:

`GenerationStrategyInterface`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down

0 comments on commit 9975af5

Please sign in to comment.