diff --git a/python/ray/ml/predictors/integrations/torch/torch_predictor.py b/python/ray/ml/predictors/integrations/torch/torch_predictor.py index 58c4593fc674..21d6d7052387 100644 --- a/python/ray/ml/predictors/integrations/torch/torch_predictor.py +++ b/python/ray/ml/predictors/integrations/torch/torch_predictor.py @@ -127,22 +127,19 @@ def predict( import numpy as np import torch - from ray.ml.predictors.torch import TorchPredictor + from ray.ml.predictors.integrations.torch import TorchPredictor - model = torch.nn.Linear(1, 1) + model = torch.nn.Linear(2, 1) predictor = TorchPredictor(model=model) data = np.array([[1, 2], [3, 4]]) predictions = predictor.predict(data) - # Only use first column as the feature - predictions = predictor.predict(data, feature_columns=[0]) - .. code-block:: python import pandas as pd import torch - from ray.ml.predictors.torch import TorchPredictor + from ray.ml.predictors.integrations.torch import TorchPredictor model = torch.nn.Linear(1, 1) predictor = TorchPredictor(model=model) @@ -155,7 +152,6 @@ def predict( # Only use first column as the feature predictions = predictor.predict(data, feature_columns=["A"]) - Returns: DataBatchType: Prediction result. """ @@ -165,10 +161,10 @@ def predict( data = self.preprocessor.transform_batch(data) if isinstance(data, np.ndarray): - # If numpy array, then convert to pandas dataframe. - data = pd.DataFrame(data) + tensor = torch.tensor(data, dtype=dtype) + else: + tensor = self._convert_to_tensor( + data, feature_columns=feature_columns, dtypes=dtype, unsqueeze=unsqueeze + ) - tensor = self._convert_to_tensor( - data, feature_columns=feature_columns, dtypes=dtype, unsqueeze=unsqueeze - ) return self._predict(tensor) diff --git a/python/ray/ml/tests/test_torch_predictor.py b/python/ray/ml/tests/test_torch_predictor.py index 95596a577e75..7aebc7b34c9f 100644 --- a/python/ray/ml/tests/test_torch_predictor.py +++ b/python/ray/ml/tests/test_torch_predictor.py @@ -1,6 +1,7 @@ import pytest import numpy as np +import pandas as pd import torch from ray.ml.predictors.integrations.torch import TorchPredictor @@ -50,7 +51,7 @@ def test_predict_model_not_training(model): assert not predictor.model.training -def test_predict_no_preprocessor(model): +def test_predict_array(model): predictor = TorchPredictor(model=model) data_batch = np.array([[1], [2], [3]]) @@ -60,7 +61,7 @@ def test_predict_no_preprocessor(model): assert predictions.to_numpy().flatten().tolist() == [2, 4, 6] -def test_predict_with_preprocessor(model, preprocessor): +def test_predict_array_with_preprocessor(model, preprocessor): predictor = TorchPredictor(model=model, preprocessor=preprocessor) data_batch = np.array([[1], [2], [3]]) @@ -70,31 +71,48 @@ def test_predict_with_preprocessor(model, preprocessor): assert predictions.to_numpy().flatten().tolist() == [4, 8, 12] -def test_predict_array_output(model): - """Tests if predictor works if model outputs an array instead of single value.""" +def test_predict_dataframe(): + predictor = TorchPredictor(model=torch.nn.Linear(2, 1, bias=False)) - predictor = TorchPredictor(model=model) - - data_batch = np.array([[1, 1], [2, 2], [3, 3]]) - predictions = predictor.predict(data_batch) + data_batch = pd.DataFrame({"X0": [0.0, 0.0, 0.0], "X1": [0.0, 0.0, 0.0]}) + predictions = predictor.predict(data_batch, dtype=torch.float) assert len(predictions) == 3 - assert np.array_equal( - predictions.to_numpy().flatten().tolist(), [[2, 2], [4, 4], [6, 6]] + assert predictions.to_numpy().flatten().tolist() == [0.0, 0.0, 0.0] + + +@pytest.mark.parametrize( + ("input_dtype", "expected_output_dtype"), + ( + (torch.float16, np.float16), + (torch.float64, np.float64), + (torch.int32, np.int32), + (torch.int64, np.int64), + ), +) +def test_predict_array_with_different_dtypes(input_dtype, expected_output_dtype): + predictor = TorchPredictor(model=torch.nn.Identity()) + + data_batch = np.array([[1], [2], [3]]) + predictions = predictor.predict(data_batch, dtype=input_dtype) + + assert all( + prediction.dtype == expected_output_dtype + for prediction in predictions["predictions"] ) -def test_predict_feature_columns(model): - predictor = TorchPredictor(model=model) +def test_predict_dataframe_with_feature_columns(): + predictor = TorchPredictor(model=torch.nn.Identity()) - data_batch = np.array([[1, 4], [2, 5], [3, 6]]) - predictions = predictor.predict(data_batch, feature_columns=[0]) + data_batch = pd.DataFrame({"X0": [0.0, 0.0, 0.0], "X1": [1.0, 1.0, 1.0]}) + predictions = predictor.predict(data_batch, feature_columns=["X0"]) assert len(predictions) == 3 - assert predictions.to_numpy().flatten().tolist() == [2, 4, 6] + assert predictions.to_numpy().flatten().tolist() == [0.0, 0.0, 0.0] -def test_predict_from_checkpoint_no_preprocessor(model): +def test_predict_array_from_checkpoint(model): checkpoint = Checkpoint.from_dict({MODEL_KEY: model}) predictor = TorchPredictor.from_checkpoint(checkpoint)