diff --git a/botorch/models/utils/assorted.py b/botorch/models/utils/assorted.py index fa1e62f30a..fadab8da66 100644 --- a/botorch/models/utils/assorted.py +++ b/botorch/models/utils/assorted.py @@ -167,7 +167,7 @@ def check_min_max_scaling( msg = "contained" if msg is not None: msg = ( - f"Input data is not {msg} to the unit cube. " + f"Data (input features) not {msg} to the unit cube. " "Please consider min-max scaling the input data." ) if raise_on_fail: @@ -197,7 +197,7 @@ def check_standardization( if Y.shape[-2] <= 1: if mean_not_zero: msg = ( - f"Data is not standardized (mean = {Ymean}). " + f"Data (outcome observations) not standardized (mean = {Ymean}). " "Please consider scaling the input to zero mean and unit variance." ) if raise_on_fail: @@ -208,7 +208,8 @@ def check_standardization( std_not_one = torch.abs(Ystd - 1).max() > atol_std if mean_not_zero or std_not_one: msg = ( - f"Data is not standardized (std = {Ystd}, mean = {Ymean}). " + "Data (outcome observations) not standardized " + f"(std = {Ystd}, mean = {Ymean})." "Please consider scaling the input to zero mean and unit variance." ) if raise_on_fail: diff --git a/test/models/utils/test_assorted.py b/test/models/utils/test_assorted.py index 459893363e..21533f762f 100644 --- a/test/models/utils/test_assorted.py +++ b/test/models/utils/test_assorted.py @@ -158,14 +158,14 @@ def test_check_standardization(self): check_standardization(Y=y, raise_on_fail=True) # check nonzero mean for case where >= 2 observations per batch - msg_more_than_1_obs = r"Data is not standardized \(std =" + msg_more_than_1_obs = r"Data \(outcome observations\) not standardized \(std =" with self.assertWarnsRegex(InputDataWarning, msg_more_than_1_obs): check_standardization(Y=Yst + 1) with self.assertRaisesRegex(InputDataError, msg_more_than_1_obs): check_standardization(Y=Yst + 1, raise_on_fail=True) # check nonzero mean for case where < 2 observations per batch - msg_one_obs = r"Data is not standardized \(mean =" + msg_one_obs = r"Data \(outcome observations\) not standardized \(mean =" y = torch.ones((3, 1, 2), dtype=torch.float32) with self.assertWarnsRegex(InputDataWarning, msg_one_obs): check_standardization(Y=y)