From 9975af5be94a5295a685e9f05f589b5357ad23a1 Mon Sep 17 00:00:00 2001 From: Jerry Lin Date: Wed, 14 Aug 2024 12:27:31 -0700 Subject: [PATCH] AuxiliaryExperiment (#2632) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/2632 Adding AuxiliaryExperiment to allow for experiments and models to use other experiments as auxiliary information during optimization Meta: Design doc https://docs.google.com/document/d/1MYq4WHPDLoWp-RUyk7jXiqswKARtyoSeM68PRayPGUk/edit?usp=sharing Reviewed By: saitcakmak, mgarrard, lena-kashtelyan Differential Revision: D60192602 --- ax/core/auxiliary.py | 41 +++++++++++++++++++ ax/core/tests/test_auxiliary.py | 24 +++++++++++ ax/core/tests/test_experiment.py | 3 +- ax/storage/json_store/encoders.py | 14 ++++++- ax/storage/json_store/registry.py | 8 +++- .../json_store/tests/test_json_store.py | 2 + ax/utils/testing/core_stubs.py | 5 +++ sphinx/source/core.rst | 9 ++++ 8 files changed, 100 insertions(+), 6 deletions(-) create mode 100644 ax/core/auxiliary.py create mode 100644 ax/core/tests/test_auxiliary.py diff --git a/ax/core/auxiliary.py b/ax/core/auxiliary.py new file mode 100644 index 00000000000..507af38025c --- /dev/null +++ b/ax/core/auxiliary.py @@ -0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from __future__ import annotations + +from typing import Optional, TYPE_CHECKING + +from ax.core.data import Data +from ax.utils.common.base import SortableBase + + +if TYPE_CHECKING: + # import as module to make sphinx-autodoc-typehints happy + from ax import core # noqa F401 + + +class AuxiliaryExperiment(SortableBase): + """Class for defining an auxiliary experiment.""" + + def __init__( + self, + experiment: core.experiment.Experiment, + data: Optional[Data] = None, + ) -> None: + """ + Lightweight container of an experiment, and its data, + that will be used as auxiliary information for another experiment. + """ + self.experiment = experiment + self.data: Data = data or experiment.lookup_data() + + def _unique_id(self) -> str: + # While there can be multiple `AuxiliarySource`-s made from the same + # experiment (and thus sharing the experiment name), the uniqueness + # here is only needed w.r.t. parent object ("main experiment", for which + # this will be an auxiliary source for). + return self.experiment.name diff --git a/ax/core/tests/test_auxiliary.py b/ax/core/tests/test_auxiliary.py new file mode 100644 index 00000000000..8a816f6a1a7 --- /dev/null +++ b/ax/core/tests/test_auxiliary.py @@ -0,0 +1,24 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from ax.core.auxiliary import AuxiliaryExperiment +from ax.utils.common.testutils import TestCase +from ax.utils.testing.core_stubs import get_experiment, get_experiment_with_data + + +class AuxiliaryExperimentTest(TestCase): + def test_AuxiliaryExperiment(self) -> None: + for get_exp_func in [get_experiment, get_experiment_with_data]: + exp = get_exp_func() + data = exp.lookup_data() + + # Test init + aux_exp = AuxiliaryExperiment(experiment=exp) + self.assertEqual(aux_exp.experiment, exp) + self.assertEqual(aux_exp.data, data) + + aux_exp = AuxiliaryExperiment(experiment=exp, data=exp.lookup_data()) + self.assertEqual(aux_exp.experiment, exp) + self.assertEqual(aux_exp.data, data) diff --git a/ax/core/tests/test_experiment.py b/ax/core/tests/test_experiment.py index de450f0f84c..0f2dfcda85f 100644 --- a/ax/core/tests/test_experiment.py +++ b/ax/core/tests/test_experiment.py @@ -11,11 +11,10 @@ from unittest.mock import MagicMock, patch import pandas as pd -from ax.core import BatchTrial, Trial +from ax.core import BatchTrial, Experiment, Trial from ax.core.arm import Arm from ax.core.base_trial import TrialStatus from ax.core.data import Data -from ax.core.experiment import Experiment from ax.core.map_data import MapData from ax.core.map_metric import MapMetric from ax.core.metric import Metric diff --git a/ax/storage/json_store/encoders.py b/ax/storage/json_store/encoders.py index 34c74336652..6b46d0bfdf3 100644 --- a/ax/storage/json_store/encoders.py +++ b/ax/storage/json_store/encoders.py @@ -12,11 +12,11 @@ from typing import Any from ax.benchmark.problems.hpo.torchvision import PyTorchCNNTorchvisionBenchmarkProblem -from ax.core import ObservationFeatures +from ax.core import Experiment, ObservationFeatures from ax.core.arm import Arm +from ax.core.auxiliary import AuxiliaryExperiment from ax.core.batch_trial import BatchTrial from ax.core.data import Data -from ax.core.experiment import Experiment from ax.core.generator_run import GeneratorRun from ax.core.map_data import MapData, MapKeyInfo from ax.core.metric import Metric @@ -710,6 +710,16 @@ def risk_measure_to_dict( } +def auxiliary_experiment_to_dict( + auxiliary_experiment: AuxiliaryExperiment, +) -> dict[str, Any]: + return { + "__type": auxiliary_experiment.__class__.__name__, + "experiment": auxiliary_experiment.experiment, + "data": auxiliary_experiment.data, + } + + def pathlib_to_dict(path: Path) -> dict[str, Any]: return {"__type": path.__class__.__name__, "pathsegments": [str(path)]} diff --git a/ax/storage/json_store/registry.py b/ax/storage/json_store/registry.py index 98e3b2d6b92..46ee42b8d6c 100644 --- a/ax/storage/json_store/registry.py +++ b/ax/storage/json_store/registry.py @@ -25,8 +25,9 @@ ) from ax.benchmark.runners.botorch_test import BotorchTestProblemRunner from ax.benchmark.runners.surrogate import SurrogateRunner -from ax.core import ObservationFeatures +from ax.core import Experiment, ObservationFeatures from ax.core.arm import Arm +from ax.core.auxiliary import AuxiliaryExperiment from ax.core.base_trial import TrialStatus from ax.core.batch_trial import ( AbandonedArm, @@ -35,7 +36,7 @@ LifecycleStage, ) from ax.core.data import Data -from ax.core.experiment import DataType, Experiment +from ax.core.experiment import DataType from ax.core.generator_run import GeneratorRun from ax.core.map_data import MapData, MapKeyInfo from ax.core.map_metric import MapMetric @@ -116,6 +117,7 @@ ) from ax.storage.json_store.encoders import ( arm_to_dict, + auxiliary_experiment_to_dict, batch_to_dict, best_model_selector_to_dict, botorch_component_to_dict, @@ -181,6 +183,7 @@ # avoid runtime subscripting errors. CORE_ENCODER_REGISTRY: dict[type, Callable[[Any], dict[str, Any]]] = { Arm: arm_to_dict, + AuxiliaryExperiment: auxiliary_experiment_to_dict, AndEarlyStoppingStrategy: logical_early_stopping_strategy_to_dict, AugmentedBraninMetric: metric_to_dict, AugmentedHartmann6Metric: metric_to_dict, @@ -293,6 +296,7 @@ "AugmentedBraninMetric": AugmentedBraninMetric, "AugmentedHartmann6Metric": AugmentedHartmann6Metric, "AutoTransitionAfterGen": AutoTransitionAfterGen, + "AuxiliaryExperiment": AuxiliaryExperiment, "Arm": Arm, "AggregatedBenchmarkResult": AggregatedBenchmarkResult, "BatchTrial": BatchTrial, diff --git a/ax/storage/json_store/tests/test_json_store.py b/ax/storage/json_store/tests/test_json_store.py index 7b78fbf7fba..02e84bc9326 100644 --- a/ax/storage/json_store/tests/test_json_store.py +++ b/ax/storage/json_store/tests/test_json_store.py @@ -59,6 +59,7 @@ get_arm, get_augmented_branin_metric, get_augmented_hartmann_metric, + get_auxiliary_experiment, get_batch_trial, get_botorch_model, get_botorch_model_with_default_acquisition_class, @@ -141,6 +142,7 @@ ("Arm", get_arm), ("AugmentedBraninMetric", get_augmented_branin_metric), ("AugmentedHartmannMetric", get_augmented_hartmann_metric), + ("AuxiliaryExperiment", get_auxiliary_experiment), ("BatchTrial", get_batch_trial), ("BenchmarkMethod", get_sobol_gpei_benchmark_method), ("BenchmarkProblem", get_single_objective_benchmark_problem), diff --git a/ax/utils/testing/core_stubs.py b/ax/utils/testing/core_stubs.py index 8d9b836696c..5d9caa3fbe1 100644 --- a/ax/utils/testing/core_stubs.py +++ b/ax/utils/testing/core_stubs.py @@ -20,6 +20,7 @@ import pandas as pd import torch from ax.core.arm import Arm +from ax.core.auxiliary import AuxiliaryExperiment from ax.core.base_trial import BaseTrial, TrialStatus from ax.core.batch_trial import AbandonedArm, BatchTrial from ax.core.data import Data @@ -879,6 +880,10 @@ def get_high_dimensional_branin_experiment(with_batch: bool = False) -> Experime return exp +def get_auxiliary_experiment() -> AuxiliaryExperiment: + return AuxiliaryExperiment(experiment=get_experiment_with_data()) + + ############################## # Search Spaces ############################## diff --git a/sphinx/source/core.rst b/sphinx/source/core.rst index 85ad971d468..9075d61d693 100644 --- a/sphinx/source/core.rst +++ b/sphinx/source/core.rst @@ -52,6 +52,15 @@ Core Classes :undoc-members: :show-inheritance: + +`AuxiliaryExperiment` +~~~~~~~~~~~~ + +.. automodule:: ax.core.auxiliary + :members: + :undoc-members: + :show-inheritance: + `GenerationStrategyInterface` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~