Fix resuming from checkpoint when using RayFSDPStrategy #43594
Merged
Why are these changes needed?
Restoring from a checkpoint when using FSDP is currently broken: the state_dict keys for each layer get modified, so torch cannot associate the weights with the layer names when loading. The current implementation always assumes that the layer keys in the state dict are prefixed with `_forward_module.` and slices each key by the length of that prefix. The underlying reason for removing `_forward_module.` is unclear to me, but we should check that a key actually has the prefix before removing it. This PR implements that check, which fixes loading checkpoints when using RayFSDPStrategy.
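To illustrate the idea, here is a minimal sketch of prefix-aware key cleaning. It is not the code from this PR: the helper name `strip_prefix_if_present` and the example keys are hypothetical, and plain dictionaries stand in for a real state dict.

```python
from collections import OrderedDict

# Prefix named in the PR description.
PREFIX = "_forward_module."


def strip_prefix_if_present(state_dict, prefix=PREFIX):
    """Return a copy of state_dict with `prefix` removed from keys that start with it.

    Keys without the prefix are copied unchanged, instead of being sliced blindly.
    (Hypothetical helper for illustration only.)
    """
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(prefix):
            cleaned[key[len(prefix):]] = value
        else:
            cleaned[key] = value
    return cleaned


# Illustrative keys (assumed, not taken from a real checkpoint).
prefixed = {"_forward_module.layer.weight": 1, "_forward_module.layer.bias": 2}
unprefixed = {"layer.weight": 1, "layer.bias": 2}

cleaned_prefixed = strip_prefix_if_present(prefixed)
cleaned_unprefixed = strip_prefix_if_present(unprefixed)

# Unconditional slicing, as described above, damages keys that lack the prefix.
blindly_sliced = {key[len(PREFIX):]: value for key, value in unprefixed.items()}
```

With the check in place, both inputs yield the same keys, whereas slicing unconditionally truncates keys that never had the prefix.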
The following is an example of the wrong state_dict keys for a checkpoint:
Correct keys: