Move load_from_ultralytics into _checkpoint.py (#373)

* Move load_from_ultralytics into _checkpoint.py
* Minor updates
* Replace asserts with exceptions
* Apply pre-commit
* Fix unittest
* Fix path of load_from_ultralytics
* Fixing lint
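The bullet "Replace asserts with exceptions" refers to input validation such as the version check in the new module below. A minimal sketch of the pattern (the argument name here is illustrative, not quoted from the removed code):

    # Before: an assert, which is silently stripped when Python runs with -O
    assert version in ["r3.1", "r4.0", "r6.0"], f"Unsupported version: {version}"

    # After: an explicit exception that always fires and carries a useful message
    if version not in ["r3.1", "r4.0", "r6.0"]:
        raise NotImplementedError(f"Currently does not support version: {version}")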
Showing 7 changed files with 230 additions and 232 deletions.
_checkpoint.py (new file)

@@ -0,0 +1,216 @@
# Copyright (c) 2020, yolort team. All rights reserved.

from functools import reduce
from typing import Dict, List, Optional

from torch import nn
from yolort.v5 import get_yolov5_size, load_yolov5_model

from .backbone_utils import darknet_pan_backbone
from .box_head import YOLOHead

__all__ = ["load_from_ultralytics"]

def load_from_ultralytics(checkpoint_path: str, version: str = "r6.0"):
    """
    Load the model state from a checkpoint trained with ultralytics/yolov5 and
    translate it into the layout that yolort expects.

    Args:
        checkpoint_path (str): Path of the YOLOv5 checkpoint model.
        version (str): Upstream version released by ultralytics/yolov5. Possible
            values are ["r3.1", "r4.0", "r6.0"]. Default: "r6.0".
    """

    if version not in ["r3.1", "r4.0", "r6.0"]:
        raise NotImplementedError(
            f"Currently does not support version: {version}. Feel free to file an issue "
            "labeled enhancement to us."
        )

    checkpoint_yolov5 = load_yolov5_model(checkpoint_path)
    num_classes = checkpoint_yolov5.yaml["nc"]
    strides = checkpoint_yolov5.stride
    # YOLOv5 may rewrite the anchor settings when its auto-anchor mechanism runs, so we
    # compute anchor_grids from the stored anchors and strides below instead of reading
    # checkpoint_yolov5.yaml["anchors"].
    num_anchors = checkpoint_yolov5.model[-1].anchors.shape[1]
    anchor_grids = (
        (checkpoint_yolov5.model[-1].anchors * checkpoint_yolov5.model[-1].stride.view(-1, 1, 1))
        .reshape(1, -1, 2 * num_anchors)
        .tolist()[0]
    )
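    # For reference (illustrative values, not read from this checkpoint): with the
    # stock yolov5s r6.0 anchors, the last module stores anchors[0] as
    # [[1.25, 1.625], [2.0, 3.75], [4.125, 2.875]] for stride 8, and the product
    # above recovers the pixel-space grid [10.0, 13.0, 16.0, 30.0, 33.0, 23.0].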

    depth_multiple = checkpoint_yolov5.yaml["depth_multiple"]
    width_multiple = checkpoint_yolov5.yaml["width_multiple"]

    use_p6 = False
    if len(strides) == 4:
        use_p6 = True

    # The *_block_maps translate yolort block indices to the positions of the matching
    # modules in the upstream nn.Sequential; head_ind/head_name locate the detection head.
    if use_p6:
        inner_block_maps = {"0": "11", "1": "12", "3": "15", "4": "16", "6": "19", "7": "20"}
        layer_block_maps = {"0": "23", "1": "24", "2": "26", "3": "27", "4": "29", "5": "30", "6": "32"}
        p6_block_maps = {"0": "9", "1": "10"}
        head_ind = 33
        head_name = "m"
    else:
        inner_block_maps = {"0": "9", "1": "10", "3": "13", "4": "14"}
        layer_block_maps = {"0": "17", "1": "18", "2": "20", "3": "21", "4": "23"}
        p6_block_maps = None
        head_ind = 24
        head_name = "m"

    convert_yolo_checkpoint = CheckpointConverter(
        depth_multiple,
        width_multiple,
        inner_block_maps=inner_block_maps,
        layer_block_maps=layer_block_maps,
        p6_block_maps=p6_block_maps,
        strides=strides,
        anchor_grids=anchor_grids,
        head_ind=head_ind,
        head_name=head_name,
        num_classes=num_classes,
        version=version,
        use_p6=use_p6,
    )
    convert_yolo_checkpoint.updating(checkpoint_yolov5)
    state_dict = convert_yolo_checkpoint.model.half().state_dict()

    size = get_yolov5_size(depth_multiple, width_multiple)

    return {
        "num_classes": num_classes,
        "depth_multiple": depth_multiple,
        "width_multiple": width_multiple,
        "strides": strides,
        "anchor_grids": anchor_grids,
        "use_p6": use_p6,
        "size": size,
        "state_dict": state_dict,
    }


class ModelWrapper(nn.Module):
    def __init__(self, backbone, head):
        super().__init__()
        self.backbone = backbone
        self.head = head


class CheckpointConverter:
    """
    Rebuild the yolort backbone/head pair and copy the weights over from a
    checkpoint trained with ultralytics/yolov5.
    """

    def __init__(
        self,
        depth_multiple: float,
        width_multiple: float,
        inner_block_maps: Optional[Dict[str, str]] = None,
        layer_block_maps: Optional[Dict[str, str]] = None,
        p6_block_maps: Optional[Dict[str, str]] = None,
        strides: Optional[List[int]] = None,
        anchor_grids: Optional[List[List[float]]] = None,
        head_ind: int = 24,
        head_name: str = "m",
        num_classes: int = 80,
        version: str = "r6.0",
        use_p6: bool = False,
    ) -> None:

        # Configuration for making the keys consistent
        if inner_block_maps is None:
            inner_block_maps = {"0": "9", "1": "10", "3": "13", "4": "14"}
        self.inner_block_maps = inner_block_maps
        if layer_block_maps is None:
            layer_block_maps = {"0": "17", "1": "18", "2": "20", "3": "21", "4": "23"}
        self.layer_block_maps = layer_block_maps
        self.p6_block_maps = p6_block_maps
        self.head_ind = head_ind
        self.head_name = head_name

        # Set model
        yolov5_size = get_yolov5_size(depth_multiple, width_multiple)
        backbone_name = f"darknet_{yolov5_size}_{version.replace('.', '_')}"

        backbone = darknet_pan_backbone(
            backbone_name, depth_multiple, width_multiple, version=version, use_p6=use_p6
        )
        num_anchors = len(anchor_grids[0]) // 2
        head = YOLOHead(backbone.out_channels, num_anchors, strides, num_classes)
        # Only backbone and head contain parameters inside, so we only wrap them both here.
        self.model = ModelWrapper(backbone, head)

    def updating(self, state_dict):
        # Obtain module state
        state_dict = obtain_module_sequential(state_dict)

        # Update backbone weights
        for name, params in self.model.backbone.body.named_parameters():
            params.data.copy_(self.attach_parameters_block(state_dict, name, None))

        for name, buffers in self.model.backbone.body.named_buffers():
            buffers.copy_(self.attach_parameters_block(state_dict, name, None))

        # Update PAN weights
        # Update P6 weights
        if self.p6_block_maps is not None:
            for name, params in self.model.backbone.pan.intermediate_blocks.p6.named_parameters():
                params.data.copy_(self.attach_parameters_block(state_dict, name, self.p6_block_maps))

            for name, buffers in self.model.backbone.pan.intermediate_blocks.p6.named_buffers():
                buffers.copy_(self.attach_parameters_block(state_dict, name, self.p6_block_maps))

        # Update inner_block weights
        for name, params in self.model.backbone.pan.inner_blocks.named_parameters():
            params.data.copy_(self.attach_parameters_block(state_dict, name, self.inner_block_maps))

        for name, buffers in self.model.backbone.pan.inner_blocks.named_buffers():
            buffers.copy_(self.attach_parameters_block(state_dict, name, self.inner_block_maps))

        # Update layer_block weights
        for name, params in self.model.backbone.pan.layer_blocks.named_parameters():
            params.data.copy_(self.attach_parameters_block(state_dict, name, self.layer_block_maps))

        for name, buffers in self.model.backbone.pan.layer_blocks.named_buffers():
            buffers.copy_(self.attach_parameters_block(state_dict, name, self.layer_block_maps))

        # Update YOLOHead weights
        for name, params in self.model.head.named_parameters():
            params.data.copy_(self.attach_parameters_heads(state_dict, name))

        for name, buffers in self.model.head.named_buffers():
            buffers.copy_(self.attach_parameters_heads(state_dict, name))

    @staticmethod
    def attach_parameters_block(state_dict, name, block_maps=None):
        keys = name.split(".")
        ind = int(block_maps[keys[0]]) if block_maps else int(keys[0])
        return rgetattr(state_dict[ind], keys[1:])
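        # Illustrative mapping (the parameter name is hypothetical): for a PAN parameter
        # named "0.conv.weight" with inner_block_maps = {"0": "9", ...}, this resolves to
        # state_dict[9].conv.weight in the upstream nn.Sequential; with block_maps=None
        # (the backbone body), the leading index is used directly.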

    def attach_parameters_heads(self, state_dict, name):
        keys = name.split(".")
        ind = int(keys[1])
        return rgetattr(getattr(state_dict[self.head_ind], self.head_name)[ind], keys[2:])


def rgetattr(obj, attr, *args):
    """
    Nested version of getattr, where attr is a list of attribute names resolved in order.
    Ref: https://stackoverflow.com/questions/31174295/getattr-and-setattr-on-nested-objects
    """

    def _getattr(obj, attr):
        return getattr(obj, attr, *args)

    return reduce(_getattr, [obj] + attr)
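# Example (illustrative): rgetattr(module, ["conv", "weight"]) returns
# module.conv.weight, resolving one attribute name per list element.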


def obtain_module_sequential(state_dict):
    # Unwrap nested ".model" attributes until the underlying nn.Sequential
    # of upstream modules is reached.
    if isinstance(state_dict, nn.Sequential):
        return state_dict
    else:
        return obtain_module_sequential(state_dict.model)
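For context, a minimal usage sketch of the function added above. The checkpoint path is a placeholder, the import path assumes the module lives under yolort.models (as the relative imports in the diff suggest), and the printed keys mirror the dictionary returned at the end of load_from_ultralytics:

    from yolort.models._checkpoint import load_from_ultralytics

    # "yolov5s.pt" is an illustrative path to a checkpoint trained with ultralytics/yolov5
    model_info = load_from_ultralytics("yolov5s.pt", version="r6.0")

    print(model_info["size"])          # e.g. "s", derived from the depth/width multiples
    print(model_info["num_classes"])   # number of classes recorded in the checkpoint
    print(model_info["anchor_grids"])  # pixel-space anchors recovered from the last module
    state_dict = model_info["state_dict"]  # weights keyed for yolort's backbone/head wrapper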