Move load_from_ultralytics into _checkpoint.py (#373)
* Move load_from_ultralytics into _checkpoint.py

* Minor updates

* Replace asserts with exceptions

* Apply pre-commit

* Fix unittest

* Fix path of load_from_ultralytics

* Fix lint
zhiqwang authored Mar 22, 2022
1 parent 7f81427 commit 5942ce8
Showing 7 changed files with 230 additions and 232 deletions.
2 changes: 1 addition & 1 deletion test/test_utils.py
@@ -7,7 +7,7 @@
 from torchvision.io import read_image
 from yolort import models
 from yolort.models import YOLOv5
-from yolort.models._utils import load_from_ultralytics
+from yolort.models._checkpoint import load_from_ultralytics
 from yolort.utils import get_image_from_url, read_image_to_tensor
 from yolort.utils.image_utils import box_cxcywh_to_xyxy
 from yolort.v5 import letterbox, scale_coords, attempt_download
216 changes: 216 additions & 0 deletions yolort/models/_checkpoint.py
@@ -0,0 +1,216 @@
# Copyright (c) 2020, yolort team. All rights reserved.

from functools import reduce
from typing import Dict, List, Optional

from torch import nn
from yolort.v5 import get_yolov5_size, load_yolov5_model

from .backbone_utils import darknet_pan_backbone
from .box_head import YOLOHead

__all__ = ["load_from_ultralytics"]


def load_from_ultralytics(checkpoint_path: str, version: str = "r6.0"):
    """
    Load a model state dict from a checkpoint trained with ultralytics/yolov5.

    Args:
        checkpoint_path (str): Path of the YOLOv5 checkpoint model.
        version (str): upstream version released by ultralytics/yolov5. Possible
            values are ["r3.1", "r4.0", "r6.0"]. Default: "r6.0".
    """

    if version not in ["r3.1", "r4.0", "r6.0"]:
        raise NotImplementedError(
            f"Currently does not support version: {version}. Feel free to file an issue "
            "labeled enhancement."
        )

    checkpoint_yolov5 = load_yolov5_model(checkpoint_path)
    num_classes = checkpoint_yolov5.yaml["nc"]
    strides = checkpoint_yolov5.stride
    # YOLOv5 changes the anchor settings when the auto-anchor mechanism is used, so we
    # compute the anchor_grids with the following formula instead of reading them from
    # checkpoint_yolov5.yaml["anchors"]
    num_anchors = checkpoint_yolov5.model[-1].anchors.shape[1]
    anchor_grids = (
        (checkpoint_yolov5.model[-1].anchors * checkpoint_yolov5.model[-1].stride.view(-1, 1, 1))
        .reshape(1, -1, 2 * num_anchors)
        .tolist()[0]
    )

    depth_multiple = checkpoint_yolov5.yaml["depth_multiple"]
    width_multiple = checkpoint_yolov5.yaml["width_multiple"]

    use_p6 = len(strides) == 4
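    # Editorial note: P6 variants of YOLOv5 (e.g. yolov5s6) detect at four
    # scales, typically strides (8, 16, 32, 64), while the standard models use
    # three, hence the check on the number of strides.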

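    # The tables below map yolort module indices (keys) onto the layer indices
    # (values) of the upstream yolov5 nn.Sequential; attach_parameters_block
    # uses them when copying weights across.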
    if use_p6:
        inner_block_maps = {"0": "11", "1": "12", "3": "15", "4": "16", "6": "19", "7": "20"}
        layer_block_maps = {"0": "23", "1": "24", "2": "26", "3": "27", "4": "29", "5": "30", "6": "32"}
        p6_block_maps = {"0": "9", "1": "10"}
        head_ind = 33
        head_name = "m"
    else:
        inner_block_maps = {"0": "9", "1": "10", "3": "13", "4": "14"}
        layer_block_maps = {"0": "17", "1": "18", "2": "20", "3": "21", "4": "23"}
        p6_block_maps = None
        head_ind = 24
        head_name = "m"

    convert_yolo_checkpoint = CheckpointConverter(
        depth_multiple,
        width_multiple,
        inner_block_maps=inner_block_maps,
        layer_block_maps=layer_block_maps,
        p6_block_maps=p6_block_maps,
        strides=strides,
        anchor_grids=anchor_grids,
        head_ind=head_ind,
        head_name=head_name,
        num_classes=num_classes,
        version=version,
        use_p6=use_p6,
    )
    convert_yolo_checkpoint.updating(checkpoint_yolov5)
    state_dict = convert_yolo_checkpoint.model.half().state_dict()

    size = get_yolov5_size(depth_multiple, width_multiple)

    return {
        "num_classes": num_classes,
        "depth_multiple": depth_multiple,
        "width_multiple": width_multiple,
        "strides": strides,
        "anchor_grids": anchor_grids,
        "use_p6": use_p6,
        "size": size,
        "state_dict": state_dict,
    }


class ModelWrapper(nn.Module):
    def __init__(self, backbone, head):
        super().__init__()
        self.backbone = backbone
        self.head = head

class CheckpointConverter:
    """
    Update checkpoint from ultralytics yolov5.
    """

    def __init__(
        self,
        depth_multiple: float,
        width_multiple: float,
        inner_block_maps: Optional[Dict[str, str]] = None,
        layer_block_maps: Optional[Dict[str, str]] = None,
        p6_block_maps: Optional[Dict[str, str]] = None,
        strides: Optional[List[int]] = None,
        anchor_grids: Optional[List[List[float]]] = None,
        head_ind: int = 24,
        head_name: str = "m",
        num_classes: int = 80,
        version: str = "r6.0",
        use_p6: bool = False,
    ) -> None:

        # Configuration for making the keys consistent
        if inner_block_maps is None:
            inner_block_maps = {"0": "9", "1": "10", "3": "13", "4": "14"}
        self.inner_block_maps = inner_block_maps
        if layer_block_maps is None:
            layer_block_maps = {"0": "17", "1": "18", "2": "20", "3": "21", "4": "23"}
        self.layer_block_maps = layer_block_maps
        self.p6_block_maps = p6_block_maps
        self.head_ind = head_ind
        self.head_name = head_name

        # Set up the model
        yolov5_size = get_yolov5_size(depth_multiple, width_multiple)
        backbone_name = f"darknet_{yolov5_size}_{version.replace('.', '_')}"

        backbone = darknet_pan_backbone(
            backbone_name, depth_multiple, width_multiple, version=version, use_p6=use_p6
        )
        num_anchors = len(anchor_grids[0]) // 2
        head = YOLOHead(backbone.out_channels, num_anchors, strides, num_classes)
        # Only the backbone and head contain parameters, so we wrap just these two here.
        self.model = ModelWrapper(backbone, head)

    def updating(self, state_dict):
        # Obtain module state
        state_dict = obtain_module_sequential(state_dict)

        # Update backbone weights
        for name, params in self.model.backbone.body.named_parameters():
            params.data.copy_(self.attach_parameters_block(state_dict, name, None))

        for name, buffers in self.model.backbone.body.named_buffers():
            buffers.copy_(self.attach_parameters_block(state_dict, name, None))

        # Update PAN weights
        # Update P6 weights
        if self.p6_block_maps is not None:
            for name, params in self.model.backbone.pan.intermediate_blocks.p6.named_parameters():
                params.data.copy_(self.attach_parameters_block(state_dict, name, self.p6_block_maps))

            for name, buffers in self.model.backbone.pan.intermediate_blocks.p6.named_buffers():
                buffers.copy_(self.attach_parameters_block(state_dict, name, self.p6_block_maps))

        # Update inner_block weights
        for name, params in self.model.backbone.pan.inner_blocks.named_parameters():
            params.data.copy_(self.attach_parameters_block(state_dict, name, self.inner_block_maps))

        for name, buffers in self.model.backbone.pan.inner_blocks.named_buffers():
            buffers.copy_(self.attach_parameters_block(state_dict, name, self.inner_block_maps))

        # Update layer_block weights
        for name, params in self.model.backbone.pan.layer_blocks.named_parameters():
            params.data.copy_(self.attach_parameters_block(state_dict, name, self.layer_block_maps))

        for name, buffers in self.model.backbone.pan.layer_blocks.named_buffers():
            buffers.copy_(self.attach_parameters_block(state_dict, name, self.layer_block_maps))

        # Update YOLOHead weights
        for name, params in self.model.head.named_parameters():
            params.data.copy_(self.attach_parameters_heads(state_dict, name))

        for name, buffers in self.model.head.named_buffers():
            buffers.copy_(self.attach_parameters_heads(state_dict, name))

    @staticmethod
    def attach_parameters_block(state_dict, name, block_maps=None):
        keys = name.split(".")
        ind = int(block_maps[keys[0]]) if block_maps else int(keys[0])
        return rgetattr(state_dict[ind], keys[1:])

    def attach_parameters_heads(self, state_dict, name):
        keys = name.split(".")
        ind = int(keys[1])
        return rgetattr(getattr(state_dict[self.head_ind], self.head_name)[ind], keys[2:])


def rgetattr(obj, attr, *args):
    """
    Nested version of getattr. Note that ``attr`` is a list of attribute names here.
    Ref: https://stackoverflow.com/questions/31174295/getattr-and-setattr-on-nested-objects
    """

    def _getattr(obj, attr):
        return getattr(obj, attr, *args)

    return reduce(_getattr, [obj] + attr)


def obtain_module_sequential(state_dict):
    # Recurse into the wrapped "model" attribute until we reach the underlying
    # nn.Sequential that holds the upstream yolov5 layers.
    if isinstance(state_dict, nn.Sequential):
        return state_dict
    return obtain_module_sequential(state_dict.model)
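
For reference, here is a minimal usage sketch of the relocated helper. The checkpoint file name "yolov5s.pt" is a hypothetical placeholder for any weights file produced by ultralytics/yolov5:

from yolort.models._checkpoint import load_from_ultralytics

# Translate an upstream checkpoint into a yolort-compatible description
model_info = load_from_ultralytics("yolov5s.pt", version="r6.0")

# The returned dict bundles the converted weights ("state_dict") with the
# architecture hyperparameters needed to rebuild the matching yolort model
print(model_info["size"], model_info["num_classes"], model_info["use_p6"])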