
Some questions about loading pretrain model and training. #30

Closed · realgump opened this issue Apr 2, 2022 · 7 comments

realgump commented Apr 2, 2022

Dear author:

I want to fine-tune your model on my own dataset for a video classification task, using the provided UniFormer-B model, but I have run into some problems.

When I start training, the top-1 error is always 100, as shown in the attached training-log screenshot.

I have previously fine-tuned on this dataset successfully with several different models using the Facebook SlowFast code.

Before training UniFormer, I:

  • replaced the dataset folder with the one from my previous SlowFast project;
  • disabled AUG and MIXUP.

Have I possibly made any mistakes?

Besides that, I notice that the SlowFast code uses slowfast/utils/checkpoint.py to load the pretrained model, while the UniFormer code adds a new loading function in slowfast/model/uniformer.py and copies the original slowfast/utils/checkpoint.py to slowfast/utils/checkpoint_amp.py (I could not spot any difference between the two files). Is there any difference between loading the pretrained model with these two functions?


Andy1621 (Collaborator) commented Apr 2, 2022

Thanks for your question. Perhaps the pre-trained weights are not loaded correctly.
To use the pre-trained weights of the video model, you have to download them and set the path in uniformer.py:

model_path = 'path_to_models'
model_path = {
    'uniformer_small_in1k': os.path.join(model_path, 'uniformer_small_in1k.pth'),
    'uniformer_small_k400_8x8': os.path.join(model_path, 'uniformer_small_k400_8x8.pth'),
    'uniformer_small_k400_16x4': os.path.join(model_path, 'uniformer_small_k400_16x4.pth'),
    'uniformer_small_k600_16x4': os.path.join(model_path, 'uniformer_small_k600_16x4.pth'),
    'uniformer_base_in1k': os.path.join(model_path, 'uniformer_base_in1k.pth'),
    'uniformer_base_k400_8x8': os.path.join(model_path, 'uniformer_base_k400_8x8.pth'),
    'uniformer_base_k400_16x4': os.path.join(model_path, 'uniformer_base_k400_16x4.pth'),
    'uniformer_base_k600_16x4': os.path.join(model_path, 'uniformer_base_k600_16x4.pth'),
}

In build.py, it will load the pre-trained model.
if cfg.MODEL.ARCH in ['uniformer']:
    checkpoint = model.get_pretrained_model(cfg)
    if checkpoint:
        logger.info('load pretrained model')
        model.load_state_dict(checkpoint, strict=False)

You can check the log to see whether it outputs:

[INFO] build.py:  46: load pretrained model

In my codebase, checkpoint_amp.py additionally saves the state of the loss_scaler used for AMP training:

def save_checkpoint(path_to_job, model, optimizer, loss_scaler, epoch, cfg):
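For illustration only, a minimal sketch of what such a checkpoint can hold, assuming the loss_scaler is a torch.cuda.amp.GradScaler; the exact key names used in checkpoint_amp.py may differ:

import torch

def save_checkpoint_sketch(path_to_checkpoint, model, optimizer, loss_scaler, epoch):
    # Hypothetical layout: besides the usual model/optimizer state, the AMP
    # GradScaler state is stored so mixed-precision training can be resumed.
    checkpoint = {
        'epoch': epoch,
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
        'scaler_state': loss_scaler.state_dict(),
    }
    torch.save(checkpoint, path_to_checkpoint)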

I do not use checkpoint_amp.py to load pre-trained models; that is done by the code in uniformer.py, which inflates the 2D kernels to 3D automatically:
def inflate_weight(self, weight_2d, time_dim, center=False):
    if center:
        # Place the 2D kernel at the temporal center; other frames stay zero.
        weight_3d = torch.zeros(*weight_2d.shape)
        weight_3d = weight_3d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)
        middle_idx = time_dim // 2
        weight_3d[:, :, middle_idx, :, :] = weight_2d
    else:
        # Repeat the 2D kernel along time and rescale to preserve the response.
        weight_3d = weight_2d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)
        weight_3d = weight_3d / time_dim
    return weight_3d

def get_pretrained_model(self, cfg):
    if cfg.UNIFORMER.PRETRAIN_NAME:
        checkpoint = torch.load(model_path[cfg.UNIFORMER.PRETRAIN_NAME], map_location='cpu')
        if 'model' in checkpoint:
            checkpoint = checkpoint['model']
        elif 'model_state' in checkpoint:
            checkpoint = checkpoint['model_state']

        state_dict_3d = self.state_dict()
        for k in checkpoint.keys():
            if checkpoint[k].shape != state_dict_3d[k].shape:
                if len(state_dict_3d[k].shape) <= 2:
                    logger.info(f'Ignore: {k}')
                    continue
                logger.info(f'Inflate: {k}, {checkpoint[k].shape} => {state_dict_3d[k].shape}')
                time_dim = state_dict_3d[k].shape[2]
                checkpoint[k] = self.inflate_weight(checkpoint[k], time_dim)

        if self.num_classes != checkpoint['head.weight'].shape[0]:
            # Drop the classification head when the number of classes differs.
            del checkpoint['head.weight']
            del checkpoint['head.bias']
        return checkpoint
    else:
        return None
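For intuition, here is a small sketch (not from the repo) of what the non-center branch of inflate_weight does to a 2D convolution kernel; the tensor shapes are illustrative:

import torch

# A 2D kernel of shape (C_out, C_in, kH, kW), e.g. a patch-embedding conv.
weight_2d = torch.randn(64, 3, 4, 4)
time_dim = 2

# Repeat along a new temporal axis and divide by time_dim, so a 3D conv with
# this kernel reproduces the 2D conv's response on a temporally constant clip.
weight_3d = weight_2d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1) / time_dim
print(weight_3d.shape)  # torch.Size([64, 3, 2, 4, 4])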

realgump (Author) commented Apr 2, 2022

Thanks for your quick response. I think I did load the pretrained model; however, there are several warnings and I am not sure whether they matter.

[04/02 11:38:20][INFO] uniformer.py: 287: Use checkpoint: False
[04/02 11:38:20][INFO] uniformer.py: 288: Checkpoint number: [0, 0, 0, 0]
[04/02 11:38:32][INFO] uniformer.py: 410: Ignore: head.weight
[04/02 11:38:32][INFO] uniformer.py: 410: Ignore: head.bias
[04/02 11:38:32][INFO] build.py: 45: load pretrained model
[04/02 11:38:42][INFO] misc.py: 184: Params: 49,657,761
[04/02 11:38:42][INFO] misc.py: 185: Mem: 0.1866145133972168 MB
[04/02 11:38:44][WARNING] jit_analysis.py: 499: Unsupported operator aten::add encountered 120 time(s)
[04/02 11:38:44][WARNING] jit_analysis.py: 499: Unsupported operator aten::gelu encountered 40 time(s)
[04/02 11:38:44][WARNING] jit_analysis.py: 499: Unsupported operator aten::mul encountered 27 time(s)
[04/02 11:38:44][WARNING] jit_analysis.py: 499: Unsupported operator aten::softmax encountered 27 time(s)
[04/02 11:38:44][WARNING] jit_analysis.py: 499: Unsupported operator aten::mean encountered 1 time(s)
[04/02 11:38:44][WARNING] jit_analysis.py: 511: The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.
blocks1.1.drop_path, blocks1.2.drop_path, blocks1.3.drop_path, blocks1.4.drop_path, blocks2.0.drop_path, blocks2.1.drop_path, blocks2.2.drop_path, blocks2.3.drop_path, blocks2.4.drop_path, blocks2.5.drop_path, blocks2.6.drop_path, blocks2.7.drop_path, blocks3.0.drop_path, blocks3.1.drop_path, blocks3.10.drop_path, blocks3.11.drop_path, blocks3.12.drop_path, blocks3.13.drop_path, blocks3.14.drop_path, blocks3.15.drop_path, blocks3.16.drop_path, blocks3.17.drop_path, blocks3.18.drop_path, blocks3.19.drop_path, blocks3.2.drop_path, blocks3.3.drop_path, blocks3.4.drop_path, blocks3.5.drop_path, blocks3.6.drop_path, blocks3.7.drop_path, blocks3.8.drop_path, blocks3.9.drop_path, blocks4.0.drop_path, blocks4.1.drop_path, blocks4.2.drop_path, blocks4.3.drop_path, blocks4.4.drop_path, blocks4.5.drop_path, blocks4.6.drop_path
[04/02 11:38:44][INFO] misc.py: 186: Flops: 96.741886464 G

realgump (Author) commented Apr 2, 2022

There is another error I forgot to mention: training is always interrupted with

RuntimeError: ERROR: Got NaN losses 2022-04-02 11:42:54.428345

Andy1621 (Collaborator) commented Apr 2, 2022

NaN losses are common when training Transformers. There are several ways to handle the problem (see the config sketch after this list):

  1. Lower the learning rate: a large learning rate makes training unstable.
  2. Weaken the data augmentation: rand-m7-n4-mstd0.5-inc1 is too strong for some datasets and models; you can use a weaker policy, such as rand.
  3. Do not use AMP/FP16: disable AMP or use FP32 for the attention. See "A resolution for NAN" facebookresearch/deit#109.
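As a rough illustration, the three suggestions map to config overrides along these lines; the key names (SOLVER.BASE_LR, AUG.AA_TYPE, TRAIN.MIXED_PRECISION) follow the SlowFast-style defaults and are assumptions to verify against this codebase's config files:

from slowfast.config.defaults import get_cfg  # assumed SlowFast-style config entry point

cfg = get_cfg()
cfg.merge_from_file('path_to_your_experiment_config.yaml')  # placeholder path

cfg.SOLVER.BASE_LR = 1e-4                    # 1. lower learning rate
cfg.AUG.AA_TYPE = 'rand-m3-n2-mstd0.5-inc1'  # 2. weaker RandAugment policy (or disable AUG)
cfg.TRAIN.MIXED_PRECISION = False            # 3. turn off AMP and train in FP32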

Andy1621 (Collaborator) commented Apr 2, 2022

Moreover, it looks like you may be using the data augmentation designed for CNNs to train UniFormer.
For a small dataset that can sometimes work, but you have to use a very low learning rate, such as 1e-4; previous CNN models often adopt a large learning rate and weight decay. Besides, CNN models are generally stronger on small datasets.

For a large dataset, I suggest using UniFormer with the same training config I used for Kinetics.

Andy1621 (Collaborator) commented Apr 2, 2022

> I think I did load the pretrained model; however, there are several warnings and I am not sure whether they matter.

Maybe you can set strict=True in build.py and check whether all the weights are loaded correctly.
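Alternatively, since load_state_dict returns the mismatched keys when strict=False, a quick sanity check (assuming model and checkpoint are built as in the snippets above) could be:

# Inspect what was actually loaded instead of failing hard with strict=True.
missing, unexpected = model.load_state_dict(checkpoint, strict=False)
print('missing keys (left randomly initialized):', missing)
print('unexpected keys (ignored from the checkpoint):', unexpected)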

realgump (Author) commented Apr 2, 2022

Thanks for your detailed response. Your advice about training really helps me a lot.
