
Llama inference instability in fp16 producing inf in the middle of the model #27179

Closed
2 of 4 tasks
fxmarty opened this issue Oct 31, 2023 · 18 comments

fxmarty (Contributor) commented Oct 31, 2023

System Info

  • transformers version: 4.35.0.dev0
  • Platform: Linux-5.15.0-1023-aws-x86_64-with-glibc2.31
  • Python version: 3.9.16
  • Huggingface_hub version: 0.17.3
  • Safetensors version: 0.3.1
  • Accelerate version: 0.25.0.dev0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.1.0+cu118 (True)
  • Using GPU in script?: A100

Who can help?

@ydshieh @fxmarty @gante

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Hi, I encounter inference instability with Llama running in fp16 when left padding is used, especially when full rows are masked out in the 4D attention mask.

At some point in the forward pass, inf values may appear in the intermediate activations, ultimately leading to tensors filled with nan and the following error:

Traceback (most recent call last):
  File "=debug.py", line 38, in <module>
    outputs = model.generate(
  File "/fsx/felix/condaenvs/fx/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/fsx/felix/transformers/src/transformers/generation/utils.py", line 1704, in generate
    return self.sample(
  File "/fsx/felix/transformers/src/transformers/generation/utils.py", line 2822, in sample
    next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
RuntimeError: probability tensor contains either `inf`, `nan` or element < 0

Note that the inf values specifically appear at a padding position.
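
For context, a fully masked row in the expanded 4D additive mask means that every key position for that (padding) query receives the large negative fill value. A minimal pure-PyTorch sketch (illustrative only, not the transformers code) of what the attention softmax then does on such a row:

import torch

# One query row whose keys are all masked out with the fp16 "minus infinity"
# fill value used in additive attention masks.
scores = torch.zeros(1, 5, dtype=torch.float16)
mask = torch.full((1, 5), torch.finfo(torch.float16).min, dtype=torch.float16)

attn = torch.softmax((scores + mask).float(), dim=-1)
print(attn)  # tensor([[0.2, 0.2, 0.2, 0.2, 0.2]]): the padded query attends
             # uniformly to every key instead of being ignored, so whatever
             # values sit at the padding position keep flowing through the model.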

Reproduction:

from transformers import AutoTokenizer, pipeline, logging, AutoModelForCausalLM
import torch

model_name_or_path = "meta-llama/Llama-2-7b-chat-hf"
token = "[specify your token]"

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True, token=token)
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = "left"

with torch.device("cuda"):
    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float16, token=token)

sentence = "Felix Marty is a French"

# Alternatively, the issue can be reproduced with:
# sentence = "Elon Musk is a South"
# max_length=9

inp = tokenizer(sentence, return_tensors='pt', padding="max_length", max_length=9).to("cuda")

print("inp", inp["input_ids"].shape)
print("inp", inp)
torch.set_printoptions(threshold=10000000)

print("\n\n*** Generate:")
with torch.no_grad():
    outputs = model.generate(
        **inp,
        max_new_tokens=10,
        do_sample=True,
        top_p=0.9,
        temperature=float(0.01),
        top_k=40
    )

print(tokenizer.batch_decode(outputs))

Printing torch.all(torch.isfinite(...)) at various points in the model, it appears the inf values start to appear in the MLP at self.act_fn(self.gate_proj(x)) * self.up_proj(x), and things go wrong from there.
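
For anyone trying to reproduce this kind of diagnosis, here is a minimal sketch (my own helper, not part of the issue) that registers forward hooks and prints the name of every submodule whose output contains non-finite values; the first name printed is where the overflow originates:

import torch

def install_finite_checks(model):
    def make_hook(name):
        def hook(module, inputs, output):
            # Only plain tensor outputs are checked; tuple outputs are skipped.
            if isinstance(output, torch.Tensor) and not torch.isfinite(output).all():
                print(f"non-finite output in: {name}")
        return hook

    for name, module in model.named_modules():
        module.register_forward_hook(make_hook(name))

# Usage: call install_finite_checks(model) before model.generate(...)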

What's interesting is that, for example, fixing the attention mask so that the two rows corresponding to the left-padding tokens are no longer fully masked (before/after screenshots omitted) solves the issue.

It makes me think that the solution implemented for SDPA to avoid fully masked rows in the attention mask (#26572) may actually be required in other cases such as this one, although it is unclear why it relates to overflow here.
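
For reference, a rough pure-PyTorch sketch (assumptions mine, not the actual transformers helper) of that SDPA workaround: rows of the expanded 4D additive mask that are fully masked get reset to attend to everything, so no query position operates only on masked-out keys.

import torch

def unmask_fully_masked_rows(expanded_mask: torch.Tensor, min_value: float) -> torch.Tensor:
    # expanded_mask: [batch, 1, query_len, key_len], additive
    # (0 for "keep", min_value for "mask out").
    fully_masked = (expanded_mask == min_value).all(dim=-1, keepdim=True)
    # For the fully masked (padding) query rows, drop the mask entirely so the
    # softmax no longer degenerates on those rows.
    return expanded_mask.masked_fill(fully_masked, 0.0)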

WDYT @gante @ydshieh? Is this something you have ever observed?

Expected behavior

No inf values appearing in the middle of inference with the fp16 model.

ydshieh (Collaborator) commented Oct 31, 2023

Related to #17937, but that one uses a dummy model.

Will take a look here.

gante (Member) commented Oct 31, 2023

@ArthurZucker has been tracking it, and has a draft PR for it: #27114

@fxmarty Can you check if applying this change fixes it?

fxmarty (Contributor, Author) commented Oct 31, 2023

@ydshieh @gante Thank you! Unfortunately, this PR is unrelated: the issue also happens with the prompt Elon Musk is a South and max_length=9 (only one padding token), where the extended attention mask (screenshot omitted) does not contain any inf.

It may just be instability in the model, but it feels weird that it arises only when some attention mask rows are fully masked.

gante (Member) commented Oct 31, 2023

Well, I guess it needs another deep dive 😬

ydshieh (Collaborator) commented Nov 2, 2023

I haven't reached a final conclusion yet, but if we change the else block in LlamaMLP.forward to

            h1 = self.gate_proj(x)        # gate projection
            h2 = self.act_fn(h1)          # SiLU activation of the gate
            h3 = self.up_proj(x)          # up projection
            h4 = self.down_proj(h2 * h3)  # down projection of the gated product
            down_proj = h4

and print the maximal absolute values of these tensors, we see their magnitude become unusually large starting from layer 29 (0-based), get amplified to around 255 in layer 30, and then h4 becomes inf.

The question is what happens in layer 29 for this input, but I am afraid it is just a numerical issue over which we don't really have much control.

I will take a further look later when I have some spare time.
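
For completeness, a sketch of how the per-layer maxima below might be collected (my own monkeypatch, not necessarily the code used here), following the h1..h4 naming above:

def trace_mlp_maxima(model):
    # Wrap every LlamaMLP so each forward pass prints the absolute maxima of
    # its intermediate tensors (assumes pretraining_tp == 1, i.e. the plain path).
    for idx, layer in enumerate(model.model.layers):
        mlp = layer.mlp

        def make_forward(mlp, idx):
            def forward(x):
                h1 = mlp.gate_proj(x)
                h2 = mlp.act_fn(h1)
                h3 = mlp.up_proj(x)
                h4 = mlp.down_proj(h2 * h3)
                print(f"layer: {idx}")
                for name, t in (("h1", h1), ("h2", h2), ("h3", h3), ("h4", h4)):
                    print(f"{name}: {t.abs().max().item()}")
                return h4
            return forward

        mlp.forward = make_forward(mlp, idx)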

------------------------------------
layer: 29
h1: 13.234375
h2: 4.5
h3: 18.15625
h4: 57.28125
------------------------------------
layer: 30
h1: 255.875
h2: 255.875
h3: 261.75
h4: inf
------------------------------------
layer: 31
h1: nan
h2: nan
h3: nan
h4: nan
------------------------------------

full

layer: 0
h1: 4.2109375
h2: 4.1484375
h3: 2.220703125
h4: 4.08203125
------------------------------------
layer: 1
h1: 19.75
h2: 19.75
h3: 18.328125
h4: 753.0
------------------------------------
layer: 2
h1: 2.8671875
h2: 2.025390625
h3: 2.2421875
h4: 4.1640625
------------------------------------
layer: 3
h1: 2.259765625
h2: 1.11328125
h3: 1.484375
h4: 0.423583984375
------------------------------------
layer: 4
h1: 4.4375
h2: 4.38671875
h3: 2.642578125
h4: 7.58203125
------------------------------------
layer: 5
h1: 3.142578125
h2: 2.416015625
h3: 2.431640625
h4: 1.2734375
------------------------------------
layer: 6
h1: 2.7265625
h2: 1.98828125
h3: 2.2578125
h4: 1.0556640625
------------------------------------
layer: 7
h1: 2.650390625
h2: 1.8046875
h3: 2.349609375
h4: 1.3251953125
------------------------------------
layer: 8
h1: 3.34375
h2: 1.76171875
h3: 2.6796875
h4: 1.8408203125
------------------------------------
layer: 9
h1: 4.37109375
h2: 2.328125
h3: 3.142578125
h4: 1.734375
------------------------------------
layer: 10
h1: 4.3046875
h2: 3.1796875
h3: 2.62109375
h4: 1.3212890625
------------------------------------
layer: 11
h1: 3.853515625
h2: 3.5078125
h3: 3.0078125
h4: 1.62890625
------------------------------------
layer: 12
h1: 3.33203125
h2: 2.224609375
h3: 2.548828125
h4: 1.005859375
------------------------------------
layer: 13
h1: 3.560546875
h2: 2.783203125
h3: 3.087890625
h4: 2.23828125
------------------------------------
layer: 14
h1: 3.841796875
h2: 2.9609375
h3: 2.63671875
h4: 0.67626953125
------------------------------------
layer: 15
h1: 3.609375
h2: 3.4765625
h3: 3.107421875
h4: 1.8994140625
------------------------------------
layer: 16
h1: 5.06640625
h2: 4.078125
h3: 4.28515625
h4: 5.421875
------------------------------------
layer: 17
h1: 5.35546875
h2: 5.33203125
h3: 3.740234375
h4: 2.919921875
------------------------------------
layer: 18
h1: 4.0
h2: 3.853515625
h3: 3.60546875
h4: 3.271484375
------------------------------------
layer: 19
h1: 4.46484375
h2: 4.4140625
h3: 3.75
h4: 4.16796875
------------------------------------
layer: 20
h1: 3.66796875
h2: 2.970703125
h3: 3.658203125
h4: 2.962890625
------------------------------------
layer: 21
h1: 5.34375
h2: 5.31640625
h3: 3.400390625
h4: 2.0234375
------------------------------------
layer: 22
h1: 3.318359375
h2: 3.203125
h3: 3.451171875
h4: 1.546875
------------------------------------
layer: 23
h1: 4.28125
h2: 4.22265625
h3: 4.109375
h4: 2.67578125
------------------------------------
layer: 24
h1: 4.21484375
h2: 3.220703125
h3: 3.15625
h4: 0.9482421875
------------------------------------
layer: 25
h1: 3.93359375
h2: 3.7109375
h3: 3.947265625
h4: 3.9921875
------------------------------------
layer: 26
h1: 4.3359375
h2: 3.865234375
h3: 4.37109375
h4: 1.7041015625
------------------------------------
layer: 27
h1: 4.55078125
h2: 3.400390625
h3: 3.630859375
h4: 1.4111328125
------------------------------------
layer: 28
h1: 4.90234375
h2: 4.4453125
h3: 7.54296875
h4: 2.0546875
------------------------------------
layer: 29
h1: 13.234375
h2: 4.5
h3: 18.15625
h4: 57.28125
------------------------------------
layer: 30
h1: 255.875
h2: 255.875
h3: 261.75
h4: inf
------------------------------------
layer: 31
h1: nan
h2: nan
h3: nan
h4: nan
------------------------------------
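
For what it's worth, the layer 30 maxima above already sit at the edge of the fp16 range. A quick illustrative check (the numbers are taken from the log; whether those maxima actually coincide element-wise is an assumption) shows that the element-wise product alone can overflow before down_proj even sums anything:

import torch

h2_max = torch.tensor(255.875, dtype=torch.float16)
h3_max = torch.tensor(261.75, dtype=torch.float16)

print(torch.finfo(torch.float16).max)  # 65504.0, the largest finite fp16 value
print(h2_max * h3_max)                 # tensor(inf, dtype=torch.float16): 255.875 * 261.75 ≈ 66975 > 65504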

ydshieh (Collaborator) commented Nov 6, 2023

After taking a further look, this doesn't seem to be related to any bug; it is just a limitation of using fp16, and it also depends on the input data.

One observation I found is: larger tensor values tend to appear when the prompt is (very) short.

Also, when this happens, I often see that many positions in the corresponding multiplications have values of the same sign.

I'm afraid there is nothing more I can provide.

fxmarty (Contributor, Author) commented Nov 6, 2023

Thanks a lot @ydshieh. Did you notice any difference depending on whether rows are fully masked in the attention mask or not?

We can probably close this one; at least it is good to know that Llama 7B has numerical instabilities during inference in fp16.

ydshieh (Collaborator) commented Nov 6, 2023

whether rows are fully masked in the attention mask or not?

Oh, I might have made a mistake! You have max_length=9 in the code snippet, so if I use a long sequence, there is no padding!
OK, I need to recheck!

ArthurZucker (Collaborator) commented

I think beam search with RoPE and fp16 does have instabilities, yes, reported here: #26332. If I am not mistaken, this is what we have here, no? And I think a recent PR to fix this was merged: #26843.
But yeah, I have a pretty huge list of bugs to process!

ydshieh (Collaborator) commented Nov 6, 2023

FYI: here the issue is not even in generation; it already appears in the first step, when just encoding the input prompt.

fxmarty (Contributor, Author) commented Nov 8, 2023

Same issue in layer 29/30 in AutoGPTQ/AutoGPTQ#412. Unmasking fully masked padding rows solves the issue there as well.


And the nans indeed start to appear at the padding index if we do not unmask:

In layer 30, without unmasking:

hidden_states after layernorm torch.Size([2, 6, 4096])
hidden_states b=0, seq_idx=0 mean: 0.00121307373046875
hidden_states b=0, seq_idx=1 mean: -0.0168914794921875
hidden_states b=0, seq_idx=2 mean: -0.00237274169921875
hidden_states b=0, seq_idx=3 mean: 0.0007181167602539062
hidden_states b=0, seq_idx=4 mean: -0.0108642578125
hidden_states b=0, seq_idx=5 mean: -0.006961822509765625
hidden_states b=1, seq_idx=0 mean: -0.0016736984252929688
hidden_states b=1, seq_idx=1 mean: 0.0012159347534179688
hidden_states b=1, seq_idx=2 mean: -0.016876220703125
hidden_states b=1, seq_idx=3 mean: -0.0023746490478515625
hidden_states b=1, seq_idx=4 mean: 0.0006799697875976562
hidden_states b=1, seq_idx=5 mean: -0.010833740234375
up_proj, down_proj
--- forward
input finite tensor(True, device='cuda:0')
output torch.Size([2, 6, 11008])
output finite tensor(True, device='cuda:0')
output absmax tensor(1.0762e+02, device='cuda:0', dtype=torch.float16)
output absmean tensor(4.6924e-01, device='cuda:0', dtype=torch.float16)
--- forward
input finite tensor(True, device='cuda:0')
output torch.Size([2, 6, 11008])
output finite tensor(True, device='cuda:0')
output absmax tensor(1.0962e+02, device='cuda:0', dtype=torch.float16)
output absmean tensor(4.5728e-01, device='cuda:0', dtype=torch.float16)
gate_proj b=0, seq_idx=0 mean: -0.047821, absmax: 14.078125
gate_proj b=0, seq_idx=1 mean: -0.208618, absmax: 23.078125
gate_proj b=0, seq_idx=2 mean: -0.253174, absmax: 23.859375
gate_proj b=0, seq_idx=3 mean: -0.270264, absmax: 27.84375
gate_proj b=0, seq_idx=4 mean: -0.184692, absmax: 14.5078125
gate_proj b=0, seq_idx=5 mean: -0.254639, absmax: 12.8203125
gate_proj b=1, seq_idx=0 mean: 0.309814, absmax: 107.625
gate_proj b=1, seq_idx=1 mean: -0.047852, absmax: 14.078125
gate_proj b=1, seq_idx=2 mean: -0.208496, absmax: 23.234375
gate_proj b=1, seq_idx=3 mean: -0.252930, absmax: 23.96875
gate_proj b=1, seq_idx=4 mean: -0.270508, absmax: 27.984375
gate_proj b=1, seq_idx=5 mean: -0.184937, absmax: 14.6484375
up_proj b=0, seq_idx=0 mean: 0.001290, absmax: 15.0546875
up_proj b=0, seq_idx=1 mean: -0.008339, absmax: 18.40625
up_proj b=0, seq_idx=2 mean: -0.016205, absmax: 18.0
up_proj b=0, seq_idx=3 mean: -0.005768, absmax: 23.234375
up_proj b=0, seq_idx=4 mean: -0.000823, absmax: 6.44921875
up_proj b=0, seq_idx=5 mean: -0.003519, absmax: 11.6171875
up_proj b=1, seq_idx=0 mean: 0.015915, absmax: 109.625
up_proj b=1, seq_idx=1 mean: 0.001284, absmax: 15.046875
up_proj b=1, seq_idx=2 mean: -0.008362, absmax: 18.5625
up_proj b=1, seq_idx=3 mean: -0.016220, absmax: 18.046875
up_proj b=1, seq_idx=4 mean: -0.005787, absmax: 23.34375
up_proj b=1, seq_idx=5 mean: -0.000838, absmax: 6.546875
act_gate b=0, seq_idx=0 mean: -0.011940, absmax: 14.078125
act_gate b=0, seq_idx=1 mean: 0.004330, absmax: 4.80859375
act_gate b=0, seq_idx=2 mean: 0.010277, absmax: 5.859375
act_gate b=0, seq_idx=3 mean: -0.015503, absmax: 6.46875
act_gate b=0, seq_idx=4 mean: 0.031921, absmax: 5.67578125
act_gate b=0, seq_idx=5 mean: -0.006973, absmax: 6.5
act_gate b=1, seq_idx=0 mean: 0.219971, absmax: 107.625
act_gate b=1, seq_idx=1 mean: -0.011948, absmax: 14.078125
act_gate b=1, seq_idx=2 mean: 0.004345, absmax: 4.80859375
act_gate b=1, seq_idx=3 mean: 0.010429, absmax: 5.859375
act_gate b=1, seq_idx=4 mean: -0.015495, absmax: 6.46484375
act_gate b=1, seq_idx=5 mean: 0.031738, absmax: 5.67578125
inter b=0, seq_idx=0 mean: 0.03338623046875, absmax: 212.0
inter b=0, seq_idx=1 mean: 0.00040793418884277344, absmax: 6.7734375
inter b=0, seq_idx=2 mean: 0.0011510848999023438, absmax: 7.125
inter b=0, seq_idx=3 mean: 0.00832366943359375, absmax: 17.46875
inter b=0, seq_idx=4 mean: 0.00707244873046875, absmax: 13.90625
inter b=0, seq_idx=5 mean: 0.0014142990112304688, absmax: 7.62890625
inter b=1, seq_idx=0 mean: 1.3212890625, absmax: 11800.0
inter b=1, seq_idx=1 mean: 0.03338623046875, absmax: 211.875
inter b=1, seq_idx=2 mean: 0.0004088878631591797, absmax: 6.796875
inter b=1, seq_idx=3 mean: 0.0011835098266601562, absmax: 7.1484375
inter b=1, seq_idx=4 mean: 0.008331298828125, absmax: 17.515625
inter b=1, seq_idx=5 mean: 0.007049560546875, absmax: 13.8828125
call down_proj
--- forward
input finite tensor(True, device='cuda:0')
output torch.Size([2, 6, 4096])
output finite tensor(False, device='cuda:0')
output absmax tensor(inf, device='cuda:0', dtype=torch.float16)
output absmean tensor(inf, device='cuda:0', dtype=torch.float16)
down_proj b=0, seq_idx=0 finite: True
down_proj b=0, seq_idx=1 finite: True
down_proj b=0, seq_idx=2 finite: True
down_proj b=0, seq_idx=3 finite: True
down_proj b=0, seq_idx=4 finite: True
down_proj b=0, seq_idx=5 finite: True
down_proj b=1, seq_idx=0 finite: False
down_proj b=1, seq_idx=1 finite: True
down_proj b=1, seq_idx=2 finite: True
down_proj b=1, seq_idx=3 finite: True
down_proj b=1, seq_idx=4 finite: True
down_proj b=1, seq_idx=5 finite: True

In layer 30, with the fully masked rows unmasked:

hidden_states after layernorm torch.Size([2, 6, 4096])
hidden_states b=0, seq_idx=0 mean: 0.0012102127075195312
hidden_states b=0, seq_idx=1 mean: -0.01690673828125
hidden_states b=0, seq_idx=2 mean: -0.002384185791015625
hidden_states b=0, seq_idx=3 mean: 0.0007028579711914062
hidden_states b=0, seq_idx=4 mean: -0.01085662841796875
hidden_states b=0, seq_idx=5 mean: -0.006946563720703125
hidden_states b=1, seq_idx=0 mean: -0.0006947517395019531
hidden_states b=1, seq_idx=1 mean: 0.00121307373046875
hidden_states b=1, seq_idx=2 mean: -0.0168609619140625
hidden_states b=1, seq_idx=3 mean: -0.0023975372314453125
hidden_states b=1, seq_idx=4 mean: 0.0006928443908691406
hidden_states b=1, seq_idx=5 mean: -0.01084136962890625
up_proj, down_proj
--- forward
input finite tensor(True, device='cuda:0')
output torch.Size([2, 6, 11008])
output finite tensor(True, device='cuda:0')
output absmax tensor(3.3969e+01, device='cuda:0', dtype=torch.float16)
output absmean tensor(4.5752e-01, device='cuda:0', dtype=torch.float16)
--- forward
input finite tensor(True, device='cuda:0')
output torch.Size([2, 6, 11008])
output finite tensor(True, device='cuda:0')
output absmax tensor(3.1141e+01, device='cuda:0', dtype=torch.float16)
output absmean tensor(4.5410e-01, device='cuda:0', dtype=torch.float16)
gate_proj b=0, seq_idx=0 mean: -0.047882, absmax: 14.078125
gate_proj b=0, seq_idx=1 mean: -0.208374, absmax: 23.09375
gate_proj b=0, seq_idx=2 mean: -0.252930, absmax: 23.875
gate_proj b=0, seq_idx=3 mean: -0.270508, absmax: 27.90625
gate_proj b=0, seq_idx=4 mean: -0.184692, absmax: 14.515625
gate_proj b=0, seq_idx=5 mean: -0.254639, absmax: 12.84375
gate_proj b=1, seq_idx=0 mean: -0.073853, absmax: 33.96875
gate_proj b=1, seq_idx=1 mean: -0.047852, absmax: 14.1015625
gate_proj b=1, seq_idx=2 mean: -0.208496, absmax: 23.21875
gate_proj b=1, seq_idx=3 mean: -0.253418, absmax: 23.953125
gate_proj b=1, seq_idx=4 mean: -0.270264, absmax: 27.984375
gate_proj b=1, seq_idx=5 mean: -0.184692, absmax: 14.5546875
up_proj b=0, seq_idx=0 mean: 0.001290, absmax: 15.046875
up_proj b=0, seq_idx=1 mean: -0.008347, absmax: 18.40625
up_proj b=0, seq_idx=2 mean: -0.016235, absmax: 17.984375
up_proj b=0, seq_idx=3 mean: -0.005745, absmax: 23.265625
up_proj b=0, seq_idx=4 mean: -0.000815, absmax: 6.4453125
up_proj b=0, seq_idx=5 mean: -0.003561, absmax: 11.6328125
up_proj b=1, seq_idx=0 mean: -0.004223, absmax: 31.140625
up_proj b=1, seq_idx=1 mean: 0.001290, absmax: 15.078125
up_proj b=1, seq_idx=2 mean: -0.008362, absmax: 18.5625
up_proj b=1, seq_idx=3 mean: -0.016251, absmax: 18.03125
up_proj b=1, seq_idx=4 mean: -0.005783, absmax: 23.328125
up_proj b=1, seq_idx=5 mean: -0.000843, absmax: 6.4765625
act_gate b=0, seq_idx=0 mean: -0.011971, absmax: 14.078125
act_gate b=0, seq_idx=1 mean: 0.004372, absmax: 4.8046875
act_gate b=0, seq_idx=2 mean: 0.010483, absmax: 5.86328125
act_gate b=0, seq_idx=3 mean: -0.015427, absmax: 6.46875
act_gate b=0, seq_idx=4 mean: 0.031860, absmax: 5.67578125
act_gate b=0, seq_idx=5 mean: -0.007015, absmax: 6.4921875
act_gate b=1, seq_idx=0 mean: 0.002026, absmax: 4.19140625
act_gate b=1, seq_idx=1 mean: -0.011955, absmax: 14.1015625
act_gate b=1, seq_idx=2 mean: 0.004314, absmax: 4.8125
act_gate b=1, seq_idx=3 mean: 0.010254, absmax: 5.86328125
act_gate b=1, seq_idx=4 mean: -0.015503, absmax: 6.4609375
act_gate b=1, seq_idx=5 mean: 0.031891, absmax: 5.6640625
inter b=0, seq_idx=0 mean: 0.033355712890625, absmax: 211.875
inter b=0, seq_idx=1 mean: 0.00041985511779785156, absmax: 6.76953125
inter b=0, seq_idx=2 mean: 0.0011568069458007812, absmax: 7.1328125
inter b=0, seq_idx=3 mean: 0.008331298828125, absmax: 17.421875
inter b=0, seq_idx=4 mean: 0.007068634033203125, absmax: 13.8828125
inter b=0, seq_idx=5 mean: 0.0014171600341796875, absmax: 7.63671875
inter b=1, seq_idx=0 mean: 0.0037746429443359375, absmax: 21.890625
inter b=1, seq_idx=1 mean: 0.033477783203125, absmax: 212.625
inter b=1, seq_idx=2 mean: 0.00041794776916503906, absmax: 6.78125
inter b=1, seq_idx=3 mean: 0.001155853271484375, absmax: 7.1328125
inter b=1, seq_idx=4 mean: 0.00830078125, absmax: 17.4375
inter b=1, seq_idx=5 mean: 0.007068634033203125, absmax: 13.828125
call down_proj
--- forward
input finite tensor(True, device='cuda:0')
output torch.Size([2, 6, 4096])
output finite tensor(True, device='cuda:0')
output absmax tensor(5.3750e+02, device='cuda:0', dtype=torch.float16)
output absmean tensor(4.9854e-01, device='cuda:0', dtype=torch.float16)
down_proj b=0, seq_idx=0 finite: True
down_proj b=0, seq_idx=1 finite: True
down_proj b=0, seq_idx=2 finite: True
down_proj b=0, seq_idx=3 finite: True
down_proj b=0, seq_idx=4 finite: True
down_proj b=0, seq_idx=5 finite: True
down_proj b=1, seq_idx=0 finite: True
down_proj b=1, seq_idx=1 finite: True
down_proj b=1, seq_idx=2 finite: True
down_proj b=1, seq_idx=3 finite: True
down_proj b=1, seq_idx=4 finite: True
down_proj b=1, seq_idx=5 finite: True

It is unclear to me what is happening here and how it relates to fully masked rows.

ydshieh (Collaborator) commented Nov 8, 2023

Great details! I am wondering whether the original training saw the unmasked version of the row, but at inference time the model now sees another version, which leads to these large values (similar to the different behavior of SDPA between torch 2.0.1 / 2.1.0 on GPU that we saw previously).

fxmarty (Contributor, Author) commented Nov 9, 2023

@ydshieh At some point I want to try the original Llama repo to see how padding is handled there.

fxmarty (Contributor, Author) commented Dec 4, 2023

not stale

callanwu (Contributor) commented

mark

ArthurZucker (Collaborator) commented

I think computing RoPE in float32 precision should partly fix this.
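
As a rough illustration (my own sketch, not the merged transformers change), "computing RoPE in float32" means building and applying the rotation in float32 and only casting back at the end, which removes one source of fp16 rounding and overflow:

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope_fp32(q: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # Do the rotary embedding math in float32, then cast the result back to the
    # original (e.g. float16) dtype of the query/key tensor.
    orig_dtype = q.dtype
    q, cos, sin = q.float(), cos.float(), sin.float()
    return (q * cos + rotate_half(q) * sin).to(orig_dtype)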

ArthurZucker (Collaborator) commented

I'll mark this as closed, because Llama now computes RoPE in float32! 🥳 Feel free to ping me if you feel this should not be closed.

PIPItian commented

You can try updating optimum to the latest version to solve this.
