Skip to content

Commit

Permalink
[RLlib] Fix convert to torch tensor (#31023)
Browse files Browse the repository at this point in the history
Signed-off-by: Artur Niederfahrenhorst <[email protected]>
  • Loading branch information
ArturNiederfahrenhorst authored Dec 13, 2022
1 parent 76b22f4 commit 8e9fa22
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 1 deletion.
7 changes: 7 additions & 0 deletions rllib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -2109,6 +2109,13 @@ py_test(
srcs = ["utils/exploration/tests/test_random_encoder.py"]
)

py_test(
name = "utils/tests/test_torch_utils",
tags = ["team:rllib", "policutils"],
size = "small",
srcs = ["utils/tests/test_torch_utils.py"]
)

# Schedules
py_test(
name = "test_schedules",
Expand Down
45 changes: 45 additions & 0 deletions rllib/utils/tests/test_torch_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import unittest

import numpy as np
import torch.cuda

import ray
from ray.rllib.utils.torch_utils import convert_to_torch_tensor


class TestTorchUtils(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
ray.init()

@classmethod
def tearDownClass(cls) -> None:
ray.shutdown()

def test_convert_to_torch_tensor(self):
# Tests whether convert_to_torch_tensor works as expected

# Test single array
array = np.array([1, 2, 3])
tensor = torch.from_numpy(array)
self.assertTrue(all(convert_to_torch_tensor(array) == tensor))

# Test torch tensor
self.assertTrue(convert_to_torch_tensor(tensor) is tensor)

# Test conversion to 32-bit float
tensor_2 = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64)
self.assertTrue(convert_to_torch_tensor(tensor_2).dtype is torch.float32)

# Test nested structure with objects tested above
converted = convert_to_torch_tensor({"a": (array, tensor), "b": tensor_2})
self.assertTrue(all(convert_to_torch_tensor(converted["a"][0]) == tensor))
self.assertTrue(convert_to_torch_tensor(converted["a"][1]) is tensor)
self.assertTrue(convert_to_torch_tensor(converted["b"]).dtype is torch.float32)


if __name__ == "__main__":
import pytest
import sys

sys.exit(pytest.main(["-v", __file__]))
2 changes: 1 addition & 1 deletion rllib/utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def mapping(item):
elif isinstance(item, np.ndarray):
# Object type (e.g. info dicts in train batch): leave as-is.
# str type (e.g. agent_id in train batch): leave as-is.
if item.dtype == object or isinstance(item.dtype, type(np.dtype("str_"))):
if item.dtype == object or item.dtype.type is np.str_:
return item
# Non-writable numpy-arrays will cause PyTorch warning.
elif item.flags.writeable is False:
Expand Down

0 comments on commit 8e9fa22

Please sign in to comment.