diff --git a/ax/models/torch/botorch_modular/input_constructors/input_transforms.py b/ax/models/torch/botorch_modular/input_constructors/input_transforms.py index 3b4aa3e5f45..1cc17df277f 100644 --- a/ax/models/torch/botorch_modular/input_constructors/input_transforms.py +++ b/ax/models/torch/botorch_modular/input_constructors/input_transforms.py @@ -10,7 +10,6 @@ import torch from ax.core.search_space import SearchSpaceDigest -from ax.models.torch.botorch_defaults import _get_batch_shape from ax.models.torch.utils import normalize_indices from ax.utils.common.typeutils import _argparse_type_encoder from botorch.models.transforms.input import ( @@ -20,7 +19,7 @@ Warp, ) from botorch.utils.containers import SliceContainer -from botorch.utils.datasets import SupervisedDataset +from botorch.utils.datasets import RankingDataset, SupervisedDataset from botorch.utils.dispatcher import Dispatcher @@ -75,22 +74,15 @@ def _input_transform_argparse_warp( Returns: A dictionary with input transform kwargs. """ - input_transform_options = input_transform_options or {} - d = dataset.X.shape[-1] - + d = len(dataset.feature_names) indices = list(range(d)) - task_features = normalize_indices(search_space_digest.task_features, d=d) for task_feature in sorted(task_features, reverse=True): del indices[task_feature] - batch_shape = _get_batch_shape(dataset.X, dataset.Y) - input_transform_options.setdefault("indices", indices) - input_transform_options.setdefault("batch_shape", batch_shape) - return input_transform_options @@ -118,18 +110,15 @@ def _input_transform_argparse_normalize( Returns: A dictionary with input transform kwargs. """ - input_transform_options = input_transform_options or {} - - d = dataset.X.shape[-1] - + d = input_transform_options.get("d", len(dataset.feature_names)) bounds = torch.as_tensor( search_space_digest.bounds, dtype=torch_dtype, device=torch_device, ).T - if isinstance(dataset.X, SliceContainer): + if isinstance(dataset, RankingDataset) and isinstance(dataset.X, SliceContainer): d = dataset.X.values.shape[-1] indices = list(range(d)) diff --git a/ax/models/torch/tests/test_input_transform_argparse.py b/ax/models/torch/tests/test_input_transform_argparse.py index 360484b40ae..2d05fb716e4 100644 --- a/ax/models/torch/tests/test_input_transform_argparse.py +++ b/ax/models/torch/tests/test_input_transform_argparse.py @@ -23,7 +23,7 @@ Normalize, Warp, ) -from botorch.utils.datasets import SupervisedDataset +from botorch.utils.datasets import MultiTaskDataset, SupervisedDataset class DummyInputTransform(InputTransform): # pyre-ignore [13] @@ -92,7 +92,6 @@ def test_argparse_input_transform(self) -> None: self.assertEqual(input_transform_kwargs, {"d": 10}) def test_argparse_normalize(self) -> None: - input_transform_kwargs = input_transform_argparse( Normalize, dataset=self.dataset, @@ -137,8 +136,30 @@ def test_argparse_normalize(self) -> None: ) ) - def test_argparse_warp(self) -> None: + # Test with MultiTaskDataset. + dataset1 = SupervisedDataset( + X=torch.rand(5, 4), + Y=torch.randn(5, 1), + feature_names=[f"x{i}" for i in range(4)], + outcome_names=["y0"], + ) + + dataset2 = SupervisedDataset( + X=torch.rand(5, 2), + Y=torch.randn(5, 1), + feature_names=[f"x{i}" for i in range(2)], + outcome_names=["y1"], + ) + mtds = MultiTaskDataset(datasets=[dataset1, dataset2], target_outcome_name="y0") + input_transform_kwargs = input_transform_argparse( + Normalize, + dataset=mtds, + search_space_digest=self.search_space_digest, + ) + self.assertEqual(input_transform_kwargs["d"], 4) + self.assertEqual(input_transform_kwargs["indices"], [0, 1, 2]) + def test_argparse_warp(self) -> None: self.search_space_digest.task_features = [0, 3] input_transform_kwargs = input_transform_argparse( Warp, @@ -148,7 +169,7 @@ def test_argparse_warp(self) -> None: self.assertEqual( input_transform_kwargs, - {"indices": [1, 2], "batch_shape": torch.Size([2])}, + {"indices": [1, 2]}, ) input_transform_kwargs = input_transform_argparse( @@ -158,9 +179,7 @@ def test_argparse_warp(self) -> None: input_transform_options={"indices": [0, 1]}, ) - self.assertEqual( - input_transform_kwargs, {"indices": [0, 1], "batch_shape": torch.Size([2])} - ) + self.assertEqual(input_transform_kwargs, {"indices": [0, 1]}) def test_argparse_input_perturbation(self) -> None: