
Recompute mlp #676

Open · wants to merge 3 commits into master
Conversation

ngc92 (Contributor) commented Jul 11, 2024

Add the option to recompute the matmul output in the MLP block. Since this output is the largest remaining single tensor, recomputing it instead of storing it yields a substantial reduction in memory consumption.

With the 350M model at batch size 20, activation memory drops from 12042 MiB to 8362 MiB; correspondingly, on 2x4060Ti, I can increase the batch size from 20 to 28 without OOM. There is still a slowdown, from 30 ktok/s to 28 ktok/s, but:
a) we cannot increase gradient accumulation arbitrarily due to bf16 rounding problems, so fitting more tokens per fwd/bwd pass allows larger effective batch sizes;
b) for smaller cards, larger models, or longer contexts, this might make the difference between being able to run at all.

So despite the noticeable drop in tok/s, I think this is an option we want to have.
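For readers unfamiliar with the technique: the idea is standard activation recomputation, where the backward pass re-runs the cheap forward matmul instead of keeping its large output resident. A minimal numpy sketch (not the PR's CUDA implementation; function and variable names like `W1`, `gelu`, `recompute` are illustrative assumptions):

```python
import numpy as np

def gelu(x):
    # tanh approximation of GELU, as used in GPT-2
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def gelu_grad(x):
    # derivative of the tanh-approximation GELU
    u = np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)
    du = np.sqrt(2.0 / np.pi) * (1.0 + 3.0 * 0.044715 * x**2)
    t = np.tanh(u)
    return 0.5 * (1.0 + t) + 0.5 * x * (1.0 - t**2) * du

def mlp_forward(x, W1, recompute=False):
    h = x @ W1              # the large intermediate tensor (B*T, 4C in a GPT MLP)
    out = gelu(h)
    # with recompute=True, only the much smaller input x is kept for backward
    saved = (x,) if recompute else (x, h)
    return out, saved

def mlp_backward(dout, saved, W1):
    if len(saved) == 1:
        (x,) = saved
        h = x @ W1          # recompute the matmul: extra FLOPs, less memory
    else:
        x, h = saved
    dh = dout * gelu_grad(h)
    dx = dh @ W1.T
    dW1 = x.T @ dh
    return dx, dW1
```

Both paths produce bitwise-identical gradients here, since the recomputed `h` is the same matmul on the same inputs; the only difference is what lives in memory between forward and backward.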
