This PR does 3 things:
1. Removes setting the remat policy on the layers-per-stage submodule, relying instead on it being set only by the pipeline iteration (previously it was set on both the layers per stage and the pipeline iteration). We saw extra remat when the policy was set twice. With this change we see the desired behavior when setting `scan_layers=False` and `scan_pipeline_iterations=True`, but with both scans set to True we still see extra memory usage (b/373949783 for more details). A minimal sketch of the before/after structure follows this list.
2. Refactors `remat_policy=minimal` to explicitly list out all desired matmuls (all of the activations @ weights, i.e. everything except QKV attention). This should do the same as the previous `jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims`, which checks for the presence of a global batch dimension (not present in activations @ weights, only in the activations @ activations inside QKV attention). However, our pipeline-parallelism implementation uses a vmap over stages, so every matmul gets a global batch dimension over stages, and that policy would actually not save anything, the opposite of its desired effect. To switch to this "explicit list" approach, this PR also fixes the `mlpwi` checkpoint names, which were not capturing the matmuls as desired before (this particular checkpoint name was not being used anywhere). A sketch of the policy change also follows this list.
3. Fixes casting to float32 in the `fused_mlp` case.
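A minimal sketch of the structural change in (1), not the actual MaxText code: the remat policy used to be applied both around the layers per stage and around the pipeline-iteration body, and nesting `jax.checkpoint` this way can rematerialize more than intended. The `layer`, `layers_per_stage`, and `pipeline_iteration` functions and shapes below are hypothetical stand-ins.

```python
import functools
import jax
import jax.numpy as jnp

policy = jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims

def layer(x, w):
    return jnp.tanh(x @ w)

def layers_per_stage(x, stage_ws):
    # Before (conceptually): the policy was also applied here, e.g.
    # x = jax.checkpoint(layer, policy=policy)(x, w), nesting remat.
    for w in stage_ws:
        x = layer(x, w)
    return x

# After: the policy is applied only around the pipeline-iteration body.
@functools.partial(jax.checkpoint, policy=policy)
def pipeline_iteration(x, stage_ws):
    return layers_per_stage(x, stage_ws)

x = jnp.ones((8, 128))
stage_ws = [jnp.ones((128, 128)) for _ in range(2)]
grads = jax.grad(lambda x, ws: pipeline_iteration(x, ws).sum())(x, stage_ws)
```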
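A sketch of the policy change in (2), under assumed checkpoint names: instead of relying on the batch-dimension heuristic (which the vmap over stages defeats, since every dot gains a stage dimension), the activations @ weights matmuls are tagged with `checkpoint_name` and saved via `save_only_these_names`. The name list and the `mlp` function below are illustrative, not the exact list in the code.

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

# Old "minimal": save dots that have no global batch dimension. Under a vmap
# over stages every dot gains a stage batch dimension, so nothing is saved.
old_minimal = jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims

# New "minimal": explicitly list the activations @ weights matmuls to save
# (everything except the activations @ activations dots inside attention).
new_minimal = jax.checkpoint_policies.save_only_these_names(
    "query_proj", "key_proj", "value_proj", "out_proj", "mlpwi", "mlpwo"
)

def mlp(x, wi, wo):
    # Tag the matmul outputs so the named policy can find them; the "mlpwi"
    # name must wrap the actual x @ wi product for the policy to take effect.
    h = checkpoint_name(x @ wi, "mlpwi")
    return checkpoint_name(jax.nn.relu(h) @ wo, "mlpwo")

remat_mlp = jax.checkpoint(mlp, policy=new_minimal)
```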
Tested:
It is difficult to add unit tests for these changes outside of our regular perf regression suite. We could assert that the old and new minimal policies behave closely enough, but that would require keeping the old policy around in the code.