Skip to content

Commit

Permalink
add output_tasks to MTGP in MBM (#2241)
Browse files Browse the repository at this point in the history
Summary:

see title. output_tasks were not set for MTGP in MBM.

Differential Revision: D54453253
  • Loading branch information
sdaulton authored and facebook-github-bot committed Mar 4, 2024
1 parent b98f3be commit 79ee045
Showing 1 changed file with 32 additions and 1 deletion.
33 changes: 32 additions & 1 deletion ax/models/torch/botorch_modular/surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
)
from botorch.models.model import Model
from botorch.models.model_list_gp_regression import ModelListGP
from botorch.models.multitask import MultiTaskGP
from botorch.models.pairwise_gp import PairwiseGP
from botorch.models.transforms.input import (
ChainedInputTransform,
Expand All @@ -57,7 +58,7 @@
)
from botorch.models.transforms.outcome import ChainedOutcomeTransform, OutcomeTransform
from botorch.utils.containers import SliceContainer
from botorch.utils.datasets import RankingDataset, SupervisedDataset
from botorch.utils.datasets import MultiTaskDataset, RankingDataset, SupervisedDataset
from botorch.utils.dispatcher import Dispatcher
from gpytorch.kernels import Kernel
from gpytorch.likelihoods.likelihood import Likelihood
Expand Down Expand Up @@ -841,3 +842,33 @@ def _submodel_input_constructor_base(
botorch_model_class_args=botorch_model_class_args,
)
return formatted_model_inputs


@submodel_input_constructor.register(MultiTaskGP)
def _submodel_input_constructor_mtgp(
botorch_model_class: Type[Model],
dataset: SupervisedDataset,
search_space_digest: SearchSpaceDigest,
surrogate: Surrogate,
) -> Dict[str, Any]:
if len(dataset.outcome_names) > 1:
raise NotImplementedError("Multi-output Multi-task GPs are not yet supported.")
formatted_model_inputs = _submodel_input_constructor_base(
botorch_model_class=botorch_model_class,
dataset=dataset,
search_space_digest=search_space_digest,
surrogate=surrogate,
)
if (
isinstance(dataset, MultiTaskDataset) and dataset.has_heterogeneous_features
) or "task_feature" not in formatted_model_inputs:
# Note: if that datasets are heterogeneous, `dataset.X`
# will fail since it tries to concatenate heterogeneous
# datasets.
return formatted_model_inputs
# specify output tasks so that model.num_outputs = 1
# since the model only models a single outcome
formatted_model_inputs["output_tasks"] = dataset.X[
:1, formatted_model_inputs["task_feature"]
].tolist()
return formatted_model_inputs

0 comments on commit 79ee045

Please sign in to comment.