[WIP] First draft for softcapping. #1025
Conversation
You can add these to the setup.py to reduce compilation time:
Oh nice, missed them!
I think softcapping should be done before the masking, i.e. the sequence is gemm, softcapping, masking, then softmax.
softcapping can be fused with dividing by softmax_scale.
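To make that proposed ordering concrete, here is a minimal PyTorch sketch (the function and tensor names are illustrative, not the kernel's):

```python
import torch

def softcapped_attention_ref(q, k, v, mask, softmax_scale, softcap):
    # 1. gemm: raw attention scores
    s = torch.einsum("bhid,bhjd->bhij", q, k)
    # 2. softcapping, with the softmax_scale division fused into the
    #    tanh argument as a single multiply by (softmax_scale / softcap)
    s = softcap * torch.tanh(s * (softmax_scale / softcap))
    # 3. masking after softcapping, so masked positions stay -inf
    #    (tanh would otherwise squash -inf to -1, i.e. -softcap after rescaling)
    s = s.masked_fill(mask, float("-inf"))
    # 4. softmax, then the weighted sum over values
    return torch.einsum("bhij,bhjd->bhid", torch.softmax(s, dim=-1), v)
```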
There is still the ... wdyt?
Sorry, I missed the tanh. Yeah, we should template it to avoid slowing down the usual attention.
Ok I put the template in for Is_softcapping. My understanding is that it happens right where I was already doing it (after the gemm, before masking, since tanh would send -inf to -50):

```cpp
flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
    acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
    smem_thr_copy_Q, smem_thr_copy_K
);
// if (cute::thread0()) { print(acc_s); }
if constexpr (Is_softcapping) {
    cute::transform(acc_s, softcapping_op{params.softcapping_scale});
}
```

And then I updated:

```cpp
if (softcapping.has_value()) {
    params.is_softcapping = true;
    params.softcapping_scale = softmax_scale / softcapping.value();
    params.scale_softmax_log2 = softcapping.value() * M_LOG2E;
    params.scale_softmax = softcapping.value();
} else {
    params.is_softcapping = false;
    params.softcapping_scale = 1.0;
    params.scale_softmax_log2 = softmax_scale * M_LOG2E;
    params.scale_softmax = softmax_scale;
}
```
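Reading those parameters together, and assuming (as a reconstruction, since the functor isn't shown above) that softcapping_op applies tanh(x * softcapping_scale), the softcapped path works out to:

```python
# symbolic derivation, not kernel code:
# p = exp2(scale_softmax_log2 * tanh(s * softcapping_scale))
#   = exp2(softcap * log2(e) * tanh(s * softmax_scale / softcap))
#   = exp(softcap * tanh(s * softmax_scale / softcap))
```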
My understanding is that there is never a partial sum in the gemm, right? (meaning the gemm being split along the reduction dimension) Do you have anything better to suggest than printf debugging for this?
@Narsil oh yea, your code looks way better than what i have lmao, let's just go with your changes. so eventually you'll have to expose a customizable softcap
@Narsil backwards pass looks something like this in naive pytorch, and done another way here. both examples are functional. i understand things are 10-20x harder translating it to CUDA; see the gradient sketch below.
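For reference, the extra gradient term in naive PyTorch, under the capped = softcap * tanh(s * softmax_scale / softcap) formulation used in this thread (a sketch, with autograd used only to check the hand derivation):

```python
import torch

softmax_scale, softcap = 0.125, 50.0
s = torch.randn(2, 4, 8, 8, dtype=torch.float64, requires_grad=True)

capped = softcap * torch.tanh(s * softmax_scale / softcap)
capped.sum().backward()

# hand-derived: d/ds [softcap * tanh(s * scale / softcap)]
#             = scale * (1 - tanh(s * scale / softcap) ** 2)
manual = softmax_scale * (1 - torch.tanh(s.detach() * softmax_scale / softcap) ** 2)
assert torch.allclose(s.grad, manual)
```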
@Narsil great job getting the ball rolling 🙏
Thanks, but this is not functional at all. (I think using the ...)
@Narsil i think you are really close!
@Narsil yea, it is really hard to contribute with these compilation times. how fast were you able to get the times down to? i'm still waiting up to 10 minutes per change
edit: it is 26 minutes, just timed it. let me work on cutting that down the rest of the day, or this is impossible to tinker with
edit 2: got it down to 2 minutes! 😄
edit 3: think i got the forward working! a lot of it was just following your lead with ...
edit 4: i'll just move onto the backwards pass, since i think you have it in the bag with the forwards.
@lucidrains Can't see your changes anywhere, are you on a branch somewhere? I got compilation times down to 1mn, but results seem wrong with them (meaning the regular non-softcapped flash fails). To reduce compilation times I uncomment the compilation flags in setup.py. I also limit the number of hidden sizes to only 1, meaning I comment out the associated kernels.
@Narsil nice, i got it to around the same ballpark! unfortunately i was working off a runpod that went down before i could push the changes 😞. but the good news is that i was luckily able to get the backwards pass working as well 🥳 i compared some of the changes for the forward hoping to catch what was wrong in your diff, but couldn't find a difference. in fact i looked to your changes first and used the way you did the ...
Nice doing the backward! Can I try your branch? (My local changes were exactly this branch.)
@Narsil lost my changes, but will work on restoring it later today (traveling with dog for the 4th, American holiday). yes, i think you are probably just off by some scale, the error could even be in your tests
i also want to do a separate PR just to make it easy for contributors to get started (specify a few hyperparameters, and enable/disable booleans for those flags, and only those kernels get compiled and tested). that ended up being the hardest part of all this
Hello, I was separately working on adding softcapping to FA for Gemma 2. But seeing this PR, I'll add my WIP repo here if needed for double-checking: https://github.com/Shreya-Pathak/flash-attention. It is using the same idea as discussed above. The forward seemed correct (read: passing tests) when I was checking, but let me know if you see some issues.
@Shreya-Pathak looks great Shreya! i actually prefer your way for the interface, with a boolean flag to turn it on with a default softcapping scale (realistically the public won't know the right value, so it should just default to something proven and working). But Nicolas' way with the ...
regardless, i think the forward issue is done, and Gemma2 inference will be viable soon. i'll PR in the backwards pass once one of you gets your changes in. there are actually two ways to approach backwards, and i'm not sure which is the right way, so may need to consult Tri
Hi @Shreya-Pathak, I looked at your changes. Personally I prefer the single flag (fewer things to know for users, and softcapping is unlikely to have a good default; Gemma2 uses 50, not 30, for instance: https://huggingface.co/google/gemma-2-9b-it/blob/main/config.json#L7). What you did was the first iteration I did, and Tri mentioned that we could do smarter things here: #1025 (comment) Basically, FA currently applies the scale during the exponentiation of the softmax (to exploit intrinsics that can fuse both ops in a single instruction). Therefore we should be able to get away with using softmax_scale / softcap directly inside the tanh, and only replace the current softmax_scales with the softcap value (softcap * log2(e) for the log2 variant). Doesn't seem to be working for me though.
@Narsil ah, good to know that Gemma 2 used 50. so when you say things don't work, is it the changes you made to account for Tri's suggestion? your first iteration worked?
@Narsil ok, i'm sure you'll figure it out. @Shreya-Pathak can probably help too. ping me when you want me to throw up the backwards pass
@Narsil I think I have also done what Tri mentioned with the softcap / softmax scale, and from a brief look at your code, you seem to be doing the same as well. Could you give more details about what tests are failing and what the differences are?
I think we should pass to tanh with softcapping_scale = softmax_scale (e.g. 1/sqrt(headdim)) / softcap_val (e.g. 50.0). Then in the softmax, instead of passing in softmax_scale_log2, we should pass in softcap_val * log2(e). Overall we're doing exp2(log2(e) * 50 * tanh(acc_s * softmax_scale / 50)).
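A quick way to sanity-check that this fused form matches a plain softmax over the capped scores (a PyTorch sketch; shapes and values are arbitrary):

```python
import math
import torch

softmax_scale, softcap = 1 / math.sqrt(128), 50.0
s = torch.randn(8, 8, dtype=torch.float64)  # stand-in for acc_s

# reference: ordinary softmax over the softcapped scores
ref = torch.softmax(softcap * torch.tanh(s * softmax_scale / softcap), dim=-1)

# fused form: tanh receives softmax_scale / softcap, and the exp2-based
# softmax receives softcap * log2(e) in place of softmax_scale * log2(e)
t = torch.tanh(s * (softmax_scale / softcap))
p = torch.exp2(softcap * math.log2(math.e) * (t - t.amax(-1, keepdim=True)))
fused = p / p.sum(-1, keepdim=True)

assert torch.allclose(ref, fused)
```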
OK it's updated and now working. I have no idea why but ...
that's strange! i used ... ah no matter, all roads lead to rome
Do you want to patch this PR for the backward?
@Narsil let's land this PR first for the forwards, as I may do two separate PRs for backwards and let Tri pick the one that makes sense
@tridao do you know what's missing to merge? Should I run the CI manually maybe (without backward, since it's not implemented yet)?
@Narsil for backwards, just assert out or throw an error if soft capping is turned on. otherwise, great job! @Shreya-Pathak too!
thanks a lot @tridao
The released build wheels are still dated May 27. |
@@ -639,6 +659,7 @@ def backward(ctx, dout, *args):
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap
Small typo here, missing a comma
@@ -556,6 +572,7 @@ def backward(ctx, dout, *args):
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap
Small typo here, missing a comma
this for me too. there is a second missing comma further down. also, my first compile produced gibberish; am trying again with updated cutlass. it has successfully compiled from the same folder before.
It compiled though (finally), and after fixing the typos above everything appears to be working. Didn't test bwd, but inference is correct on Gemma2 models and there's no noticeable overhead. Thanks for this. 🥇
I installed via pip, version flash-attn 2.6.0.post1. It still does not fix the problem. Is this fix included in the pip release, or is there another way to install it?
Glad to have this question answered: to the line of code ...
Quick & dirty implementation (but seems functional).
Fixes #1016