Problem in migrating InternImage into mmdetection 3 #11039

Open
VanLinLin opened this issue Oct 15, 2023 · 5 comments
@VanLinLin

Describe the issue

I'm migrating the InternImage backbone into mmdetection 3 from the InternImage repo (which is built on mmdetection 2.28), and it's almost done. However, when I try to train the model, an error occurs and I can't figure out what happened.

Reproduction

  1. What command or script did you run?
python mmdetection/tools/train.py
  2. What config did you run?
# --------------------------------------------------------
# InternImage
# Copyright (c) 2022 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
_base_ = [
    '../_base_/datasets/coco_detection.py',
    '../_base_/default_runtime.py',
    '../_base_/schedules/schedule_1x.py',
]
pretrained = 'https://huggingface.co/OpenGVLab/InternImage/resolve/main/internimage_l_22k_192to384.pth'
model = dict(
    type='DINO',
    num_queries=900,  # num_matching_queries
    with_box_refine=True,
    as_two_stage=True,
    backbone=dict(
        type='InternImage',
        core_op='DCNv3',
        channels=160,
        depths=[5, 5, 22, 5],
        groups=[10, 20, 40, 80],
        mlp_ratio=4.,
        drop_path_rate=0.4,
        norm_layer='LN',
        layer_scale=1.0,
        offset_scale=2.0,
        post_norm=True,
        with_cp=False,
        out_indices=(1, 2, 3),
        init_cfg=dict(type='Pretrained', checkpoint=pretrained)),
    neck=dict(
        type='ChannelMapper',
        in_channels=[320, 640, 1280],
        kernel_size=1,
        out_channels=256,
        act_cfg=None,
        norm_cfg=dict(type='GN', num_groups=32),
        num_outs=4),
    encoder=dict(
        num_layers=6,
        layer_cfg=dict(
            self_attn_cfg=dict(embed_dims=256, num_levels=4,
                               dropout=0.0),  # 0.1 for DeformDETR
            ffn_cfg=dict(
                embed_dims=256,
                feedforward_channels=2048,  # 1024 for DeformDETR
                ffn_drop=0.0))),  # 0.1 for DeformDETR
    decoder=dict(
        num_layers=6,
        return_intermediate=True,
        layer_cfg=dict(
            self_attn_cfg=dict(embed_dims=256, num_heads=8,
                               dropout=0.0),  # 0.1 for DeformDETR
            cross_attn_cfg=dict(embed_dims=256, num_levels=4,
                                dropout=0.0),  # 0.1 for DeformDETR
            ffn_cfg=dict(
                embed_dims=256,
                feedforward_channels=2048,  # 1024 for DeformDETR
                ffn_drop=0.0)),  # 0.1 for DeformDETR
        post_norm_cfg=None),
    positional_encoding=dict(
        num_feats=128,
        normalize=True,
        offset=0.0,  # -0.5 for DeformDETR
        temperature=20),  # 10000 for DeformDETR
    bbox_head=dict(
        type='DINOHead',
        # num_classes=80,
        num_classes=5,
        sync_cls_avg_factor=True,
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),  # 2.0 in DeformDETR
        loss_bbox=dict(type='L1Loss', loss_weight=5.0),
        loss_iou=dict(type='GIoULoss', loss_weight=2.0)),
    dn_cfg=dict(  # TODO: Move to model.train_cfg ?
        label_noise_scale=0.5,
        box_noise_scale=1.0,  # 0.4 for DN-DETR
        group_cfg=dict(dynamic=True, num_groups=None,
                       num_dn_queries=100)),  # TODO: half num_dn_queries
    # training and testing settings
    train_cfg=dict(
        assigner=dict(
            type='HungarianAssigner',
            match_costs=[
                dict(type='FocalLossCost', weight=2.0),
                dict(type='BBoxL1Cost', weight=5.0, box_format='xywh'),
                dict(type='IoUCost', iou_mode='giou', weight=2.0)
            ])),
    test_cfg=dict(max_per_img=300))
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
# train_pipeline. NOTE: the img_scale and the Pad size_divisor are different
# from the default settings in mmdet.
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(type='RandomFlip', prob=0.5),
    dict(
        type='AutoAugment',
        policies=[
            [
                dict(
                    type='RandomChoiceResize',
                    scales=[(480, 1333), (512, 1333), (544, 1333),
                            (576, 1333), (608, 1333), (640, 1333),
                            (672, 1333), (704, 1333), (736, 1333),
                            (768, 1333), (800, 1333)],
                    keep_ratio=True)
            ],
            [
                dict(
                    type='RandomChoiceResize',
                    scales=[(400, 4200), (500, 4200), (600, 4200)],
                    keep_ratio=True),
                dict(
                    type='RandomCrop',
                    crop_type='absolute_range',
                    crop_size=(384, 600),
                    allow_negative_crop=False),
                dict(
                    type='RandomChoiceResize',
                    scales=[(480, 1333), (512, 1333), (544, 1333),
                            (576, 1333), (608, 1333), (640, 1333),
                            (672, 1333), (704, 1333), (736, 1333),
                            (768, 1333), (800, 1333)],
                    # override=True,
                    keep_ratio=True)
            ]
        ]),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='PackDetInputs')
]
# By default, models are trained on 8 GPUs with 2 images per GPU
dataset_type = 'CocoDataset'  # Dataset type, used to define the dataset.
data_root = 'data/packages_coco/'
metainfo = {
    'classes': ("Folding_Knife", "Straight_Knife", "Scissor",
                "Utility_Knife", "Multi-tool_Knife")
}
train_dataloader = dict(
    batch_size=6,
    num_workers=2,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    # Batch sampler that groups data with similar aspect ratios in each batch to save GPU memory
    batch_sampler=dict(type='AspectRatioBatchSampler'),
    dataset=dict(
        type=dataset_type,
        metainfo=metainfo,
        data_root=data_root,
        ann_file='annotations/instances_train2017.json',
        data_prefix=dict(img='train2017/'),  # image path prefix
        pipeline=train_pipeline),
)

val_dataloader = dict(
    batch_size=1,
    num_workers=2,
    persistent_workers=True,
    drop_last=False,  # whether to drop the last incomplete batch
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        metainfo=metainfo,
        data_root=data_root,
        ann_file='annotations/instances_val2017.json',
        data_prefix=dict(img='val2017/'),
        test_mode=True,
    ),
)
# If you want standard test, please manually configure the test dataset
test_dataloader = val_dataloader

val_evaluator = dict(  # evaluator used during validation
    type='CocoMetric',  # COCO metric computing AR, AP and mAP for detection and instance segmentation
    ann_file=data_root + 'annotations/instances_val2017.json',  # path to the annotation file
    metric=['bbox'],  # metrics to compute; `bbox` for detection, `segm` for instance segmentation
    format_only=False)
test_evaluator = val_evaluator  # evaluator used during testing

train_cfg = dict(
    # type of training loop; see https://github.com/open-mmlab/mmengine/blob/main/mmengine/runner/loops.py
    type='EpochBasedTrainLoop',
    max_epochs=30,  # maximum number of training epochs
    val_interval=1)  # validation interval; validate once per epoch
val_cfg = dict(type='ValLoop')  # type of validation loop
test_cfg = dict(type='TestLoop')  # type of test loop


# optimizer
optim_wrapper_cfg = dict(  # optimizer wrapper config
    type='OptimWrapper',  # type of optimizer wrapper; switch to AmpOptimWrapper to enable mixed-precision training
    optimizer=dict(
        _delete_=True, type='AdamW', lr=0.0001, weight_decay=0.0001,
        constructor='CustomLayerDecayOptimizerConstructor',
        paramwise_cfg=dict(num_layers=37, layer_decay_rate=0.90,
                           depths=[5, 5, 22, 5])),
    optimizer_config=dict(
        _delete_=True, grad_clip=dict(max_norm=0.1, norm_type=2))
)

# learning policy
param_scheduler = [
    dict(
        type='LinearLR',  # linear learning-rate warm-up
        start_factor=0.001,  # warm-up factor
        by_epoch=False,  # update the warm-up learning rate per iteration
        begin=0,  # start from the first iteration
        end=500),  # end at iteration 500
    dict(
        type='MultiStepLR',  # multi-step learning-rate schedule during training
        by_epoch=True,  # update the learning rate per epoch
        begin=0,   # start from the first epoch
        end=12,  # end at epoch 12
        milestones=[8, 11],  # epochs at which the learning rate decays
        gamma=0.1)  # learning-rate decay factor
]

default_hooks = dict(
    checkpoint=dict(
        type='CheckpointHook',
        save_best='auto',
        max_keep_ckpts=3,
        interval=1))

vis_backends = [
    dict(type='LocalVisBackend'),
    dict(type='TensorboardVisBackend')
]
visualizer = dict(
    type='DetLocalVisualizer', vis_backends=vis_backends, name='visualizer')
  3. Did you make any modifications on the code or config? Did you understand what you modified? Yes
  4. What dataset did you use? SIXray

Environment

sys.platform: win32
Python: 3.9.13 (tags/v3.9.13:6de2ca5, May 17 2022, 16:36:42) [MSC v.1929 64 bit (AMD64)]
CUDA available: True
numpy_random_seed: 2147483648
GPU 0: NVIDIA GeForce RTX 3060
CUDA_HOME: C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7
NVCC: Cuda compilation tools, release 11.7, V11.7.99
MSVC: Microsoft (R) C/C++ Optimizing Compiler Version 19.32.31332 for x64
GCC: n/a
PyTorch: 1.13.1+cu117
PyTorch compiling details: PyTorch built with:
  - C++ Version: 199711
  - MSVC 192829337
  - Intel(R) Math Kernel Library Version 2020.0.2 Product Build 20200624 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v2.6.0 (Git Hash 52b5f107dd9cf10910aaa19cb47f3abf9b349815)
  - OpenMP 2019
  - LAPACK is enabled (usually provided by MKL)
  - CPU capability usage: AVX2
  - CUDA Runtime 11.7
  - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_61,code=sm_61;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_37,code=compute_37
  - CuDNN 8.5
  - Magma 2.5.4
  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.7, CUDNN_VERSION=8.5.0, CXX_COMPILER=C:/actions-runner/_work/pytorch/pytorch/builder/windows/tmp_bin/sccache-cl.exe, CXX_FLAGS=/DWIN32 /D_WINDOWS /GR /EHsc /w /bigobj -DUSE_PTHREADPOOL -openmp:experimental -IC:/actions-runner/_work/pytorch/pytorch/builder/windows/mkl/include -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOCUPTI -DUSE_FBGEMM -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -DEDGE_PROFILER_USE_KINETO, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.13.1, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=OFF, USE_NNPACK=OFF, USE_OPENMP=ON, USE_ROCM=OFF,

TorchVision: 0.14.1+cu117
OpenCV: 4.8.1
MMEngine: 0.9.0
MMDetection: 3.1.0+f78af77

Results

Traceback (most recent call last):
  File "F:\Lab\mmdetection3\tools\train.py", line 137, in <module>
    main()
  File "F:\Lab\mmdetection3\tools\train.py", line 133, in main
    runner.train()
  File "F:\Lab\mmdetection3\.venv\lib\site-packages\mmengine\runner\runner.py", line 1777, in train
    model = self.train_loop.run()  # type: ignore
  File "F:\Lab\mmdetection3\.venv\lib\site-packages\mmengine\runner\loops.py", line 96, in run
    self.run_epoch()
  File "F:\Lab\mmdetection3\.venv\lib\site-packages\mmengine\runner\loops.py", line 112, in run_epoch
    self.run_iter(idx, data_batch)
  File "F:\Lab\mmdetection3\.venv\lib\site-packages\mmengine\runner\loops.py", line 128, in run_iter
    outputs = self.runner.model.train_step(
  File "F:\Lab\mmdetection3\.venv\lib\site-packages\mmengine\model\base_model\base_model.py", line 114, in train_step
    losses = self._run_forward(data, mode='loss')  # type: ignore
  File "F:\Lab\mmdetection3\.venv\lib\site-packages\mmengine\model\base_model\base_model.py", line 346, in _run_forward
    results = self(**data, mode=mode)
  File "F:\Lab\mmdetection3\.venv\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "F:\Lab\mmdetection3\.venv\lib\site-packages\mmdet-3.1.0-py3.9.egg\mmdet\models\detectors\base.py", line 92, in forward
    return self.loss(inputs, data_samples)
  File "F:\Lab\mmdetection3\.venv\lib\site-packages\mmdet-3.1.0-py3.9.egg\mmdet\models\detectors\base_detr.py", line 98, in loss
    img_feats = self.extract_feat(batch_inputs)
  File "F:\Lab\mmdetection3\.venv\lib\site-packages\mmdet-3.1.0-py3.9.egg\mmdet\models\detectors\base_detr.py", line 237, in extract_feat
    x = self.backbone(batch_inputs)
  File "F:\Lab\mmdetection3\.venv\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "F:\Lab\mmdetection3\.venv\lib\site-packages\mmdet-3.1.0-py3.9.egg\mmdet\models\backbones\intern_image.py", line 706, in forward
    x = self.patch_embed(x)
  File "F:\Lab\mmdetection3\.venv\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "F:\Lab\mmdetection3\.venv\lib\site-packages\mmdet-3.1.0-py3.9.egg\mmdet\models\backbones\intern_image.py", line 275, in forward
    x = self.conv1(x)
  File "F:\Lab\mmdetection3\.venv\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "F:\Lab\mmdetection3\.venv\lib\site-packages\torch\nn\modules\conv.py", line 463, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "F:\Lab\mmdetection3\.venv\lib\site-packages\torch\nn\modules\conv.py", line 459, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
TypeError: conv2d() received an invalid combination of arguments - got (list, Parameter, Parameter, tuple, tuple, tuple, int), but expected one of:
 * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, tuple of ints padding, tuple of ints dilation, int groups)
      didn't match because some of the arguments have invalid types: (list, Parameter, Parameter, tuple, tuple, tuple, int)
 * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, str padding, tuple of ints dilation, int groups)
      didn't match because some of the arguments have invalid types: (list, Parameter, Parameter, tuple, tuple, tuple, int)
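
The failure is reproducible with plain PyTorch: F.conv2d rejects a Python list where it expects a batched tensor. A minimal sketch of the same error, independent of the InternImage code:

import torch
import torch.nn.functional as F

weight = torch.randn(16, 3, 3, 3)  # (out_channels, in_channels, kH, kW)

# A stacked (B, C, H, W) tensor works.
x = torch.randn(2, 3, 224, 224)
print(F.conv2d(x, weight, padding=1).shape)  # torch.Size([2, 16, 224, 224])

# A list of (C, H, W) tensors raises the TypeError shown above.
try:
    F.conv2d([torch.randn(3, 224, 224)] * 2, weight, padding=1)
except TypeError as e:
    print(e)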

Issue fix

Not yet
VanLinLin added the reimplementation label on Oct 15, 2023
@hhaAndroid (Collaborator)

@VanLinLin The configuration seems fine. I suspect there might be an error in your "InternImage" code.

@VanLinLin (Author)

@hhaAndroid Hi, thanks for your reply. I've finished migrating InternImage into mmdetection 3. That's pretty cool, and thank you again!

@shreejalt

@VanLinLin
What changes did you make to get it running in v3 of mmdetection? I only have one mmdetection repo, version 3.x, so I want to know how hard it is to change the source code.

@AlanBlanchet commented Apr 2, 2024

I'm having the same problem. The batch is passed list-wise, while conv2d expects its input to be a tensor of shape (B, C, H, W). Instead we have a list of tensors of shape (C, H, W):

TypeError: conv2d() received an invalid combination of arguments - got (list, Parameter, NoneType, tuple, tuple, tuple, int), but expected one of:
 * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, tuple of ints padding, tuple of ints dilation, int groups)
      didn't match because some of the arguments have invalid types: (list of [Tensor, Tensor], Parameter, NoneType, tuple of (int, int), tuple of (int, int), tuple of (int, int), int)
 * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, str padding, tuple of ints dilation, int groups)
      didn't match because some of the arguments have invalid types: (list of [Tensor, Tensor], Parameter, NoneType, tuple of (int, int), tuple of (int, int), tuple of (int, int), int)

Here I have a batch size of 2 ("list of [Tensor, Tensor]").
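
In plain PyTorch terms, the mismatch looks like this (a minimal sketch, not the mmdetection source):

import torch

# What the backbone receives here: one (C, H, W) tensor per image.
samples = [torch.randn(3, 224, 224), torch.randn(3, 224, 224)]

# What conv2d needs: a single stacked (B, C, H, W) tensor.
batch = torch.stack(samples, dim=0)
print(batch.shape)  # torch.Size([2, 3, 224, 224])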

@AlanBlanchet
Copy link

Solution found in this issue.
It seems that you have to specify the collate_fn in your dataloader config:

train_dataloader = dict(
    batch_size=2,  # Batch size of a single GPU
    num_workers=2,  # Worker to pre-fetch data for each single GPU
    persistent_workers=True,
    collate_fn=dict(type="default_collate"),
    ....
)
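
For context: mmengine dataloaders default to pseudo_collate, which leaves the batch as a list of per-sample tensors, whereas default_collate stacks same-shaped tensors into a single batch tensor, which is what the conv stem expects. A rough illustration with plain PyTorch (torch.utils.data.default_collate behaves the same way here):

import torch
from torch.utils.data import default_collate

samples = [torch.randn(3, 224, 224), torch.randn(3, 224, 224)]

as_list = samples                   # pseudo_collate-style result: still a list
stacked = default_collate(samples)  # one stacked (2, 3, 224, 224) tensor
print(type(as_list), stacked.shape)

Note that default_collate requires every image in a batch to share the same shape. In stock mmdetection 3 configs, stacking and padding are normally handled by the model's data_preprocessor (e.g. DetDataPreprocessor), which the config above does not set, and that may be the underlying cause.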
