[RLlib] Fixed bug in restoring a GPU-trained algorithm (#35024)
Signed-off-by: Kourosh Hakhamaneshi <[email protected]>
kouroshHakha authored May 8, 2023
1 parent 386e395 commit 67706f9
Showing 4 changed files with 76 additions and 9 deletions.
29 changes: 29 additions & 0 deletions rllib/BUILD
@@ -2458,6 +2458,16 @@ py_test(
args = ["TestCheckpointRestorePG"]
)


py_test(
name = "tests/test_checkpoint_restore_pg_gpu",
main = "tests/test_algorithm_checkpoint_restore.py",
tags = ["team:rllib", "tests_dir", "gpu"],
size = "large",
srcs = ["tests/test_algorithm_checkpoint_restore.py"],
args = ["TestCheckpointRestorePG"]
)

py_test(
name = "tests/test_checkpoint_restore_off_policy",
main = "tests/test_algorithm_checkpoint_restore.py",
@@ -2467,6 +2477,16 @@ py_test(
args = ["TestCheckpointRestoreOffPolicy"]
)


py_test(
name = "tests/test_checkpoint_restore_off_policy_gpu",
main = "tests/test_algorithm_checkpoint_restore.py",
tags = ["team:rllib", "tests_dir", "gpu"],
size = "large",
srcs = ["tests/test_algorithm_checkpoint_restore.py"],
args = ["TestCheckpointRestoreOffPolicy"]
)

py_test(
name = "tests/test_checkpoint_restore_evolution_algos",
main = "tests/test_algorithm_checkpoint_restore.py",
@@ -2476,6 +2496,15 @@ py_test(
args = ["TestCheckpointRestoreEvolutionAlgos"]
)

py_test(
name = "tests/test_checkpoint_restore_evolution_algos_gpu",
main = "tests/test_algorithm_checkpoint_restore.py",
tags = ["team:rllib", "tests_dir", "gpu"],
size = "medium",
srcs = ["tests/test_algorithm_checkpoint_restore.py"],
args = ["TestCheckpointRestoreEvolutionAlgos"]
)

py_test(
name = "policy/tests/test_policy_checkpoint_restore",
main = "policy/tests/test_policy_checkpoint_restore.py",
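
The three new *_gpu targets reuse the same test module and main entry point as their CPU counterparts; the added `gpu` tag lets GPU CI jobs select them. A minimal sketch of the environment-variable gating these tests rely on follows; the variable name RLLIB_NUM_GPUS comes from the test diff further down, while the surrounding code is illustrative only:

import os

# Assumed to be exported by GPU CI jobs; defaults to 0 on CPU-only runners.
num_gpus = int(os.environ.get("RLLIB_NUM_GPUS", "0"))
# Each algorithm config in the test then requests that many GPUs, e.g.:
# config.resources(num_gpus=num_gpus)
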
8 changes: 7 additions & 1 deletion rllib/policy/torch_policy.py
@@ -775,7 +775,13 @@ def set_state(self, state: PolicyState) -> None:
if optimizer_vars:
assert len(optimizer_vars) == len(self._optimizers)
for o, s in zip(self._optimizers, optimizer_vars):
optim_state_dict = convert_to_torch_tensor(s, device=self.device)
# Torch optimizer param_groups include entries such as betas, lr, etc. These
# parameters should be left as scalars and not converted to tensors;
# otherwise, torch.optim.step() will complain.
optim_state_dict = {"param_groups": s["param_groups"]}
optim_state_dict["state"] = convert_to_torch_tensor(
s["state"], device=self.device
)
o.load_state_dict(optim_state_dict)
# Set exploration's state.
if hasattr(self, "exploration") and "_exploration_state" in state:
8 changes: 7 additions & 1 deletion rllib/policy/torch_policy_v2.py
@@ -993,7 +993,13 @@ def set_state(self, state: PolicyState) -> None:
if optimizer_vars:
assert len(optimizer_vars) == len(self._optimizers)
for o, s in zip(self._optimizers, optimizer_vars):
optim_state_dict = convert_to_torch_tensor(s, device=self.device)
# Torch optimizer param_groups include entries such as betas, lr, etc. These
# parameters should be left as scalars and not converted to tensors;
# otherwise, torch.optim.step() will complain.
optim_state_dict = {"param_groups": s["param_groups"]}
optim_state_dict["state"] = convert_to_torch_tensor(
s["state"], device=self.device
)
o.load_state_dict(optim_state_dict)
# Set exploration's state.
if hasattr(self, "exploration") and "_exploration_state" in state:
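
The identical change in torch_policy.py and torch_policy_v2.py above is the core of the fix: when restoring, only the optimizer's per-parameter "state" buffers are converted to tensors on the policy's device, while the "param_groups" entries (lr, betas, eps, ...) stay plain Python values, which is what torch.optim expects at step() time. A minimal sketch of the same pattern with a bare torch optimizer, outside RLlib (the helper name and structure are illustrative, not RLlib code):

import torch

def move_optim_state(saved: dict, device: str) -> dict:
    # Move only tensor values (e.g. momentum/variance buffers); leave ints,
    # floats, and other plain Python values untouched.
    def to_device(value):
        if isinstance(value, torch.Tensor):
            return value.to(device)
        if isinstance(value, dict):
            return {k: to_device(v) for k, v in value.items()}
        return value
    return {
        "param_groups": saved["param_groups"],  # hyperparameters stay scalars
        "state": to_device(saved["state"]),     # per-parameter buffers go to `device`
    }

params = [torch.nn.Parameter(torch.randn(3, 3))]
opt = torch.optim.Adam(params, lr=1e-3)
params[0].grad = torch.zeros_like(params[0])
opt.step()  # populate the optimizer state
opt.load_state_dict(move_optim_state(opt.state_dict(), "cpu"))  # or "cuda:0" if available
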
40 changes: 33 additions & 7 deletions rllib/tests/test_algorithm_checkpoint_restore.py
@@ -16,6 +16,7 @@
from ray.rllib.algorithms.ars import ARSConfig
from ray.rllib.algorithms.a3c import A3CConfig
from ray.tune.registry import get_trainable_cls
import os


def get_mean_action(alg, obs):
@@ -32,7 +33,12 @@ def get_mean_action(alg, obs):
# explore=None if we compare the mean of the distribution of actions for the
# same observation to be the same.
algorithms_and_configs = {
"A3C": (A3CConfig().exploration(explore=False).rollouts(num_rollout_workers=1)),
"A3C": (
A3CConfig()
.exploration(explore=False)
.rollouts(num_rollout_workers=1)
.resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
),
"APEX_DDPG": (
ApexDDPGConfig()
.exploration(explore=False)
@@ -42,51 +48,65 @@ def get_mean_action(alg, obs):
optimizer={"num_replay_buffer_shards": 1},
num_steps_sampled_before_learning_starts=0,
)
.resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
),
"ARS": (
ARSConfig()
.exploration(explore=False)
.rollouts(num_rollout_workers=2, observation_filter="MeanStdFilter")
.training(num_rollouts=10, noise_size=2500000)
.resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
),
"DDPG": (
DDPGConfig()
.exploration(explore=False)
.reporting(min_sample_timesteps_per_iteration=100)
.training(num_steps_sampled_before_learning_starts=0)
.resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
),
"DQN": (
DQNConfig()
.exploration(explore=False)
.training(num_steps_sampled_before_learning_starts=0)
.resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
),
"ES": (
ESConfig()
.exploration(explore=False)
.training(episodes_per_batch=10, train_batch_size=100, noise_size=2500000)
.rollouts(observation_filter="MeanStdFilter", num_rollout_workers=2)
.resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
),
"PPO": (
# See the comment before the `algorithms_and_configs` dict.
# explore is set to None for PPO in favor of RLModule API support.
PPOConfig()
.training(num_sgd_iter=5, train_batch_size=1000)
.rollouts(num_rollout_workers=2)
.resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
),
"SimpleQ": (
SimpleQConfig()
.exploration(explore=False)
.training(num_steps_sampled_before_learning_starts=0)
.resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
),
"SAC": (
SACConfig()
.exploration(explore=False)
.training(num_steps_sampled_before_learning_starts=0)
.resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
),
}


def ckpt_restore_test(algo_name, tf2=False, object_store=False, replay_buffer=False):
def ckpt_restore_test(
algo_name,
tf2=False,
object_store=False,
replay_buffer=False,
run_restored_algorithm=True,
):
config = algorithms_and_configs[algo_name].to_dict()
# If required, store replay buffer data in checkpoints as well.
if replay_buffer:
@@ -172,22 +192,28 @@ def ckpt_restore_test(algo_name, tf2=False, object_store=False, replay_buffer=Fa
raise AssertionError(
"algo={} [a1={} a2={}]".format(algo_name, a1, a2)
)
# Stop both algos.
# Stop algo 1.
alg1.stop()

if run_restored_algorithm:
# Check that algo 2 can still run.
print("Starting second run on Algo 2...")
alg2.train()
alg2.stop()


class TestCheckpointRestorePG(unittest.TestCase):
@classmethod
def setUpClass(cls):
ray.init(num_cpus=5)
ray.init()

@classmethod
def tearDownClass(cls):
ray.shutdown()

def test_a3c_checkpoint_restore(self):
ckpt_restore_test("A3C")
# TODO(Kourosh) A3C cannot run a restored algorithm for some reason.
ckpt_restore_test("A3C", run_restored_algorithm=False)

def test_ppo_checkpoint_restore(self):
ckpt_restore_test("PPO", object_store=True)
@@ -196,7 +222,7 @@ def test_ppo_checkpoint_restore(self):
class TestCheckpointRestoreOffPolicy(unittest.TestCase):
@classmethod
def setUpClass(cls):
ray.init(num_cpus=5)
ray.init()

@classmethod
def tearDownClass(cls):
@@ -221,7 +247,7 @@ def test_simpleq_checkpoint_restore(self):
class TestCheckpointRestoreEvolutionAlgos(unittest.TestCase):
@classmethod
def setUpClass(cls):
ray.init(num_cpus=5)
ray.init()

@classmethod
def tearDownClass(cls):
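
For context, here is a hedged sketch of the train/checkpoint/restore round trip that ckpt_restore_test exercises, written against the public Algorithm API (the environment, algorithm choice, and exact return type of save() are illustrative and may differ slightly across Ray versions; this is not the test code itself):

import os

from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .environment("CartPole-v1")
    # Same gating as the tests: use GPUs only when the environment provides them.
    .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
)
algo = config.build()
algo.train()
checkpoint = algo.save()  # writes an algorithm checkpoint to disk
algo.stop()

# Restore (possibly on a CPU-only machine). With the fix above, param_groups stay
# scalars and only the per-parameter optimizer state is placed on the new device.
restored = Algorithm.from_checkpoint(checkpoint)
restored.train()
restored.stop()
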
