Fix pipeline rematerialization #971

Open · wants to merge 1 commit into base: main
Conversation

@gobbleturk (Collaborator) commented Oct 20, 2024

This PR does 3 things:

  1. Removes setting the remat policy on the layers-per-stage submodule, relying instead on it being set only by the pipeline iteration (previously it was set in both the layers per stage and the pipeline iteration). We saw extra remat when the policy was set twice. With this change plus scan_layers=False and scan_pipeline_iterations=True we see the desired behavior, but with both scans set to True we still see extra memory usage (see b/373949783 for more details). See the first sketch after this list.

  2. Refactors remat_policy=minimal to explicitly list all of the desired matmuls, i.e. all of the activation @ weight matmuls (everything except the activation @ activation matmuls in qkv attention); see the second sketch after this list. This should behave the same as the previous jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims, which saves only the matmuls that have no global batch dimension (batch dimensions appear in the activation @ activation matmuls of qkv attention, not in the activation @ weight matmuls). However, our pipeline parallelism implementation uses a vmap over stages, so every matmul gains a global stage batch dimension, and that policy would save nothing at all, the opposite of its intended effect. To switch to this "explicit list" approach, this PR also fixes the mlpwi checkpoint names, which were not capturing the intended matmuls before (this particular checkpoint name was not being used anywhere).

  3. Fixes casting to float32 in the fused_mlp case.
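
A minimal sketch of change (1), using plain jax.checkpoint as a stand-in for the Flax module wrapping MaxText actually uses (the function names here are illustrative, not the PR's code):

```python
import jax
import jax.numpy as jnp

# Stand-in remat policy; the real config selects the policy by name.
policy = jax.checkpoint_policies.nothing_saveable

def layers_per_stage(x, w):
    # The layers a single stage runs within one pipeline iteration.
    return jnp.tanh(x @ w)

# Before: the remat policy was applied to both the layers-per-stage submodule
# and the pipeline iteration, which is where we saw the extra remat.
inner = jax.checkpoint(layers_per_stage, policy=policy)
pipeline_iteration_before = jax.checkpoint(inner, policy=policy)

# After: the policy is applied only at the pipeline-iteration level.
pipeline_iteration_after = jax.checkpoint(layers_per_stage, policy=policy)
```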
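
And a minimal sketch of change (2). checkpoint_name, save_only_these_names, and checkpoint_dots_with_no_batch_dims are existing JAX APIs; the toy mlp function and its exact checkpoint names are only illustrative of the approach described above:

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

def mlp(x, wi, wo):
    # Tag the activation @ weight matmul outputs so a name-based policy can save them.
    h = checkpoint_name(x @ wi, "mlpwi")
    return checkpoint_name(jax.nn.relu(h) @ wo, "mlpwo")

# Old "minimal": save dots that have no batch dimension. Under the pipeline's
# vmap over stages every matmul gains a stage batch dimension, so nothing is saved.
old_policy = jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims

# New "minimal": explicitly save the named activation @ weight matmuls, which
# is unaffected by the extra stage dimension introduced by vmap.
new_policy = jax.checkpoint_policies.save_only_these_names("mlpwi", "mlpwo")

remat_mlp = jax.checkpoint(mlp, policy=new_policy)
```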

Tested:

  • For the general (not minimal-specific) PP rematting, see b/373949783, where we confirm the expected bwd pass time and memory usage
  • For the minimal policy refactor:
    • Before this PR, with minimal remat: XPROF 26GB HBM peak, 1.033 step time
    • After this PR, with minimal remat: XPROF 26GB HBM peak, 1.017 step time
    • The two profiles are not identical but very close; I am not sure exactly how they differ
  • It is difficult to add unit tests for these changes outside of our regular perf regression suite. We could assert that the old and new minimal policies are close enough, but then we would need to keep the old policy around in the code

@gobbleturk force-pushed the mattdavidow-fix-pipeline-remat branch from 364f5e4 to bb33d78 on October 20, 2024 at 19:17
@jonb377 (Collaborator) commented Oct 21, 2024

> with both scans set to True we still see extra memory usage

I'm curious: is the main fix to avoid the extra remat the change from RemattedBlockLayer to BlockLayer, or just disabling scan_layers?

I see in the bug that even with BlockLayer we're still doing extra compute in the bwd pass with scan_layers set (while.94).

@gobbleturk (Collaborator, Author) commented Oct 21, 2024

> with both scans set to True we still see extra memory usage
>
> I'm curious: is the main fix to avoid the extra remat the change from RemattedBlockLayer to BlockLayer, or just disabling scan_layers?
>
> I see in the bug that even with BlockLayer we're still doing extra compute in the bwd pass with scan_layers set (while.94).

Both the RemattedBlockLayer -> BlockLayer change and scan_layers=False are necessary for optimal performance; however, making either one of these changes on its own still improves performance somewhat compared to making no changes.
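
For reference, a rough sketch of that combination, assuming MaxText-style flags (only scan_layers, scan_pipeline_iterations, BlockLayer, and RemattedBlockLayer come from this thread; the rest is illustrative):

```python
# Settings that gave the desired remat behavior in this PR (flag names as
# discussed above; this dict is just an illustration, not MaxText's config API).
pipeline_remat_overrides = {
    "scan_layers": False,              # do not scan the layers within a stage
    "scan_pipeline_iterations": True,  # keep scanning over pipeline iterations
}
# In addition, the pipeline wraps the per-stage layers with BlockLayer rather than
# RemattedBlockLayer, so the remat policy is applied only at the pipeline iteration.
```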

@jonb377 (Collaborator) left a comment

Thanks for addressing this, Matt! Very interesting behavior with multiple scans/remats.

Long-term, we want to get back to double-scan at least, right?
