From cd3791bc0350fdd1f036bddb85ce2e8e76a48cf9 Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Tue, 30 Jul 2024 22:34:13 -0700 Subject: [PATCH] Add storage support for SobolQMCNormalSampler Differential Revision: D60432564 --- ax/storage/botorch_modular_registry.py | 3 +++ ax/storage/json_store/encoders.py | 6 ++++++ ax/storage/json_store/registry.py | 3 +++ ax/storage/json_store/tests/test_json_store.py | 18 ++++++++++++++++++ 4 files changed, 30 insertions(+) diff --git a/ax/storage/botorch_modular_registry.py b/ax/storage/botorch_modular_registry.py index 876c5215a85..063e10b1c7a 100644 --- a/ax/storage/botorch_modular_registry.py +++ b/ax/storage/botorch_modular_registry.py @@ -75,6 +75,7 @@ OutcomeTransform, Standardize, ) +from botorch.sampling.normal import SobolQMCNormalSampler # Miscellaneous BoTorch imports from gpytorch.constraints import Interval @@ -168,6 +169,7 @@ Interval: "Interval", GammaPrior: "GammaPrior", LogNormalPrior: "LogNormalPrior", + SobolQMCNormalSampler: "SobolQMCNormalSampler", } """ @@ -205,6 +207,7 @@ LogNormalPrior: GPYTORCH_COMPONENT_REGISTRY, InputTransform: INPUT_TRANSFORM_REGISTRY, OutcomeTransform: OUTCOME_TRANSFORM_REGISTRY, + SobolQMCNormalSampler: GPYTORCH_COMPONENT_REGISTRY, } diff --git a/ax/storage/json_store/encoders.py b/ax/storage/json_store/encoders.py index 1b5e8a6540f..4b1839f5930 100644 --- a/ax/storage/json_store/encoders.py +++ b/ax/storage/json_store/encoders.py @@ -63,6 +63,7 @@ from ax.utils.common.typeutils import not_none from ax.utils.common.typeutils_torch import torch_type_to_str from botorch.models.transforms.input import ChainedInputTransform, InputTransform +from botorch.sampling.base import MCSampler from botorch.utils.types import _DefaultType from torch import Tensor @@ -593,6 +594,11 @@ def botorch_component_to_dict(input_obj: Any) -> Dict[str, Any]: state_dict = botorch_input_transform_to_init_args(input_transform=input_obj) else: state_dict = dict(input_obj.state_dict()) + if isinstance(input_obj, MCSampler): + # The sampler args are not part of the state dict. Manually add them. + # Sample shape cannot be added to torch state dict since it is not a tensor. + state_dict["sample_shape"] = input_obj.sample_shape + state_dict["seed"] = input_obj.seed return { "__type": f"{class_type.__name__}", "index": class_type, diff --git a/ax/storage/json_store/registry.py b/ax/storage/json_store/registry.py index 19dce07c158..e77a51a0244 100644 --- a/ax/storage/json_store/registry.py +++ b/ax/storage/json_store/registry.py @@ -167,6 +167,7 @@ from botorch.acquisition.acquisition import AcquisitionFunction from botorch.models.model import Model from botorch.models.transforms.input import ChainedInputTransform, Normalize, Round +from botorch.sampling.normal import SobolQMCNormalSampler from botorch.utils.types import DEFAULT from gpytorch.constraints import Interval from gpytorch.likelihoods.likelihood import Likelihood @@ -252,6 +253,7 @@ SearchSpace: search_space_to_dict, SingleDiagnosticBestModelSelector: best_model_selector_to_dict, HierarchicalSearchSpace: search_space_to_dict, + SobolQMCNormalSampler: botorch_component_to_dict, SumConstraint: sum_parameter_constraint_to_dict, Surrogate: surrogate_to_dict, BenchmarkMetric: metric_to_dict, @@ -385,6 +387,7 @@ "SurrogateMetric": BenchmarkMetric, # backward-compatiblity # NOTE: SurrogateRunners -> SyntheticRunner on load due to complications "SurrogateRunner": SyntheticRunner, + "SobolQMCNormalSampler": SobolQMCNormalSampler, "SyntheticRunner": SyntheticRunner, "SurrogateSpec": SurrogateSpec, "Trial": Trial, diff --git a/ax/storage/json_store/tests/test_json_store.py b/ax/storage/json_store/tests/test_json_store.py index 465657d6ae6..7306daf58d6 100644 --- a/ax/storage/json_store/tests/test_json_store.py +++ b/ax/storage/json_store/tests/test_json_store.py @@ -130,6 +130,7 @@ sobol_gpei_generation_node_gs, ) from ax.utils.testing.utils import generic_equals +from botorch.sampling.normal import SobolQMCNormalSampler # pyre-fixme[5]: Global expression must be annotated. @@ -731,3 +732,20 @@ def test_generation_step_backwards_compatibility(self) -> None: generation_step = object_from_json(json) self.assertIsInstance(generation_step, GenerationStep) self.assertEqual(generation_step.model_kwargs, {"other_kwarg": 5}) + + def test_SobolQMCNormalSampler(self) -> None: + # This fails default equality checks, so testing it separately. + sampler = SobolQMCNormalSampler(sample_shape=torch.Size([2])) + sampler_json = object_to_json( + sampler, + encoder_registry=CORE_ENCODER_REGISTRY, + class_encoder_registry=CORE_CLASS_ENCODER_REGISTRY, + ) + sampler_loaded = object_from_json( + sampler_json, + decoder_registry=CORE_DECODER_REGISTRY, + class_decoder_registry=CORE_CLASS_DECODER_REGISTRY, + ) + self.assertIsInstance(sampler_loaded, SobolQMCNormalSampler) + self.assertEqual(sampler.sample_shape, sampler_loaded.sample_shape) + self.assertEqual(sampler.seed, sampler_loaded.seed)