Skip to content

Commit

Permalink
Even more fixes to unused kwargs (#1782)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1782

X-link: pytorch/botorch#1985

See previous diff

Reviewed By: lena-kashtelyan

Differential Revision: D48338443

fbshipit-source-id: 022be0195dc74475d38ec32a259a659838f633c1
  • Loading branch information
esantorella authored and facebook-github-bot committed Aug 17, 2023
1 parent 3ad1734 commit 7811fac
Show file tree
Hide file tree
Showing 8 changed files with 171 additions and 141 deletions.
7 changes: 1 addition & 6 deletions ax/modelbridge/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@ def get_MOO_NEHVI(
# having a batch limit is very important for avoiding
# memory issues in the initialization
"options": {"batch_limit": DEFAULT_EHVI_BATCH_LIMIT},
"sequential": True,
},
},
use_input_warping=use_input_warping,
Expand Down Expand Up @@ -174,7 +173,6 @@ def get_MTGP_NEHVI(
# having a batch limit is very important for avoiding
# memory issues in the initialization
"options": {"batch_limit": DEFAULT_EHVI_BATCH_LIMIT},
"sequential": True,
},
},
optimization_config=optimization_config,
Expand Down Expand Up @@ -522,7 +520,7 @@ def get_MOO_EHVI(
torch_device=device,
acqf_constructor=get_EHVI,
default_model_gen_options={
"acquisition_function_kwargs": {"sequential": True},
"acquisition_function_kwargs": {},
"optimizer_kwargs": {
# having a batch limit is very important for avoiding
# memory issues in the initialization
Expand Down Expand Up @@ -562,7 +560,6 @@ def get_MOO_PAREGO(
default_model_gen_options={
"acquisition_function_kwargs": {
"chebyshev_scalarization": True,
"sequential": True,
}
},
optimization_config=optimization_config,
Expand Down Expand Up @@ -599,7 +596,6 @@ def get_MOO_RS(
default_model_gen_options={
"acquisition_function_kwargs": {
"random_scalarization": True,
"sequential": True,
}
},
optimization_config=optimization_config,
Expand Down Expand Up @@ -673,7 +669,6 @@ def get_MTGP_PAREGO(
default_model_gen_options={
"acquisition_function_kwargs": {
"chebyshev_scalarization": True,
"sequential": True,
}
},
optimization_config=optimization_config,
Expand Down
2 changes: 0 additions & 2 deletions ax/modelbridge/tests/test_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,6 @@ def test_MOO_RS(self) -> None:
{
"acquisition_function_kwargs": {
"random_scalarization": True,
"sequential": True,
}
},
moo_rs._default_model_gen_options,
Expand All @@ -335,7 +334,6 @@ def test_MOO_PAREGO(self) -> None:
{
"acquisition_function_kwargs": {
"chebyshev_scalarization": True,
"sequential": True,
}
},
moo_parego._default_model_gen_options,
Expand Down
1 change: 0 additions & 1 deletion ax/models/tests/test_botorch_moo_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,6 @@ def test_get_ehvi(self, _) -> None:
X_pending=X_pending,
constraints=cons_tfs,
mc_samples=128,
qmc=True,
alpha=0.0,
seed=seed,
ref_point=new_obj_thresholds.tolist(),
Expand Down
4 changes: 3 additions & 1 deletion ax/models/tests/test_botorch_moo_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,9 @@ def test_BotorchMOOModel_with_qehvi(
"acquisition_function_kwargs": {
"cache_root": False,
"prune_baseline": False,
},
}
if use_qnehvi
else {},
},
)
gen_results = model.gen(
Expand Down
209 changes: 120 additions & 89 deletions ax/models/torch/botorch_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from ax.models.torch.utils import _to_inequality_constraints
from ax.models.torch_base import TorchModel
from ax.models.types import TConfig
from ax.utils.common.constants import Keys
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.acquisition.fixed_feature import FixedFeatureAcquisitionFunction
from botorch.acquisition.objective import ConstrainedMCObjective, GenericMCObjective
Expand Down Expand Up @@ -51,6 +50,90 @@
)


def _construct_model(
task_feature: Optional[int],
Xs: List[Tensor],
Ys: List[Tensor],
Yvars: List[Tensor],
fidelity_features: List[int],
metric_names: List[str],
use_input_warping: bool = False,
prior: Optional[Dict[str, Any]] = None,
*,
multitask_gp_ranks: Optional[Dict[str, Union[Prior, float]]] = None,
**kwargs: Any,
) -> GPyTorchModel:
"""
Figures out how to call `_get_model` depending on inputs. Used by
`get_and_fit_model`.
"""
if task_feature is None:
if len(Xs) == 1:
# Use single output, single task GP
return _get_model(
X=Xs[0],
Y=Ys[0],
Yvar=Yvars[0],
task_feature=task_feature,
fidelity_features=fidelity_features,
use_input_warping=use_input_warping,
prior=deepcopy(prior),
**kwargs,
)
if all(torch.equal(Xs[0], X) for X in Xs[1:]) and not use_input_warping:
# Use batched multioutput, single task GP
# Require using a ModelListGP if using input warping
Y = torch.cat(Ys, dim=-1)
Yvar = torch.cat(Yvars, dim=-1)
return _get_model(
X=Xs[0],
Y=Y,
Yvar=Yvar,
task_feature=task_feature,
fidelity_features=fidelity_features,
prior=deepcopy(prior),
**kwargs,
)

if task_feature is None:
models = [
_get_model(
X=X,
Y=Y,
Yvar=Yvar,
use_input_warping=use_input_warping,
prior=deepcopy(prior),
**kwargs,
)
for X, Y, Yvar in zip(Xs, Ys, Yvars)
]
else:
# use multi-task GP
mtgp_rank_dict = {} if multitask_gp_ranks is None else multitask_gp_ranks
# assembles list of ranks associated with each metric
if len({len(Xs), len(Ys), len(Yvars), len(metric_names)}) > 1:
raise ValueError(
"Lengths of Xs, Ys, Yvars, and metric_names must match. Your "
f"inputs have lengths {len(Xs)}, {len(Ys)}, {len(Yvars)}, and "
f"{len(metric_names)}, respectively."
)
mtgp_rank_list = [mtgp_rank_dict.get(metric, None) for metric in metric_names]
models = [
_get_model(
X=X,
Y=Y,
Yvar=Yvar,
task_feature=task_feature,
rank=mtgp_rank,
use_input_warping=use_input_warping,
prior=deepcopy(prior),
**kwargs,
)
for X, Y, Yvar, mtgp_rank in zip(Xs, Ys, Yvars, mtgp_rank_list)
]
return ModelListGP(*models)


def get_and_fit_model(
Xs: List[Tensor],
Ys: List[Tensor],
Expand All @@ -63,6 +146,8 @@ def get_and_fit_model(
use_input_warping: bool = False,
use_loocv_pseudo_likelihood: bool = False,
prior: Optional[Dict[str, Any]] = None,
*,
multitask_gp_ranks: Optional[Dict[str, Union[Prior, float]]] = None,
**kwargs: Any,
) -> GPyTorchModel:
r"""Instantiates and fits a botorch GPyTorchModel using the given data.
Expand All @@ -88,6 +173,7 @@ def get_and_fit_model(
- sd_prior: A scalar prior over nonnegative numbers, which is used for the
default LKJCovariancePrior task_covar_prior.
- eta: The eta parameter on the default LKJ task_covar_prior.
kwargs: Passed to `_get_model`.
Returns:
A fitted GPyTorchModel.
Expand All @@ -109,82 +195,25 @@ def get_and_fit_model(
task_feature = task_features[0]
else:
task_feature = None
model = None

model = _construct_model(
task_feature=task_feature,
Xs=Xs,
Ys=Ys,
Yvars=Yvars,
fidelity_features=fidelity_features,
metric_names=metric_names,
use_input_warping=use_input_warping,
prior=prior,
multitask_gp_ranks=multitask_gp_ranks,
**kwargs,
)

# TODO: Better logic for deciding when to use a ModelListGP. Currently the
# logic is unclear. The two cases in which ModelListGP is used are
# (i) the training inputs (Xs) are not the same for the different outcomes, and
# (ii) a multi-task model is used

if task_feature is None:
if len(Xs) == 1:
# Use single output, single task GP
model = _get_model(
X=Xs[0],
Y=Ys[0],
Yvar=Yvars[0],
task_feature=task_feature,
fidelity_features=fidelity_features,
use_input_warping=use_input_warping,
prior=deepcopy(prior),
**kwargs,
)
elif all(torch.equal(Xs[0], X) for X in Xs[1:]) and not use_input_warping:
# Use batched multioutput, single task GP
# Require using a ModelListGP if using input warping
Y = torch.cat(Ys, dim=-1)
Yvar = torch.cat(Yvars, dim=-1)
model = _get_model(
X=Xs[0],
Y=Y,
Yvar=Yvar,
task_feature=task_feature,
fidelity_features=fidelity_features,
prior=deepcopy(prior),
**kwargs,
)
# TODO: Is this equivalent an "else:" here?

if model is None:
if task_feature is None:
models = [
_get_model(
X=X,
Y=Y,
Yvar=Yvar,
use_input_warping=use_input_warping,
prior=deepcopy(prior),
**kwargs,
)
for X, Y, Yvar in zip(Xs, Ys, Yvars)
]
else:
# use multi-task GP
mtgp_rank_dict = kwargs.pop("multitask_gp_ranks", {})
# assembles list of ranks associated with each metric
if len({len(Xs), len(Ys), len(Yvars), len(metric_names)}) > 1:
raise ValueError(
"Lengths of Xs, Ys, Yvars, and metric_names must match. Your "
f"inputs have lengths {len(Xs)}, {len(Ys)}, {len(Yvars)}, and "
f"{len(metric_names)}, respectively."
)
mtgp_rank_list = [
mtgp_rank_dict.get(metric, None) for metric in metric_names
]
models = [
_get_model(
X=X,
Y=Y,
Yvar=Yvar,
task_feature=task_feature,
rank=mtgp_rank,
use_input_warping=use_input_warping,
prior=deepcopy(prior),
**kwargs,
)
for X, Y, Yvar, mtgp_rank in zip(Xs, Ys, Yvars, mtgp_rank_list)
]
model = ModelListGP(*models)
model.to(Xs[0])
if state_dict is not None:
model.load_state_dict(state_dict)
Expand Down Expand Up @@ -237,7 +266,11 @@ def _get_acquisition_func(
# pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
# `typing.Dict` to avoid runtime subscripting errors.
mc_objective_kwargs: Optional[Dict] = None,
**kwargs: Any,
*,
chebyshev_scalarization: bool = False,
prune_baseline: bool = True,
mc_samples: int = 512,
marginalize_dim: Optional[int] = None,
) -> AcquisitionFunction:
r"""Instantiates a acquisition function.
Expand Down Expand Up @@ -266,7 +299,6 @@ def _get_acquisition_func(
For GenericMCObjective, leave it as None. For PenalizedMCObjective,
it needs to be specified in the format of kwargs.
mc_samples: The number of MC samples to use (default: 512).
qmc: If True, use qMC instead of MC (default: True).
prune_baseline: If True, prune the baseline points for NEI (default: True).
chebyshev_scalarization: Use augmented Chebyshev scalarization.
Expand All @@ -276,7 +308,7 @@ def _get_acquisition_func(
if X_observed is None:
raise ValueError(NO_FEASIBLE_POINTS_MESSAGE)
# construct Objective module
if kwargs.get("chebyshev_scalarization", False):
if chebyshev_scalarization:
with torch.no_grad():
Y = model.posterior(X_observed).mean # pyre-ignore [16]
obj_tf = get_chebyshev_scalarization(weights=objective_weights, Y=Y)
Expand Down Expand Up @@ -312,13 +344,12 @@ def objective(samples: Tensor, X: Optional[Tensor] = None) -> Tensor:
objective=objective,
X_observed=X_observed,
X_pending=X_pending,
prune_baseline=kwargs.get("prune_baseline", True),
mc_samples=kwargs.get("mc_samples", 512),
qmc=kwargs.get("qmc", True),
prune_baseline=prune_baseline,
mc_samples=mc_samples,
# pyre-fixme[6]: Expected `Optional[int]` for 9th param but got
# `Union[float, int]`.
seed=torch.randint(1, 10000, (1,)).item(),
marginalize_dim=kwargs.get("marginalize_dim"),
marginalize_dim=marginalize_dim,
)


Expand All @@ -330,7 +361,11 @@ def scipy_optimizer(
equality_constraints: Optional[List[Tuple[Tensor, Tensor, float]]] = None,
fixed_features: Optional[Dict[int, float]] = None,
rounding_func: Optional[Callable[[Tensor], Tensor]] = None,
**kwargs: Any,
*,
num_restarts: int = 20,
raw_samples: Optional[int] = None,
joint_optimization: bool = False,
options: Optional[Dict[str, Union[bool, float, int, str]]] = None,
) -> Tuple[Tensor, Tensor]:
r"""Optimizer using scipy's minimize module on a numpy-adpator.
Expand Down Expand Up @@ -360,25 +395,21 @@ def scipy_optimizer(
values, where `i`-th element is the expected acquisition value
conditional on having observed candidates `0,1,...,i-1`.
"""
num_restarts: int = kwargs.pop(Keys.NUM_RESTARTS, 20)
raw_samples: int = kwargs.pop(Keys.RAW_SAMPLES, 50 * num_restarts)

if kwargs.get("joint_optimization", False):
sequential = False
else:
sequential = True
options: Dict[str, Union[bool, float, int, str]] = {
sequential = not joint_optimization
optimize_acqf_options: Dict[str, Union[bool, float, int, str]] = {
"batch_limit": 5,
"init_batch_limit": 32,
}
options.update(kwargs.get("options", {}))
if options is not None:
optimize_acqf_options.update(options)
X, expected_acquisition_value = optimize_acqf(
acq_function=acq_function,
bounds=bounds,
q=n,
num_restarts=num_restarts,
raw_samples=raw_samples,
options=options,
raw_samples=50 * num_restarts if raw_samples is None else raw_samples,
options=optimize_acqf_options,
inequality_constraints=inequality_constraints,
equality_constraints=equality_constraints,
fixed_features=fixed_features,
Expand Down
Loading

0 comments on commit 7811fac

Please sign in to comment.