Remove LayerNorm in Transformer (#81)
* Rename module to TransformerAttentionNetwork

* Remove LayerNorm layers

* Fix module loading

* Update yolotr checkpoints
zhiqwang authored Mar 18, 2021
1 parent 293520e commit 2332a2c
Showing 4 changed files with 36 additions and 30 deletions.
4 changes: 2 additions & 2 deletions test/test_models.py
@@ -2,7 +2,7 @@
 import torch
 
 from yolort.models.backbone_utils import darknet_pan_backbone
-from yolort.models.yolotr import darknet_pan_tr_backbone
+from yolort.models.transformer import darknet_tan_backbone
 from yolort.models.anchor_utils import AnchorGenerator
 from yolort.models.box_head import YoloHead, PostProcess, SetCriterion
@@ -112,7 +112,7 @@ def _init_test_backbone_with_pan_tr(self):
         backbone_name = 'darknet_s_r4_0'
         depth_multiple = 0.33
         width_multiple = 0.5
-        backbone_with_fpn_tr = darknet_pan_tr_backbone(backbone_name, depth_multiple, width_multiple)
+        backbone_with_fpn_tr = darknet_tan_backbone(backbone_name, depth_multiple, width_multiple)
         return backbone_with_fpn_tr
 
     def test_backbone_with_pan_tr(self):
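The renamed helper is a drop-in replacement for the old darknet_pan_tr_backbone; a minimal smoke check, assuming the arguments the test above uses ('darknet_s_r4_0', depth_multiple 0.33, width_multiple 0.5) and an input size divisible by 32:

    >>> import torch
    >>> from yolort.models.transformer import darknet_tan_backbone
    >>> backbone = darknet_tan_backbone('darknet_s_r4_0', 0.33, 0.5)
    >>> x = torch.rand(1, 3, 416, 416)
    >>> out = backbone(x)  # feature maps from the TAN neck, one per detection scale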
2 changes: 1 addition & 1 deletion yolort/models/__init__.py
@@ -73,7 +73,7 @@ def yolotr(upstream_version: str = 'v4.0', export_friendly: bool = False, **kwar
         export_friendly (bool): Deciding whether to use (ONNX/TVM) export friendly mode.
     """
     if upstream_version == 'v4.0':
-        model = YOLOModule(arch="yolov5_darknet_pan_s_tr", **kwargs)
+        model = YOLOModule(arch="yolov5_darknet_tan_s_r40", **kwargs)
     else:
         raise NotImplementedError("Currently only supports v4.0 versions")
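Callers of the top-level factory are unaffected by the arch rename; a sketch, assuming yolotr forwards extra keyword arguments such as pretrained through to YOLOModule as shown above:

    >>> from yolort.models import yolotr
    >>> model = yolotr(upstream_version='v4.0', pretrained=True)
    >>> model = model.eval()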
44 changes: 24 additions & 20 deletions yolort/models/yolotr.py → yolort/models/transformer.py
@@ -1,4 +1,9 @@
 # Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
+"""
+The transformer attention network blocks.
+
+Mostly copy-paste from <https://github.com/dingyiwei/yolov5/tree/Transformer>.
+"""
 from torch import nn
 
 from .common import Conv, C3
@@ -10,7 +15,7 @@
 from typing import Callable, List, Optional
 
 
-def darknet_pan_tr_backbone(
+def darknet_tan_backbone(
     backbone_name: str,
     depth_multiple: float,
     width_multiple: float,
@@ -19,13 +24,13 @@ def darknet_pan_tr_backbone(
     version: str = 'v4.0',
 ):
     """
-    Constructs a specified DarkNet backbone with PAN on top. Freezes the specified number of
+    Constructs a specified DarkNet backbone with TAN on top. Freezes the specified number of
     layers in the backbone.
 
     Examples::
 
-        >>> from models.backbone_utils import darknet_pan_tr_backbone
-        >>> backbone = darknet_pan_tr_backbone('darknet3_1', pretrained=True, trainable_layers=3)
+        >>> from models.transformer import darknet_tan_backbone
+        >>> backbone = darknet_tan_backbone('darknet3_1', pretrained=True, trainable_layers=3)
         >>> # get some dummy image
         >>> x = torch.rand(1, 3, 64, 64)
         >>> # compute the output
@@ -55,20 +60,23 @@ def darknet_pan_tr_backbone(

     in_channels_list = [int(gw * width_multiple) for gw in [256, 512, 1024]]
 
-    return BackboneWithPANTranformer(backbone, return_layers, in_channels_list, depth_multiple, version)
+    return BackboneWithTAN(backbone, return_layers, in_channels_list, depth_multiple, version)
 
 
-class BackboneWithPANTranformer(BackboneWithPAN):
+class BackboneWithTAN(BackboneWithPAN):
+    """
+    Adds a TAN on top of a model.
+    """
     def __init__(self, backbone, return_layers, in_channels_list, depth_multiple, version):
         super().__init__(backbone, return_layers, in_channels_list, depth_multiple, version)
-        self.pan = PathAggregationNetworkTransformer(
+        self.pan = TransformerAttentionNetwork(
             in_channels_list,
             depth_multiple,
             version=version,
         )
 
 
-class PathAggregationNetworkTransformer(PathAggregationNetwork):
+class TransformerAttentionNetwork(PathAggregationNetwork):
     def __init__(
         self,
         in_channels_list: List[int],
@@ -116,24 +124,20 @@ class TransformerLayer(nn.Module):
     def __init__(self, c, num_heads):
         """
         Args:
-            c (int):
-            num_heads:
+            c (int): number of channels
+            num_heads: number of heads
         """
         super().__init__()
 
-        self.ln1 = nn.LayerNorm(c)
         self.q = nn.Linear(c, c, bias=False)
         self.k = nn.Linear(c, c, bias=False)
         self.v = nn.Linear(c, c, bias=False)
 
         self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads)
-        self.ln2 = nn.LayerNorm(c)
         self.fc1 = nn.Linear(c, c, bias=False)
         self.fc2 = nn.Linear(c, c, bias=False)
 
     def forward(self, x):
-        x_ = self.ln1(x)
-        x = self.ma(self.q(x_), self.k(x_), self.v(x_))[0] + x
-        x = self.ln2(x)
+        x = self.ma(self.q(x), self.k(x), self.v(x))[0] + x
         x = self.fc2(self.fc1(x)) + x
         return x

@@ -142,10 +146,10 @@ class TransformerBlock(nn.Module):
     def __init__(self, c1, c2, num_heads, num_layers):
         """
         Args:
-            c1 (int): ch_in
-            c2 (int): ch_out
-            num_heads:
-            num_layers:
+            c1 (int): number of input channels
+            c2 (int): number of output channels
+            num_heads: number of heads
+            num_layers: number of layers
         """
         super().__init__()
 
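With both LayerNorm layers gone, TransformerLayer reduces to two plain residual branches: multi-head self-attention over linear q/k/v projections, then a two-layer feed-forward, each added back onto the raw input. A standalone shape check, assuming nn.MultiheadAttention's default (seq_len, batch, embed_dim) layout:

    >>> import torch
    >>> from yolort.models.transformer import TransformerLayer
    >>> layer = TransformerLayer(c=128, num_heads=4)
    >>> x = torch.rand(50, 1, 128)  # (seq_len, batch, embed_dim)
    >>> layer(x).shape  # both residual branches preserve the shape
    torch.Size([50, 1, 128])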
16 changes: 9 additions & 7 deletions yolort/models/yolo.py
@@ -7,15 +7,15 @@
 from torchvision.models.utils import load_state_dict_from_url
 
 from .backbone_utils import darknet_pan_backbone
-from .yolotr import darknet_pan_tr_backbone
+from .transformer import darknet_tan_backbone
 from .anchor_utils import AnchorGenerator
 from .box_head import YoloHead, SetCriterion, PostProcess
 
 from typing import Tuple, Any, List, Dict, Optional
 
 __all__ = ['YOLO', 'yolov5_darknet_pan_s_r31', 'yolov5_darknet_pan_m_r31', 'yolov5_darknet_pan_l_r31',
            'yolov5_darknet_pan_s_r40', 'yolov5_darknet_pan_m_r40', 'yolov5_darknet_pan_l_r40',
-           'yolov5_darknet_pan_s_tr']
+           'yolov5_darknet_tan_s_r40']


class YOLO(nn.Module):
@@ -128,13 +128,15 @@ def forward(
 model_urls_root = 'https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0'
 
 model_urls = {
+    # Path Aggregation Network
     'yolov5_darknet_pan_s_r31_coco': f'{model_urls_root}/yolov5_darknet_pan_s_r31_coco-eb728698.pt',
     'yolov5_darknet_pan_m_r31_coco': f'{model_urls_root}/yolov5_darknet_pan_m_r31_coco-670dc553.pt',
     'yolov5_darknet_pan_l_r31_coco': f'{model_urls_root}/yolov5_darknet_pan_l_r31_coco-4dcc8209.pt',
     'yolov5_darknet_pan_s_r40_coco': f'{model_urls_root}/yolov5_darknet_pan_s_r40_coco-e3fd213d.pt',
     'yolov5_darknet_pan_m_r40_coco': f'{model_urls_root}/yolov5_darknet_pan_m_r40_coco-d295cb02.pt',
     'yolov5_darknet_pan_l_r40_coco': f'{model_urls_root}/yolov5_darknet_pan_l_r40_coco-4416841f.pt',
-    'yolov5_darknet_pan_s_tr_coco': f'{model_urls_root}/yolov5_darknet_pan_s_tr_coco-f09f21f7.pt',
+    # Transformer Attention Network
+    'yolov5_darknet_tan_s_r40_coco': f'{model_urls_root}/yolov5_darknet_tan_s_r40_coco-fe1069ce.pt',
 }


@@ -303,21 +305,21 @@ def yolov5_darknet_pan_l_r40(pretrained: bool = False, progress: bool = True, nu
         pretrained=pretrained, progress=progress, num_classes=num_classes, **kwargs)
 
 
-def yolov5_darknet_pan_s_tr(pretrained: bool = False, progress: bool = True, num_classes: int = 80,
-                            **kwargs: Any) -> YOLO:
+def yolov5_darknet_tan_s_r40(pretrained: bool = False, progress: bool = True, num_classes: int = 80,
+                             **kwargs: Any) -> YOLO:
     r"""yolov5 small with a transformer block model from
     `"dingyiwei/yolov5" <https://github.com/ultralytics/yolov5/pull/2333>`_.
 
     Args:
         pretrained (bool): If True, returns a model pre-trained on COCO
         progress (bool): If True, displays a progress bar of the download to stderr
     """
     backbone_name = 'darknet_s_r4_0'
-    weights_name = 'yolov5_darknet_pan_s_tr_coco'
+    weights_name = 'yolov5_darknet_tan_s_r40_coco'
     depth_multiple = 0.33
     width_multiple = 0.5
     version = 'v4.0'
 
-    backbone = darknet_pan_tr_backbone(backbone_name, depth_multiple, width_multiple, version=version)
+    backbone = darknet_tan_backbone(backbone_name, depth_multiple, width_multiple, version=version)
 
     anchor_grids = [[10, 13, 16, 30, 33, 23],
                     [30, 61, 62, 45, 59, 119],
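The renamed builder pairs with the updated yolov5_darknet_tan_s_r40_coco checkpoint; a sketch, assuming the eval-mode forward takes a batched image tensor and returns post-processed detections like the sibling yolov5_darknet_pan_* builders:

    >>> import torch
    >>> from yolort.models.yolo import yolov5_darknet_tan_s_r40
    >>> model = yolov5_darknet_tan_s_r40(pretrained=True).eval()
    >>> x = torch.rand(1, 3, 416, 416)
    >>> detections = model(x)  # boxes, labels and scores after PostProcess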
