
[RLlib] Add Optimizer State To Learner get_state #34760

Merged

Conversation

@avnishn (Member) commented Apr 25, 2023

Signed-off-by: Avnish [email protected]

As part of the TODOs, add the optimizer state to the Learner's get_state.

TBF, I don't know who will ever need the optimizer state at runtime other than for testing, but now we at least support it, for completeness.
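A rough sketch of the resulting state layout (only the "module_state" key appears in the diffs below; the optimizer key name here is a placeholder, not necessarily the one used in the code):

```python
# Hypothetical sketch of the Learner state after this change -- key names other
# than "module_state" are placeholders, not the exact ones used in RLlib.
state = learner.get_state()
# state ~= {
#     "module_state": {...},     # RLModule weights
#     "optimizer_state": {...},  # optimizer variables (placeholder key name)
# }
learner.set_state(state)  # restores module weights and optimizer state together
```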

Why are these changes needed?

Related issue number

Checks

  • I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
    • I've added any new APIs to the API Reference. For example, if I added a
      method in Tune, I've added it in doc/source/tune/api/ under the
      corresponding .rst file.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

@@ -852,7 +852,12 @@ def set_state(self, state: Mapping[str, Any]) -> None:
         # having both can become confusing. Can we simplify this API requirement?
         self._check_is_built()
         # TODO: once we figure out the optimizer format, we can set/get the state
-        self._module.set_state(state.get("module_state", {}))
+        module_state = state["module_state"]
@avnishn (Member Author):

I could probably check for the existence of these keys first and then error out if necessary.
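A minimal sketch of that key check, assuming the same `set_state` signature as in the hunk above (hypothetical, not the code merged in this PR):

```python
from typing import Any, Mapping


def set_state(self, state: Mapping[str, Any]) -> None:
    self._check_is_built()
    # Fail fast with a descriptive error instead of a KeyError if the caller
    # passes a partial state dict.
    if "module_state" not in state:
        raise ValueError(
            "`state` is missing the 'module_state' key; pass the full dict "
            "returned by `Learner.get_state()`."
        )
    self._module.set_state(state["module_state"])
```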

@@ -36,7 +36,7 @@

 LOCAL_SCALING_CONFIGS = {
     "local-cpu": LearnerGroupScalingConfig(num_workers=0, num_gpus_per_worker=0),
-    "local-gpu": LearnerGroupScalingConfig(num_workers=0, num_gpus_per_worker=0.5),
+    "local-gpu": LearnerGroupScalingConfig(num_workers=0, num_gpus_per_worker=1),
@avnishn (Member Author):

We don't actually support fractional GPUs, so this doesn't matter.

@@ -267,6 +267,25 @@ def _load_optimizers(self, path: Union[str, pathlib.Path]) -> None:
     def set_weights(self, weights: Mapping[str, Any]) -> None:
         self._module.set_state(weights)

+    @override(Learner)
+    def get_optimizer_weights(self) -> Mapping[str, Any]:
@avnishn (Member Author):

I'm trying to find a way to reuse these functions when saving the optimizer state, but it's difficult since there is little overlap -- when saving the optimizer state, we write it in native TensorFlow format instead of numpy.
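To illustrate the gap, a sketch of the two paths (function names are made up for this example; it assumes a Keras optimizer that exposes its slot variables via `variables()`):

```python
import tensorflow as tf


def get_optimizer_weights_sketch(optimizer):
    # Runtime path: expose the optimizer variables as framework-agnostic numpy
    # arrays, e.g. for tests or for shipping the state between processes.
    return [v.numpy() for v in optimizer.variables()]


def save_optimizer_sketch(optimizer, path: str) -> None:
    # Checkpointing path: persist the same variables in TensorFlow's native
    # checkpoint format instead of converting them to numpy first.
    tf.train.Checkpoint(optimizer=optimizer).write(path)
```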

@@ -172,6 +172,21 @@ def mapping(item):
     return tree.map_structure(mapping, x)


+@PublicAPI
+def copy_and_move_to_device(x: TensorStructType, device: Optional[str] = None):
@avnishn (Member Author):

I tried using convert_to_torch_tensor when reloading optimizer state dicts, but it was doing something funny that caused some of my types to get improperly cast, which led to a precision error down the line. Instead, I created this function, which probably also deserves its own test.
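For reference, a minimal sketch of what such a utility could look like, mirroring the `tree.map_structure` pattern visible in the hunk above (an assumption about the behavior, not the exact function added in this PR):

```python
from typing import Optional

import torch
import tree  # dm-tree


def copy_and_move_to_device_sketch(x, device: Optional[str] = None):
    """Deep-copies every torch.Tensor in `x` and moves the copy to `device`.

    Non-tensor leaves are returned unchanged, and tensor dtypes are preserved
    (no implicit casting), which matters when restoring optimizer state dicts.
    """

    def mapping(item):
        if isinstance(item, torch.Tensor):
            copied = item.detach().clone()
            return copied.to(device) if device is not None else copied
        return item

    return tree.map_structure(mapping, x)
```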

Contributor:

Is the funny thing something we should fix? Not saying that we should, just asking for your opinion.

@@ -172,6 +172,21 @@ def mapping(item):
     return tree.map_structure(mapping, x)


+@PublicAPI
+def copy_and_move_to_device(x: TensorStructType, device: Optional[str] = None):
Contributor:

Can we add a docstring here to clarify what sort of copy this is, and what happens to items that are not torch Tensors?
Are some of the optimizer weights not torch tensors? Usually I'd expect this to error out if elements of the TensorStruct are not torch tensors.

@avnishn (Member Author):

I've added a docstring, hoping it adds enough clarity :)

@ArturNiederfahrenhorst (Contributor) left a comment:

Left some nits :) Thanks for the PR!

@ArturNiederfahrenhorst (Contributor) left a comment:

Cool! Thanks for the additional util! 😃

@gjoliver merged commit 0d59be7 into ray-project:master on Apr 28, 2023
ProjectsByJackHe pushed a commit to ProjectsByJackHe/ray that referenced this pull request May 4, 2023
architkulkarni pushed a commit to architkulkarni/ray that referenced this pull request May 16, 2023