diff --git a/ax/modelbridge/cross_validation.py b/ax/modelbridge/cross_validation.py index e319e050650..e24c12050dd 100644 --- a/ax/modelbridge/cross_validation.py +++ b/ax/modelbridge/cross_validation.py @@ -566,12 +566,12 @@ def _predict_on_training_data( observations = observations_from_data( experiment=experiment, data=data ) # List[Observation] - observation_features = [obs.features for obs in observations] - # Transform observation features - observation_features = deepcopy(observation_features) + # Transform observations -- this will transform both obs data and features for t in model_bridge.transforms.values(): - observation_features = t.transform_observation_features(observation_features) + observations = t.transform_observations(observations) + + observation_features = [obs.features for obs in observations] # Make predictions in transformed space observation_data_pred = model_bridge._predict(observation_features)