
A resolution for NAN #109

Closed
Andy1621 opened this issue Aug 6, 2021 · 2 comments


Andy1621 commented Aug 6, 2021

As shown in issue #29, the vision transformer is not stable when training it with AMP, even when the model is deep.
Without AMP, the loss does not become NaN, but training runs much more slowly.
I have found that the loss becomes NaN inside the attention block, and simply using FP32 for attention solves the problem. In my experiments, after replacing the attention block with the following code, training can be resumed normally (sometimes you will need to change the random seed).

import torch
import torch.nn as nn


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)

        with torch.cuda.amp.autocast(enabled=False):
            q, k, v = qkv[0].float(), qkv[1].float(), qkv[2].float()   # make torchscript happy (cannot use tensor as tuple)
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = (attn @ v).transpose(1, 2).reshape(B, N, C)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
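
If it helps, here is a minimal sanity check (my own sketch, not part of the original code, assuming a CUDA device is available) that runs the modified block under autocast and confirms the output stays finite:

# Hypothetical sanity check: run the FP32-attention block under AMP
# autocast and verify the output contains no NaN/Inf.
attn = Attention(dim=384, num_heads=6).cuda()
x = torch.randn(8, 197, 384, device='cuda')   # (batch, tokens, dim), ViT-S-like shapes

with torch.cuda.amp.autocast():
    out = attn(x)

# The attention scores are computed in FP32 inside the block; the final
# projection still runs under autocast, so `out` may be FP16.
print(torch.isfinite(out).all())   # expect tensor(True, device='cuda:0')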

Andy1621 commented Aug 6, 2021

I hope everyone who runs into NaN can try this code and let me know whether it works.
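
For reference, a quick way to check whether the fix works during AMP training is to test the loss for NaN/Inf at every step. A rough sketch (my own, with `model`, `criterion`, `optimizer`, and `loader` assumed to exist):

# Hypothetical AMP training-loop snippet: stop as soon as the loss
# becomes non-finite so a bad run is caught immediately.
scaler = torch.cuda.amp.GradScaler()
for images, targets in loader:               # loader/model/criterion/optimizer are assumed
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = criterion(model(images.cuda()), targets.cuda())
    if not torch.isfinite(loss):
        raise RuntimeError(f"Loss became {loss.item()}, NaN/Inf detected")
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()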

@TouvronHugo (Contributor) commented

Hi @Andy1621,
Thank you for your message.
Could you please post it in issue #29 instead of opening a new issue?
(Don't hesitate to re-open issue #29 if needed.)
That way it is easier to find all the information related to this subject ;)
Best,
Hugo
