[WIP] First draft for softcapping. #1025

Merged · 2 commits merged into Dao-AILab:main on Jul 8, 2024

Conversation

Narsil
Contributor

Narsil commented Jul 3, 2024

Quick&dirty implementation (but seems functional).

Fixes #1016

@Narsil Narsil marked this pull request as draft July 3, 2024 19:40
@Narsil Narsil mentioned this pull request Jul 3, 2024
@tridao
Contributor

tridao commented Jul 3, 2024

You can add these to the setup.py to reduce compilation time:

                        "-DFLASHATTENTION_DISABLE_BACKWARD",
                        "-DFLASHATTENTION_DISABLE_DROPOUT",
                        "-DFLASHATTENTION_DISABLE_ALIBI",
                        "-DFLASHATTENTION_DISABLE_UNEVEN_K",
                        "-DFLASHATTENTION_DISABLE_LOCAL",

@Narsil
Contributor Author

Narsil commented Jul 3, 2024

Oh nice, missed those!

@tridao
Contributor

tridao commented Jul 3, 2024

I think softcapping should be done before the masking. i.e. the sequence is gemm, softcapping, masking, then softmax.
If you do softcapping after masking, then some masked tokens will contribute a tiny amount to the softmax. In practice it's probably ok if the softcap value is large (like 50), but if it's small (e.g. 1.0), this can lead to information leakage from future tokens to past tokens.
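
To make the ordering concrete, here is a small PyTorch reference (an illustration only; the function name and shapes are assumptions, not kernel code):

import torch

def softcapped_attention_scores(q, k, mask, softmax_scale, softcap):
    """Reference only; mask is True where attention is disallowed."""
    s = (q @ k.transpose(-2, -1)) * softmax_scale   # gemm (+ scale)
    s = softcap * torch.tanh(s / softcap)           # softcapping
    s = s.masked_fill(mask, float("-inf"))          # masking
    return torch.softmax(s, dim=-1)                 # softmax: masked positions get exactly 0

# With the order reversed (mask, then softcap), tanh(-inf) == -1, so a masked score
# becomes -softcap instead of -inf and still receives a small nonzero softmax weight.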

@tridao
Contributor

tridao commented Jul 3, 2024

softcapping can be fused with the multiplication by softmax_scale.
i.e. we do S = gemm(Q, K), then S *= softmax_scale / softcap, where "softmax_scale / softcap" is a constant we can compute beforehand and put in the params.
Then we do masking, then take the max.
When it's time to do exp, we do exp(scores * softcap - max * softcap). This will use a fused multiply-add, so it's just 1 instruction.

@Narsil
Contributor Author

Narsil commented Jul 4, 2024

There is still the tanh that's missing somewhere; where would you put it?
I was thinking of adding a const template parameter for softcapping so the cost of the branch wouldn't affect non-softcapped kernels.

wdyt?

@tridao
Contributor

tridao commented Jul 4, 2024

Sorry I missed the tanh.
The steps should be: S = gemm(Q, K), then S = tanh(S * softmax_scale / softcap), then masking, taking the max, then exp2f(scores * softcap * log2(e) - max * softcap * log2(e)).
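
A quick numerical sanity check of this recipe (my own PyTorch sketch, not the kernel code), comparing the fused formulation against a naive softmax of softcap * tanh(S * softmax_scale / softcap):

import math
import torch

softmax_scale, softcap = 1 / math.sqrt(128), 50.0
raw = torch.randn(4, 8) * 100          # stand-in for the raw gemm output

# Naive: softmax(softcap * tanh(raw * softmax_scale / softcap))
ref = torch.softmax(softcap * torch.tanh(raw * softmax_scale / softcap), dim=-1)

# Fused-style: fold softmax_scale / softcap into the tanh argument once,
# then carry softcap * log2(e) into the exp2 step (one FMA per element in the kernel).
s = torch.tanh(raw * (softmax_scale / softcap))
c = softcap * math.log2(math.e)
m = s.max(dim=-1, keepdim=True).values
p = torch.exp2(s * c - m * c)
fused = p / p.sum(dim=-1, keepdim=True)

assert torch.allclose(ref, fused, atol=1e-5)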

Yeah we should template to avoid slowing down the usual attention.

@Narsil
Contributor Author

Narsil commented Jul 4, 2024

Ok I put in the template for Is_softcapping; however, I cannot get your idea working. I may be missing something.

My understanding is that

tanh(x * softmax_scale * 1 / softcap)

happens right where I was already doing it (after the gemm, before masking, since tanh would map -inf to -softcap, e.g. -50).
So I added a new op there:
https://github.com/Dao-AILab/flash-attention/pull/1025/files#diff-9e1775131ae22a74dc4e0333c57539573394a059db8097ef74fae24243347ce1R330-R332

// QK^T gemm producing the raw attention scores acc_s.
flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
    acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
    smem_thr_copy_Q, smem_thr_copy_K
);
// if (cute::thread0()) { print(acc_s); }
// Apply tanh-based softcapping right after the gemm, before masking.
if constexpr (Is_softcapping) {
    cute::transform(acc_s, softcapping_op{params.softcapping_scale});
}

And then I updated scale_softmax and scale_softmax_log2 to the equivalent values, but with the softcap value instead of the regular softmax_scale (since it's effectively taking its place).

https://github.com/Dao-AILab/flash-attention/pull/1025/files#diff-406036c9702cf749b9e58833b342cfeb66a40c0faa1b43e2e8610f43c1332a5bR104-R118

if (softcapping.has_value()){
    params.is_softcapping = true;
    // Constant folded into the tanh argument: softmax_scale / softcap.
    params.softcapping_scale = softmax_scale / softcapping.value();
    // The softcap value takes over softmax_scale's role in the exp2f step.
    params.scale_softmax_log2 = softcapping.value() * M_LOG2E;
    params.scale_softmax = softcapping.value();
} else {
    params.is_softcapping = false;
    params.softcapping_scale = 1.0;
    params.scale_softmax_log2 = softmax_scale * M_LOG2E;
    params.scale_softmax = softmax_scale;
}

@Narsil
Contributor Author

Narsil commented Jul 4, 2024

My understanding is that scale_softmax_log2 is only ever used in the partial exponentiation (the log2 is there so we can use exp2f, which I assume has better intrinsics than expf, or is it more about numerical stability?).

scale_softmax is only used in the renormalization, which as I understand it is used to return the softmax to users who ask for it.

There is never a partial sum in the gemm, right? (Meaning the gemm being split along the hidden_dim rank, so the tanh wouldn't capture the entire sum. I don't think it is, but I'm looking for culprits.)

Do you have anything better to suggest than printf debugging for this?
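
As an aside on the exp2f question (my understanding, not from this PR): the rewrite is exp(x) = exp2(x * log2(e)), and folding log2(e) into the precomputed scale lets the kernel call the cheaper exp2f while still fusing the scale into a multiply-add; the numerical stability comes from subtracting the running max, not from the base change. For instance:

import math

x = 3.7
# exp(x) == exp2(x * log2(e)); the base change is about using the fast exp2 path,
# while subtracting the max is what keeps the softmax numerically stable.
assert math.isclose(math.exp(x), 2.0 ** (x * math.log2(math.e)))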

@lucidrains
Contributor

@Narsil oh yea, your code looks way better than what i have lmao, let's just go with your changes

so eventually you'll have to expose a customizable softcapping_scale, which you should default to 30. (iirc used by Grok and Gemma2). maybe researchers will find a better value in the future

@lucidrains
Contributor

lucidrains commented Jul 4, 2024

@Narsil backwards pass looks something like this in naive pytorch - and done another way here. both examples are functional

i understand things are 10-20x harder translating it to CUDA
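
As a rough stand-in for those linked examples, a hedged sketch of just the softcapping piece (not the full attention backward): since the capped score is softcap * tanh(s / softcap), its derivative is 1 - tanh(s / softcap)**2, and that factor multiplies the incoming score gradient.

import torch

def softcap_fn(s, softcap):
    return softcap * torch.tanh(s / softcap)

softcap = 50.0
s = torch.randn(4, 8, requires_grad=True)
g = torch.randn(4, 8)                    # upstream gradient w.r.t. the capped scores

softcap_fn(s, softcap).backward(g)

# Manual gradient: d/ds [softcap * tanh(s / softcap)] = 1 - tanh(s / softcap)**2
manual = g * (1 - torch.tanh(s / softcap) ** 2)
assert torch.allclose(s.grad, manual, atol=1e-6)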

@lucidrains
Contributor

@Narsil great job getting the ball rolling 🙏

@Narsil
Contributor Author

Narsil commented Jul 4, 2024

Thanks, but this is not functional at all. (I think using the DISABLE_ flags breaks flash already.)

@lucidrains
Contributor

@Narsil i think you are really close!

@lucidrains
Contributor

lucidrains commented Jul 4, 2024

@Narsil yea, it is really hard to contribute with these compilation times. how fast were you able to get the times down to? i'm still waiting up to 10 minutes per change

edit: it is 26 minutes, just timed it, let me work on cutting that down rest of the day, or this is impossible to tinker with

edit 2: got it down to 2 minutes! 😄

edit 3: think i got the forward working! a lot of it was just following your lead with cute::transform, so i think you may just be missing something small

edit 4: i'll just move onto the backwards pass, since i think you have it in the bag with the forwards.

@Narsil
Contributor Author

Narsil commented Jul 5, 2024

@lucidrains Can't see your changes anywhere, are you on a branch somewhere?

I got compilation times down to 1 min, but the results seem wrong with them (meaning the regular non-softcapped flash fails).
I am starting a branch from scratch and bisecting code removal to keep at least one test working and compilation times low.

To reduce compilation times I uncomment the compilation flags in setup.py (-DXXXX). Each flag divides the time in half (since they are all booleans afaik). LOCAL and UNEVEN_K I keep (not entirely sure exactly why, but they seem necessary for the test I'm keeping around).

Then I also limit the hidden size to only one value, meaning I comment out the associated kernels in setup.py and then update things a bit everywhere to cut those while keeping the symbols in the binary (undefined-symbol fun).
Basically it needs the changes made here (launch_template + static_switch at least).

@lucidrains
Contributor

@Narsil nice, i got it to around the same ballpark!

unfortunately i was working off a runpod that went down before i could push the changes 😞. but the good news is that i was luckily able to get the backwards pass working as well 🥳

i compared some of the changes for the forward hoping to catch what was wrong in your diff, but couldn't find a difference. in fact i looked at your changes first and used the way you did the cute::transform, so it must be something small on your end

@Narsil
Contributor Author

Narsil commented Jul 5, 2024

Nice work on the backward!
You're saying this branch is supposed to work?

Can I try your branch? (My local changes were exactly this branch.)
A10G (sm_86)
CUDA 12.5
Ubuntu 20.04
Here. I've also dealt a few times with my path being polluted (so two versions of flash coexisted until I manually cleared them).

@lucidrains
Contributor

@Narsil lost my changes, but will work on restoring it later today (traveling with dog for the 4th, American holiday)

yes, i think you are probably just off by some scale, error could even be in your tests

@lucidrains
Contributor

lucidrains commented Jul 5, 2024

i also want to do a separate PR just to make it easy for contributors to get started (specify a few hyperparameters, and enable/disable booleans for those flags, and only those kernels get compiled and tested)

that ended up being the hardest part of all this

@Shreya-Pathak

Shreya-Pathak commented Jul 5, 2024

Hello, I was separately working on adding softcapping to FA for Gemma 2. But seeing this PR, I'll add my WIP repo here if needed for double-checking: https://github.com/Shreya-Pathak/flash-attention. It uses the same idea as discussed above. The forward seemed correct (read: passing tests) when I was checking, but let me know if you see any issues.
Additionally, I'm eager to get softcapping supported in FA as soon as possible, so please let me know how I can contribute. If any support is needed on the backward pass, I will be happy to spend some time on it.

@lucidrains
Contributor

lucidrains commented Jul 5, 2024

@Shreya-Pathak looks great Shreya! i actually prefer your way for the interface, with a boolean flag to turn it on and a default softcapping scale (realistically the public won't know the right value, so it should just default to something proven and working). But Nicolas' way with the cute::transform on the cuda end seems cleaner

@lucidrains
Contributor

lucidrains commented Jul 5, 2024

regardless, i think the forward is basically done, and Gemma2 inference will be viable soon

i'll PR in the backwards pass once one of you gets your changes in. there are actually two ways to approach the backwards, and i'm not sure which is the right one, so i may need to consult Tri

@Narsil
Contributor Author

Narsil commented Jul 5, 2024

Hi @Shreya-Pathak,

I looked at your changes.

Personally I prefer the single flag (fewer things for users to know, and softcapping is unlikely to have a good default; Gemma2 uses 50, not 30, for instance: https://huggingface.co/google/gemma-2-9b-it/blob/main/config.json#L7).
Taste and preference, I guess; ultimately it's a core maintainer's call, I think.

What you did was the first iteration I did, and Tri mentioned that we could do smarter things here: #1025 (comment)

Basically, FA currently applies the scale during the exponentiation of the softmax (to exploit intrinsics that can fuse both ops in a single instruction). Therefore we should be able to get away with using softmax_scale / softcap directly inside the tanh, and only replacing the current softmax_scale with softcap, which will get multiplied as usual (keeping the instruction fused).

Doesn't seem to be working for me, though.

@lucidrains
Contributor

lucidrains commented Jul 5, 2024

@Narsil ah, good to know that Gemma 2 used 50., i believe Grok used 30.. and yea, it is a matter of preference (edit: actually i may change my mind on this; since they are using different values, it's probably better to leave it undefaulted)

so when you say things don't work, is it the changes you made to account for Tri's suggestion? did your first iteration work?

@lucidrains
Contributor

@Narsil ok, i'm sure you'll figure it out. @Shreya-Pathak can probably help too

ping me when you want me to throw up the backwards pass

@Shreya-Pathak

@Narsil I think I have also done what Tri mentioned with the softcap / softmax scale, and from a brief look at your code you seem to be doing the same as well. Could you give more details about which tests are failing and what the differences are?
Regarding the default value, you can still change the softcap value in the calling function AFAICT, but the difference is only in style, so anything is fine by me.

@tridao
Contributor

tridao commented Jul 5, 2024

I think we should pass to tanh softcapping_scale = softmax_scale (e.g. 1/sqrt(headdim)) / softcap_val (e.g. 50.0).
As an example, with headdim = 128 and softcap = 50, we would do tanh(acc_s * 1/sqrt(128) / 50) = tanh(acc_s * 1.77e-3).

Then in the softmax, instead of passing in scale_softmax_log2, we should pass in softcap_val * log2(e).
In this example we would have 50.0 * log2(e) = 72.1.

Overall we're doing exp2(log2(e) * 50 * tanh(acc_s * softmax_scale / 50)).
But hopefully, by multiplying these constants together beforehand, we can reduce the number of instructions.
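
A quick check of these constants and of the overall expression (my own arithmetic sketch, not part of the PR):

import math

softmax_scale = 1 / math.sqrt(128)
softcap = 50.0

print(softmax_scale / softcap)        # ~1.77e-3, the constant folded into the tanh
print(softcap * math.log2(math.e))    # ~72.13, the value that replaces scale_softmax_log2

# exp2(log2(e) * 50 * tanh(x * softmax_scale / 50)) == exp(50 * tanh(x * softmax_scale / 50))
x = 123.4
lhs = 2.0 ** (softcap * math.log2(math.e) * math.tanh(x * softmax_scale / softcap))
rhs = math.exp(softcap * math.tanh(x * softmax_scale / softcap))
assert math.isclose(lhs, rhs)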

@Narsil
Contributor Author

Narsil commented Jul 6, 2024

OK it's updated and now working.

I have no idea why, but cute::transform seems to be the culprit.
Thanks @Shreya-Pathak for the apply_softcap function, which I did use in the end.

@lucidrains
Contributor

> OK it's updated and now working.
>
> I have no idea why but cute::transform seems to be the culprit. Thanks @Shreya-Pathak for the apply_softcap function that I did use in the end.

that's strange! i used cute::transform in both fwd and bwd

ah no matter, all roads lead to rome

@Narsil
Contributor Author

Narsil commented Jul 6, 2024

Do you want to patch this PR for the backward?

@lucidrains
Contributor

lucidrains commented Jul 6, 2024

@Narsil let's land this PR first for the forwards, as I may do two separate PRs for backwards and let Tri pick the one that makes sense

@Narsil Narsil marked this pull request as ready for review July 8, 2024 07:25
@Narsil
Contributor Author

Narsil commented Jul 8, 2024

@tridao do you know what's missing to merge this?

Should I maybe run the CI manually (without backward, since it's not implemented yet)?

@lucidrains
Contributor

@Narsil for backwards, just assert out or throw an error if soft capping is turned on

otherwise great job! @Shreya-Pathak too!

@iamsaurabhgupt

great stuff @Narsil .
waiting for @tridao to merge.

@tridao tridao merged commit 8f873cc into Dao-AILab:main Jul 8, 2024
@iamsaurabhgupt

thanks a lot @tridao

@iamsaurabhgupt

The released build wheels are still dated May 27.
I think there must be some build pipeline that takes time to update release assets?

@@ -639,6 +659,7 @@ def backward(ctx, dout, *args):
ctx.softmax_scale,
ctx.causal,
ctx.window_size,
ctx.softcap

Small typo here, missing a comma

@@ -556,6 +572,7 @@ def backward(ctx, dout, *args):
ctx.softmax_scale,
ctx.causal,
ctx.window_size,
ctx.softcap

Small typo here, missing a comma

Same for me. There is a second missing comma further down. Also, my first compile produced gibberish; am trying again with updated cutlass. It has successfully compiled from the same folder before.

@turboderp

It compiled, though (finally), and after fixing the typos above everything appears to be working. Didn't test bwd, but inference is correct on Gemma2 models and there's no noticeable overhead. Thanks for this. 🥇

@Oxi84

Oxi84 commented Jul 12, 2024

I installed via pip, version flash-attn 2.6.0.post1.

It still does not fix the problem.

Is this fix included in the pip release, or is there another way to install it?

@foreverlms

I'd be glad to have this question answered:
If Is_softcap is False, where is the scaling of QK^T performed?

(Referring to this line of code.)
