Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add storage support for SobolQMCNormalSampler #2622

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions ax/storage/botorch_modular_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
OutcomeTransform,
Standardize,
)
from botorch.sampling.normal import SobolQMCNormalSampler

# Miscellaneous BoTorch imports
from gpytorch.constraints import Interval
Expand Down Expand Up @@ -168,6 +169,7 @@
Interval: "Interval",
GammaPrior: "GammaPrior",
LogNormalPrior: "LogNormalPrior",
SobolQMCNormalSampler: "SobolQMCNormalSampler",
}

"""
Expand Down Expand Up @@ -205,6 +207,7 @@
LogNormalPrior: GPYTORCH_COMPONENT_REGISTRY,
InputTransform: INPUT_TRANSFORM_REGISTRY,
OutcomeTransform: OUTCOME_TRANSFORM_REGISTRY,
SobolQMCNormalSampler: GPYTORCH_COMPONENT_REGISTRY,
}


Expand Down
6 changes: 6 additions & 0 deletions ax/storage/json_store/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
from ax.utils.common.typeutils import not_none
from ax.utils.common.typeutils_torch import torch_type_to_str
from botorch.models.transforms.input import ChainedInputTransform, InputTransform
from botorch.sampling.base import MCSampler
from botorch.utils.types import _DefaultType
from torch import Tensor

Expand Down Expand Up @@ -593,6 +594,11 @@ def botorch_component_to_dict(input_obj: Any) -> Dict[str, Any]:
state_dict = botorch_input_transform_to_init_args(input_transform=input_obj)
else:
state_dict = dict(input_obj.state_dict())
if isinstance(input_obj, MCSampler):
# The sampler args are not part of the state dict. Manually add them.
# Sample shape cannot be added to torch state dict since it is not a tensor.
state_dict["sample_shape"] = input_obj.sample_shape
state_dict["seed"] = input_obj.seed
return {
"__type": f"{class_type.__name__}",
"index": class_type,
Expand Down
3 changes: 3 additions & 0 deletions ax/storage/json_store/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.models.model import Model
from botorch.models.transforms.input import ChainedInputTransform, Normalize, Round
from botorch.sampling.normal import SobolQMCNormalSampler
from botorch.utils.types import DEFAULT
from gpytorch.constraints import Interval
from gpytorch.likelihoods.likelihood import Likelihood
Expand Down Expand Up @@ -252,6 +253,7 @@
SearchSpace: search_space_to_dict,
SingleDiagnosticBestModelSelector: best_model_selector_to_dict,
HierarchicalSearchSpace: search_space_to_dict,
SobolQMCNormalSampler: botorch_component_to_dict,
SumConstraint: sum_parameter_constraint_to_dict,
Surrogate: surrogate_to_dict,
BenchmarkMetric: metric_to_dict,
Expand Down Expand Up @@ -385,6 +387,7 @@
"SurrogateMetric": BenchmarkMetric, # backward-compatiblity
# NOTE: SurrogateRunners -> SyntheticRunner on load due to complications
"SurrogateRunner": SyntheticRunner,
"SobolQMCNormalSampler": SobolQMCNormalSampler,
"SyntheticRunner": SyntheticRunner,
"SurrogateSpec": SurrogateSpec,
"Trial": Trial,
Expand Down
18 changes: 18 additions & 0 deletions ax/storage/json_store/tests/test_json_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@
sobol_gpei_generation_node_gs,
)
from ax.utils.testing.utils import generic_equals
from botorch.sampling.normal import SobolQMCNormalSampler


# pyre-fixme[5]: Global expression must be annotated.
Expand Down Expand Up @@ -731,3 +732,20 @@ def test_generation_step_backwards_compatibility(self) -> None:
generation_step = object_from_json(json)
self.assertIsInstance(generation_step, GenerationStep)
self.assertEqual(generation_step.model_kwargs, {"other_kwarg": 5})

def test_SobolQMCNormalSampler(self) -> None:
# This fails default equality checks, so testing it separately.
sampler = SobolQMCNormalSampler(sample_shape=torch.Size([2]))
sampler_json = object_to_json(
sampler,
encoder_registry=CORE_ENCODER_REGISTRY,
class_encoder_registry=CORE_CLASS_ENCODER_REGISTRY,
)
sampler_loaded = object_from_json(
sampler_json,
decoder_registry=CORE_DECODER_REGISTRY,
class_decoder_registry=CORE_CLASS_DECODER_REGISTRY,
)
self.assertIsInstance(sampler_loaded, SobolQMCNormalSampler)
self.assertEqual(sampler.sample_shape, sampler_loaded.sample_shape)
self.assertEqual(sampler.seed, sampler_loaded.seed)