As shown in issue #29, the vision transformer is not stable when trained with AMP, especially when the model is deep.
Without AMP, the loss does not become NaN, but training runs much more slowly.
I have found that the loss becomes NaN in the attention block, and simply using FP32 for the attention computation solves the problem. In my experiments, after replacing the attention block with the following code, training can be resumed normally (sometimes you will need to change the random seed).
import torch
import torch.nn as nn

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # compute attention in FP32 even when the rest of the model runs under AMP
        with torch.cuda.amp.autocast(enabled=False):
            q, k, v = qkv[0].float(), qkv[1].float(), qkv[2].float()  # make torchscript happy (cannot use tensor as tuple)
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
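For completeness, here is a minimal usage sketch (not part of the original report) of how a model containing this patched Attention module could be trained under AMP with torch.cuda.amp.autocast and GradScaler; the toy model, dummy data, and hyperparameters below are illustrative assumptions only:

import torch
import torch.nn as nn

# toy model and data purely for illustration; in practice `model` would be the full
# vision transformer whose blocks use the patched Attention defined above
model = nn.Sequential(nn.Linear(192, 192), Attention(dim=192, num_heads=3)).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(8, 197, 192, device='cuda')        # (batch, tokens, dim) dummy input
target = torch.randn(8, 197, 192, device='cuda')   # dummy regression target

optimizer.zero_grad()
with torch.cuda.amp.autocast():        # forward pass runs in mixed precision...
    out = model(x)                     # ...except attention, which falls back to FP32 internally
    loss = nn.functional.mse_loss(out, target)
scaler.scale(loss).backward()          # scale the loss to avoid FP16 gradient underflow
scaler.step(optimizer)
scaler.update()

The outer training loop does not need to change: the patched module disables autocast locally, so only the attention math runs in FP32 while the rest of the network keeps the AMP speedup.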
Hi @Andy1621,
Thank you for your message,
Could you please post it in issue #29 instead of opening a new issue?
(Don't hesitate to re-open issue#29 if needed)
This way it is easier to find all the information related to this subject ;)
Best,
Hugo