From 79ee0450eb466aa4319a7a39390bcc77a2991ae5 Mon Sep 17 00:00:00 2001 From: Sam Daulton Date: Mon, 4 Mar 2024 08:18:46 -0800 Subject: [PATCH] add output_tasks to MTGP in MBM (#2241) Summary: see title. output_tasks were not set for MTGP in MBM. Differential Revision: D54453253 --- ax/models/torch/botorch_modular/surrogate.py | 33 +++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/ax/models/torch/botorch_modular/surrogate.py b/ax/models/torch/botorch_modular/surrogate.py index 27f9ac9bde0..5371935a9ba 100644 --- a/ax/models/torch/botorch_modular/surrogate.py +++ b/ax/models/torch/botorch_modular/surrogate.py @@ -49,6 +49,7 @@ ) from botorch.models.model import Model from botorch.models.model_list_gp_regression import ModelListGP +from botorch.models.multitask import MultiTaskGP from botorch.models.pairwise_gp import PairwiseGP from botorch.models.transforms.input import ( ChainedInputTransform, @@ -57,7 +58,7 @@ ) from botorch.models.transforms.outcome import ChainedOutcomeTransform, OutcomeTransform from botorch.utils.containers import SliceContainer -from botorch.utils.datasets import RankingDataset, SupervisedDataset +from botorch.utils.datasets import MultiTaskDataset, RankingDataset, SupervisedDataset from botorch.utils.dispatcher import Dispatcher from gpytorch.kernels import Kernel from gpytorch.likelihoods.likelihood import Likelihood @@ -841,3 +842,33 @@ def _submodel_input_constructor_base( botorch_model_class_args=botorch_model_class_args, ) return formatted_model_inputs + + +@submodel_input_constructor.register(MultiTaskGP) +def _submodel_input_constructor_mtgp( + botorch_model_class: Type[Model], + dataset: SupervisedDataset, + search_space_digest: SearchSpaceDigest, + surrogate: Surrogate, +) -> Dict[str, Any]: + if len(dataset.outcome_names) > 1: + raise NotImplementedError("Multi-output Multi-task GPs are not yet supported.") + formatted_model_inputs = _submodel_input_constructor_base( + botorch_model_class=botorch_model_class, + dataset=dataset, + search_space_digest=search_space_digest, + surrogate=surrogate, + ) + if ( + isinstance(dataset, MultiTaskDataset) and dataset.has_heterogeneous_features + ) or "task_feature" not in formatted_model_inputs: + # Note: if that datasets are heterogeneous, `dataset.X` + # will fail since it tries to concatenate heterogeneous + # datasets. + return formatted_model_inputs + # specify output tasks so that model.num_outputs = 1 + # since the model only models a single outcome + formatted_model_inputs["output_tasks"] = dataset.X[ + :1, formatted_model_inputs["task_feature"] + ].tolist() + return formatted_model_inputs