[RLlib] Learner group checkpointing #34379

Merged
merged 13 commits on Apr 18, 2023
21 changes: 21 additions & 0 deletions release/release_tests.yaml
@@ -2995,6 +2995,27 @@
# RLlib tests
########################

- name: rllib_learner_group_checkpointing_multinode
  group: RLlib tests
  working_dir: rllib_tests

  frequency: nightly
  team: rllib

  cluster:
    cluster_env: app_config.yaml
    cluster_compute: multi_node_checkpointing_compute_config.yaml

  run:
    timeout: 3600
    script: pytest checkpointing_tests/test_learner_group_checkpointing.py

    wait_for_nodes:
      num_nodes: 3

  alert: default


- name: rllib_learning_tests_a2c_tf
  group: RLlib tests
  working_dir: rllib_tests
4 changes: 4 additions & 0 deletions release/rllib_tests/app_config.yaml
@@ -23,6 +23,7 @@ python:
# so we built it for py3 and use that instead. This wheel was tested for python 3.7, 3.8,
# and 3.9.
- https://ray-ci-deps-wheels.s3.us-west-2.amazonaws.com/AutoROM.accept_rom_license-0.5.4-py3-none-any.whl
- pytest
conda_packages: []

post_build_cmds:
@@ -41,3 +42,6 @@ post_build_cmds:
- mv mujoco210-linux-x86_64.tar.gz ~/.mujoco/.
- cd ~/.mujoco
- tar -xf ~/.mujoco/mujoco210-linux-x86_64.tar.gz

# not strictly necessary, but makes debugging easier
- git clone https://github.com/ray-project/ray.git
126 changes: 126 additions & 0 deletions release/rllib_tests/checkpointing_tests/test_learner_group_checkpointing.py
@@ -0,0 +1,126 @@
import gymnasium as gym
import itertools
import numpy as np
import tempfile
import unittest

import ray
from ray.rllib.core.learner.scaling_config import LearnerGroupScalingConfig
from ray.rllib.core.testing.utils import get_learner_group
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.test_utils import check


FAKE_BATCH = {
    SampleBatch.OBS: np.array(
        [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2]],
        dtype=np.float32,
    ),
    SampleBatch.NEXT_OBS: np.array(
        [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2]],
        dtype=np.float32,
    ),
    SampleBatch.ACTIONS: np.array([0, 1, 1]),
    SampleBatch.PREV_ACTIONS: np.array([0, 1, 1]),
    SampleBatch.REWARDS: np.array([1.0, -1.0, 0.5], dtype=np.float32),
    SampleBatch.PREV_REWARDS: np.array([1.0, -1.0, 0.5], dtype=np.float32),
    SampleBatch.TERMINATEDS: np.array([False, False, True]),
    SampleBatch.TRUNCATEDS: np.array([False, False, False]),
    SampleBatch.VF_PREDS: np.array([0.5, 0.6, 0.7], dtype=np.float32),
    SampleBatch.ACTION_DIST_INPUTS: np.array(
        [[-2.0, 0.5], [-3.0, -0.3], [-0.1, 2.5]], dtype=np.float32
    ),
    SampleBatch.ACTION_LOGP: np.array([-0.5, -0.1, -0.2], dtype=np.float32),
    SampleBatch.EPS_ID: np.array([0, 0, 0]),
    SampleBatch.AGENT_INDEX: np.array([0, 0, 0]),
}


REMOTE_SCALING_CONFIGS = {
    "remote-cpu": LearnerGroupScalingConfig(num_workers=1),
    "remote-gpu": LearnerGroupScalingConfig(num_workers=1, num_gpus_per_worker=1),
    "multi-gpu-ddp": LearnerGroupScalingConfig(num_workers=2, num_gpus_per_worker=1),
    "multi-cpu-ddp": LearnerGroupScalingConfig(num_workers=2, num_cpus_per_worker=2),
    # "multi-gpu-ddp-pipeline": LearnerGroupScalingConfig(
    #     num_workers=2, num_gpus_per_worker=2
    # ),
}


class TestLearnerGroupCheckpointing(unittest.TestCase):
    def setUp(self) -> None:
        ray.init()

    def tearDown(self) -> None:
        ray.shutdown()

    def test_save_load_state(self):
        fws = ["tf", "torch"]
        scaling_modes = REMOTE_SCALING_CONFIGS.keys()
        test_iterator = itertools.product(fws, scaling_modes)

        batch = SampleBatch(FAKE_BATCH)
        for fw, scaling_mode in test_iterator:
            print(f"Testing framework: {fw}, scaling mode: {scaling_mode}.")
            env = gym.make("CartPole-v1")

            scaling_config = REMOTE_SCALING_CONFIGS[scaling_mode]
            initial_learner_group = get_learner_group(
                fw, env, scaling_config, eager_tracing=True
            )

            # checkpoint the initial learner state for later comparison
            initial_learner_checkpoint_dir = tempfile.TemporaryDirectory().name
            initial_learner_group.save_state(initial_learner_checkpoint_dir)
            initial_learner_group_weights = initial_learner_group.get_weights()

            # do a single update
            initial_learner_group.update(batch.as_multi_agent(), reduce_fn=None)

            # checkpoint the learner state after 1 update for later comparison
            learner_after_1_update_checkpoint_dir = tempfile.TemporaryDirectory().name
            initial_learner_group.save_state(learner_after_1_update_checkpoint_dir)

            # remove that learner, construct a new one, and load the state of the old
            # learner into the new one
            initial_learner_group.shutdown()
            del initial_learner_group
            new_learner_group = get_learner_group(
                fw, env, scaling_config, eager_tracing=True
            )
            new_learner_group.load_state(learner_after_1_update_checkpoint_dir)

            # do another update
            results_with_break = new_learner_group.update(
                batch.as_multi_agent(), reduce_fn=None
            )
            weights_after_1_update_with_break = new_learner_group.get_weights()
            new_learner_group.shutdown()
            del new_learner_group

            # construct a new learner group and load the initial state of the learner
            learner_group = get_learner_group(
                fw, env, scaling_config, eager_tracing=True
            )
            learner_group.load_state(initial_learner_checkpoint_dir)
            check(learner_group.get_weights(), initial_learner_group_weights)
            learner_group.update(batch.as_multi_agent(), reduce_fn=None)
            results_without_break = learner_group.update(
                batch.as_multi_agent(), reduce_fn=None
            )
            weights_after_1_update_without_break = learner_group.get_weights()
            learner_group.shutdown()
            del learner_group

            # compare the results of the two updates
            check(results_with_break, results_without_break)
            check(
                weights_after_1_update_with_break, weights_after_1_update_without_break
            )


if __name__ == "__main__":
    import pytest
    import sys

    sys.exit(pytest.main(["-v", __file__]))
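
For quick reference, the checkpoint round-trip exercised by this test condenses to the sketch below. It is not part of the PR; it reuses the test's get_learner_group helper and LearnerGroupScalingConfig, and the checkpoint-directory handling is illustrative only.

# Condensed, illustrative sketch of the save/load round-trip (not part of the PR).
import tempfile

import gymnasium as gym
import ray
from ray.rllib.core.learner.scaling_config import LearnerGroupScalingConfig
from ray.rllib.core.testing.utils import get_learner_group

ray.init()
env = gym.make("CartPole-v1")
scaling_config = LearnerGroupScalingConfig(num_workers=1)

# Build a learner group and write its state to a checkpoint directory.
learner_group = get_learner_group("torch", env, scaling_config, eager_tracing=True)
ckpt_dir = tempfile.mkdtemp()
learner_group.save_state(ckpt_dir)
learner_group.shutdown()

# Build a fresh learner group and restore the saved state into it.
restored_group = get_learner_group("torch", env, scaling_config, eager_tracing=True)
restored_group.load_state(ckpt_dir)
restored_group.shutdown()
ray.shutdown()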
22 changes: 22 additions & 0 deletions release/rllib_tests/multi_node_checkpointing_compute_config.yaml
@@ -0,0 +1,22 @@
cloud_id: {{env["ANYSCALE_CLOUD_ID"]}}
region: us-west-2

max_workers: 3

head_node_type:
  name: head_node
  instance_type: m5.2xlarge

worker_node_types:
  - name: worker_node
    instance_type: g4dn.xlarge
    min_workers: 2
    max_workers: 2
    use_spot: false

aws:
  BlockDeviceMappings:
    - DeviceName: /dev/sda1
      Ebs:
        DeleteOnTermination: true
        VolumeSize: 150
132 changes: 132 additions & 0 deletions rllib/core/learner/learner_group.py
@@ -1,4 +1,6 @@
from collections import deque
import pathlib
import socket
from typing import Any, List, Mapping, Type, Optional, Callable, Set, TYPE_CHECKING

import ray
@@ -17,6 +19,8 @@
from ray.rllib.utils.typing import ResultDict
from ray.rllib.utils.numpy import convert_to_numpy
from ray.train._internal.backend_executor import BackendExecutor
from ray.tune.utils.file_transfer import sync_dir_between_nodes


if TYPE_CHECKING:
    from ray.rllib.core.learner.learner import Learner
@@ -404,6 +408,134 @@ def set_is_module_trainable(
        if is_module_trainable is not None:
            self._is_module_trainable = is_module_trainable

    def save_state(self, path: str) -> None:
        """Saves the state of the LearnerGroup.

        Args:
            path: The path to save the state to.
        """
        if self.is_local:
            self._learner.save_state(path)
        else:
            worker = self._worker_manager.healthy_actor_ids()[0]
            worker_ip_addr = self._worker_manager.foreach_actor(
                self._get_ip_address, remote_actor_ids=[worker]
            )
            worker_ip_addr = self._get_results(worker_ip_addr)[0]
            self_ip_addr = self._get_ip_address()

            if worker_ip_addr == self_ip_addr:
                self._worker_manager.foreach_actor(
                    lambda w: w.save_state(path), remote_actor_ids=[worker]
                )
            else:
                # save the checkpoint to a temporary location on the worker

                # create a temporary directory on the worker
                worker_temp_dir = self._worker_manager.foreach_actor(
                    self._create_temporary_dir, remote_actor_ids=[worker]
                )
                worker_temp_dir = self._get_results(worker_temp_dir)[0]

                # save the checkpoint to the temporary directory on the worker
                self._worker_manager.foreach_actor(
                    lambda w: w.save_state(worker_temp_dir), remote_actor_ids=[worker]
                )

                # sync the temporary directory on the worker to the local directory
                sync_dir_between_nodes(
                    worker_ip_addr, worker_temp_dir, self_ip_addr, path
                )

                # creating this function here instead of making it a member function
                # because it uses the worker_temp_dir variable, and this can't
                # be passed in as an argument to foreach_actor
                def remove_dir(w):
                    import shutil

                    shutil.rmtree(worker_temp_dir)
Review comment (Member): can you make this a member function on Worker as well?
so you can do lambda w: w.remove_worker_temp_dir() below.
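
A hypothetical sketch of what that suggestion could look like; the method names and their placement on the Learner worker are assumptions inferred from the comment, not part of this PR.

import shutil
import tempfile


class Learner:  # illustrative stand-in for the RLlib Learner actor
    def __init__(self) -> None:
        self._temp_checkpoint_dir = None

    def create_worker_temp_dir(self) -> str:
        # Create and remember a temporary checkpoint directory on this worker.
        self._temp_checkpoint_dir = tempfile.mkdtemp()
        return self._temp_checkpoint_dir

    def remove_worker_temp_dir(self) -> None:
        # Delete the temporary checkpoint directory created above, if any.
        if self._temp_checkpoint_dir is not None:
            shutil.rmtree(self._temp_checkpoint_dir, ignore_errors=True)
            self._temp_checkpoint_dir = None

The cleanup in LearnerGroup.save_state could then be a plain lambda, e.g. self._worker_manager.foreach_actor(lambda w: w.remove_worker_temp_dir(), remote_actor_ids=[worker]).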


                # remove the temporary directory on the worker
                self._worker_manager.foreach_actor(
                    remove_dir, remote_actor_ids=[worker]
                )

    def load_state(self, path: str) -> None:
        """Loads the state of the LearnerGroup.

        Args:
            path: The path to load the state from.
        """
        path = pathlib.Path(path)
        if not path.exists():
            raise ValueError(f"Path {path} does not exist.")
        if not path.is_dir():
            raise ValueError(
                f"Path {path} is not a directory. "
                "Please specify a directory containing the checkpoint files."
            )
        path = str(path.absolute())
        assert len(self._workers) == self._worker_manager.num_healthy_actors()
        if self.is_local:
            self._learner.load_state(path)
        else:
            head_node_ip = socket.gethostbyname(socket.gethostname())
            workers = self._worker_manager.healthy_actor_ids()

            def _load_state(w):
                # doing imports here since they might not be imported on the worker
                import socket
                import tempfile

                hostname = socket.gethostname()
Review comment (Member Author): ray.util.get_node_ip
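
This presumably points at Ray's node-IP helper; a minimal sketch of the suggested swap, assuming ray.util.get_node_ip_address() is the intended API:

import socket

import ray

# Approach used in this PR: resolve the node IP via the socket module.
ip_via_socket = socket.gethostbyname(socket.gethostname())

# Suggested alternative: ask Ray for the IP it associates with this node,
# which sidesteps hostname-resolution quirks on some cluster setups.
ip_via_ray = ray.util.get_node_ip_address()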

                worker_node_ip = socket.gethostbyname(hostname)
                # if the worker is on the same node as the head, load the checkpoint
                # directly from the path; otherwise sync the checkpoint from the head
                # to the worker and load it from there
                if worker_node_ip == head_node_ip:
                    w.load_state(path)
                else:
                    with tempfile.TemporaryDirectory() as temp_dir:
                        sync_dir_between_nodes(
                            head_node_ip, path, worker_node_ip, temp_dir
                        )
                        w.load_state(temp_dir)

            self._worker_manager.foreach_actor(_load_state, remote_actor_ids=workers)

    @staticmethod
    def _create_temporary_dir(_=None) -> str:
        """Creates a temporary directory.

        Args:
            _: Unused arg. Exists to make this function compatible with foreach_actor
                calls.

        Returns:
            The path to the temporary directory.
        """
        import tempfile

        return tempfile.mkdtemp()

    @staticmethod
    def _get_ip_address(_=None) -> str:
        """Returns this process's address.

        Args:
            _: Unused arg. Exists to make this function compatible with foreach_actor
                calls.

        Returns:
            The address of this process.
        """
        import socket

        hostname = socket.gethostname()

        return socket.gethostbyname(hostname)

    def shutdown(self):
        """Shuts down the LearnerGroup."""
        if not self._is_local: