[RLlib] Learner group checkpointing #34379
Conversation
Signed-off-by: Avnish <[email protected]>
Signed-off-by: Avnish <[email protected]>
…ner_group_checkpointing
Signed-off-by: Avnish <[email protected]>
- stop creating multiple distributed tf strategies
- add multinode release test for checkpointing
Signed-off-by: avnishn <[email protected]>
Signed-off-by: avnishn <[email protected]>
Signed-off-by: Avnish <[email protected]>
Signed-off-by: Avnish <[email protected]>
…ner_group_checkpointing
Signed-off-by: Avnish <[email protected]>
Signed-off-by: Avnish <[email protected]>
rllib/core/rl_module/rl_module.py
Outdated
@@ -609,3 +609,7 @@ def as_multi_agent(self) -> "MultiAgentRLModule":
        marl_module = MultiAgentRLModule()
        marl_module.add_module(DEFAULT_POLICY_ID, self)
        return marl_module

    def unwrapped(self) -> "RLModule":
        """Returns the underlying module if this module is a wrapper."""
Can you specify what "wrapper" means here? What are some examples of RLModule wrappers?
Torch RLModules get wrapped with the torch DDP RLModule wrapper.
done
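For readers following along, here is a minimal sketch of the wrapper/unwrap relationship being discussed; the class names below are illustrative stand-ins, not the actual RLlib classes:

```python
import torch.nn as nn


class PlainRLModuleSketch(nn.Module):
    """Stand-in for a plain (non-wrapped) RLModule."""

    def unwrapped(self) -> "PlainRLModuleSketch":
        # A module that is not a wrapper simply returns itself.
        return self


class DDPRLModuleWrapperSketch(nn.Module):
    """Stand-in for a framework-specific wrapper, e.g. the torch DDP wrapper."""

    def __init__(self, wrapped: PlainRLModuleSketch) -> None:
        super().__init__()
        self.wrapped = wrapped

    def unwrapped(self) -> PlainRLModuleSketch:
        # Peel off the wrapper so spec/checkpoint code always sees the
        # original module, never the DDP shell around it.
        return self.wrapped.unwrapped()


module = DDPRLModuleWrapperSketch(PlainRLModuleSketch())
assert module.unwrapped() is module.wrapped
```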
from ray.rllib.env.multi_agent_env import make_multi_agent
from ray.rllib.utils.test_utils import check


DEFAULT_POLICY_ID = "default_policy"
Nice! :)
@@ -26,6 +25,9 @@
Optimizer = Union["tf.keras.optimizers.Optimizer", "torch.optim.Optimizer"]


DEFAULT_POLICY_ID = "default_policy"
Why don't we import this from policy.py here?
I want to avoid mixing policy code in the new stack.
rllib/core/learner/tf/tf_learner.py
Outdated
-        # the default strategy is a no-op that can be used in the local mode
-        # cpu only case, build will override this if needed.
-        self._strategy = tf.distribute.get_strategy()
+        self._strategy = None
Can you leave the comment on what self._strategy is (or should be when not None)?
The strategy is a tf.distribute strategy object that is used for the DDP logic.
Added a param annotation.
done
rllib/core/learner/tf/tf_learner.py
Outdated
@@ -349,6 +347,25 @@ def remove_module(self, module_id: ModuleID) -> None:
        if self._enable_tf_function:
            self._update_fn = tf.function(self._do_update_fn, reduce_retracing=True)

    def _make_distributed_strategy(self):
        """Create a distributed strategy for the learner."""
Same, can you add a little more explanation here on what a "strategy" is and which types exist (an example?)?
A strategy is a tf.distribute strategy object. The different strategy types are handled within the function.
done
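To make the answer above more concrete, here is a rough sketch of how such a strategy factory could choose between tf.distribute strategies; the flags and structure are assumptions for illustration, not the exact RLlib code:

```python
import tensorflow as tf


def make_distributed_strategy_sketch(
    distributed: bool, use_gpu: bool
) -> tf.distribute.Strategy:
    if distributed:
        # Multiple learner workers: synchronize variables and gradients
        # across workers.
        return tf.distribute.MultiWorkerMirroredStrategy()
    if use_gpu:
        # Single worker with a GPU: place variables and ops on that device.
        return tf.distribute.OneDeviceStrategy("/gpu:0")
    # Local CPU-only mode: the default strategy is effectively a no-op.
    return tf.distribute.get_strategy()


# The learner would then build its module and optimizers inside
# strategy.scope() so variables are created under the chosen strategy.
```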
@@ -483,7 +483,7 @@ def from_module(self, module: MultiAgentRLModule) -> "MultiAgentRLModuleSpec":
            The MultiAgentRLModuleSpec.
        """
        module_specs = {
-            module_id: SingleAgentRLModuleSpec.from_module(rl_module)
+            module_id: SingleAgentRLModuleSpec.from_module(rl_module.unwrapped())
Nit: Explain why we need to unwrap here. rl_module could be a framework-specific DDP wrapper?
done.
@@ -19,7 +25,7 @@

REMOTE_SCALING_CONFIGS = {
    "remote-cpu": LearnerGroupScalingConfig(num_workers=1),
-    "remote-gpu": LearnerGroupScalingConfig(num_workers=1, num_gpus_per_worker=0.5),
+    "remote-gpu": LearnerGroupScalingConfig(num_workers=1, num_gpus_per_worker=1),
Why did we change this? Would it break if we used fractional GPUs here?
This learner group actually won't even take fractional GPUs, so it is pointless. I changed it while I was doing some debugging.
learner_group.load_state(initial_learner_checkpoint_dir)
check(learner_group.get_weights(), initial_learner_group_weights)
learner_group.update(batch.as_multi_agent(), reduce_fn=None)
results_without_break = learner_group.update(
Could we check here again to see whether the weights after one update (based off the initial state) are the same as the weights of the original learner (after one update)?
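A sketch of that extra assertion, reusing names from the test snippet above; the expected-weights variable is hypothetical and would have to be captured from the original learner group right after its first update:

```python
# Restore from the initial checkpoint and verify the weights round-trip.
learner_group.load_state(initial_learner_checkpoint_dir)
check(learner_group.get_weights(), initial_learner_group_weights)

# One update from the restored state should reproduce the weights the
# original learner group had after its own first update (hypothetical
# variable captured earlier in the test).
learner_group.update(batch.as_multi_agent(), reduce_fn=None)
check(learner_group.get_weights(), weights_after_first_update)
```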
Awesome PR @avnish! Thanks for covering this important feature in our release tests from here on.
Just a few nits, questions, and suggestions for better comments.
Signed-off-by: Avnish <[email protected]>
…ner_group_checkpointing
LGTM now. Thanks!
Let me talk to you offline about how you intend to use this.
def remove_dir(w):
    import shutil

    shutil.rmtree(worker_temp_dir)
Can you make this a member function on Worker as well, so you can do lambda w: w.remove_worker_temp_dir() below?
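A possible shape for that refactor, with hypothetical names for the worker class and its temp-dir attribute:

```python
import shutil
import tempfile


class WorkerSketch:
    """Illustrative worker from the release test; the names are assumptions."""

    def __init__(self) -> None:
        # Hypothetical attribute holding this worker's local checkpoint copy.
        self.worker_temp_dir = tempfile.mkdtemp()

    def remove_worker_temp_dir(self) -> None:
        # Member-function version of the remove_dir closure above.
        shutil.rmtree(self.worker_temp_dir)


# Call sites can then use the suggested form, e.g.:
#     some_foreach_helper(lambda w: w.remove_worker_temp_dir())
```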
import socket
import tempfile

hostname = socket.gethostname()
ray.util.get_node_ip
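Assuming the suggestion refers to ray.util.get_node_ip_address(), a minimal usage sketch:

```python
import ray

# Returns the IP address of the node this process is running on, which is
# usually what a multi-node release test needs instead of a bare hostname.
node_ip = ray.util.get_node_ip_address()
print(node_ip)
```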
Signed-off-by: Avnish <[email protected]>
Implement multinode learner group checkpointing and tests.
Why are these changes needed?
Related issue number
Checks
- I've signed off every commit (using git commit -s) in this PR.
- I've run scripts/format.sh to lint the changes in this PR.
- If I've added a method in Tune, I've added it in doc/source/tune/api/ under the corresponding .rst file.