Possible data converting problem when using flash attention 2 with whisper #27260

Closed
changyeli opened this issue Nov 3, 2023 · 18 comments

Comments

@changyeli

System Info

  • transformers version: 4.35.0
  • Platform: Linux-5.14.0-284.25.1.el9_2.x86_64-x86_64-with-glibc2.34
  • Python version: 3.11.5
  • Huggingface_hub version: 0.17.3
  • Safetensors version: 0.4.0
  • Accelerate version: 0.23.0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.1.0+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: trainer's default

Who can help?

@sanchit-gandhi

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I was trying to fine-tune whisper small with flash attention 2 on a private dataset. I followed the post here for most of the code. Here are some changes I made:

import configparser

import pandas as pd
import torch
from datasets import Audio, load_dataset
from transformers import (
    AutoProcessor,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    WhisperFeatureExtractor,
    WhisperForConditionalGeneration,
    WhisperTokenizer,
)

model_card = "openai/whisper-small"
model_name = model_card.split("/")[-1]
config = configparser.ConfigParser()
config.read("config.ini")
tran_df = pd.read_csv("../total_df.csv")
processor = AutoProcessor.from_pretrained(
    model_card)
tokenizer = WhisperTokenizer.from_pretrained(
    model_card)
feature_extractor = WhisperFeatureExtractor.from_pretrained(model_card)
temo_dt = load_dataset(
    "audiofolder", data_dir=config['DATA']['dataset'],
    split="train[:1%]")
temo_dt = temo_dt.train_test_split(test_size=0.3)
temo_dt = temo_dt.cast_column("audio", Audio(sampling_rate=16000))

model = WhisperForConditionalGeneration.from_pretrained(
    model_card, use_flash_attention_2=True,
    torch_dtype=torch.float16)
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(
    language="english", task="transcribe")
model.config.suppress_tokens = []
# DataCollatorSpeechSeq2SeqWithPadding and compute_metrics_wer are defined
# earlier in the script, following the fine-tuning blog post
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
# training process
training_args = Seq2SeqTrainingArguments(
    output_dir=f"../{model_name}", 
    per_device_train_batch_size=4,
    gradient_accumulation_steps=16,
    learning_rate=1e-5,
    warmup_steps=500,
    max_steps=6000,
    # speed up
    gradient_checkpointing=True,
    evaluation_strategy="steps",
    per_device_eval_batch_size=16,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=1000,
    eval_steps=1000,
    logging_steps=25,
    report_to="none",
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    auto_find_batch_size=True,
    torch_compile=True,
)
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=temo_dt["train"],
    eval_dataset=temo_dt["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics_wer,
    tokenizer=processor.feature_extractor,
)
trainer.train()

It gave me this error: RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.cuda.HalfTensor) should be the same.

So I tried to convert temo_dt to half precision using the following code:

format = {'type': 'torch', 'format_kwargs': {'dtype': torch.float16}}
temo_dt.set_format(**format)

But it returned this error: RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.cuda.FloatTensor instead (while checking arguments for embedding).

Very interestingly, I can fine-tune the whisper small model perfectly without flash attention 2 using the code above. Is there anything I missed?

Expected behavior

Fine-tuning whisper should go as expected with use_flash_attention_2=True.

@LysandreJik
Member

cc @younesbelkada

@younesbelkada
Contributor

Hi @changyeli, can you share the full traceback of the two errors you are getting?

@younesbelkada
Contributor

Can you also share the content of the dataset? Note that only the input_features need to be cast to float16, not every column of your dataset. Perhaps you can use dataset.map(xxx, batched=True) to cast only input_features to float16: https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py#L1684
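
For example, a minimal sketch of such a cast (assuming the dataset already contains a precomputed input_features column, as it does after the feature-extraction step from the blog post; cast_input_features is just an illustrative name):

import numpy as np

def cast_input_features(batch):
    # Cast only the precomputed log-mel features to float16;
    # leave the integer label ids and other columns untouched.
    batch["input_features"] = [
        np.asarray(feat, dtype=np.float16) for feat in batch["input_features"]
    ]
    return batch

temo_dt = temo_dt.map(cast_input_features, batched=True)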

@changyeli
Author

@younesbelkada Sure, here is the full traceback using torch_dtype=torch.float16:

Traceback (most recent call last):
  File "/home/suppl/scripts/fine_tune_whisper.py", line 161, in <module>
    trainer.train()
  File "/home//anaconda3/envs/whisper/lib/python3.11/site-packages/transformers/trainer.py", line 1555, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/home//anaconda3/envs/whisper/lib/python3.11/site-packages/accelerate/utils/memory.py", line 136, in decorator
    return function(batch_size, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home//anaconda3/envs/whisper/lib/python3.11/site-packages/transformers/trainer.py", line 1860, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/transformers/trainer.py", line 2725, in training_step
    loss = self.compute_loss(model, inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/transformers/trainer.py", line 2748, in compute_loss
    outputs = model(**inputs)
              ^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/parallel/data_parallel.py", line 185, in forward
    outputs = self.parallel_apply(replicas, inputs, module_kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/parallel/data_parallel.py", line 200, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py", line 110, in parallel_apply
    output.reraise()
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/_utils.py", line 694, in reraise
    raise exception
RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py", line 85, in _worker
    output = module(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/transformers/models/whisper/modeling_whisper.py", line 1683, in forward
    outputs = self.model(
              ^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/transformers/models/whisper/modeling_whisper.py", line 1543, in forward
    encoder_outputs = self.encoder(
                      ^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/transformers/models/whisper/modeling_whisper.py", line 1119, in forward
    inputs_embeds = nn.functional.gelu(self.conv1(input_features))
                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 310, in forward
    return self._conv_forward(input, self.weight, self.bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 306, in _conv_forward
    return F.conv1d(input, weight, bias, self.stride,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.cuda.HalfTensor) should be the same

@changyeli
Author

I also have a somewhat related question regarding gradient checkpointing. I got the following error when I set gradient_checkpointing=True on a small subset.

Segmentation fault (core dumped)

It looks like something related to memory. I used the whisper small model and ~500 samples for fine-tuning and got this error. Interestingly, it was perfectly fine with inference - I can run inference with whisper large on a much bigger dataset. How should I proceed?

@changyeli
Author

The input features are indeed in float32 dtype. I tried to cast them using the following code, but it didn't work:

format = {'type': 'torch', 'format_kwargs': {'dtype': torch.float16}}
temo_dt['train'].set_format(columns=['input_features'], **format)

@sanchit-gandhi
Contributor

Could you try removing:

model = WhisperForConditionalGeneration.from_pretrained(
    model_card, use_flash_attention_2=True,
-   torch_dtype=torch.float16)

And then setting:

training_args = Seq2SeqTrainingArguments(
    output_dir=f"../{model_name}", 
    per_device_train_batch_size=4,
    gradient_accumulation_steps=16,
    learning_rate=1e-5,
    warmup_steps=500,
    max_steps=6000,
    # speed up
    gradient_checkpointing=True,
    evaluation_strategy="steps",
    per_device_eval_batch_size=16,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=1000,
    eval_steps=1000,
    logging_steps=25,
    report_to="none",
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    auto_find_batch_size=True,
    torch_compile=True,
+   fp16=True
)

=> if we let the Trainer handle mixed precision this way, it should cast the input_features to the correct dtype for us.

@changyeli
Author

Hi @sanchit-gandhi, thanks for the suggestion! But I still got a data-conversion-related error with your approach - here is the full traceback:

ializing it on CPU with `model.to('cuda')`.
  0%|                                                                                               | 0/870 [00:00<?, ?it/s]The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.float32.
Traceback (most recent call last):
  File "/home/coraal-suppl/scripts/fine_tune_whisper.py", line 164, in <module>
    trainer.train()
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/transformers/trainer.py", line 1555, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/accelerate/utils/memory.py", line 136, in decorator
    return function(batch_size, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/transformers/trainer.py", line 1860, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/transformers/trainer.py", line 2725, in training_step
    loss = self.compute_loss(model, inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/transformers/trainer.py", line 2748, in compute_loss
    outputs = model(**inputs)
              ^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/parallel/data_parallel.py", line 185, in forward
    outputs = self.parallel_apply(replicas, inputs, module_kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/parallel/data_parallel.py", line 200, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py", line 110, in parallel_apply
    output.reraise()
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/_utils.py", line 694, in reraise
    raise exception
RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py", line 85, in _worker
    output = module(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/transformers/models/whisper/modeling_whisper.py", line 1683, in forward
    outputs = self.model(
              ^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/transformers/models/whisper/modeling_whisper.py", line 1543, in forward
    encoder_outputs = self.encoder(
                      ^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/transformers/models/whisper/modeling_whisper.py", line 1159, in forward
    layer_outputs = encoder_layer(
                    ^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/transformers/models/whisper/modeling_whisper.py", line 722, in forward
    hidden_states, attn_weights, _ = self.self_attn(
                                     ^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/transformers/models/whisper/modeling_whisper.py", line 569, in forward
    attn_output = self._flash_attention_forward(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/transformers/models/whisper/modeling_whisper.py", line 629, in _flash_attention_forward
    attn_output = flash_attn_func(
                  ^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/flash_attn/flash_attn_interface.py", line 708, in flash_attn_func
    return FlashAttnFunc.apply(
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/autograd/function.py", line 539, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/flash_attn/flash_attn_interface.py", line 437, in forward
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
                                                                ^^^^^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/flash_attn/flash_attn_interface.py", line 49, in _flash_attn_forward
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
                                                                ^^^^^^^^^^^^^^^^^^^^
RuntimeError: FlashAttention only support fp16 and bf16 data type

  0%|          | 0/870 [00:09<?, ?it/s]


github-actions bot commented Jan 4, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@changyeli
Author

Hi @sanchit-gandhi, I keep getting similar errors after upgrading to the most recent version (as of 01/13/24). I loaded the model using:

model = WhisperForConditionalGeneration.from_pretrained(
        model_card,
        attn_implementation="flash_attention_2",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        use_safetensors=True)

The torch_dtype here is required; otherwise it raises the following error even if I set fp16=True in the Seq2SeqTrainingArguments: ValueError: Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes. You passed torch.float32, this might lead to unexpected behaviour.

I got this error when trying to fine-tune the whisper model: RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.cuda.HalfTensor) should be the same. This is the same error I got when I opened this issue. I also converted the input data using:

dt.set_format(
        columns=['input_features', 'labels'],
        **format
    )

And called the model via:

model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_card,
        attn_implementation="flash_attention_2",
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True)

But I got a different error message:

    trainer.train()
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/transformers/trainer.py", line 1537, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/transformers/trainer.py", line 1854, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/transformers/trainer.py", line 2735, in training_step
    loss = self.compute_loss(model, inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/transformers/trainer.py", line 2758, in compute_loss
    outputs = model(**inputs)
              ^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/parallel/data_parallel.py", line 185, in forward
    outputs = self.parallel_apply(replicas, inputs, module_kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/parallel/data_parallel.py", line 200, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py", line 110, in parallel_apply
    output.reraise()
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/_utils.py", line 694, in reraise
    raise exception
RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py", line 85, in _worker
    output = module(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/transformers/models/whisper/modeling_whisper.py", line 1818, in forward
    outputs = self.model(
              ^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/transformers/models/whisper/modeling_whisper.py", line 1694, in forward
    decoder_outputs = self.decoder(
                      ^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/transformers/models/whisper/modeling_whisper.py", line 1442, in forward
    inputs_embeds = self.embed_tokens(input_ids)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/modules/sparse.py", line 162, in forward
    return F.embedding(
           ^^^^^^^^^^^^
  File "/home/anaconda3/envs/whisper/lib/python3.11/site-packages/torch/nn/functional.py", line 2233, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.cuda.FloatTensor instead (while checking arguments for embedding)

0%| | 0/22 [00:08<?, ?it/s]

It's interesting, as I use FA2 with decoder-only language models and it works fine, but no luck with whisper yet. Any suggestions on how to proceed?

@changyeli
Author

I did several experiments - it turned out I had forgotten to drop unused columns during preprocessing, but the problem still exists.
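
(By "drop unused columns" I mean something like the following sketch, where prepare_dataset stands for the blog post's feature-extraction function:)

temo_dt = temo_dt.map(
    prepare_dataset,  # hypothetical helper that builds input_features and labels
    remove_columns=temo_dt.column_names["train"],
)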

For fine-tuning, if I remove torch_dtype=torch.float16 when loading the pre-trained model and enable fp16=True in the training arguments, I get the following error:

ValueError: Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes. You passed torch.float32, this might lead to unexpected behaviour.

@RohitMidha23

@sanchit-gandhi any update on this? How do we use flash attention 2 here?

My guess is that this has to do with mixed-precision training not working correctly. In accelerate, passing mixed_precision="no" made this error go away, but how do we do the same in the Trainer?

@cvl01

cvl01 commented Aug 7, 2024

I am encountering the same kind of issue. Has any solution been found yet? I only get the error on the evaluation step of the trainer; the training itself seems to work fine.

RuntimeError: Input type (float) and bias type (c10::BFloat16) should be the same

@amyeroberts
Collaborator

cc @sanchit-gandhi @ylacombe

@ylacombe
Contributor

ylacombe commented Sep 2, 2024

Hey @changyeli, @cvl01 and @RohitMidha23, is this still an issue?

@changyeli
Author

Hey @ylacombe, I somehow bypassed this issue with #32370 and #32366.

It works with attn_implementation="sdpa" when loading the model's pre-trained weights.
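
Concretely, the load looks something like this (a sketch, keeping the other arguments from my earlier snippet):

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_card,
    attn_implementation="sdpa",  # instead of "flash_attention_2"
    low_cpu_mem_usage=True,
    use_safetensors=True)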

@RohitMidha23

@changyeli that sort of goes against the point as you aren't using flash attention anymore - which is inherently designed to be faster...

@changyeli
Author

changyeli commented Sep 3, 2024

@RohitMidha23 I think sdpa supports flash attention. This is what I found in the documentation:

PyTorch’s torch.nn.functional.scaled_dot_product_attention (SDPA) can also call FlashAttention and memory-efficient attention kernels under the hood. SDPA support is currently being added natively in Transformers and is used by default for torch>=2.1.1 when an implementation is available. You may also set attn_implementation="sdpa" in from_pretrained() to explicitly request SDPA to be used.

But I agree regarding speed - SDPA does not give that much of a speed-up.
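
For what it's worth, on torch 2.1 you can sanity-check whether SDPA can actually reach the flash kernel for a given dtype/shape with something like this sketch (not part of the training script; it raises if the flash backend is unavailable):

import torch
import torch.nn.functional as F

# dummy fp16 query/key/value tensors on GPU: (batch, heads, seq_len, head_dim)
q = k = v = torch.randn(1, 8, 64, 64, device="cuda", dtype=torch.float16)

# restrict SDPA to the flash backend only; this errors out if flash attention
# cannot be used for these inputs (e.g. float32 tensors)
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v)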
