[RLlib] Add Optimizer State To Learner get_state #34760
Conversation
Signed-off-by: Avnish <[email protected]>
rllib/core/learner/learner.py (outdated)

```diff
@@ -852,7 +852,12 @@ def set_state(self, state: Mapping[str, Any]) -> None:
         # having both can become confusing. Can we simplify this API requirement?
         self._check_is_built()
         # TODO: once we figure out the optimizer format, we can set/get the state
-        self._module.set_state(state.get("module_state", {}))
+        module_state = state["module_state"]
```
I could probably check for the existence of these keys first, then error out if necessary.
```diff
@@ -36,7 +36,7 @@
 LOCAL_SCALING_CONFIGS = {
     "local-cpu": LearnerGroupScalingConfig(num_workers=0, num_gpus_per_worker=0),
-    "local-gpu": LearnerGroupScalingConfig(num_workers=0, num_gpus_per_worker=0.5),
+    "local-gpu": LearnerGroupScalingConfig(num_workers=0, num_gpus_per_worker=1),
```
We don't actually support fractional GPUs, so this doesn't matter.
```diff
@@ -267,6 +267,25 @@ def _load_optimizers(self, path: Union[str, pathlib.Path]) -> None:
     def set_weights(self, weights: Mapping[str, Any]) -> None:
         self._module.set_state(weights)

+    @override(Learner)
+    def get_optimizer_weights(self) -> Mapping[str, Any]:
```
I'm trying to find a way to reuse these functions when saving the optimizer state, but it's difficult since there is actually little overlap -- when saving the optimizer state, we save in native TensorFlow format instead of numpy.
rllib/utils/torch_utils.py (outdated)

```diff
@@ -172,6 +172,21 @@ def mapping(item):
     return tree.map_structure(mapping, x)


+@PublicAPI
+def copy_and_move_to_device(x: TensorStructType, device: Optional[str] = None):
```
I tried using convert_to_torch_tensor when reloading optimizer state dicts, but it was doing something funny that caused some of my types to be improperly cast, which led to a precision error down the line. Instead I created this function, which probably also deserves its own test.
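To make the semantics under discussion concrete, here is a hedged, self-contained sketch of the deep-copy-and-move behavior (it is not RLlib's implementation): tensor leaves are copied onto the target device, while non-tensor leaves (such as step counters in an optimizer state dict) pass through unchanged. `FakeTensor` stands in for `torch.Tensor` so the sketch runs without torch.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class FakeTensor:
    """Illustrative stand-in for torch.Tensor with a device attribute."""

    data: list
    device: str = "cpu"


def copy_and_move_to_device_sketch(x: Any, device: Optional[str] = None) -> Any:
    """Recursively copy a nested structure, moving tensor leaves to `device`."""
    if isinstance(x, dict):
        return {k: copy_and_move_to_device_sketch(v, device) for k, v in x.items()}
    if isinstance(x, (list, tuple)):
        return type(x)(copy_and_move_to_device_sketch(v, device) for v in x)
    if isinstance(x, FakeTensor):
        # A real implementation would do something like
        # x.detach().clone().to(device) here.
        return FakeTensor(list(x.data), device or x.device)
    # Non-tensor leaves pass through unchanged.
    return x


state = {"step": 3, "exp_avg": FakeTensor([0.1, 0.2])}
moved = copy_and_move_to_device_sketch(state, device="cuda:0")
```

Because every tensor leaf is a fresh copy, mutating the moved structure leaves the original untouched, which is exactly the property a reload path wants.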
Is the funny thing something we should fix? Not saying that we should, just asking for your opinion.
rllib/utils/torch_utils.py (outdated)

```diff
@@ -172,6 +172,21 @@ def mapping(item):
     return tree.map_structure(mapping, x)


+@PublicAPI
+def copy_and_move_to_device(x: TensorStructType, device: Optional[str] = None):
```
Can we add a docstring here to clarify what sort of copy this is?
Also, can we document what happens to items that are not torch Tensors?
Are some of the optimizer weights not torch tensors? Usually I'd expect this to error out if elements of the TensorStructType are not torch tensors.
I've added a docstring, hoping it adds enough clarity :)
Left some nits :) Thanks for the PR!
Cool! Thanks for the additional util! 😃
Signed-off-by: Jack He <[email protected]>
As part of the TODOs, add the optimizer state to the Learner's get_state.
TBF, I don't know who will ever need the optimizer state at runtime other than for testing, but now at least we support it, for completeness.
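A hedged sketch of the state layout this PR moves toward: `get_state()` returns both module and optimizer state, so a Learner can round-trip through `set_state()`. The key names mirror the PR discussion, but the class and its fields are illustrative, not RLlib's actual `Learner`.

```python
from typing import Any, Dict


class _LearnerStateSketch:
    """Illustrative Learner whose state covers module AND optimizer."""

    def __init__(self) -> None:
        self._module_state: Dict[str, Any] = {"w": [1.0, 2.0]}
        self._optimizer_state: Dict[str, Any] = {"step": 0}

    def get_state(self) -> Dict[str, Any]:
        # Both components travel together, so restoring is lossless.
        return {
            "module_state": dict(self._module_state),
            "optimizer_state": dict(self._optimizer_state),
        }

    def set_state(self, state: Dict[str, Any]) -> None:
        self._module_state = state["module_state"]
        self._optimizer_state = state["optimizer_state"]


src = _LearnerStateSketch()
src._optimizer_state["step"] = 7  # e.g. after some training updates
dst = _LearnerStateSketch()
dst.set_state(src.get_state())
```

Without the optimizer half, a restored learner would silently reset optimizer internals (momentum buffers, step counts), which is why including it matters even if few callers need it at runtime.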
Why are these changes needed?
Related issue number
Checks
- I've signed off every commit (`git commit -s`) in this PR.
- I've run `scripts/format.sh` to lint the changes in this PR.
- For any new method introduced in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file.