From 8e9fa223d35ec87fd33a8eb0fe706bc6a77dcf23 Mon Sep 17 00:00:00 2001
From: Artur Niederfahrenhorst
Date: Tue, 13 Dec 2022 22:15:04 +0100
Subject: [PATCH] [RLlib] Fix convert to torch tensor (#31023)

Signed-off-by: Artur Niederfahrenhorst
---
 rllib/BUILD                           |  7 +++++
 rllib/utils/tests/test_torch_utils.py | 45 +++++++++++++++++++++++++++
 rllib/utils/torch_utils.py            |  2 +-
 3 files changed, 53 insertions(+), 1 deletion(-)
 create mode 100644 rllib/utils/tests/test_torch_utils.py

diff --git a/rllib/BUILD b/rllib/BUILD
index 4b54bd27e8eb..8b4e8ba9c783 100644
--- a/rllib/BUILD
+++ b/rllib/BUILD
@@ -2109,6 +2109,13 @@ py_test(
     srcs = ["utils/exploration/tests/test_random_encoder.py"]
 )
 
+py_test(
+    name = "utils/tests/test_torch_utils",
+    tags = ["team:rllib", "policutils"],
+    size = "small",
+    srcs = ["utils/tests/test_torch_utils.py"]
+)
+
 # Schedules
 py_test(
     name = "test_schedules",
diff --git a/rllib/utils/tests/test_torch_utils.py b/rllib/utils/tests/test_torch_utils.py
new file mode 100644
index 000000000000..83baf344bc3a
--- /dev/null
+++ b/rllib/utils/tests/test_torch_utils.py
@@ -0,0 +1,45 @@
+import unittest
+
+import numpy as np
+import torch.cuda
+
+import ray
+from ray.rllib.utils.torch_utils import convert_to_torch_tensor
+
+
+class TestTorchUtils(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls) -> None:
+        ray.init()
+
+    @classmethod
+    def tearDownClass(cls) -> None:
+        ray.shutdown()
+
+    def test_convert_to_torch_tensor(self):
+        # Tests whether convert_to_torch_tensor works as expected
+
+        # Test single array
+        array = np.array([1, 2, 3])
+        tensor = torch.from_numpy(array)
+        self.assertTrue(all(convert_to_torch_tensor(array) == tensor))
+
+        # Test torch tensor
+        self.assertTrue(convert_to_torch_tensor(tensor) is tensor)
+
+        # Test conversion to 32-bit float
+        tensor_2 = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64)
+        self.assertTrue(convert_to_torch_tensor(tensor_2).dtype is torch.float32)
+
+        # Test nested structure with objects tested above
+        converted = convert_to_torch_tensor({"a": (array, tensor), "b": tensor_2})
+        self.assertTrue(all(convert_to_torch_tensor(converted["a"][0]) == tensor))
+        self.assertTrue(convert_to_torch_tensor(converted["a"][1]) is tensor)
+        self.assertTrue(convert_to_torch_tensor(converted["b"]).dtype is torch.float32)
+
+
+if __name__ == "__main__":
+    import pytest
+    import sys
+
+    sys.exit(pytest.main(["-v", __file__]))
diff --git a/rllib/utils/torch_utils.py b/rllib/utils/torch_utils.py
index 7f59a7a9e62a..98e437736e83 100644
--- a/rllib/utils/torch_utils.py
+++ b/rllib/utils/torch_utils.py
@@ -149,7 +149,7 @@ def mapping(item):
         elif isinstance(item, np.ndarray):
             # Object type (e.g. info dicts in train batch): leave as-is.
             # str type (e.g. agent_id in train batch): leave as-is.
-            if item.dtype == object or isinstance(item.dtype, type(np.dtype("str_"))):
+            if item.dtype == object or item.dtype.type is np.str_:
                 return item
             # Non-writable numpy-arrays will cause PyTorch warning.
             elif item.flags.writeable is False: