Possible data converting problem when using flash attention 2 with whisper #27260
Comments
Hi @changyeli,
Can you also share the content of the dataset? Note only the
@younesbelkada Sure, here is the full traceback:

Traceback (most recent call last):
File "/home/suppl/scripts/fine_tune_whisper.py", line 161, in <module>
trainer.train()
File "/home//anaconda3/envs/whisper/lib/python3.11/site-packages/transformers/trainer.py", line 1555, in train
return inner_training_loop(
^^^^^^^^^^^^^^^^^^^^
File "/home//anaconda3/envs/whisper/lib/python3.11/site-packages/accelerate/utils/memory.py", line 136, in decorator
return function(batch_size, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home//anaconda3/envs/whisper/lib/python3.11/site-packages/transformers/trainer.py", line 1860, in _inner_training_loop
tr_loss_step = self.training_step(model, inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/transformers/trainer.py", line 2725, in training_step
loss = self.compute_loss(model, inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/transformers/trainer.py", line 2748, in compute_loss
outputs = model(**inputs)
^^^^^^^^^^^^^^^
File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/parallel/data_parallel.py", line 185, in forward
outputs = self.parallel_apply(replicas, inputs, module_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/parallel/data_parallel.py", line 200, in parallel_apply
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py", line 110, in parallel_apply
output.reraise()
File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/_utils.py", line 694, in reraise
raise exception
RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py", line 85, in _worker
output = module(*input, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/transformers/models/whisper/modeling_whisper.py", line 1683, in forward
outputs = self.model(
^^^^^^^^^^^
File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/transformers/models/whisper/modeling_whisper.py", line 1543, in forward
encoder_outputs = self.encoder(
^^^^^^^^^^^^^
File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/transformers/models/whisper/modeling_whisper.py", line 1119, in forward
inputs_embeds = nn.functional.gelu(self.conv1(input_features))
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 310, in forward
return self._conv_forward(input, self.weight, self.bias)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 306, in _conv_forward
return F.conv1d(input, weight, bias, self.stride,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.cuda.HalfTensor) should be the same
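The root cause of that final RuntimeError can be reproduced outside Whisper: the model weights were loaded in float16 (via torch_dtype=torch.float16), while the input features arrive as float32. A minimal CPU sketch of the mismatch and one common fix, casting the batch to the model's dtype (the layer sizes here are illustrative, not Whisper's):

```python
import torch
import torch.nn as nn

# Half-precision "model" weights plus a float32 input: this is the
# combination that triggers "Input type ... and weight type ... should
# be the same" in the conv1d call of the traceback above.
layer = nn.Linear(4, 4).to(torch.float16)
x = torch.randn(2, 4)  # float32, like un-cast input_features

assert x.dtype != layer.weight.dtype  # the mismatch from the traceback

# The fix: cast the batch to the model's dtype before the forward pass.
x = x.to(layer.weight.dtype)
assert x.dtype == layer.weight.dtype  # now the forward would not raise
```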
I also have a somewhat related question regarding gradient checkpointing. I got the following error when I set

It looks like something related to memory. I used the whisper small model and ~500 samples for fine-tuning and got this error. Interestingly, it was perfectly fine with inference: I can run inference using whisper large on a much bigger dataset. How should I proceed?
The input features are indeed in

format = {'type': 'torch', 'format_kwargs': {'dtype': torch.float16}}
temo_dt['train'].set_format(columns=['input_features'], **format)
Could you try removing:

model = WhisperForConditionalGeneration.from_pretrained(
    model_card, use_flash_attention_2=True,
-   torch_dtype=torch.float16)

And then setting:

training_args = Seq2SeqTrainingArguments(
    output_dir=f"../{model_name}",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=16,
    learning_rate=1e-5,
    warmup_steps=500,
    max_steps=6000,
    # speed up
    gradient_checkpointing=True,
    evaluation_strategy="steps",
    per_device_eval_batch_size=16,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=1000,
    eval_steps=1000,
    logging_steps=25,
    report_to="none",
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    auto_find_batch_size=True,
    torch_compile=True,
+   fp16=True
)

=> we should cast the
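The idea behind this suggestion: keep the weights in float32 and let the Trainer's fp16=True run the forward pass under autocast, so inputs and weights are lowered consistently instead of clashing. A rough stand-alone illustration with plain torch.autocast (bfloat16 is used here because CPU autocast supports it; on GPU, fp16=True uses the same mechanism with float16):

```python
import torch
import torch.nn as nn

# Weights stay float32, as when torch_dtype is NOT passed to from_pretrained.
model = nn.Conv1d(80, 8, kernel_size=3)
x = torch.randn(1, 80, 100)  # float32 features

# Mixed precision: autocast lowers the compute dtype per-op, so there is
# no input/weight dtype clash even though nothing was cast by hand.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y = model(x)

assert model.weight.dtype == torch.float32  # master weights untouched
```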
Hi @sanchit-gandhi, thanks for the suggestion! But I still got a data-conversion-related error with your approach. Here is the full trace of the error message:
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Hi @sanchit-gandhi, I keep getting similar errors after upgrading to the most recent version (as of 01/13/24). Loaded the model using:

model = WhisperForConditionalGeneration.from_pretrained(
    model_card,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    use_safetensors=True)

As the

I got this error when trying to fine-tune the whisper model:

RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.cuda.HalfTensor) should be the same.

This is the same error I got when I opened this issue. I also converted the input data using:

dt.set_format(
    columns=['input_features', 'labels'],
    **format
)

And called the model via:

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_card,
    attn_implementation="flash_attention_2",
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
    use_safetensors=True)

But I got a different error message:

trainer.train()
0%| | 0/22 [00:08<?, ?it/s]

It's interesting, as I use FA2 in decoder-only language models and it works fine, but no luck on whisper yet. Any suggestions on how to proceed?
Did several experiments. It turned out I forgot to drop unused columns during preprocessing, but the problem still exists. For fine-tuning, if I remove

ValueError: Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes. You passed torch.float32, this might lead to unexpected behaviour.
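For context, that ValueError comes from a guard in transformers that refuses to enable Flash Attention 2 on float32 weights. A rough re-creation of the check (the function name here is hypothetical, not the library's actual internal):

```python
import torch

def check_fa2_dtype(dtype: torch.dtype) -> None:
    # Hypothetical sketch of the guard behind the ValueError above:
    # FA2 kernels only operate on half-precision tensors.
    if dtype not in (torch.float16, torch.bfloat16):
        raise ValueError(
            "Flash Attention 2.0 only supports torch.float16 and "
            f"torch.bfloat16 dtypes. You passed {dtype}."
        )

check_fa2_dtype(torch.float16)   # accepted
check_fa2_dtype(torch.bfloat16)  # accepted
```

So loading with torch_dtype=torch.float16 (or bfloat16) satisfies this guard, which is exactly what reintroduces the input/weight mismatch unless the input features, or the Trainer's mixed-precision mode, are brought in line with the weights.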
@sanchit-gandhi any update on this? How do we use flash attention 2 here? My guess is that this has to do with mixed-precision training not working properly. In
I am encountering the same kind of issue. Has any solution been found yet? I only get the error on the evaluation step of the trainer; the training itself seems to work fine.
Hey @changyeli, @cvl01, and @RohitMidha23, is this still an issue?
@changyeli that sort of goes against the point, as you aren't using
@RohitMidha23 I think SDPA supports flash attention. This is what I found in the documentation:

But I agree regarding the speed: it does not speed things up that much using SDPA.
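For reference, attn_implementation="sdpa" routes attention through PyTorch's torch.nn.functional.scaled_dot_product_attention, which dispatches to a flash-attention kernel when the device, dtype, and shapes allow, and otherwise falls back to a memory-efficient or math implementation. A minimal sketch of the underlying call (toy shapes, not Whisper's):

```python
import torch
import torch.nn.functional as F

# (batch, heads, seq_len, head_dim): toy shapes for illustration
q = torch.randn(1, 2, 16, 8)
k = torch.randn(1, 2, 16, 8)
v = torch.randn(1, 2, 16, 8)

# PyTorch picks the best available backend (flash, mem-efficient, or math);
# on CPU/float32 this silently uses the fallback, which is why the speedup
# can be modest compared to a true FA2 path.
out = F.scaled_dot_product_attention(q, k, v)
assert out.shape == (1, 2, 16, 8)
```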
System Info

transformers version: 4.35.0

Who can help?

@sanchit-gandhi

Information

Tasks

examples folder (such as GLUE/SQuAD, ...)

Reproduction
I was trying to fine-tune whisper small with flash attention 2 on private data. Followed the post here for most of the code. Here are some changes I made:
It gave me this error:
RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.cuda.HalfTensor) should be the same.
So I tried to convert the temo_dt to half tensors using the following code:

But it returned this error:
RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.cuda.FloatTensor instead (while checking arguments for embedding).
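That second error has a different cause: casting every column to float16 also converted the labels, but the decoder's embedding layer requires integer token ids. A small reproduction (toy vocabulary size):

```python
import torch
import torch.nn as nn

emb = nn.Embedding(10, 4)          # toy vocab of 10, embedding dim 4
labels = torch.tensor([1, 2, 3])   # int64 token ids: fine
out = emb(labels)

# Casting labels to float16 (as a blanket set_format over all columns does)
# reproduces "Expected tensor for argument #1 'indices' ... Long, Int".
try:
    emb(labels.to(torch.float16))
    raised = False
except RuntimeError:
    raised = True
assert raised  # so cast only input_features, never the labels
```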
Very interestingly, I can fine-tune the whisper small model perfectly without flash attention 2 using the code above. Is there anything I missed?
Expected behavior
Fine-tuning whisper should go as expected with use_flash_attention_2=True.