Skip to content

Commit

Permalink
Move model input constructor helper functions off of Surrogate, to …
Browse files Browse the repository at this point in the history
…standalone functions (#2655)

Summary:
Pull Request resolved: #2655

Context: These methods do not reference attributes of the surrogate, so it is confusing for these to be `Surrogate` methods. Also, one might want to use these in the absence of a surrogate.

This PR:
* Makes `_make_botorch_input_transform`, `_make_botorch_outcome_transform`, and `_set_formatted_inputs` into standalone functions rather than `Surrogate` methods.

Reviewed By: saitcakmak, Balandat

Differential Revision: D61215241

fbshipit-source-id: 33339394987202510ec8256cc6b3576cf074c214
  • Loading branch information
esantorella authored and facebook-github-bot committed Aug 14, 2024
1 parent d8078c5 commit a61c362
Showing 1 changed file with 155 additions and 163 deletions.
318 changes: 155 additions & 163 deletions ax/models/torch/botorch_modular/surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,159 @@ def _extract_model_kwargs(
return kwargs


def _make_botorch_input_transform(
input_classes: list[type[InputTransform]],
dataset: SupervisedDataset,
search_space_digest: SearchSpaceDigest,
input_options: dict[str, dict[str, Any]],
) -> Optional[InputTransform]:
"""
Makes a BoTorch input transform from the provided input classes and options.
"""
if not (
isinstance(input_classes, list)
and all(issubclass(c, InputTransform) for c in input_classes)
):
raise UserInputError("Expected a list of input transforms.")
if len(input_classes) == 0:
return None

input_transform_kwargs = [
input_transform_argparse(
single_input_class,
dataset=dataset,
search_space_digest=search_space_digest,
input_transform_options=input_options.get(single_input_class.__name__, {}),
)
for single_input_class in input_classes
]

input_transforms = [
# pyre-fixme[45]: Cannot instantiate abstract class `InputTransform`.
single_input_class(**single_input_transform_kwargs)
for single_input_class, single_input_transform_kwargs in zip(
input_classes, input_transform_kwargs
)
]

input_instance = (
ChainedInputTransform(
**{f"tf{i}": input_transforms[i] for i in range(len(input_transforms))}
)
if len(input_transforms) > 1
else input_transforms[0]
)

return input_instance


def _make_botorch_outcome_transform(
input_classes: list[type[OutcomeTransform]],
input_options: dict[str, dict[str, Any]],
dataset: SupervisedDataset,
) -> Optional[OutcomeTransform]:
"""
Makes a BoTorch outcome transform from the provided classes and options.
"""
if not (
isinstance(input_classes, list)
and all(issubclass(c, OutcomeTransform) for c in input_classes)
):
raise UserInputError("Expected a list of outcome transforms.")
if len(input_classes) == 0:
return None

outcome_transform_kwargs = [
outcome_transform_argparse(
input_class,
outcome_transform_options=input_options.get(input_class.__name__, {}),
dataset=dataset,
)
for input_class in input_classes
]

outcome_transforms = [
# pyre-fixme[45]: Cannot instantiate abstract class `OutcomeTransform`.
input_class(**single_outcome_transform_kwargs)
for input_class, single_outcome_transform_kwargs in zip(
input_classes, outcome_transform_kwargs
)
]

outcome_transform_instance = (
ChainedOutcomeTransform(
**{f"otf{i}": otf for i, otf in enumerate(outcome_transforms)}
)
if len(outcome_transforms) > 1
else outcome_transforms[0]
)
return outcome_transform_instance


def _set_formatted_inputs(
formatted_model_inputs: dict[str, Any],
# pyre-ignore [2] The proper hint for the second arg is Union[None,
# Type[Kernel], Type[Likelihood], List[Type[OutcomeTransform]],
# List[Type[InputTransform]]]. Keeping it as Any saves us from a
# bunch of checked_cast calls within the for loop.
inputs: list[tuple[str, Any, dict[str, Any]]],
dataset: SupervisedDataset,
botorch_model_class_args: list[str],
search_space_digest: SearchSpaceDigest,
botorch_model_class: type[Model],
) -> None:
"""Modifies `formatted_model_inputs` in place."""
for input_name, input_class, input_options in inputs:
if input_class is None:
# This is a temporary solution until all BoTorch models use
# `Standardize` by default, see TODO [T197435440].
# After this, we should update `Surrogate` to use `DEFAULT`
# (https://fburl.com/code/22f4397e) for both of these args. This will
# allow users to explicitly disable the default transforms by passing
# in `None`.
if (
input_name in ["outcome_transform"]
and input_name in botorch_model_class_args
):
formatted_model_inputs[input_name] = None
continue
if input_name not in botorch_model_class_args:
raise UserInputError(
f"The BoTorch model class {botorch_model_class.__name__} does not "
f"support the input {input_name}."
)
input_options = deepcopy(input_options) or {}

if input_name == "covar_module":
covar_module_with_defaults = covar_module_argparse(
input_class,
dataset=dataset,
botorch_model_class=botorch_model_class,
**input_options,
)

formatted_model_inputs[input_name] = input_class(
**covar_module_with_defaults
)

elif input_name == "input_transform":
formatted_model_inputs[input_name] = _make_botorch_input_transform(
input_classes=input_class,
input_options=input_options,
dataset=dataset,
search_space_digest=search_space_digest,
)

elif input_name == "outcome_transform":
formatted_model_inputs[input_name] = _make_botorch_outcome_transform(
input_classes=input_class,
input_options=input_options,
dataset=dataset,
)
else:
formatted_model_inputs[input_name] = input_class(**input_options)


class Surrogate(Base):
"""
**All classes in 'botorch_modular' directory are under
Expand Down Expand Up @@ -344,167 +497,6 @@ def _should_reuse_last_model(
return True
return False

def _set_formatted_inputs(
self,
formatted_model_inputs: dict[str, Any],
# pyre-ignore [2] The proper hint for the second arg is Union[None,
# Type[Kernel], Type[Likelihood], List[Type[OutcomeTransform]],
# List[Type[InputTransform]]]. Keeping it as Any saves us from a
# bunch of checked_cast calls within the for loop.
inputs: list[tuple[str, Any, dict[str, Any]]],
dataset: SupervisedDataset,
botorch_model_class_args: list[str],
search_space_digest: SearchSpaceDigest,
botorch_model_class: type[Model],
) -> None:
"""Modifies `formatted_model_inputs` in place."""
for input_name, input_class, input_options in inputs:
if input_class is None:
# This is a temporary solution until all BoTorch models use
# `Standardize` by default, see TODO [T197435440].
# After this, we should update `Surrogate` to use `DEFAULT`
# (https://fburl.com/code/22f4397e) for both of these args. This will
# allow users to explicitly disable the default transforms by passing
# in `None`.
if (
input_name in ["outcome_transform"]
and input_name in botorch_model_class_args
):
formatted_model_inputs[input_name] = None
continue
if input_name not in botorch_model_class_args:
# TODO: We currently only pass in `covar_module` and `likelihood`
# if they are inputs to the BoTorch model. This interface will need
# to be expanded to a ModelFactory, see D22457664, to accommodate
# different models in the future.
raise UserInputError(
f"The BoTorch model class {botorch_model_class.__name__} does not "
f"support the input {input_name}."
)
input_options = deepcopy(input_options) or {}

if input_name == "covar_module":
covar_module_with_defaults = covar_module_argparse(
input_class,
dataset=dataset,
botorch_model_class=botorch_model_class,
**input_options,
)

formatted_model_inputs[input_name] = input_class(
**covar_module_with_defaults
)

elif input_name == "input_transform":
formatted_model_inputs[input_name] = self._make_botorch_input_transform(
input_classes=input_class,
input_options=input_options,
dataset=dataset,
search_space_digest=search_space_digest,
)

elif input_name == "outcome_transform":
formatted_model_inputs[input_name] = (
self._make_botorch_outcome_transform(
input_classes=input_class,
input_options=input_options,
dataset=dataset,
)
)
else:
formatted_model_inputs[input_name] = input_class(**input_options)

def _make_botorch_input_transform(
self,
input_classes: list[type[InputTransform]],
dataset: SupervisedDataset,
search_space_digest: SearchSpaceDigest,
input_options: dict[str, dict[str, Any]],
) -> Optional[InputTransform]:
"""
Makes a BoTorch input transform from the provided input classes and options.
"""
if not (
isinstance(input_classes, list)
and all(issubclass(c, InputTransform) for c in input_classes)
):
raise UserInputError("Expected a list of input transforms.")
if len(input_classes) == 0:
return None

input_transform_kwargs = [
input_transform_argparse(
single_input_class,
dataset=dataset,
search_space_digest=search_space_digest,
input_transform_options=input_options.get(
single_input_class.__name__, {}
),
)
for single_input_class in input_classes
]

input_transforms = [
# pyre-fixme[45]: Cannot instantiate abstract class `InputTransform`.
single_input_class(**single_input_transform_kwargs)
for single_input_class, single_input_transform_kwargs in zip(
input_classes, input_transform_kwargs
)
]

input_instance = (
ChainedInputTransform(
**{f"tf{i}": input_transforms[i] for i in range(len(input_transforms))}
)
if len(input_transforms) > 1
else input_transforms[0]
)

return input_instance

def _make_botorch_outcome_transform(
self,
input_classes: list[type[OutcomeTransform]],
input_options: dict[str, dict[str, Any]],
dataset: SupervisedDataset,
) -> Optional[OutcomeTransform]:
"""
Makes a BoTorch outcome transform from the provided classes and options.
"""
if not (
isinstance(input_classes, list)
and all(issubclass(c, OutcomeTransform) for c in input_classes)
):
raise UserInputError("Expected a list of outcome transforms.")
if len(input_classes) == 0:
return None

outcome_transform_kwargs = [
outcome_transform_argparse(
input_class,
outcome_transform_options=input_options.get(input_class.__name__, {}),
dataset=dataset,
)
for input_class in input_classes
]

outcome_transforms = [
# pyre-fixme[45]: Cannot instantiate abstract class `OutcomeTransform`.
input_class(**single_outcome_transform_kwargs)
for input_class, single_outcome_transform_kwargs in zip(
input_classes, outcome_transform_kwargs
)
]

outcome_transform_instance = (
ChainedOutcomeTransform(
**{f"otf{i}": otf for i, otf in enumerate(outcome_transforms)}
)
if len(outcome_transforms) > 1
else outcome_transforms[0]
)
return outcome_transform_instance

def fit(
self,
datasets: Sequence[SupervisedDataset],
Expand Down Expand Up @@ -754,7 +746,7 @@ def _extract_construct_input_transform_args(
) -> tuple[Optional[Sequence[type[InputTransform]]], dict[str, dict[str, Any]]]:
"""
Extracts input transform classes and input transform options that will
be used in `self._set_formatted_inputs` and ultimately passed to
be used in `_set_formatted_inputs` and ultimately passed to
BoTorch.
Args:
Expand Down Expand Up @@ -852,7 +844,7 @@ def _submodel_input_constructor_base(
)

botorch_model_class_args = inspect.getfullargspec(botorch_model_class).args
surrogate._set_formatted_inputs(
_set_formatted_inputs(
formatted_model_inputs=formatted_model_inputs,
inputs=[
(
Expand Down

0 comments on commit a61c362

Please sign in to comment.