Skip to content

Commit

Permalink
Update acquisition.py
Browse files Browse the repository at this point in the history
  • Loading branch information
jungtaekkim committed May 27, 2024
1 parent 2775c96 commit 9c4a8f1
Showing 1 changed file with 18 additions and 17 deletions.
35 changes: 18 additions & 17 deletions bayeso/acquisition.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
#
# author: Jungtaek Kim ([email protected])
# last updated: November 29, 2022
# author: Jungtaek Kim ([email protected])
# last updated: May 26, 2024
#
"""It defines acquisition functions, each of which
is employed to determine where next to evaluate."""

import typing
import numpy as np
import scipy.stats

Expand Down Expand Up @@ -44,9 +45,9 @@ def pi(
assert isinstance(pred_std, np.ndarray)
assert isinstance(Y_train, np.ndarray)
assert isinstance(jitter, float)
assert len(pred_mean.shape) == 1
assert len(pred_std.shape) == 1
assert len(Y_train.shape) == 2
assert pred_mean.ndim == 1
assert pred_std.ndim == 1
assert Y_train.ndim == 2
assert pred_mean.shape[0] == pred_std.shape[0]

with np.errstate(divide="ignore"):
Expand Down Expand Up @@ -84,9 +85,9 @@ def ei(
assert isinstance(pred_std, np.ndarray)
assert isinstance(Y_train, np.ndarray)
assert isinstance(jitter, float)
assert len(pred_mean.shape) == 1
assert len(pred_std.shape) == 1
assert len(Y_train.shape) == 2
assert pred_mean.ndim == 1
assert pred_std.ndim == 1
assert Y_train.ndim == 2
assert pred_mean.shape[0] == pred_std.shape[0]

with np.errstate(divide="ignore"):
Expand All @@ -100,7 +101,7 @@ def ei(
def ucb(
pred_mean: np.ndarray,
pred_std: np.ndarray,
Y_train: constants.TYPING_UNION_ARRAY_NONE = None,
Y_train: typing.Union[type(None), np.ndarray] = None,
kappa: float = 2.0,
increase_kappa: bool = True,
) -> np.ndarray:
Expand Down Expand Up @@ -135,10 +136,10 @@ def ucb(
assert isinstance(Y_train, (np.ndarray, type(None)))
assert isinstance(kappa, float)
assert isinstance(increase_kappa, bool)
assert len(pred_mean.shape) == 1
assert len(pred_std.shape) == 1
assert pred_mean.ndim == 1
assert pred_std.ndim == 1
if Y_train is not None:
assert len(Y_train.shape) == 2
assert Y_train.ndim == 2
assert pred_mean.shape[0] == pred_std.shape[0]

if increase_kappa and Y_train is not None:
Expand Down Expand Up @@ -182,9 +183,9 @@ def aei(
assert isinstance(Y_train, np.ndarray)
assert isinstance(noise, float)
assert isinstance(jitter, float)
assert len(pred_mean.shape) == 1
assert len(pred_std.shape) == 1
assert len(Y_train.shape) == 2
assert pred_mean.ndim == 1
assert pred_std.ndim == 1
assert Y_train.ndim == 2
assert pred_mean.shape[0] == pred_std.shape[0]

with np.errstate(divide="ignore"):
Expand Down Expand Up @@ -212,7 +213,7 @@ def pure_exploit(pred_mean: np.ndarray) -> np.ndarray:
"""

assert isinstance(pred_mean, np.ndarray)
assert len(pred_mean.shape) == 1
assert pred_mean.ndim == 1

return -pred_mean

Expand All @@ -233,6 +234,6 @@ def pure_explore(pred_std: np.ndarray) -> np.ndarray:
"""

assert isinstance(pred_std, np.ndarray)
assert len(pred_std.shape) == 1
assert pred_std.ndim == 1

return pred_std

0 comments on commit 9c4a8f1

Please sign in to comment.