This repository has been archived by the owner on Mar 15, 2024. It is now read-only.

Loss NAN for Deit Base #29

Closed

ChengyueGongR opened this issue Jan 9, 2021 · 24 comments

@ChengyueGongR

I have reproduced the small and tiny models, but I ran into problems reproducing the base model at both 224 and 384 image sizes. With high probability, the loss becomes NaN after training for a few epochs.
My setup is 16 GPUs with a batch size of 64 per GPU, and I did not change any hyper-parameters in run_with_submitit.py. Do you have any idea how to solve this problem?
Thanks for your help.

@vtddggg

vtddggg commented Jan 9, 2021

I also met this problem. I use 32 GPUs for training.

@vtddggg

vtddggg commented Jan 11, 2021

A simple solution for my case: I found that transformer training is sensitive to the learning rate. You have to keep a very small learning rate during training (lr < 0.0015 worked best for me); otherwise the gradients become NaN under AMP. So you can first try reducing the learning rate.

An alternative is to disable AMP. I tried this as well and it also works for me.
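
For reference, "disabling AMP" here just means running the forward/backward pass in plain FP32, with no autocast context and no GradScaler. A minimal, generic sketch of such a training step (the actual engine.py uses a criterion and loss_scaler with slightly different signatures, so treat this only as an illustration):

import torch

def train_step_no_amp(model, criterion, optimizer, samples, targets, max_norm=None):
    # Plain FP32 step: no autocast context and no GradScaler.
    outputs = model(samples)
    loss = criterion(outputs, targets)

    optimizer.zero_grad()
    loss.backward()
    if max_norm is not None:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    optimizer.step()
    return loss.item()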

@fmassa
Contributor

fmassa commented Jan 11, 2021

Hi,

I've just run the default command-line for DeiT base with the state of the codebase as of yesterday on 16 GPUs, and training is going fine so far (we are already at epoch 100 without issues).

We are using PyTorch 1.7.0 with CUDA 10.1, and the default hyperparameters work fine for us on this setup.

Can you share your PyTorch / torchvision / CUDA versions?

@ChengyueGongR
Author

@vtddggg Thanks for your advice, it works for me.

@xwjabc

xwjabc commented Jan 16, 2021

Hi @vtddggg @ChengyueGongR, I have also encountered the NaN issue. I wonder if there is an easy way to disable AMP?
It seems the with torch.cuda.amp.autocast(): block in engine.py should be removed. However, I wonder whether we also need to rewrite the loss_scaler? Thanks!
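
One option is to leave the call sites alone and swap in a no-op scaler. The sketch below assumes the scaler is invoked as loss_scaler(loss, optimizer, clip_grad=..., parameters=..., create_graph=...), matching timm's NativeScaler as used in engine.py; the autocast block can then be changed to torch.cuda.amp.autocast(enabled=False) or removed.

import torch

class NoAmpScaler:
    # Drop-in stand-in for timm's NativeScaler when AMP is disabled:
    # plain backward + optional gradient clipping + optimizer step.
    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False):
        loss.backward(create_graph=create_graph)
        if clip_grad is not None and parameters is not None:
            torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
        optimizer.step()

    def state_dict(self):
        return {}

    def load_state_dict(self, state_dict):
        pass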

@cxxgtxy

cxxgtxy commented Jan 30, 2021

I met the same problem.

@haoweiz23

haoweiz23 commented Feb 18, 2021

> Hi,
>
> I've just run the default command-line for DeiT base with the state of the codebase as of yesterday on 16 GPUs, and training is going fine so far (we are already at epoch 100 without issues).
>
> We are using PyTorch 1.7.0 with CUDA 10.1, and the default hyperparameters work fine for us on this setup.
>
> Can you share your PyTorch / torchvision / CUDA versions?

Hi, I met the NaN problem. My versions are:
torch == 1.7.1
cuda == 11.2
GPU count: 8

@HubHop

HubHop commented Mar 9, 2021

Same issue with torch 1.7.1, cuda 11.2, torchvision 0.8.2

@cxxgtxy

cxxgtxy commented Mar 11, 2021

> A simple solution for my case: I found that transformer training is sensitive to the learning rate. You have to keep a very small learning rate during training (lr < 0.0015 worked best for me); otherwise the gradients become NaN under AMP. So you can first try reducing the learning rate.
>
> An alternative is to disable AMP. I tried this as well and it also works for me.

Changing the learning rate of the base model degrades its performance by a clear margin.
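
Note that the lr values quoted above refer to the effective learning rate after the linear scaling that DeiT-style scripts apply to the base lr. A small sketch of that arithmetic, assuming the usual base lr of 5e-4, 64 images per GPU, and scaling by global batch size / 512:

# Linear lr scaling rule (sketch).
base_lr = 5e-4
global_batch_size = 16 * 64                       # 16 GPUs x 64 images per GPU
effective_lr = base_lr * global_batch_size / 512
print(effective_lr)                               # 0.001 with 16 GPUs; 0.002 with 32 GPUs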

@HubHop

HubHop commented Mar 11, 2021

> > A simple solution for my case: I found that transformer training is sensitive to the learning rate. You have to keep a very small learning rate during training (lr < 0.0015 worked best for me); otherwise the gradients become NaN under AMP. So you can first try reducing the learning rate.
> >
> > An alternative is to disable AMP. I tried this as well and it also works for me.
>
> Changing the learning rate of the base model degrades its performance by a clear margin.

I tried disabling AMP and replacing the scaler with a plain loss.backward() and optimizer.step(); unfortunately this did not help. Using a longer warmup also does not avoid the problem.

As far as I can tell, the input itself is never NaN. In my experiment, the NaN first appears in the output x of the first Transformer block, at a specific iteration in epoch 7.
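
One way to narrow down where the blow-up starts is to register forward hooks that check every module's output for non-finite values. A minimal, hypothetical debugging helper along those lines (not part of the repo):

import torch

def register_nan_checks(model):
    # Raise as soon as some module produces NaN/Inf in its output, naming the module.
    def make_hook(name):
        def hook(module, inputs, output):
            if isinstance(output, torch.Tensor) and not torch.isfinite(output).all():
                raise RuntimeError(f"Non-finite output detected in module: {name}")
        return hook
    for name, module in model.named_modules():
        module.register_forward_hook(make_hook(name))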

@wangpichao

> > A simple solution for my case: I found that transformer training is sensitive to the learning rate. You have to keep a very small learning rate during training (lr < 0.0015 worked best for me); otherwise the gradients become NaN under AMP. So you can first try reducing the learning rate.
> >
> > An alternative is to disable AMP. I tried this as well and it also works for me.
>
> Changing the learning rate of the base model degrades its performance by a clear margin.

Have you solved the problem?

@cheerss

cheerss commented Apr 6, 2021

Same problem for me: "Loss is nan, stop training". My packages are pytorch-1.7.1 / CUDA-10.1.

Also, deit-tiny and deit-small work well, but deit-base does not.

@TouvronHugo
Contributor

Hi Everyone,
Concerning DeiT, I have not encountered any NaN problems with the Tiny, Small and Base models.
Nevertheless, in our latest paper, Going deeper with Image Transformers, we did hit this problem with deeper architectures. To solve it we used the LayerScale method together with an adjustment of the stochastic depth coefficient. This may help you solve your stability problems.
Best,
Hugo
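
For readers unfamiliar with LayerScale: it multiplies the output of each residual branch by a learnable per-channel vector initialized to a small value (e.g. around 1e-5, depending on depth), as described in the CaiT paper. A minimal sketch, not the paper's exact implementation:

import torch
import torch.nn as nn

class LayerScale(nn.Module):
    # Learnable per-channel scaling of a residual branch, initialized to a small value.
    def __init__(self, dim, init_value=1e-5):
        super().__init__()
        self.gamma = nn.Parameter(init_value * torch.ones(dim))

    def forward(self, x):
        return self.gamma * x

# Inside a Transformer block it is used roughly as:
#   x = x + drop_path(ls1(attn(norm1(x))))
#   x = x + drop_path(ls2(mlp(norm2(x))))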

@TouvronHugo
Contributor

As there is no more activity, I am closing the issue. Don't hesitate to reopen it if necessary.

@liuzhuang13

In my experiments on the base model, I found that if I disable repeated augmentation, i.e., run with "--no-repeated-aug", the NaN error disappears. This helps even when I run with "--drop-path 0", which in my observation increases the probability of NaN under the default setting.
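
For reference, that corresponds to adding the flag to the usual training command, e.g. on a single 8-GPU node (paths below are placeholders):

python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --model deit_base_patch16_224 --batch-size 64 --data-path /path/to/imagenet --output_dir /path/to/output --no-repeated-aug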

@ShoufaChen

Hi, @liuzhuang13

Thanks for your suggestion. May I ask whether the performance drops with --no-repeated-aug?

@netw0rkf10w

netw0rkf10w commented Jul 11, 2021

Same issue when submitting the following command

python -u main.py --model deit_base_patch16_224 --batch-size 64 --data-path ~/data/imagenet --output_dir ~/experiments/deit_base

to 16 GPUs (4 Slurm nodes).

There seems to be a tricky bug somewhere in the code...

(Edited to add more information: PyTorch 1.9, CUDA 10.2)

@liuzhuang13

> Hi, @liuzhuang13
>
> Thanks for your suggestion. May I ask whether the performance drops with --no-repeated-aug?

@ShoufaChen Yes, in my observation the performance drops from 81.5% to 81.1% when I remove repeated augmentation. In later experiments I also found that the NaN disappears if you don't use AMP, but training is much slower.

@liuzhuang13

@ShoufaChen I just found that DeiT still does not converge even if I disable AMP, though it no longer reports NaN. The case I mentioned where disabling AMP solves the NaN problem was for another architecture. Not sure this is useful, but just FYI.

@ShoufaChen

@liuzhuang13 Many thanks for the information.

@Andy1621

Andy1621 commented Aug 6, 2021

Without AMP the loss does not become NaN, but training runs much more slowly.
I have found that the loss becomes NaN inside the attention, and simply using FP32 for the attention solves the problem. In my experiments, after replacing the attention block with the following code, the model can be resumed and trained normally (sometimes you also need to change the random seed).

import torch
import torch.nn as nn


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)

        # Compute the attention itself in FP32, even when the rest of the model runs under AMP.
        with torch.cuda.amp.autocast(enabled=False):
            q, k, v = qkv[0].float(), qkv[1].float(), qkv[2].float()   # make torchscript happy (cannot use tensor as tuple)
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = (attn @ v).transpose(1, 2).reshape(B, N, C)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x

I hope everyone who hits NaN can try this code and let me know if it works.
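
If you want to try this without editing timm itself, one option is to monkey-patch the module-level Attention class before building the model. This is only a sketch and assumes the models are built through timm's vision_transformer module, where the blocks look up Attention by name at construction time:

import timm
import timm.models.vision_transformer as vit

# Swap in the FP32-attention variant defined above, then build the model.
vit.Attention = Attention
model = timm.create_model('deit_base_patch16_224', pretrained=False)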

@HeegonJin

@Andy1621 Your suggestion did not solve my case. I worked around the problem by changing the random seed.

@Andy1621

Andy1621 commented Mar 2, 2022

@HeegonJin You can try full FP32 training, which costs more GPU memory and compute but is stable. Alternatively, you can use a smaller learning rate or weaker data augmentation.

@cxxgtxy

cxxgtxy commented Mar 28, 2022

It seems that this issue is still not resolved. I have run the training many times, and the NaN with DeiT-Base is easy to reproduce.
