RotaryEmbedding computation is wrong for certain position/feature pairs in reduced precision (both fp16 and bfloat) #1003
Comments
Noting this as a possible cause of DeepSpeed issue #3742
Isn't NeoX already doing it the …
The cast-to-float happens after …
Doesn't the …
So DeepSpeed casts the module buffer to the model dtype when …
I don't know about documentation, but you can see it in DeepspeedEngine here: https://github.com/microsoft/DeepSpeed/blob/46784cb58edf7bbe9b6bbec95212de7b81e55b01/deepspeed/runtime/engine.py#L1142 The best fix is probably to do all the cos/sin table calculation and caching in the constructor, when you can be sure about dtypes (or insert appropriate casts in …)
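A minimal sketch of the constructor-time approach described above (the class and argument names are illustrative, not the actual NeoX code): do all of the table math in fp32 when the module is built, so that any later engine-wide cast to fp16/bf16 only rounds the finished cos/sin values rather than the intermediate products.

```python
import torch

class RotaryEmbeddingFP32Tables(torch.nn.Module):
    """Illustrative sketch: compute cos/sin tables entirely in fp32
    at construction time, before any dtype casts can happen."""

    def __init__(self, dim, max_seq_len, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_seq_len, dtype=torch.float32)
        freqs = torch.outer(t, inv_freq)          # fp32 throughout
        emb = torch.cat((freqs, freqs), dim=-1)
        # An engine may later cast these buffers to fp16/bf16, but that
        # only rounds the finished table, not the position*frequency math.
        self.register_buffer("cos_cached", emb.cos()[:, None, None, :])
        self.register_buffer("sin_cached", emb.sin()[:, None, None, :])

    def forward(self, x, seq_len):
        return (self.cos_cached[:seq_len].to(x.dtype),
                self.sin_cached[:seq_len].to(x.dtype))
```

Slicing a precomputed `max_seq_len` table in `forward` also avoids recomputing when `seq_len` changes, at the cost of fixing the maximum length up front.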
@cbcase How does this implementation look? I admit it's quite hacky because I wanted to preserve backward compatibility of checkpoints. Also, right now when seq_len changes it computes a new cos/sin table instead of doing a slice. If this looks good to you, I can turn it into a version that's ready for main. https://github.com/EleutherAI/gpt-neox/blob/math-lm-2-rotary/megatron/model/positional_embeddings.py#L38
Describe the bug

The `RotaryEmbedding` module does substantially all of the computation of the cached cos and sin tables in whatever is the model precision (usually fp16 or bfloat16). For certain (position, feature) pairs, this produces wildly different values than the corresponding fp32 computation.

To Reproduce
Here is a small reproducer:
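The original script was not captured here; a minimal sketch of this kind of reproducer (names are illustrative, not the exact code from the report): build the rotary frequency table the way `RotaryEmbedding` does, once with the outer product in reduced precision and once in fp32, and compare the resulting cos tables.

```python
import torch

dim, seq_len, base = 64, 4096, 10000

# inv_freq as RotaryEmbedding registers it (fp32)...
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))

def cos_table(dtype):
    # ...but cast to the model dtype (as DeepSpeed does to module buffers)
    # before the position-times-frequency outer product.
    t = torch.arange(seq_len).to(dtype)
    emb = torch.outer(t, inv_freq.to(dtype))  # values up to ~4095
    return torch.cos(emb.float())

ref = cos_table(torch.float32)
for dtype in (torch.float16, torch.bfloat16):
    err = (cos_table(dtype) - ref).abs().max()
    # Errors are O(1): rounding emb at magnitude ~1000s shifts the phase
    # by a sizable fraction of the 2*pi period of cos.
    print(dtype, err)
```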
On my machine, I get:
As you can see, the issue is that some of the outer product values (`emb`) are large, so the small relative rounding to fp16 is large in absolute magnitude compared to the period of cos / sin. Note that the issue is relatively worse in bfloat, since it has less precision (across a wider range).

Expected behavior
The computed embeddings shouldn't depend on model precision (up to the point of rounding the computed cos/sin tables).
Proposed solution

Lots of reasonable ways to rework this. I expect the main pain is that `inv_freq` is a buffer stored in people's checkpoints with the model dtype.
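One way to keep old checkpoints loadable while sidestepping the dtype problem, sketched below with hypothetical names (this is not the repo's actual fix): keep registering the `inv_freq` buffer so existing state_dicts still load, but rebuild it in fp32 after loading, discarding whatever dtype the checkpoint stored it in.

```python
import torch

class RotaryEmbeddingCompat(torch.nn.Module):
    """Hypothetical sketch: checkpoint-compatible fp32 inv_freq."""

    def __init__(self, dim, base=10000):
        super().__init__()
        self.dim, self.base = dim, base
        # Keep the buffer so old checkpoints' keys still match.
        self.register_buffer("inv_freq", self._fp32_inv_freq())

    def _fp32_inv_freq(self):
        return 1.0 / (self.base
                      ** (torch.arange(0, self.dim, 2).float() / self.dim))

    def _load_from_state_dict(self, *args, **kwargs):
        super()._load_from_state_dict(*args, **kwargs)
        # Whatever dtype the checkpoint stored (e.g. fp16 after an engine
        # cast), restore an exact fp32 copy on the right device.
        self.inv_freq = self._fp32_inv_freq().to(self.inv_freq.device)
```

With this, loading a state_dict whose `inv_freq` was saved in fp16 still yields bit-exact fp32 frequencies, so the cos/sin tables no longer inherit checkpoint rounding.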