Fix usage of batch shape for warp transform
Summary: Follow-up to pytorch/botorch#2109. The input warping transform is now built with the input batch shape rather than the output-augmented batch shape, removing the potential for erroneous usage.

Differential Revision: D51369374
saitcakmak authored and facebook-github-bot committed Nov 15, 2023
1 parent 6c0564e commit e8599dd
Showing 2 changed files with 15 additions and 12 deletions.
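
For context on what the fix changes: the input warping transform acts on the raw inputs, so it should be built with the plain input batch shape (X.shape[:-2]), while the covariance module is built per outcome and uses the output-augmented batch shape. Before this commit a single, output-augmented batch_shape was passed to both. A minimal sketch of how the two shapes can differ for a batched, multi-outcome training set (illustrative tensors and variable names only, not Ax or BoTorch API):

    import torch

    # Batched inputs with two outcomes: batch shape (3,), n = 10 points, d = 2 features.
    X = torch.rand(3, 10, 2)
    Y = torch.rand(3, 10, 2)

    input_batch_shape = X.shape[:-2]  # torch.Size([3]); what the warping transform expects
    aug_batch_shape = input_batch_shape
    if Y.shape[-1] > 1:
        # Multi-output GPs are modeled as a batch over outcomes, hence the extra dimension.
        aug_batch_shape = aug_batch_shape + torch.Size([Y.shape[-1]])
    print(input_batch_shape, aug_batch_shape)  # torch.Size([3]) torch.Size([3, 2])

Handing the augmented shape to the warping transform would give it a batch dimension the inputs do not have; the diff below switches it to X.shape[:-2] and explicitly rejects the batched multi-output case, where warping remains unsupported.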
ax/models/tests/test_botorch_defaults.py: 4 changes (2 additions, 2 deletions)
@@ -468,7 +468,7 @@ def test_get_customized_covar_module(self) -> None:
         covar_module = _get_customized_covar_module(
             covar_module_prior_dict={},
             ard_num_dims=ard_num_dims,
-            batch_shape=batch_shape,
+            aug_batch_shape=batch_shape,
             task_feature=None,
         )
         self.assertIsInstance(covar_module, Module)
@@ -495,7 +495,7 @@ def test_get_customized_covar_module(self) -> None:
                 "outputscale_prior": GammaPrior(2.0, 12.0),
             },
             ard_num_dims=ard_num_dims,
-            batch_shape=batch_shape,
+            aug_batch_shape=batch_shape,
             task_feature=3,
         )
         self.assertIsInstance(covar_module, Module)
ax/models/torch/botorch_defaults.py: 23 changes (13 additions, 10 deletions)
@@ -711,7 +711,6 @@ def _get_model(
     is_nan = torch.isnan(Yvar)
     any_nan_Yvar = torch.any(is_nan)
     all_nan_Yvar = torch.all(is_nan)
-    batch_shape = _get_batch_shape(X, Y)
     if any_nan_Yvar and not all_nan_Yvar:
         if task_feature:
             # TODO (jej): Replace with inferred noise before making perf judgements.
@@ -722,10 +721,14 @@ def _get_model(
                 "errors. Variances should all be specified, or none should be."
             )
     if use_input_warping:
+        if Y.shape[-1] > 1 and X.ndim > 2:
+            raise UnsupportedError(  # pragma: no cover
+                "Input warping is not supported for batched multi output models."
+            )
         warp_tf = get_warping_transform(
             d=X.shape[-1],
             task_feature=task_feature,
-            batch_shape=batch_shape,
+            batch_shape=X.shape[:-2],
         )
     else:
         warp_tf = None
@@ -741,7 +744,7 @@ def _get_model(
         covar_module = _get_customized_covar_module(
             covar_module_prior_dict=covar_module_prior_dict,
             ard_num_dims=X.shape[-1],
-            batch_shape=batch_shape,
+            batch_shape=_get_aug_batch_shape(X, Y),
             task_feature=task_feature,
         )
 
@@ -804,17 +807,17 @@ def _get_model(
 def _get_customized_covar_module(
     covar_module_prior_dict: Dict[str, Prior],
     ard_num_dims: int,
-    batch_shape: torch.Size,
+    aug_batch_shape: torch.Size,
     task_feature: Optional[int] = None,
 ) -> Kernel:
     """Construct a GP kernel based on customized prior dict.
 
     Args:
         covar_module_prior_dict: Dict. The keys are the names of the prior and values
             are the priors. e.g. {"lengthscale_prior": GammaPrior(3.0, 6.0)}.
-        ard_num_dims: The dimension of the input, including task features
-        batch_shape: The batch_shape of the model
-        task_feature: The index of the task feature
+        ard_num_dims: The dimension of the inputs, including task features.
+        aug_batch_shape: The output dimension augmented batch shape of the model.
+        task_feature: The index of the task feature.
     """
     # TODO: add more checks of covar_module_prior_dict
     if task_feature is not None:
@@ -823,19 +826,19 @@ def _get_customized_covar_module(
         MaternKernel(
             nu=2.5,
             ard_num_dims=ard_num_dims,
-            batch_shape=batch_shape,
+            batch_shape=aug_batch_shape,
             lengthscale_prior=covar_module_prior_dict.get(
                 "lengthscale_prior", GammaPrior(3.0, 6.0)
             ),
         ),
-        batch_shape=batch_shape,
+        batch_shape=aug_batch_shape,
         outputscale_prior=covar_module_prior_dict.get(
             "outputscale_prior", GammaPrior(2.0, 0.15)
         ),
     )
 
 
-def _get_batch_shape(X: Tensor, Y: Tensor) -> torch.Size:
+def _get_aug_batch_shape(X: Tensor, Y: Tensor) -> torch.Size:
     """Obtain the output-augmented batch shape of GP model.
 
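
The body of _get_aug_batch_shape is not shown in this diff. Based on its docstring ("Obtain the output-augmented batch shape of GP model") and the augmentation convention used for multi-output GPs in BoTorch, a plausible sketch of what it computes follows; this is an illustration, not the verbatim Ax implementation:

    import torch
    from torch import Tensor

    def _get_aug_batch_shape(X: Tensor, Y: Tensor) -> torch.Size:
        """Sketch: input batch shape, augmented with the outcome dimension if Y is multi-output."""
        batch_shape = X.shape[:-2]  # input batch shape, e.g. torch.Size([]) for unbatched X
        if Y.shape[-1] > 1:
            # Multi-output models are represented as a batch over outcomes.
            batch_shape += torch.Size([Y.shape[-1]])
        return batch_shape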
