Skip to content

Commit

Permalink
Merge 24455d9 into 3985791
Browse files Browse the repository at this point in the history
  • Loading branch information
saitcakmak authored Nov 16, 2023
2 parents 3985791 + 24455d9 commit aa5e700
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 41 deletions.
9 changes: 2 additions & 7 deletions ax/core/generation_strategy_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,12 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod, abstractproperty
from typing import Any, Dict, Iterable, List, Optional, Tuple
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional

from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.core.generator_run import GeneratorRun
from ax.core.optimization_config import (
MultiObjectiveOptimizationConfig,
OptimizationConfig,
)
from ax.core.types import TModelPredictArm, TParameterization
from ax.utils.common.base import Base
from ax.utils.common.typeutils import not_none

Expand Down
14 changes: 12 additions & 2 deletions ax/models/tests/test_botorch_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
)
from botorch.acquisition.objective import ConstrainedMCObjective
from botorch.acquisition.penalized import L1PenaltyObjective, PenalizedMCObjective
from botorch.exceptions.errors import UnsupportedError
from botorch.models.gp_regression import SingleTaskGP
from botorch.models.gp_regression_fidelity import SingleTaskMultiFidelityGP
from botorch.models.multitask import MultiTaskGP
Expand Down Expand Up @@ -215,6 +216,15 @@ def test_get_model(self) -> None:
self.assertIsInstance(model.likelihood, FixedNoiseGaussianLikelihood)
self.assertEqual(covar_module, model.covar_module)

# test input warping dimension checks.
with self.assertRaisesRegex(UnsupportedError, "batched multi output models"):
_get_model(
X=torch.ones(4, 3, 2),
Y=torch.ones(4, 3, 2),
Yvar=torch.zeros(4, 3, 2),
use_input_warping=True,
)

@mock.patch("ax.models.torch.botorch_defaults._get_model", wraps=_get_model)
@fast_botorch_optimize
# pyre-fixme[3]: Return type must be annotated.
Expand Down Expand Up @@ -468,7 +478,7 @@ def test_get_customized_covar_module(self) -> None:
covar_module = _get_customized_covar_module(
covar_module_prior_dict={},
ard_num_dims=ard_num_dims,
batch_shape=batch_shape,
aug_batch_shape=batch_shape,
task_feature=None,
)
self.assertIsInstance(covar_module, Module)
Expand All @@ -495,7 +505,7 @@ def test_get_customized_covar_module(self) -> None:
"outputscale_prior": GammaPrior(2.0, 12.0),
},
ard_num_dims=ard_num_dims,
batch_shape=batch_shape,
aug_batch_shape=batch_shape,
task_feature=3,
)
self.assertIsInstance(covar_module, Module)
Expand Down
24 changes: 14 additions & 10 deletions ax/models/torch/botorch_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,7 +711,6 @@ def _get_model(
is_nan = torch.isnan(Yvar)
any_nan_Yvar = torch.any(is_nan)
all_nan_Yvar = torch.all(is_nan)
batch_shape = _get_batch_shape(X, Y)
if any_nan_Yvar and not all_nan_Yvar:
if task_feature:
# TODO (jej): Replace with inferred noise before making perf judgements.
Expand All @@ -722,10 +721,14 @@ def _get_model(
"errors. Variances should all be specified, or none should be."
)
if use_input_warping:
if Y.shape[-1] > 1 and X.ndim > 2:
raise UnsupportedError(
"Input warping is not supported for batched multi output models."
)
warp_tf = get_warping_transform(
d=X.shape[-1],
task_feature=task_feature,
batch_shape=batch_shape,
batch_shape=X.shape[:-2],
)
else:
warp_tf = None
Expand All @@ -741,7 +744,7 @@ def _get_model(
covar_module = _get_customized_covar_module(
covar_module_prior_dict=covar_module_prior_dict,
ard_num_dims=X.shape[-1],
batch_shape=batch_shape,
aug_batch_shape=_get_aug_batch_shape(X, Y),
task_feature=task_feature,
)

Expand Down Expand Up @@ -804,17 +807,18 @@ def _get_model(
def _get_customized_covar_module(
covar_module_prior_dict: Dict[str, Prior],
ard_num_dims: int,
batch_shape: torch.Size,
aug_batch_shape: torch.Size,
task_feature: Optional[int] = None,
) -> Kernel:
"""Construct a GP kernel based on customized prior dict.
Args:
covar_module_prior_dict: Dict. The keys are the names of the prior and values
are the priors. e.g. {"lengthscale_prior": GammaPrior(3.0, 6.0)}.
ard_num_dims: The dimension of the input, including task features
batch_shape: The batch_shape of the model
task_feature: The index of the task feature
ard_num_dims: The dimension of the inputs, including task features.
aug_batch_shape: The output dimension augmented batch shape of the model
(different from the batch shape for batched multi-output models).
task_feature: The index of the task feature.
"""
# TODO: add more checks of covar_module_prior_dict
if task_feature is not None:
Expand All @@ -823,19 +827,19 @@ def _get_customized_covar_module(
MaternKernel(
nu=2.5,
ard_num_dims=ard_num_dims,
batch_shape=batch_shape,
batch_shape=aug_batch_shape,
lengthscale_prior=covar_module_prior_dict.get(
"lengthscale_prior", GammaPrior(3.0, 6.0)
),
),
batch_shape=batch_shape,
batch_shape=aug_batch_shape,
outputscale_prior=covar_module_prior_dict.get(
"outputscale_prior", GammaPrior(2.0, 0.15)
),
)


def _get_batch_shape(X: Tensor, Y: Tensor) -> torch.Size:
def _get_aug_batch_shape(X: Tensor, Y: Tensor) -> torch.Size:
"""Obtain the output-augmented batch shape of GP model.
Args:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

import torch
from ax.core.search_space import SearchSpaceDigest
from ax.models.torch.botorch_defaults import _get_batch_shape
from ax.models.torch.utils import normalize_indices
from ax.utils.common.typeutils import _argparse_type_encoder
from botorch.models.transforms.input import (
Expand All @@ -20,7 +19,7 @@
Warp,
)
from botorch.utils.containers import SliceContainer
from botorch.utils.datasets import SupervisedDataset
from botorch.utils.datasets import RankingDataset, SupervisedDataset
from botorch.utils.dispatcher import Dispatcher


Expand Down Expand Up @@ -75,22 +74,15 @@ def _input_transform_argparse_warp(
Returns:
A dictionary with input transform kwargs.
"""

input_transform_options = input_transform_options or {}
d = dataset.X.shape[-1]

d = len(dataset.feature_names)
indices = list(range(d))

task_features = normalize_indices(search_space_digest.task_features, d=d)

for task_feature in sorted(task_features, reverse=True):
del indices[task_feature]

batch_shape = _get_batch_shape(dataset.X, dataset.Y)

input_transform_options.setdefault("indices", indices)
input_transform_options.setdefault("batch_shape", batch_shape)

return input_transform_options


Expand Down Expand Up @@ -118,18 +110,15 @@ def _input_transform_argparse_normalize(
Returns:
A dictionary with input transform kwargs.
"""

input_transform_options = input_transform_options or {}

d = dataset.X.shape[-1]

d = input_transform_options.get("d", len(dataset.feature_names))
bounds = torch.as_tensor(
search_space_digest.bounds,
dtype=torch_dtype,
device=torch_device,
).T

if isinstance(dataset.X, SliceContainer):
if isinstance(dataset, RankingDataset) and isinstance(dataset.X, SliceContainer):
d = dataset.X.values.shape[-1]

indices = list(range(d))
Expand Down
33 changes: 26 additions & 7 deletions ax/models/torch/tests/test_input_transform_argparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
Normalize,
Warp,
)
from botorch.utils.datasets import SupervisedDataset
from botorch.utils.datasets import MultiTaskDataset, SupervisedDataset


class DummyInputTransform(InputTransform): # pyre-ignore [13]
Expand Down Expand Up @@ -92,7 +92,6 @@ def test_argparse_input_transform(self) -> None:
self.assertEqual(input_transform_kwargs, {"d": 10})

def test_argparse_normalize(self) -> None:

input_transform_kwargs = input_transform_argparse(
Normalize,
dataset=self.dataset,
Expand Down Expand Up @@ -137,8 +136,30 @@ def test_argparse_normalize(self) -> None:
)
)

def test_argparse_warp(self) -> None:
# Test with MultiTaskDataset.
dataset1 = SupervisedDataset(
X=torch.rand(5, 4),
Y=torch.randn(5, 1),
feature_names=[f"x{i}" for i in range(4)],
outcome_names=["y0"],
)

dataset2 = SupervisedDataset(
X=torch.rand(5, 2),
Y=torch.randn(5, 1),
feature_names=[f"x{i}" for i in range(2)],
outcome_names=["y1"],
)
mtds = MultiTaskDataset(datasets=[dataset1, dataset2], target_outcome_name="y0")
input_transform_kwargs = input_transform_argparse(
Normalize,
dataset=mtds,
search_space_digest=self.search_space_digest,
)
self.assertEqual(input_transform_kwargs["d"], 4)
self.assertEqual(input_transform_kwargs["indices"], [0, 1, 2])

def test_argparse_warp(self) -> None:
self.search_space_digest.task_features = [0, 3]
input_transform_kwargs = input_transform_argparse(
Warp,
Expand All @@ -148,7 +169,7 @@ def test_argparse_warp(self) -> None:

self.assertEqual(
input_transform_kwargs,
{"indices": [1, 2], "batch_shape": torch.Size([2])},
{"indices": [1, 2]},
)

input_transform_kwargs = input_transform_argparse(
Expand All @@ -158,9 +179,7 @@ def test_argparse_warp(self) -> None:
input_transform_options={"indices": [0, 1]},
)

self.assertEqual(
input_transform_kwargs, {"indices": [0, 1], "batch_shape": torch.Size([2])}
)
self.assertEqual(input_transform_kwargs, {"indices": [0, 1]})

def test_argparse_input_perturbation(self) -> None:

Expand Down

0 comments on commit aa5e700

Please sign in to comment.