[draft] proposed fix for incorrect mask application in FSDP #1807
Merged
Conversation
bfineran pushed a commit that referenced this pull request on Dec 7, 2023:

* WIP
* WIP trainer refactor
* loss updates and removing manager references
* WIP generation script
* events updating properly
* fix structure init
* running for text_generation
* dataloaders and cleaning up finetuning script
* reorganizing and fsdp
* fix gradient bug
* add fsdp config
* clean up for debugging
* clean up textgen script
* model/recipe save and loading
* quality and fixing tests
* fix test
* fix recipe load
* [Finetuning] Model/Recipe reloading and Checkpoints (#1795)
* Initial commit
* Add end to end tests
* Add e2e tests for constant pruning modifier
* Move imports inside the test functions so that torch isn't imported unless running the tests
* Update setup.py to not run modifier tests unless pytorch is specified
* [Bugfix] .dict() method on Recipe (#1753)
* Bugfix .dict() method on Recipe
* Remove extraneous local test, [faulty commit]
* [modifier refactor] Add serialization tests (#1755)
* Add serialization tests
* Clean up
* Keep original stage and group names, clean up _get_yaml_dict
* fix comment
* Typo
* [Unit Tests][Modifier Refactor] (#1756)
* Move valid recipes to a helper file, add tests for session.py
* Increase test coverage of src/sparseml/core/session.py to 100%, run style, add logs to .gitignore
* Increase coverage of tests/sparseml/core/test_state.py to 100%
* add tests for lifecycle/event.py
* Increase code coverage of lifecycle/event to 100%
* increase lifecycle/session.py code coverage to 93%
* Address review comments from @Satrat
* Address review comments on 1752 (#1772): update makefile to only ignore *pytorch.py files in modifier dir, fix order in test, add regex to makefile, add helper function to determine if torch tests should be run, check masks, make transformers import optional in sparsegpt.py
* Fix merge conflict
* Add more tests to check valid modifiers are created (#1774)
* [Bug][ConstantPruningModifier] Fix mask de-register bug (#1773)
* Fix mask de-register logic
* forgot to remove commented out line
* Move tests inside pytorch directory as requested
* Fix session reset (#1790)
* save recipe with model
* saving/loading/checkpointing
* clean up structure initialization
* clean up end stages
* style
* fixing test failures
* fix test file
---------
Co-authored-by: rahul-tuli <[email protected]>
* style
* add init for modifiers util
* consolidate classes
* cleaning up mixin classes and precision callback
* specific train/eval fn
* clean print statements
* Additional Datasets for Finetuning (#1803)
* wip support for additional datasets
* support for splits and load_dataset args
* clean up
* c4 and op working with splits
* load less data, run faster
* [draft] proposed fix for incorrect mask application in FSDP (#1807)
* [draft] proposed fix for incorrect mask application in FSDP
* fix for multi-gpu
* fix for hanging model save
* clean up
---------
Co-authored-by: Sara Adkins <[email protected]>
* clean up logging
* adding transformers GHA tests
* clean up GHA
* clean up GHA
* Docstrings + Testing for Finetuning (#1832)
* initial commit
* docstrings for dataset registry
* docstrings for helpers and clean reload_model_state
* import fix
* session_mixin docstrings
* session mixin documentation and CLI hooks
* cleaning up CLI calls
* WIP unit tests
* tests for dataset loading
* session mixin unit tests
* addressing PR comments
* fix unit test
* more unit test fixes
* Distillation Support for Finetuning (#1865)
* initial commit
* propagate teacher to modifier
* cherrypick distil changes
* WIP for distillation loss fixes
* WIP fixing distillation
* fixing kd_wrapper issues
* fixing comparison reference issue
* cleanup for PR
* more cleanup
* fixing finalization sync
* fix for saving
* update example fsdp
* fixing unit tests
* update fsdp config
* update fsdp config
* remove copied function
* Misc Finetuning Checkpointing Fixes (#1881)
* initial commit
* speeding up fsdp, fixing (some) checkpoint bugs
---------
In the current layer masking implementation, references to the parameters are stored at initialization and reused to apply masks on each modifier update.
In FSDP mode, it seems that masking through these stored references does update them, but the updated parameters no longer have any effect on the FSDP module.
This fix implements the simple flow @dsikka used in the sparsify MVP: the masks are applied over a reference to the current FSDP module at update time, i.e. instead of applying masks to the saved layer references, they are applied directly over fresh references taken from the model (see the sketch below).
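For illustration only, here is a minimal sketch of that flow; the mask storage and helper name below are assumptions made for the example, not SparseML's actual API:

```python
import torch
from torch import nn

# Purely illustrative: masks tracked per module, per parameter name. The real
# modifier keeps this state differently; the point is only *where* masks get applied.
masks: "dict[nn.Module, dict[str, torch.Tensor]]" = {}

@torch.no_grad()
def mask_module(module: nn.Module) -> None:
    """Multiply each of the module's own parameters by its registered mask."""
    module_masks = masks.get(module, {})
    for param_name, param in module.named_parameters(recurse=False):
        mask = module_masks.get(param_name)
        if mask is not None:
            param.data.mul_(mask.to(device=param.device, dtype=param.dtype))

# At modifier-update time, walk fresh references from the model the trainer is
# actually using rather than references cached at initialization, e.g.:
#   model.apply(mask_module)
```

The design point is that the masks are resolved against whatever module the trainer currently holds, so the multiplication lands on tensors that still back the (possibly FSDP-wrapped) model.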
Handing off to @Satrat.

Confirmation of fix

Test command:
accelerate launch --config_file fsdp_config.yaml test_trainer.py
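The fsdp_config.yaml referenced by the command is not included here; as an assumption, a typical accelerate FSDP config (the kind produced by `accelerate config`) looks roughly like the following, with the exact keys depending on the accelerate version:

```yaml
compute_environment: LOCAL_MACHINE
distributed_type: FSDP
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_offload_params: false
  fsdp_sharding_strategy: 1  # FULL_SHARD
  fsdp_state_dict_type: FULL_STATE_DICT
mixed_precision: bf16
num_machines: 1
num_processes: 2  # number of GPUs
```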
Snippet of output with the sparsity log (previously 0.0 for all sparsity values):
Update 11/1/23
The above fix works for n_gpu=1 but not for multi-GPU; the latest commit should fix the multi-GPU case. Essentially, what was happening is that we were initializing the model into our SparseSession before it was wrapped by FSDP. To fix this I added a new callback for on_train_begin that replaces the session's PyTorch model with the FSDP-wrapped one.

Using model.apply as implemented in the initial fix works because FSDP overrides the module apply function, see https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.apply. In order to access and update the underlying model outside of apply, we need to use the summon_full_params function, see https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.summon_full_params. This fixed the issue with reading out the sparsities after training; a sketch of that pattern follows.
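As an illustration of that pattern (a sketch assuming `model` is the FSDP-wrapped module; the helper name is made up for the example and is not part of SparseML), reading sparsities outside of apply via summon_full_params could look like:

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

@torch.no_grad()
def read_sparsities(model: torch.nn.Module) -> dict:
    """Compute per-parameter sparsity from an FSDP-wrapped model.

    summon_full_params temporarily gathers the full, unsharded parameters on
    each rank, so named_parameters() yields the real weight tensors rather
    than FSDP's flattened shards.
    """
    sparsities = {}
    with FSDP.summon_full_params(model):
        for name, param in model.named_parameters():
            if param.numel():
                sparsities[name] = float((param == 0).sum()) / param.numel()
    return sparsities
```

A call like read_sparsities(trainer.model) after training would run on every rank; in practice the logging would likely be gated to rank 0.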
We may need to implement this idea in other areas of the codebase.
Remaining things to wrap up: