[RLlib] Fix large batch size for synchronous algos after EnvRunner failures. (#47356)
sven1977 authored Aug 27, 2024
1 parent 68811a5 commit c67fb76
Showing 2 changed files with 76 additions and 60 deletions.
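
The gist of the fix: in `synchronous_parallel_sample`, stop collecting as soon as the requested number of timesteps has been gathered, so that a second pass over the sampling loop (triggered, for example, by a failed and restarted EnvRunner) cannot inflate the train batch. Below is a minimal, self-contained sketch of that pattern; `Batch`, `sample_from_runners`, and `collect_until` are illustrative stand-ins, not RLlib APIs.

from dataclasses import dataclass
from typing import Callable, List


@dataclass
class Batch:
    env_steps: int


def collect_until(
    sample_from_runners: Callable[[], List[Batch]], max_env_steps: int
) -> List[Batch]:
    # Keep running sampling rounds until the step budget is met.
    collected: List[Batch] = []
    env_steps = 0
    while env_steps < max_env_steps:
        for batch in sample_from_runners():
            collected.append(batch)
            env_steps += batch.env_steps
            # Break out (and ignore the remaining samples of this round) once
            # the budget is reached; otherwise every extra round would append
            # all of its batches and overshoot by up to a full round.
            if env_steps >= max_env_steps:
                break
    return collected


# Example: 6 env runners each returning a 50-step fragment toward a 300-step batch.
batches = collect_until(lambda: [Batch(env_steps=50) for _ in range(6)], 300)
assert sum(b.env_steps for b in batches) == 300
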
15 changes: 12 additions & 3 deletions rllib/execution/rollout_ops.py
@@ -140,6 +140,8 @@ def synchronous_parallel_sample(
agent_or_env_steps += sum(
int(stat_dict[NUM_ENV_STEPS_SAMPLED]) for stat_dict in stats_dicts
)
sample_batches_or_episodes.extend(sampled_data)
all_stats_dicts.extend(stats_dicts)
else:
for batch_or_episode in sampled_data:
if max_agent_steps:
@@ -154,9 +156,16 @@ def synchronous_parallel_sample(
if _uses_new_env_runners
else batch_or_episode.env_steps()
)
sample_batches_or_episodes.extend(sampled_data)
if _return_metrics:
all_stats_dicts.extend(stats_dicts)
sample_batches_or_episodes.append(batch_or_episode)
# Break out (and ignore the remaining samples) if max timesteps (batch
# size) reached. We want to avoid collecting batches that are too large
# only because of a failed/restarted worker causing a second iteration
# of the main loop.
if (
max_agent_or_env_steps is not None
and agent_or_env_steps >= max_agent_or_env_steps
):
break

if concat is True:
# If we have episodes flatten the episode list.
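
With that break in place, the collected batch can overshoot the configured target by at most one rollout fragment, which is exactly what the rewritten test below asserts against the learner's reported mean batch size. A condensed sketch of that check, assuming a built PPO `config` and a `results` dict as produced in the test (the helper name is hypothetical):

from ray.rllib.core import DEFAULT_MODULE_ID
from ray.rllib.utils.metrics import LEARNER_RESULTS


def check_batch_size_bound(config, results) -> None:
    # Mean train batch size reported by the (single-module) learner.
    avg_batch = results[LEARNER_RESULTS][DEFAULT_MODULE_ID][
        "module_train_batch_size_mean"
    ]
    # At least the configured target ...
    assert config.total_train_batch_size <= avg_batch
    # ... but never a whole extra sampling round on top of it.
    assert (
        avg_batch
        < config.total_train_batch_size + config.get_rollout_fragment_length()
    )
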
121 changes: 64 additions & 57 deletions rllib/tests/test_node_failure.py
@@ -1,20 +1,22 @@
# This workload tests RLlib's ability to recover from failing worker nodes
import time
import unittest

import ray
from ray._private.test_utils import get_other_nodes
from ray.cluster_utils import Cluster
from ray.util.state import list_actors
from ray.rllib.algorithms.ppo import PPO, PPOConfig
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core import DEFAULT_MODULE_ID
from ray.rllib.utils.metrics import (
ENV_RUNNER_RESULTS,
EPISODE_RETURN_MEAN,
LEARNER_RESULTS,
)


num_redis_shards = 5
redis_max_memory = 10**8
object_store_memory = 10**8
num_nodes = 3


assert (
num_nodes * object_store_memory + num_redis_shards * redis_max_memory
< ray._private.utils.get_system_memory() / 2
@@ -24,7 +26,7 @@
)


class NodeFailureTests(unittest.TestCase):
class TestNodeFailures(unittest.TestCase):
def setUp(self):
# Simulate a cluster on one machine.
self.cluster = Cluster()
@@ -46,69 +48,74 @@ def tearDown(self):
ray.shutdown()
self.cluster.shutdown()

def test_continue_training_on_failure(self):
# We tolerate failing workers and pause training
def test_continue_training_on_env_runner_node_failures(self):
# We tolerate failing workers and pause training.
config = (
PPOConfig()
.api_stack(
enable_rl_module_and_learner=True,
enable_env_runner_and_connector_v2=True,
)
.environment("CartPole-v1")
.env_runners(
num_env_runners=6,
validate_env_runners_after_construction=True,
)
.fault_tolerance(recreate_failed_env_runners=True)
.training(
train_batch_size=300,
.fault_tolerance(
ignore_env_runner_failures=True,
recreate_failed_env_runners=True,
)
)
ppo = PPO(config=config)

# One step with all nodes up, enough to satisfy resource requirements
ppo.train()

self.assertEqual(ppo.env_runner_group.num_healthy_remote_workers(), 6)
self.assertEqual(ppo.env_runner_group.num_remote_workers(), 6)

# Remove the first non-head node.
node_to_kill = get_other_nodes(self.cluster, exclude_head=True)[0]
self.cluster.remove_node(node_to_kill)
algo = config.build()

# step() should continue with 4 rollout workers.
ppo.train()
best_return = 0.0
for i in range(40):
results = algo.train()
print(f"ITER={i} results={results}")

self.assertEqual(ppo.env_runner_group.num_healthy_remote_workers(), 4)
self.assertEqual(ppo.env_runner_group.num_remote_workers(), 6)

# node comes back immediately.
self.cluster.add_node(
redis_port=None,
num_redis_shards=None,
num_cpus=2,
num_gpus=0,
object_store_memory=object_store_memory,
redis_max_memory=redis_max_memory,
dashboard_host="0.0.0.0",
)

# Now, let's wait for Ray to restart all the RolloutWorker actors.
while True:
states = [
a["state"] == "ALIVE"
for a in list_actors()
if a["class_name"] == "RolloutWorker"
best_return = max(
best_return, results[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN]
)
avg_batch = results[LEARNER_RESULTS][DEFAULT_MODULE_ID][
"module_train_batch_size_mean"
]
if all(states):
break
# Otherwise, wait a bit.
time.sleep(1)

# This step should continue with 4 workers, but by the end
# of weight syncing, the 2 recovered rollout workers should
# be back.
ppo.train()

# Workers should be back up, everything back to normal.
self.assertEqual(ppo.env_runner_group.num_healthy_remote_workers(), 6)
self.assertEqual(ppo.env_runner_group.num_remote_workers(), 6)
self.assertGreaterEqual(avg_batch, config.total_train_batch_size)
self.assertLess(
avg_batch,
config.total_train_batch_size + config.get_rollout_fragment_length(),
)

self.assertEqual(algo.env_runner_group.num_remote_workers(), 6)
healthy_env_runners = algo.env_runner_group.num_healthy_remote_workers()
# After node has been removed, we expect 2 workers to be gone.
if (i - 1) % 5 == 0:
self.assertEqual(healthy_env_runners, 4)
# Otherwise, all workers should be there (but might still be in the process
# of coming up).
else:
self.assertIn(healthy_env_runners, [4, 5, 6])

# print(f"healthy workers = {algo.env_runner_group.healthy_worker_ids()}")
# Shut down one node every n iterations.
if i % 5 == 0:
to_kill = get_other_nodes(self.cluster, exclude_head=True)[0]
print(f"Killing node {to_kill} ...")
self.cluster.remove_node(to_kill)

# Bring back a previously failed node.
elif (i - 1) % 5 == 0:
print("Bringing back node ...")
self.cluster.add_node(
redis_port=None,
num_redis_shards=None,
num_cpus=2,
num_gpus=0,
object_store_memory=object_store_memory,
redis_max_memory=redis_max_memory,
dashboard_host="0.0.0.0",
)

self.assertGreaterEqual(best_return, 450.0)


if __name__ == "__main__":