[RLlib] Learner group checkpointing #34379
Merged: amogkam merged 13 commits into ray-project:master from avnishn:learner_group_checkpointing on Apr 18, 2023.
Commits (13, all by avnishn):
22f4d97  Initial commit
42afc7c  Temp
bbddd27  Merge branch 'master' of https://github.com/ray-project/ray into lear…
e2afe87  Working for 1 cpu distributed -- need to test on multinode
4217faa  Make learner checkpointing work on multinode
a485944  Fix bugs
6a45d20  Fix load state, make all actors run
0c15398  Fix release testing path
fbce4b0  Merge branch 'master' of https://github.com/ray-project/ray into lear…
81d1ba1  Fix script with broken imports
ee27fcd  Move import from out of tests dir
35a7621  Address comments
3318485  Merge branch 'master' of https://github.com/ray-project/ray into lear…
release/rllib_tests/checkpointing_tests/test_learner_group_checkpointing.py (126 additions, 0 deletions)
@@ -0,0 +1,126 @@
import gymnasium as gym
import itertools
import numpy as np
import tempfile
import unittest

import ray
from ray.rllib.core.learner.scaling_config import LearnerGroupScalingConfig
from ray.rllib.core.testing.utils import get_learner_group
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.test_utils import check


FAKE_BATCH = {
    SampleBatch.OBS: np.array(
        [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2]],
        dtype=np.float32,
    ),
    SampleBatch.NEXT_OBS: np.array(
        [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2]],
        dtype=np.float32,
    ),
    SampleBatch.ACTIONS: np.array([0, 1, 1]),
    SampleBatch.PREV_ACTIONS: np.array([0, 1, 1]),
    SampleBatch.REWARDS: np.array([1.0, -1.0, 0.5], dtype=np.float32),
    SampleBatch.PREV_REWARDS: np.array([1.0, -1.0, 0.5], dtype=np.float32),
    SampleBatch.TERMINATEDS: np.array([False, False, True]),
    SampleBatch.TRUNCATEDS: np.array([False, False, False]),
    SampleBatch.VF_PREDS: np.array([0.5, 0.6, 0.7], dtype=np.float32),
    SampleBatch.ACTION_DIST_INPUTS: np.array(
        [[-2.0, 0.5], [-3.0, -0.3], [-0.1, 2.5]], dtype=np.float32
    ),
    SampleBatch.ACTION_LOGP: np.array([-0.5, -0.1, -0.2], dtype=np.float32),
    SampleBatch.EPS_ID: np.array([0, 0, 0]),
    SampleBatch.AGENT_INDEX: np.array([0, 0, 0]),
}


REMOTE_SCALING_CONFIGS = {
    "remote-cpu": LearnerGroupScalingConfig(num_workers=1),
    "remote-gpu": LearnerGroupScalingConfig(num_workers=1, num_gpus_per_worker=1),
    "multi-gpu-ddp": LearnerGroupScalingConfig(num_workers=2, num_gpus_per_worker=1),
    "multi-cpu-ddp": LearnerGroupScalingConfig(num_workers=2, num_cpus_per_worker=2),
    # "multi-gpu-ddp-pipeline": LearnerGroupScalingConfig(
    #     num_workers=2, num_gpus_per_worker=2
    # ),
}


class TestLearnerGroupCheckpointing(unittest.TestCase):
    def setUp(self) -> None:
        ray.init()

    def tearDown(self) -> None:
        ray.shutdown()

    def test_save_load_state(self):
        fws = ["tf", "torch"]
        scaling_modes = REMOTE_SCALING_CONFIGS.keys()
        test_iterator = itertools.product(fws, scaling_modes)

        batch = SampleBatch(FAKE_BATCH)
        for fw, scaling_mode in test_iterator:
            print(f"Testing framework: {fw}, scaling mode: {scaling_mode}.")
            env = gym.make("CartPole-v1")

            scaling_config = REMOTE_SCALING_CONFIGS[scaling_mode]
            initial_learner_group = get_learner_group(
                fw, env, scaling_config, eager_tracing=True
            )

            # checkpoint the initial learner state for later comparison
            initial_learner_checkpoint_dir = tempfile.TemporaryDirectory().name
            initial_learner_group.save_state(initial_learner_checkpoint_dir)
            initial_learner_group_weights = initial_learner_group.get_weights()

            # do a single update
            initial_learner_group.update(batch.as_multi_agent(), reduce_fn=None)

            # checkpoint the learner state after 1 update for later comparison
            learner_after_1_update_checkpoint_dir = tempfile.TemporaryDirectory().name
            initial_learner_group.save_state(learner_after_1_update_checkpoint_dir)

            # remove that learner, construct a new one, and load the state of the old
            # learner into the new one
            initial_learner_group.shutdown()
            del initial_learner_group
            new_learner_group = get_learner_group(
                fw, env, scaling_config, eager_tracing=True
            )
            new_learner_group.load_state(learner_after_1_update_checkpoint_dir)

            # do another update
            results_with_break = new_learner_group.update(
                batch.as_multi_agent(), reduce_fn=None
            )
            weights_after_1_update_with_break = new_learner_group.get_weights()
            new_learner_group.shutdown()
            del new_learner_group

            # construct a new learner group and load the initial state of the learner
            learner_group = get_learner_group(
                fw, env, scaling_config, eager_tracing=True
            )
            learner_group.load_state(initial_learner_checkpoint_dir)
            check(learner_group.get_weights(), initial_learner_group_weights)
            learner_group.update(batch.as_multi_agent(), reduce_fn=None)
            results_without_break = learner_group.update(
                batch.as_multi_agent(), reduce_fn=None
            )
            weights_after_1_update_without_break = learner_group.get_weights()
            learner_group.shutdown()
            del learner_group

            # compare the results of the two updates
            check(results_with_break, results_without_break)
            check(
                weights_after_1_update_with_break, weights_after_1_update_without_break
            )


if __name__ == "__main__":
    import pytest
    import sys

    sys.exit(pytest.main(["-v", __file__]))
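For reference, the save/load round trip this test exercises boils down to roughly the following sketch. It reuses the testing helpers imported above; the framework string, scaling config, and checkpoint directory are illustrative, not part of the PR:

import tempfile

import gymnasium as gym
import ray

from ray.rllib.core.learner.scaling_config import LearnerGroupScalingConfig
from ray.rllib.core.testing.utils import get_learner_group

ray.init()
env = gym.make("CartPole-v1")
scaling_config = LearnerGroupScalingConfig(num_workers=1)

# Build a learner group and checkpoint its state into a directory.
learner_group = get_learner_group("torch", env, scaling_config, eager_tracing=True)
checkpoint_dir = tempfile.mkdtemp()
learner_group.save_state(checkpoint_dir)
learner_group.shutdown()

# A freshly built learner group can restore that state from the same directory.
restored_group = get_learner_group("torch", env, scaling_config, eager_tracing=True)
restored_group.load_state(checkpoint_dir)
restored_group.shutdown()
ray.shutdown()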
release/rllib_tests/multi_node_checkpointing_compute_config.yaml (22 additions, 0 deletions)
@@ -0,0 +1,22 @@
cloud_id: {{env["ANYSCALE_CLOUD_ID"]}}
region: us-west-2

max_workers: 3

head_node_type:
  name: head_node
  instance_type: m5.2xlarge

worker_node_types:
  - name: worker_node
    instance_type: g4dn.xlarge
    min_workers: 2
    max_workers: 2
    use_spot: false

aws:
  BlockDeviceMappings:
    - DeviceName: /dev/sda1
      Ebs:
        DeleteOnTermination: true
        VolumeSize: 150
@@ -1,4 +1,6 @@
from collections import deque
import pathlib
import socket
from typing import Any, List, Mapping, Type, Optional, Callable, Set, TYPE_CHECKING

import ray
@@ -17,6 +19,8 @@
from ray.rllib.utils.typing import ResultDict
from ray.rllib.utils.numpy import convert_to_numpy
from ray.train._internal.backend_executor import BackendExecutor
from ray.tune.utils.file_transfer import sync_dir_between_nodes


if TYPE_CHECKING:
    from ray.rllib.core.learner.learner import Learner
@@ -404,6 +408,134 @@ def set_is_module_trainable(
        if is_module_trainable is not None:
            self._is_module_trainable = is_module_trainable

    def save_state(self, path: str) -> None:
        """Saves the state of the LearnerGroup.

        Args:
            path: The path to save the state to.
        """
        if self.is_local:
            self._learner.save_state(path)
        else:
            worker = self._worker_manager.healthy_actor_ids()[0]
            worker_ip_addr = self._worker_manager.foreach_actor(
                self._get_ip_address, remote_actor_ids=[worker]
            )
            worker_ip_addr = self._get_results(worker_ip_addr)[0]
            self_ip_addr = self._get_ip_address()

            if worker_ip_addr == self_ip_addr:
                self._worker_manager.foreach_actor(
                    lambda w: w.save_state(path), remote_actor_ids=[worker]
                )
            else:
                # save the checkpoint to a temporary location on the worker

                # create a temporary directory on the worker
                worker_temp_dir = self._worker_manager.foreach_actor(
                    self._create_temporary_dir, remote_actor_ids=[worker]
                )
                worker_temp_dir = self._get_results(worker_temp_dir)[0]

                # save the checkpoint to the temporary directory on the worker
                self._worker_manager.foreach_actor(
                    lambda w: w.save_state(worker_temp_dir), remote_actor_ids=[worker]
                )

                # sync the temporary directory on the worker to the local directory
                sync_dir_between_nodes(
                    worker_ip_addr, worker_temp_dir, self_ip_addr, path
                )

                # creating this function here instead of making it a member function
                # because it uses the worker_temp_dir variable, and this can't
                # be passed in as an argument to foreach_actor
                def remove_dir(w):
                    import shutil

                    shutil.rmtree(worker_temp_dir)

                # remove the temporary directory on the worker
                self._worker_manager.foreach_actor(
                    remove_dir, remote_actor_ids=[worker]
                )
    def load_state(self, path: str) -> None:
        """Loads the state of the LearnerGroup.

        Args:
            path: The path to load the state from.
        """
        path = pathlib.Path(path)
        if not path.is_dir():
            raise ValueError(
                f"Path {path} is not a directory. "
                "Please specify a directory containing the checkpoint files."
            )
        if not path.exists():
            raise ValueError(f"Path {path} does not exist.")
        path = str(path.absolute())
        assert len(self._workers) == self._worker_manager.num_healthy_actors()
        if self.is_local:
            self._learner.load_state(path)
        else:
            head_node_ip = socket.gethostbyname(socket.gethostname())
            workers = self._worker_manager.healthy_actor_ids()

            def _load_state(w):
                # doing imports here since they might not be imported on the worker
                import socket
                import tempfile

                hostname = socket.gethostname()

(Inline review comment on the line above: ray.util.get_node_ip)

                worker_node_ip = socket.gethostbyname(hostname)
                # if the worker is on the same node as the head, load the checkpoint
                # directly from the path otherwise sync the checkpoint from the head
                # to the worker and load it from there
                if worker_node_ip == head_node_ip:
                    w.load_state(path)
                else:
                    with tempfile.TemporaryDirectory() as temp_dir:
                        sync_dir_between_nodes(
                            head_node_ip, path, worker_node_ip, temp_dir
                        )
                        w.load_state(temp_dir)

            self._worker_manager.foreach_actor(_load_state, remote_actor_ids=workers)
    @staticmethod
    def _create_temporary_dir(_=None) -> str:
        """Creates a temporary directory.

        Args:
            _: Unused arg. Exists to make this function compatible with foreach_actor
                calls.

        Returns:
            The path to the temporary directory.
        """
        import tempfile

        return tempfile.mkdtemp()

    @staticmethod
    def _get_ip_address(_=None) -> str:
        """Returns this process's address.

        Args:
            _: Unused arg. Exists to make this function compatible with foreach_actor
                calls.

        Returns:
            The address of this process.
        """
        import socket

        hostname = socket.gethostname()

        return socket.gethostbyname(hostname)

    def shutdown(self):
        """Shuts down the LearnerGroup."""
        if not self._is_local:
(The remainder of this file's diff did not load.)
Review comment: can you make this a member function on Worker as well? So you can do lambda w: w.remove_worker_temp_dir() below.
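Taken together with the inline ray.util.get_node_ip note earlier in the diff, this suggestion amounts to moving the node-IP lookup and the temp-dir cleanup onto the worker itself. A rough sketch of what that could look like, using hypothetical helper names (not part of the merged diff); the node IP comes from Ray's ray.util.get_node_ip_address():

import shutil
import tempfile

import ray


class WorkerCheckpointHelpersMixin:
    """Hypothetical worker-side helpers sketched from the review comments."""

    def get_node_ip(self) -> str:
        # Suggested replacement for socket.gethostbyname(socket.gethostname()).
        return ray.util.get_node_ip_address()

    def create_worker_temp_dir(self) -> str:
        # Create and remember a temporary directory for staging a checkpoint.
        self._worker_temp_dir = tempfile.mkdtemp()
        return self._worker_temp_dir

    def remove_worker_temp_dir(self) -> None:
        # With cleanup as a member function, the driver can simply call
        # lambda w: w.remove_worker_temp_dir() via foreach_actor.
        shutil.rmtree(self._worker_temp_dir)

The driver-side save_state could then invoke these through foreach_actor instead of defining closures over worker_temp_dir inline.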