-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
2775c96
commit 9c4a8f1
Showing
1 changed file
with
18 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,11 @@ | ||
# | ||
# author: Jungtaek Kim ([email protected]) | ||
# last updated: November 29, 2022 | ||
# author: Jungtaek Kim ([email protected]) | ||
# last updated: May 26, 2024 | ||
# | ||
"""It defines acquisition functions, each of which | ||
is employed to determine where next to evaluate.""" | ||
|
||
import typing | ||
import numpy as np | ||
import scipy.stats | ||
|
||
|
@@ -44,9 +45,9 @@ def pi( | |
assert isinstance(pred_std, np.ndarray) | ||
assert isinstance(Y_train, np.ndarray) | ||
assert isinstance(jitter, float) | ||
assert len(pred_mean.shape) == 1 | ||
assert len(pred_std.shape) == 1 | ||
assert len(Y_train.shape) == 2 | ||
assert pred_mean.ndim == 1 | ||
assert pred_std.ndim == 1 | ||
assert Y_train.ndim == 2 | ||
assert pred_mean.shape[0] == pred_std.shape[0] | ||
|
||
with np.errstate(divide="ignore"): | ||
|
@@ -84,9 +85,9 @@ def ei( | |
assert isinstance(pred_std, np.ndarray) | ||
assert isinstance(Y_train, np.ndarray) | ||
assert isinstance(jitter, float) | ||
assert len(pred_mean.shape) == 1 | ||
assert len(pred_std.shape) == 1 | ||
assert len(Y_train.shape) == 2 | ||
assert pred_mean.ndim == 1 | ||
assert pred_std.ndim == 1 | ||
assert Y_train.ndim == 2 | ||
assert pred_mean.shape[0] == pred_std.shape[0] | ||
|
||
with np.errstate(divide="ignore"): | ||
|
@@ -100,7 +101,7 @@ def ei( | |
def ucb( | ||
pred_mean: np.ndarray, | ||
pred_std: np.ndarray, | ||
Y_train: constants.TYPING_UNION_ARRAY_NONE = None, | ||
Y_train: typing.Union[type(None), np.ndarray] = None, | ||
kappa: float = 2.0, | ||
increase_kappa: bool = True, | ||
) -> np.ndarray: | ||
|
@@ -135,10 +136,10 @@ def ucb( | |
assert isinstance(Y_train, (np.ndarray, type(None))) | ||
assert isinstance(kappa, float) | ||
assert isinstance(increase_kappa, bool) | ||
assert len(pred_mean.shape) == 1 | ||
assert len(pred_std.shape) == 1 | ||
assert pred_mean.ndim == 1 | ||
assert pred_std.ndim == 1 | ||
if Y_train is not None: | ||
assert len(Y_train.shape) == 2 | ||
assert Y_train.ndim == 2 | ||
assert pred_mean.shape[0] == pred_std.shape[0] | ||
|
||
if increase_kappa and Y_train is not None: | ||
|
@@ -182,9 +183,9 @@ def aei( | |
assert isinstance(Y_train, np.ndarray) | ||
assert isinstance(noise, float) | ||
assert isinstance(jitter, float) | ||
assert len(pred_mean.shape) == 1 | ||
assert len(pred_std.shape) == 1 | ||
assert len(Y_train.shape) == 2 | ||
assert pred_mean.ndim == 1 | ||
assert pred_std.ndim == 1 | ||
assert Y_train.ndim == 2 | ||
assert pred_mean.shape[0] == pred_std.shape[0] | ||
|
||
with np.errstate(divide="ignore"): | ||
|
@@ -212,7 +213,7 @@ def pure_exploit(pred_mean: np.ndarray) -> np.ndarray: | |
""" | ||
|
||
assert isinstance(pred_mean, np.ndarray) | ||
assert len(pred_mean.shape) == 1 | ||
assert pred_mean.ndim == 1 | ||
|
||
return -pred_mean | ||
|
||
|
@@ -233,6 +234,6 @@ def pure_explore(pred_std: np.ndarray) -> np.ndarray: | |
""" | ||
|
||
assert isinstance(pred_std, np.ndarray) | ||
assert len(pred_std.shape) == 1 | ||
assert pred_std.ndim == 1 | ||
|
||
return pred_std |