Even more fixes to unused kwargs #1782

Closed · wants to merge 1 commit
7 changes: 1 addition & 6 deletions ax/modelbridge/factory.py
@@ -92,7 +92,6 @@ def get_MOO_NEHVI(
                 "optimizer_kwargs": {
                     # having a batch limit is very important for avoiding
                     # memory issues in the initialization
                     "options": {"batch_limit": DEFAULT_EHVI_BATCH_LIMIT},
-                    "sequential": True,
                 },
             },
             use_input_warping=use_input_warping,
@@ -174,7 +173,6 @@ def get_MTGP_NEHVI(
                 "optimizer_kwargs": {
                     # having a batch limit is very important for avoiding
                     # memory issues in the initialization
                     "options": {"batch_limit": DEFAULT_EHVI_BATCH_LIMIT},
-                    "sequential": True,
                 },
             },
             optimization_config=optimization_config,
@@ -522,7 +520,7 @@ def get_MOO_EHVI(
            torch_device=device,
            acqf_constructor=get_EHVI,
            default_model_gen_options={
-               "acquisition_function_kwargs": {"sequential": True},
+               "acquisition_function_kwargs": {},
                "optimizer_kwargs": {
                    # having a batch limit is very important for avoiding
                    # memory issues in the initialization
@@ -562,7 +560,6 @@ def get_MOO_PAREGO(
            default_model_gen_options={
                "acquisition_function_kwargs": {
                    "chebyshev_scalarization": True,
-                   "sequential": True,
                }
            },
            optimization_config=optimization_config,
@@ -599,7 +596,6 @@ def get_MOO_RS(
            default_model_gen_options={
                "acquisition_function_kwargs": {
                    "random_scalarization": True,
-                   "sequential": True,
                }
            },
            optimization_config=optimization_config,
@@ -673,7 +669,6 @@ def get_MTGP_PAREGO(
            default_model_gen_options={
                "acquisition_function_kwargs": {
                    "chebyshev_scalarization": True,
-                   "sequential": True,
                }
            },
            optimization_config=optimization_config,
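All six hunks in `factory.py` make the same change: `"sequential": True` is dropped from the default option dicts. Per the PR title, this key was unused, and with the explicit keyword-only signatures introduced in `botorch_defaults.py` below, a stale key would now raise instead of being silently swallowed. A minimal sketch of that failure mode (both function names here are illustrative, not from Ax):

```python
# Sketch of the bug class this PR cleans up: a kwarg that the consumer
# silently ignores. Neither function below is from the Ax codebase.
from typing import Any


def make_acqf(*, mc_samples: int = 512, prune_baseline: bool = True) -> dict:
    """New style: explicit keyword-only parameters."""
    return {"mc_samples": mc_samples, "prune_baseline": prune_baseline}


def make_acqf_permissive(**kwargs: Any) -> dict:
    """Old style: unknown keys vanish without a trace."""
    return {
        "mc_samples": kwargs.get("mc_samples", 512),
        "prune_baseline": kwargs.get("prune_baseline", True),
    }


make_acqf_permissive(sequential=True)  # accepted and silently dropped

try:
    make_acqf(sequential=True)  # type: ignore[call-arg]
except TypeError as e:
    print(e)  # got an unexpected keyword argument 'sequential'
```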
2 changes: 0 additions & 2 deletions ax/modelbridge/tests/test_factory.py
@@ -315,7 +315,6 @@ def test_MOO_RS(self) -> None:
            {
                "acquisition_function_kwargs": {
                    "random_scalarization": True,
-                   "sequential": True,
                }
            },
            moo_rs._default_model_gen_options,
@@ -335,7 +334,6 @@ def test_MOO_PAREGO(self) -> None:
            {
                "acquisition_function_kwargs": {
                    "chebyshev_scalarization": True,
-                   "sequential": True,
                }
            },
            moo_parego._default_model_gen_options,
1 change: 0 additions & 1 deletion ax/models/tests/test_botorch_moo_defaults.py
@@ -260,7 +260,6 @@ def test_get_ehvi(self, _) -> None:
                X_pending=X_pending,
                constraints=cons_tfs,
                mc_samples=128,
-               qmc=True,
                alpha=0.0,
                seed=seed,
                ref_point=new_obj_thresholds.tolist(),
4 changes: 3 additions & 1 deletion ax/models/tests/test_botorch_moo_model.py
@@ -521,7 +521,9 @@ def test_BotorchMOOModel_with_qehvi(
                "acquisition_function_kwargs": {
                    "cache_root": False,
                    "prune_baseline": False,
-               },
+               }
+               if use_qnehvi
+               else {},
            },
        )
        gen_results = model.gen(
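The test fix reflects that `cache_root` and `prune_baseline` are accepted by qNEHVI but not by qEHVI, so the kwargs dict is now built conditionally. The pattern in isolation (a sketch; `use_qnehvi` mirrors the test's parameter):

```python
# Conditional-kwargs pattern from the updated test: only pass options that
# the selected acquisition function actually accepts.
use_qnehvi = True  # mirrors the test parameter of the same name

acquisition_function_kwargs = (
    {"cache_root": False, "prune_baseline": False} if use_qnehvi else {}
)
print(acquisition_function_kwargs)
```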
209 changes: 120 additions & 89 deletions ax/models/torch/botorch_defaults.py
@@ -12,7 +12,6 @@
 from ax.models.torch.utils import _to_inequality_constraints
 from ax.models.torch_base import TorchModel
 from ax.models.types import TConfig
-from ax.utils.common.constants import Keys
 from botorch.acquisition.acquisition import AcquisitionFunction
 from botorch.acquisition.fixed_feature import FixedFeatureAcquisitionFunction
 from botorch.acquisition.objective import ConstrainedMCObjective, GenericMCObjective
@@ -51,6 +50,90 @@
 )


+def _construct_model(
+    task_feature: Optional[int],
+    Xs: List[Tensor],
+    Ys: List[Tensor],
+    Yvars: List[Tensor],
+    fidelity_features: List[int],
+    metric_names: List[str],
+    use_input_warping: bool = False,
+    prior: Optional[Dict[str, Any]] = None,
+    *,
+    multitask_gp_ranks: Optional[Dict[str, Union[Prior, float]]] = None,
+    **kwargs: Any,
+) -> GPyTorchModel:
+    """
+    Figures out how to call `_get_model` depending on inputs. Used by
+    `get_and_fit_model`.
+    """
+    if task_feature is None:
+        if len(Xs) == 1:
+            # Use single output, single task GP
+            return _get_model(
+                X=Xs[0],
+                Y=Ys[0],
+                Yvar=Yvars[0],
+                task_feature=task_feature,
+                fidelity_features=fidelity_features,
+                use_input_warping=use_input_warping,
+                prior=deepcopy(prior),
+                **kwargs,
+            )
+        if all(torch.equal(Xs[0], X) for X in Xs[1:]) and not use_input_warping:
+            # Use batched multioutput, single task GP
+            # Require using a ModelListGP if using input warping
+            Y = torch.cat(Ys, dim=-1)
+            Yvar = torch.cat(Yvars, dim=-1)
+            return _get_model(
+                X=Xs[0],
+                Y=Y,
+                Yvar=Yvar,
+                task_feature=task_feature,
+                fidelity_features=fidelity_features,
+                prior=deepcopy(prior),
+                **kwargs,
+            )
+
+    if task_feature is None:
+        models = [
+            _get_model(
+                X=X,
+                Y=Y,
+                Yvar=Yvar,
+                use_input_warping=use_input_warping,
+                prior=deepcopy(prior),
+                **kwargs,
+            )
+            for X, Y, Yvar in zip(Xs, Ys, Yvars)
+        ]
+    else:
+        # use multi-task GP
+        mtgp_rank_dict = {} if multitask_gp_ranks is None else multitask_gp_ranks
+        # assembles list of ranks associated with each metric
+        if len({len(Xs), len(Ys), len(Yvars), len(metric_names)}) > 1:
+            raise ValueError(
+                "Lengths of Xs, Ys, Yvars, and metric_names must match. Your "
+                f"inputs have lengths {len(Xs)}, {len(Ys)}, {len(Yvars)}, and "
+                f"{len(metric_names)}, respectively."
+            )
+        mtgp_rank_list = [mtgp_rank_dict.get(metric, None) for metric in metric_names]
+        models = [
+            _get_model(
+                X=X,
+                Y=Y,
+                Yvar=Yvar,
+                task_feature=task_feature,
+                rank=mtgp_rank,
+                use_input_warping=use_input_warping,
+                prior=deepcopy(prior),
+                **kwargs,
+            )
+            for X, Y, Yvar, mtgp_rank in zip(Xs, Ys, Yvars, mtgp_rank_list)
+        ]
+    return ModelListGP(*models)
+
+
 def get_and_fit_model(
     Xs: List[Tensor],
     Ys: List[Tensor],
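The new `_construct_model` helper hoists the model-selection logic out of `get_and_fit_model`. Its dispatch reduces to four cases; the sketch below distills that decision logic, returning labels instead of fitted models:

```python
# Distilled decision logic of `_construct_model` (a sketch; the real
# function returns fitted GPyTorch models rather than strings).
from typing import List, Optional

import torch
from torch import Tensor


def model_choice(
    Xs: List[Tensor],
    task_feature: Optional[int],
    use_input_warping: bool,
) -> str:
    if task_feature is None:
        if len(Xs) == 1:
            return "single-output, single-task GP"
        if all(torch.equal(Xs[0], X) for X in Xs[1:]) and not use_input_warping:
            # Batched multi-output models require shared training inputs;
            # input warping forces a ModelListGP instead.
            return "batched multi-output, single-task GP"
        return "ModelListGP of single-task GPs"
    return "ModelListGP of multi-task GPs"


X = torch.rand(8, 3)
print(model_choice([X], None, False))       # single-output, single-task GP
print(model_choice([X, X], None, False))    # batched multi-output, single-task GP
print(model_choice([X, X], None, True))     # ModelListGP of single-task GPs
print(model_choice([X, torch.rand(8, 3)], 2, False))  # ModelListGP of multi-task GPs
```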
@@ -63,6 +146,8 @@ def get_and_fit_model(
     use_input_warping: bool = False,
     use_loocv_pseudo_likelihood: bool = False,
     prior: Optional[Dict[str, Any]] = None,
+    *,
+    multitask_gp_ranks: Optional[Dict[str, Union[Prior, float]]] = None,
     **kwargs: Any,
 ) -> GPyTorchModel:
     r"""Instantiates and fits a botorch GPyTorchModel using the given data.
@@ -88,6 +173,7 @@ def get_and_fit_model(
             - sd_prior: A scalar prior over nonnegative numbers, which is used for the
                 default LKJCovariancePrior task_covar_prior.
             - eta: The eta parameter on the default LKJ task_covar_prior.
+        kwargs: Passed to `_get_model`.

     Returns:
         A fitted GPyTorchModel.
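With `multitask_gp_ranks` promoted from a popped `**kwargs` entry to an explicit keyword-only parameter, a call now looks like the sketch below. Tensor shapes and the metric name are illustrative, and the rank dict is ignored here since no task features are passed:

```python
# Sketch of calling get_and_fit_model with the now-explicit keyword-only
# `multitask_gp_ranks` argument. Shapes and names are illustrative.
import torch
from ax.models.torch.botorch_defaults import get_and_fit_model

Xs = [torch.rand(10, 3, dtype=torch.double)]
Ys = [torch.rand(10, 1, dtype=torch.double)]
Yvars = [torch.full((10, 1), 0.01, dtype=torch.double)]

model = get_and_fit_model(
    Xs=Xs,
    Ys=Ys,
    Yvars=Yvars,
    task_features=[],  # no task feature -> single-task branch
    fidelity_features=[],
    metric_names=["objective"],
    multitask_gp_ranks={"objective": 1},  # keyword-only; only used by MTGPs
)
```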
@@ -109,82 +195,25 @@ def get_and_fit_model(
         task_feature = task_features[0]
     else:
         task_feature = None
-    model = None

+    model = _construct_model(
+        task_feature=task_feature,
+        Xs=Xs,
+        Ys=Ys,
+        Yvars=Yvars,
+        fidelity_features=fidelity_features,
+        metric_names=metric_names,
+        use_input_warping=use_input_warping,
+        prior=prior,
+        multitask_gp_ranks=multitask_gp_ranks,
+        **kwargs,
+    )
+
     # TODO: Better logic for deciding when to use a ModelListGP. Currently the
     # logic is unclear. The two cases in which ModelListGP is used are
     # (i) the training inputs (Xs) are not the same for the different outcomes, and
     # (ii) a multi-task model is used
-
-    if task_feature is None:
-        if len(Xs) == 1:
-            # Use single output, single task GP
-            model = _get_model(
-                X=Xs[0],
-                Y=Ys[0],
-                Yvar=Yvars[0],
-                task_feature=task_feature,
-                fidelity_features=fidelity_features,
-                use_input_warping=use_input_warping,
-                prior=deepcopy(prior),
-                **kwargs,
-            )
-        elif all(torch.equal(Xs[0], X) for X in Xs[1:]) and not use_input_warping:
-            # Use batched multioutput, single task GP
-            # Require using a ModelListGP if using input warping
-            Y = torch.cat(Ys, dim=-1)
-            Yvar = torch.cat(Yvars, dim=-1)
-            model = _get_model(
-                X=Xs[0],
-                Y=Y,
-                Yvar=Yvar,
-                task_feature=task_feature,
-                fidelity_features=fidelity_features,
-                prior=deepcopy(prior),
-                **kwargs,
-            )
-    # TODO: Is this equivalent an "else:" here?
-
-    if model is None:
-        if task_feature is None:
-            models = [
-                _get_model(
-                    X=X,
-                    Y=Y,
-                    Yvar=Yvar,
-                    use_input_warping=use_input_warping,
-                    prior=deepcopy(prior),
-                    **kwargs,
-                )
-                for X, Y, Yvar in zip(Xs, Ys, Yvars)
-            ]
-        else:
-            # use multi-task GP
-            mtgp_rank_dict = kwargs.pop("multitask_gp_ranks", {})
-            # assembles list of ranks associated with each metric
-            if len({len(Xs), len(Ys), len(Yvars), len(metric_names)}) > 1:
-                raise ValueError(
-                    "Lengths of Xs, Ys, Yvars, and metric_names must match. Your "
-                    f"inputs have lengths {len(Xs)}, {len(Ys)}, {len(Yvars)}, and "
-                    f"{len(metric_names)}, respectively."
-                )
-            mtgp_rank_list = [
-                mtgp_rank_dict.get(metric, None) for metric in metric_names
-            ]
-            models = [
-                _get_model(
-                    X=X,
-                    Y=Y,
-                    Yvar=Yvar,
-                    task_feature=task_feature,
-                    rank=mtgp_rank,
-                    use_input_warping=use_input_warping,
-                    prior=deepcopy(prior),
-                    **kwargs,
-                )
-                for X, Y, Yvar, mtgp_rank in zip(Xs, Ys, Yvars, mtgp_rank_list)
-            ]
-        model = ModelListGP(*models)
     model.to(Xs[0])
     if state_dict is not None:
         model.load_state_dict(state_dict)
@@ -237,7 +266,11 @@ def _get_acquisition_func(
     # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
     # `typing.Dict` to avoid runtime subscripting errors.
     mc_objective_kwargs: Optional[Dict] = None,
-    **kwargs: Any,
+    *,
+    chebyshev_scalarization: bool = False,
+    prune_baseline: bool = True,
+    mc_samples: int = 512,
+    marginalize_dim: Optional[int] = None,
 ) -> AcquisitionFunction:
     r"""Instantiates an acquisition function.
@@ -266,7 +299,6 @@ def _get_acquisition_func(
             For GenericMCObjective, leave it as None. For PenalizedMCObjective,
             it needs to be specified in the format of kwargs.
         mc_samples: The number of MC samples to use (default: 512).
-        qmc: If True, use qMC instead of MC (default: True).
         prune_baseline: If True, prune the baseline points for NEI (default: True).
         chebyshev_scalarization: Use augmented Chebyshev scalarization.
@@ -276,7 +308,7 @@ def _get_acquisition_func(
     if X_observed is None:
         raise ValueError(NO_FEASIBLE_POINTS_MESSAGE)
     # construct Objective module
-    if kwargs.get("chebyshev_scalarization", False):
+    if chebyshev_scalarization:
         with torch.no_grad():
             Y = model.posterior(X_observed).mean  # pyre-ignore [16]
         obj_tf = get_chebyshev_scalarization(weights=objective_weights, Y=Y)
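For context, `get_chebyshev_scalarization` is the BoTorch helper used in this branch: it takes objective weights plus observed outcomes `Y` (used for normalization) and returns a callable that maps sampled outcomes to scalars. A quick self-contained illustration:

```python
# Standalone illustration of the augmented Chebyshev scalarization used in
# the `chebyshev_scalarization` branch above.
import torch
from botorch.utils.multi_objective.scalarization import get_chebyshev_scalarization

Y = torch.rand(20, 2, dtype=torch.double)  # stands in for posterior means
weights = torch.tensor([0.5, 0.5], dtype=torch.double)

scalarize = get_chebyshev_scalarization(weights=weights, Y=Y)
samples = torch.rand(4, 2, dtype=torch.double)
print(scalarize(samples).shape)  # torch.Size([4]): one scalar per sample
```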
@@ -312,13 +344,12 @@ def objective(samples: Tensor, X: Optional[Tensor] = None) -> Tensor:
         objective=objective,
         X_observed=X_observed,
         X_pending=X_pending,
-        prune_baseline=kwargs.get("prune_baseline", True),
-        mc_samples=kwargs.get("mc_samples", 512),
-        qmc=kwargs.get("qmc", True),
+        prune_baseline=prune_baseline,
+        mc_samples=mc_samples,
         # pyre-fixme[6]: Expected `Optional[int]` for 9th param but got
         #  `Union[float, int]`.
         seed=torch.randint(1, 10000, (1,)).item(),
-        marginalize_dim=kwargs.get("marginalize_dim"),
+        marginalize_dim=marginalize_dim,
     )


@@ -330,7 +361,11 @@ def scipy_optimizer(
     equality_constraints: Optional[List[Tuple[Tensor, Tensor, float]]] = None,
     fixed_features: Optional[Dict[int, float]] = None,
     rounding_func: Optional[Callable[[Tensor], Tensor]] = None,
-    **kwargs: Any,
+    *,
+    num_restarts: int = 20,
+    raw_samples: Optional[int] = None,
+    joint_optimization: bool = False,
+    options: Optional[Dict[str, Union[bool, float, int, str]]] = None,
 ) -> Tuple[Tensor, Tensor]:
     r"""Optimizer using scipy's minimize module on a numpy adapter.
@@ -360,25 +395,21 @@ def scipy_optimizer(
         values, where `i`-th element is the expected acquisition value
         conditional on having observed candidates `0,1,...,i-1`.
     """
-    num_restarts: int = kwargs.pop(Keys.NUM_RESTARTS, 20)
-    raw_samples: int = kwargs.pop(Keys.RAW_SAMPLES, 50 * num_restarts)
-
-    if kwargs.get("joint_optimization", False):
-        sequential = False
-    else:
-        sequential = True
-    options: Dict[str, Union[bool, float, int, str]] = {
+    sequential = not joint_optimization
+    optimize_acqf_options: Dict[str, Union[bool, float, int, str]] = {
         "batch_limit": 5,
         "init_batch_limit": 32,
     }
-    options.update(kwargs.get("options", {}))
+    if options is not None:
+        optimize_acqf_options.update(options)
     X, expected_acquisition_value = optimize_acqf(
         acq_function=acq_function,
         bounds=bounds,
         q=n,
         num_restarts=num_restarts,
-        raw_samples=raw_samples,
-        options=options,
+        raw_samples=50 * num_restarts if raw_samples is None else raw_samples,
+        options=optimize_acqf_options,
         inequality_constraints=inequality_constraints,
         equality_constraints=equality_constraints,
         fixed_features=fixed_features,
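Net effect on `scipy_optimizer`: `num_restarts`, `raw_samples`, `joint_optimization`, and `options` become explicit keyword-only parameters (previously fished out of `**kwargs` via `Keys` constants), `sequential` is derived as `not joint_optimization`, and `raw_samples` falls back to `50 * num_restarts`. A small sketch of the resolved defaults (the helper name is mine, not from Ax):

```python
# Distills how the rewritten scipy_optimizer resolves its defaults.
from typing import Optional


def resolve_optimizer_args(
    *,
    num_restarts: int = 20,
    raw_samples: Optional[int] = None,
    joint_optimization: bool = False,
) -> dict:
    return {
        "num_restarts": num_restarts,
        "raw_samples": 50 * num_restarts if raw_samples is None else raw_samples,
        "sequential": not joint_optimization,
    }


print(resolve_optimizer_args())
# {'num_restarts': 20, 'raw_samples': 1000, 'sequential': True}
print(resolve_optimizer_args(joint_optimization=True))
# {'num_restarts': 20, 'raw_samples': 1000, 'sequential': False}
```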