[RLlib] Add Optimizer State To Learner get_state (#34760)
Signed-off-by: Avnish <[email protected]>
avnishn authored Apr 28, 2023
1 parent 7401b39 commit 0d59be7
Showing 7 changed files with 203 additions and 12 deletions.
4 changes: 2 additions & 2 deletions rllib/BUILD
@@ -2283,8 +2283,8 @@ py_test(

py_test(
name = "utils/tests/test_torch_utils",
tags = ["team:rllib", "utils"],
size = "small",
tags = ["team:rllib", "utils", "gpu"],
size = "medium",
srcs = ["utils/tests/test_torch_utils.py"]
)

43 changes: 40 additions & 3 deletions rllib/core/learner/learner.py
@@ -849,14 +849,29 @@ def set_state(self, state: Mapping[str, Any]) -> None:
Args:
state: The state of the optimizer and module. Can be obtained
-                from `get_state`.
+                from `get_state`. State is a dictionary with two keys:
+                "module_state" and "optimizer_state". The value of each key
+                is a dictionary that can be passed to `set_weights` and
+                `set_optimizer_weights` respectively.
"""
# TODO (Kourosh): We have both get(set)_state and get(set)_weights. I think
# having both can become confusing. Can we simplify this API requirement?
self._check_is_built()
# TODO: once we figure out the optimizer format, we can set/get the state
-        self._module.set_state(state.get("module_state", {}))
+        if "module_state" not in state:
+            raise ValueError(
+                "state must have a key 'module_state' for the module weights"
+            )
+        if "optimizer_state" not in state:
+            raise ValueError(
+                "state must have a key 'optimizer_state' for the optimizer weights"
+            )
+
+        module_state = state.get("module_state")
+        optimizer_state = state.get("optimizer_state")
+        self.set_weights(module_state)
+        self.set_optimizer_weights(optimizer_state)

def get_state(self) -> Mapping[str, Any]:
"""Get the state of the learner.
@@ -867,7 +882,29 @@ def get_state(self) -> Mapping[str, Any]:
"""
self._check_is_built()
# TODO: once we figure out the optimizer format, we can set/get the state
return {"module_state": self._module.get_state()}
return {
"module_state": self.get_weights(),
"optimizer_state": self.get_optimizer_weights(),
}
# return {"module_state": self.get_weights(), "optimizer_state": {}}

def set_optimizer_weights(self, weights: Mapping[str, Any]) -> None:
"""Set the weights of the optimizer.
Args:
weights: The weights of the optimizer.
"""
raise NotImplementedError

def get_optimizer_weights(self) -> Mapping[str, Any]:
"""Get the weights of the optimizer.
Returns:
The weights of the optimizer.
"""
raise NotImplementedError

def _get_metadata(self) -> Dict[str, Any]:
metadata = {
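
Taken together, the learner.py changes make a Learner's state round-trip as a two-key dict. Below is a minimal usage sketch; `learner` and `restored_learner` are hypothetical stand-ins for already-built Learner instances that share the same module and optimizer layout:

# Sketch only: both learners are assumed to be built and identically configured.
state = learner.get_state()
# state == {"module_state": <module weights>, "optimizer_state": <optimizer weights>}

restored_learner.set_state(state)  # validates both keys, then calls
                                   # set_weights() and set_optimizer_weights()

# The two halves can also be applied independently:
restored_learner.set_weights(state["module_state"])
restored_learner.set_optimizer_weights(state["optimizer_state"])
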
27 changes: 25 additions & 2 deletions rllib/core/learner/tests/test_learner_group.py
@@ -36,7 +36,7 @@

LOCAL_SCALING_CONFIGS = {
"local-cpu": LearnerGroupScalingConfig(num_workers=0, num_gpus_per_worker=0),
"local-gpu": LearnerGroupScalingConfig(num_workers=0, num_gpus_per_worker=0.5),
"local-gpu": LearnerGroupScalingConfig(num_workers=0, num_gpus_per_worker=1),
}


@@ -45,6 +45,17 @@
@ray.remote(num_gpus=1)
class RemoteTrainingHelper:
def local_training_helper(self, fw, scaling_mode) -> None:
if fw == "torch":
import torch

torch.manual_seed(0)
elif fw == "tf":
import tensorflow as tf

# this is done by rllib already inside of the policy class, but we need to
# do it here for testing purposes
tf.compat.v1.enable_eager_execution()
tf.random.set_seed(0)
env = gym.make("CartPole-v1")
scaling_config = LOCAL_SCALING_CONFIGS[scaling_mode]
lr = 1e-3
@@ -71,13 +82,25 @@ def local_training_helper(self, fw, scaling_mode) -> None:

# make the state of the learner and the local learner_group identical
local_learner.set_state(learner_group.get_state())
# learner_group.set_state(learner_group.get_state())
check(local_learner.get_state(), learner_group.get_state())

# do another update
batch = reader.next()
ma_batch = MultiAgentBatch(
{new_module_id: batch, DEFAULT_POLICY_ID: batch}, env_steps=batch.count
)
-        check(local_learner.update(ma_batch), learner_group.update(ma_batch))
+        # the optimizer state is not initialized fully until the first time that
+        # training is completed. A call to get state before that won't contain the
+        # optimizer state. So we do a dummy update here to initialize the optimizer
+        local_learner.update(ma_batch)
+        learner_group.update(ma_batch)
+
+        check(local_learner.get_state(), learner_group.get_state())
+        local_learner_results = local_learner.update(ma_batch)
+        learner_group_results = learner_group.update(ma_batch)
+
+        check(local_learner_results, learner_group_results)

check(local_learner.get_state(), learner_group.get_state())

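
The comment added in the test points at a subtlety: optimizers typically create their per-parameter state lazily, so get_state() only returns a fully populated "optimizer_state" after at least one update. A short sketch of the resulting pattern; `learner`, `other_learner`, and `ma_batch` are assumed to be a built Learner, a second identically configured Learner, and a MultiAgentBatch as in the test above:

# Sketch only: optimizer state (e.g. Adam's moment estimates) is created by the
# first update, so snapshot and copy state only after training has run once.
learner.update(ma_batch)         # materializes the optimizer's internal state
state = learner.get_state()      # "optimizer_state" is now fully populated
other_learner.set_state(state)   # safe to mirror into the second learner
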
19 changes: 19 additions & 0 deletions rllib/core/learner/tf/tf_learner.py
@@ -270,6 +270,25 @@ def _load_optimizers(self, path: Union[str, pathlib.Path]) -> None:
def set_weights(self, weights: Mapping[str, Any]) -> None:
self._module.set_state(weights)

@override(Learner)
def get_optimizer_weights(self) -> Mapping[str, Any]:
optim_weights = {}
with tf.init_scope():
for name, optim in self._named_optimizers.items():
optim_weights[name] = [var.numpy() for var in optim.variables()]
return optim_weights

@override(Learner)
def set_optimizer_weights(self, weights: Mapping[str, Any]) -> None:
for name, weight_array in weights.items():
if name not in self._named_optimizers:
raise ValueError(
f"Optimizer {name} in weights is not known."
f"Known optimizers are {self._named_optimizers.keys()}"
)
optim = self._named_optimizers[name]
optim.set_weights(weight_array)

@override(Learner)
def get_param_ref(self, param: ParamType) -> Hashable:
return param.ref()
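
On the TF side, optimizer state is exported per named optimizer as a list of numpy arrays (via optimizer.variables()) and restored with Keras' set_weights(). A minimal sketch of the intended round-trip, assuming two built TfLearner instances with identical optimizer layouts (`tf_learner` and `other_tf_learner` are hypothetical names):

# Sketch only: each value is a list of numpy arrays in the order returned by the
# optimizer's variables(); unknown optimizer names raise a ValueError.
optim_weights = tf_learner.get_optimizer_weights()
other_tf_learner.set_optimizer_weights(optim_weights)

Exporting plain numpy arrays rather than live tf.Variable handles presumably keeps the state cheap to serialize when learners run as remote Ray workers.
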
38 changes: 34 additions & 4 deletions rllib/core/learner/torch/torch_learner.py
@@ -27,9 +27,13 @@
from ray.rllib.core.rl_module.torch.torch_rl_module import TorchDDPRLModule
from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.utils.annotations import override
-from ray.rllib.utils.torch_utils import clip_gradients, convert_to_torch_tensor
from ray.rllib.utils.typing import TensorType
from ray.rllib.utils.nested_dict import NestedDict
+from ray.rllib.utils.torch_utils import (
+    clip_gradients,
+    convert_to_torch_tensor,
+    copy_torch_tensors,
+)
from ray.rllib.utils.framework import try_import_torch

torch, nn = try_import_torch()
@@ -119,16 +123,42 @@ def set_weights(self, weights: Mapping[str, Any]) -> None:
def _save_optimizers(self, path: Union[str, pathlib.Path]) -> None:
path = pathlib.Path(path)
path.mkdir(parents=True, exist_ok=True)
-        for name, optim in self._named_optimizers.items():
-            torch.save(optim.state_dict(), path / f"{name}.pt")
+        optim_weights = self.get_optimizer_weights()
+        for name, weights in optim_weights.items():
+            torch.save(weights, path / f"{name}.pt")

@override(Learner)
def _load_optimizers(self, path: Union[str, pathlib.Path]) -> None:
path = pathlib.Path(path)
if not path.exists():
raise ValueError(f"Directory {path} does not exist.")
weights = {}
for name in self._named_optimizers.keys():
weights[name] = torch.load(path / f"{name}.pt")
self.set_optimizer_weights(weights)

@override(Learner)
def get_optimizer_weights(self) -> Mapping[str, Any]:
optimizer_name_weights = {}
for name, optim in self._named_optimizers.items():
-            optim.load_state_dict(torch.load(path / f"{name}.pt"))
optim_state_dict = optim.state_dict()
optim_state_dict_cpu = copy_torch_tensors(optim_state_dict, device="cpu")
optimizer_name_weights[name] = optim_state_dict_cpu
return optimizer_name_weights

@override(Learner)
def set_optimizer_weights(self, weights: Mapping[str, Any]) -> None:
for name, weight_dict in weights.items():
if name not in self._named_optimizers:
raise ValueError(
f"Optimizer {name} in weights is not known."
f"Known optimizers are {self._named_optimizers.keys()}"
)
optim = self._named_optimizers[name]
weight_dict_correct_device = copy_torch_tensors(
weight_dict, device=self._device
)
optim.load_state_dict(weight_dict_correct_device)

@override(Learner)
def get_param_ref(self, param: ParamType) -> Hashable:
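
On the torch side everything goes through optimizer.state_dict(): tensors are copied to CPU on export and moved back to the learner's device on import, and optimizer checkpointing now reuses the same two methods. A standalone sketch of the pattern, using a plain torch model and optimizer as stand-ins and the copy_torch_tensors helper added further below:

import torch

from ray.rllib.utils.torch_utils import copy_torch_tensors

model = torch.nn.Linear(4, 1)
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

# Run one step so Adam materializes its per-parameter state (step, exp_avg, exp_avg_sq).
model(torch.zeros(2, 4)).pow(2).mean().backward()
optim.step()

# Export: copy every tensor in the nested state dict to CPU, as
# TorchLearner.get_optimizer_weights() does.
cpu_state = copy_torch_tensors(optim.state_dict(), device="cpu")

# Import: move the tensors to the target device (CPU in this sketch) and load
# them back, as TorchLearner.set_optimizer_weights() does.
optim.load_state_dict(copy_torch_tensors(cpu_state, device="cpu"))
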
53 changes: 52 additions & 1 deletion rllib/utils/tests/test_torch_utils.py
@@ -4,7 +4,10 @@
import torch.cuda

import ray
-from ray.rllib.utils.torch_utils import convert_to_torch_tensor
+from ray.rllib.utils.torch_utils import (
+    convert_to_torch_tensor,
+    copy_torch_tensors,
+)


class TestTorchUtils(unittest.TestCase):
@@ -43,6 +46,54 @@ def test_convert_to_torch_tensor(self):
self.assertTrue(converted["b"].dtype is torch.float32)
self.assertTrue(converted["c"] is None)

def test_copy_torch_tensors(self):
array = np.array([1, 2, 3], dtype=np.float32)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
tensor = torch.from_numpy(array).to(device)
tensor_2 = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64).to(device)

# Test single tensor
copied_tensor = copy_torch_tensors(tensor, device)
self.assertTrue(copied_tensor.device == device)
self.assertNotEqual(id(copied_tensor), id(tensor))
self.assertTrue(all(copied_tensor == tensor))

# check that dtypes aren't modified
copied_tensor_2 = copy_torch_tensors(tensor_2, device)
self.assertTrue(copied_tensor_2.dtype == tensor_2.dtype)
self.assertFalse(copied_tensor_2.dtype == torch.float32)

# Test nested structure can be converted
nested_structure = {"a": tensor, "b": tensor_2, "c": 1}
copied_nested_structure = copy_torch_tensors(nested_structure, device)
self.assertTrue(copied_nested_structure["a"].device == device)
self.assertTrue(copied_nested_structure["b"].device == device)
self.assertTrue(copied_nested_structure["c"] == 1)
self.assertNotEqual(id(copied_nested_structure["a"]), id(tensor))
self.assertNotEqual(id(copied_nested_structure["b"]), id(tensor_2))
self.assertTrue(all(copied_nested_structure["a"] == tensor))
self.assertTrue(all(copied_nested_structure["b"] == tensor_2))

# if gpu is available test moving tensor from cpu to gpu and vice versa
if torch.cuda.is_available():
tensor = torch.from_numpy(array).to("cpu")
copied_tensor = copy_torch_tensors(tensor, "cuda:0")
self.assertFalse(copied_tensor.device == torch.device("cpu"))
self.assertTrue(copied_tensor.device == torch.device("cuda:0"))
self.assertNotEqual(id(copied_tensor), id(tensor))
self.assertTrue(
all(copied_tensor.detach().cpu().numpy() == tensor.detach().numpy())
)

tensor = torch.from_numpy(array).to("cuda:0")
copied_tensor = copy_torch_tensors(tensor, "cpu")
self.assertFalse(copied_tensor.device == torch.device("cuda:0"))
self.assertTrue(copied_tensor.device == torch.device("cpu"))
self.assertNotEqual(id(copied_tensor), id(tensor))
self.assertTrue(
all(copied_tensor.detach().numpy() == tensor.detach().cpu().numpy())
)


if __name__ == "__main__":
import pytest
31 changes: 31 additions & 0 deletions rllib/utils/torch_utils.py
@@ -231,6 +231,37 @@ def mapping(item):
return tree.map_structure(mapping, x)


@PublicAPI
def copy_torch_tensors(x: TensorStructType, device: Optional[str] = None):
"""Creates a copy of `x` and makes deep copies torch.Tensors in x.
Also moves the copied tensors to the specified device (if not None).
Note if an object in x is not a torch.Tensor, it will be shallow-copied.
Args:
x : Any (possibly nested) struct possibly containing torch.Tensors.
device : The device to move the tensors to.
Returns:
Any: A new struct with the same structure as `x`, but with all
torch.Tensors deep-copied and moved to the specified device.
"""

def mapping(item):
if isinstance(item, torch.Tensor):
return (
torch.clone(item.detach())
if device is None
else item.detach().to(device)
)
else:
return item

return tree.map_structure(mapping, x)


@PublicAPI
def explained_variance(y: TensorType, pred: TensorType) -> TensorType:
"""Computes the explained variance for a pair of labels and predictions.
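
For completeness, a small hedged usage example of the new helper on a nested structure, mirroring the unit test above:

import torch

from ray.rllib.utils.torch_utils import copy_torch_tensors

batch = {"obs": torch.arange(4.0), "num_steps": 3}
cpu_batch = copy_torch_tensors(batch, device="cpu")

assert torch.equal(cpu_batch["obs"], batch["obs"])  # same values ...
assert cpu_batch["obs"] is not batch["obs"]         # ... but a distinct tensor object
assert cpu_batch["num_steps"] == 3                  # non-tensor leaves pass through unchanged
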
