diff --git a/bayeso/acquisition.py b/bayeso/acquisition.py index ea153c9..a31f325 100644 --- a/bayeso/acquisition.py +++ b/bayeso/acquisition.py @@ -1,10 +1,11 @@ # -# author: Jungtaek Kim (jtkim@postech.ac.kr) -# last updated: November 29, 2022 +# author: Jungtaek Kim (jungtaek.kim.mail@gmail.com) +# last updated: May 26, 2024 # """It defines acquisition functions, each of which is employed to determine where next to evaluate.""" +import typing import numpy as np import scipy.stats @@ -44,9 +45,9 @@ def pi( assert isinstance(pred_std, np.ndarray) assert isinstance(Y_train, np.ndarray) assert isinstance(jitter, float) - assert len(pred_mean.shape) == 1 - assert len(pred_std.shape) == 1 - assert len(Y_train.shape) == 2 + assert pred_mean.ndim == 1 + assert pred_std.ndim == 1 + assert Y_train.ndim == 2 assert pred_mean.shape[0] == pred_std.shape[0] with np.errstate(divide="ignore"): @@ -84,9 +85,9 @@ def ei( assert isinstance(pred_std, np.ndarray) assert isinstance(Y_train, np.ndarray) assert isinstance(jitter, float) - assert len(pred_mean.shape) == 1 - assert len(pred_std.shape) == 1 - assert len(Y_train.shape) == 2 + assert pred_mean.ndim == 1 + assert pred_std.ndim == 1 + assert Y_train.ndim == 2 assert pred_mean.shape[0] == pred_std.shape[0] with np.errstate(divide="ignore"): @@ -100,7 +101,7 @@ def ei( def ucb( pred_mean: np.ndarray, pred_std: np.ndarray, - Y_train: constants.TYPING_UNION_ARRAY_NONE = None, + Y_train: typing.Union[type(None), np.ndarray] = None, kappa: float = 2.0, increase_kappa: bool = True, ) -> np.ndarray: @@ -135,10 +136,10 @@ def ucb( assert isinstance(Y_train, (np.ndarray, type(None))) assert isinstance(kappa, float) assert isinstance(increase_kappa, bool) - assert len(pred_mean.shape) == 1 - assert len(pred_std.shape) == 1 + assert pred_mean.ndim == 1 + assert pred_std.ndim == 1 if Y_train is not None: - assert len(Y_train.shape) == 2 + assert Y_train.ndim == 2 assert pred_mean.shape[0] == pred_std.shape[0] if increase_kappa and Y_train is not None: @@ -182,9 +183,9 @@ def aei( assert isinstance(Y_train, np.ndarray) assert isinstance(noise, float) assert isinstance(jitter, float) - assert len(pred_mean.shape) == 1 - assert len(pred_std.shape) == 1 - assert len(Y_train.shape) == 2 + assert pred_mean.ndim == 1 + assert pred_std.ndim == 1 + assert Y_train.ndim == 2 assert pred_mean.shape[0] == pred_std.shape[0] with np.errstate(divide="ignore"): @@ -212,7 +213,7 @@ def pure_exploit(pred_mean: np.ndarray) -> np.ndarray: """ assert isinstance(pred_mean, np.ndarray) - assert len(pred_mean.shape) == 1 + assert pred_mean.ndim == 1 return -pred_mean @@ -233,6 +234,6 @@ def pure_explore(pred_std: np.ndarray) -> np.ndarray: """ assert isinstance(pred_std, np.ndarray) - assert len(pred_std.shape) == 1 + assert pred_std.ndim == 1 return pred_std