
VideoMAEforPretrained cannot be trained with Bfloat16 #27295

Closed
2 of 4 tasks
ikergarcia1996 opened this issue Nov 5, 2023 · 2 comments
Comments

@ikergarcia1996
Contributor

ikergarcia1996 commented Nov 5, 2023

System Info

  • transformers version: 4.35.0
  • Platform: Linux-6.5.6-76060506-generic-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • Huggingface_hub version: 0.16.4
  • Safetensors version: 0.3.1
  • Accelerate version: 0.24.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.2.0.dev20230907 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: No

Who can help?

@amyeroberts

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

It is not possible to train VideoMAEForPreTraining with bfloat16 because the labels are always stored as float32.
The following code snippet triggers the error:

from transformers import AutoImageProcessor, VideoMAEForPreTraining
import numpy as np
import torch

# Dummy video: 16 frames of random 224x224 RGB data
num_frames = 16
video = list(np.random.randint(0, 256, (num_frames, 3, 224, 224)))

image_processor = AutoImageProcessor.from_pretrained("MCG-NJU/videomae-base")
# Load the model weights in bfloat16
model = VideoMAEForPreTraining.from_pretrained("MCG-NJU/videomae-base", torch_dtype=torch.bfloat16).to("cuda")

pixel_values = image_processor(video, return_tensors="pt").pixel_values

# Random boolean mask over the tubelet patches
num_patches_per_frame = (model.config.image_size // model.config.patch_size) ** 2
seq_length = (num_frames // model.config.tubelet_size) * num_patches_per_frame
bool_masked_pos = torch.randint(0, 2, (1, seq_length)).bool()

outputs = model(pixel_values.to(device=model.device, dtype=model.dtype), bool_masked_pos=bool_masked_pos)
loss = outputs.loss

loss.backward()  # RuntimeError: Found dtype Float but expected BFloat16
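
As a side note, the error can be worked around by keeping the weights in float32 and using torch.autocast for the bfloat16 compute, since autocast runs mse_loss in float32 and the dtypes seen by backward stay consistent. A minimal sketch of that workaround (assuming a CUDA device, and reusing pixel_values and bool_masked_pos from the snippet above):

# Workaround sketch: float32 weights + autocast, instead of torch_dtype=torch.bfloat16
model = VideoMAEForPreTraining.from_pretrained("MCG-NJU/videomae-base").to("cuda")

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    outputs = model(pixel_values.to(model.device), bool_masked_pos=bool_masked_pos)

outputs.loss.backward()  # succeeds: the loss was computed in float32 under autocast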

Full traceback

RuntimeError                              Traceback (most recent call last)
Cell In[1], line 20
     17 outputs = model(pixel_values.to(device=model.device,dtype=model.dtype), bool_masked_pos=bool_masked_pos)
     18 loss = outputs.loss
---> 20 loss.backward()

File ~/miniconda3/envs/transformers/lib/python3.10/site-packages/torch/_tensor.py:492, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    482 if has_torch_function_unary(self):
    483     return handle_torch_function(
    484         Tensor.backward,
    485         (self,),
   (...)
    490         inputs=inputs,
    491     )
--> 492 torch.autograd.backward(
    493     self, gradient, retain_graph, create_graph, inputs=inputs
    494 )

File ~/miniconda3/envs/transformers/lib/python3.10/site-packages/torch/autograd/__init__.py:251, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    246     retain_graph = create_graph
    248 # The reason we repeat the same comment below is that
    249 # some Python versions print out the first line of a multi-line function
    250 # calls in the traceback and some print out the last line
--> 251 Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    252     tensors,
    253     grad_tensors_,
    254     retain_graph,
    255     create_graph,
    256     inputs,
    257     allow_unreachable=True,
    258     accumulate_grad=True,
    259 )

RuntimeError: Found dtype Float but expected BFloat16

The problem is that when computing the loss, the labels are in float32; therefore, the returned loss is also in float32.

logits: torch.bfloat16
labels: torch.float32
loss: torch.float32
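
The same failure can be reproduced in isolation with a plain MSE loss, independently of the model (a minimal sketch):

import torch

logits = torch.randn(4, dtype=torch.bfloat16, requires_grad=True)  # stands in for the model output
labels = torch.randn(4, dtype=torch.float32)                       # stands in for the reconstruction targets

loss = torch.nn.functional.mse_loss(logits, labels)
print(loss.dtype)  # torch.float32: the loss is promoted to the labels' dtype
loss.backward()    # RuntimeError: Found dtype Float but expected BFloat16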

Expected behavior

Labels should be converted to the same dtype as the logits.

PR #27296 fixes the error, although I am not 100% sure that it is the best way to handle the problem.
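
The gist of the change is to cast the labels to the logits' dtype before computing the loss, along these lines (a sketch of the idea only, not the exact diff in the PR):

import torch
from torch import nn

logits = torch.randn(4, dtype=torch.bfloat16, requires_grad=True)
labels = torch.randn(4, dtype=torch.float32)

# Cast the float32 reconstruction targets to the compute dtype of the logits
loss = nn.MSELoss()(logits, labels.to(logits.dtype))
loss.backward()  # succeeds: the loss is now bfloat16, matching the graph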

@ikergarcia1996 ikergarcia1996 changed the title Fix VideoMAEforPretrained cannot be trained with Bfloat16 VideoMAEforPretrained cannot be trained with Bfloat16 Nov 5, 2023
@amyeroberts
Collaborator

Hi @ikergarcia1996, thanks for reporting and opening a PR!

I've started a review on the PR around implementation specifics; once merged, it should resolve the issue.

@ikergarcia1996
Contributor Author

Fixed by #27296.
