Add the option to recompute the matmul in the MLP part.
Since this matmul's output is the largest remaining single tensor, we get substantial improvements in memory consumption.
With a 350M model at batch size 20, activation memory drops from 12042 MiB to 8362 MiB; equivalently, on 2x 4060 Ti, I can instead increase the batch size from 20 to 28 without OOM. Overall, there is still a slowdown, from 30 ktok/s to 28 ktok/s, but
a) we cannot increase gradient accumulation arbitrarily due to bf16 rounding problems, so fitting more tokens into each fwd/bwd pass allows larger effective batch sizes, and
b) for smaller cards, larger models, or longer contexts, this may make the difference between being able to run at all or not,
so I think that, despite the noticeable drop in tok/s, this is an option we want to have.
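For reference, here is a minimal PyTorch-style sketch of the general trade-off; the `MLP` module and `recompute` flag below are purely illustrative and not this PR's actual API, assuming a standard 4x-expansion MLP block:

```python
# Illustrative sketch only -- this repo's implementation and flag names may differ.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class MLP(nn.Module):
    """Transformer MLP block with optional recompute of the expansion matmul."""

    def __init__(self, d_model: int, recompute: bool = False):
        super().__init__()
        self.fc = nn.Linear(d_model, 4 * d_model)   # the expensive expansion matmul
        self.act = nn.GELU()
        self.proj = nn.Linear(4 * d_model, d_model)
        self.recompute = recompute  # hypothetical flag mirroring this PR's option

    def _expand(self, x: torch.Tensor) -> torch.Tensor:
        # This intermediate is 4x the model width -- the largest single
        # activation tensor in the block, hence the biggest memory win.
        return self.act(self.fc(x))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.recompute and self.training:
            # Don't store the 4*d_model intermediate; recompute it in backward.
            h = checkpoint(self._expand, x, use_reentrant=False)
        else:
            h = self._expand(x)
        return self.proj(h)
```

With recompute enabled, the 4x-width intermediate is freed after the forward pass and recomputed once during backward, which matches the numbers above: lower activation memory at the cost of some extra compute per step.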