Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix ImportError of circular reference #369

Merged
merged 3 commits into from
Mar 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion test/test_data_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
import numpy as np
import torch
from torch import Tensor
from yolort.data import contains_any_tensor, _helper as data_helper
from yolort.data import _helper as data_helper
from yolort.data.data_module import DetectionDataModule
from yolort.utils import contains_any_tensor


def test_contains_any_tensor():
Expand Down
3 changes: 2 additions & 1 deletion test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from torchvision.io import read_image
from yolort import models
from yolort.models import YOLOv5
from yolort.utils import get_image_from_url, load_from_ultralytics, read_image_to_tensor
from yolort.models._utils import load_from_ultralytics
from yolort.utils import get_image_from_url, read_image_to_tensor
from yolort.utils.image_utils import box_cxcywh_to_xyxy
from yolort.v5 import letterbox, scale_coords, attempt_download

Expand Down
4 changes: 0 additions & 4 deletions yolort/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1 @@
# Copyright (c) 2021, yolort team. All rights reserved.

from ._helper import contains_any_tensor

__all__ = ["contains_any_tensor"]
15 changes: 0 additions & 15 deletions yolort/data/_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,10 @@

import logging
from pathlib import Path, PosixPath
from typing import Type, Any
from zipfile import ZipFile

import torch
from tabulate import tabulate
from torch import Tensor

from .coco import COCODetection
from .transforms import collate_fn, default_train_transforms, default_val_transforms
Expand Down Expand Up @@ -49,19 +47,6 @@ def get_coco_api_from_dataset(dataset):
raise NotImplementedError("Currently only supports COCO datasets")


def contains_any_tensor(value: Any, dtype: Type = Tensor) -> bool:
"""
Determine whether or not a list contains any Type
"""
if isinstance(value, dtype):
return True
if isinstance(value, (list, tuple)):
return any(contains_any_tensor(v, dtype=dtype) for v in value)
elif isinstance(value, dict):
return any(contains_any_tensor(v, dtype=dtype) for v in value.values())
return False


def prepare_coco128(
data_path: PosixPath,
dirname: str = "coco128",
Expand Down
201 changes: 200 additions & 1 deletion yolort/models/_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
# Copyright (c) 2020, yolort team. All rights reserved.

import math
from typing import Tuple, Optional
from functools import reduce
from typing import Dict, List, Optional, Tuple

import torch
from torch import nn, Tensor
from yolort.v5 import get_yolov5_size, load_yolov5_model

from . import yolo


def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
Expand Down Expand Up @@ -107,6 +111,201 @@ def bbox_iou(box1: Tensor, box2: Tensor, x1y1x2y2: bool = True, eps: float = 1e-
return iou - (rho2 / c2 + v * alpha) # CIoU


def load_from_ultralytics(checkpoint_path: str, version: str = "r6.0"):
"""
Allows the user to load model state file from the checkpoint trained from
the ultralytics/yolov5.

Args:
checkpoint_path (str): Path of the YOLOv5 checkpoint model.
version (str): upstream version released by the ultralytics/yolov5, Possible
values are ["r3.1", "r4.0", "r6.0"]. Default: "r6.0".
"""

assert version in ["r3.1", "r4.0", "r6.0"], "Currently does not support this version."

checkpoint_yolov5 = load_yolov5_model(checkpoint_path)
num_classes = checkpoint_yolov5.yaml["nc"]
strides = checkpoint_yolov5.stride
# YOLOv5 will change the anchors setting when using the auto-anchor mechanism. So we
# use the following formula to compute the anchor_grids instead of attaching it via
# checkpoint_yolov5.yaml["anchors"]
num_anchors = checkpoint_yolov5.model[-1].anchors.shape[1]
anchor_grids = (
(checkpoint_yolov5.model[-1].anchors * checkpoint_yolov5.model[-1].stride.view(-1, 1, 1))
.reshape(1, -1, 2 * num_anchors)
.tolist()[0]
)

depth_multiple = checkpoint_yolov5.yaml["depth_multiple"]
width_multiple = checkpoint_yolov5.yaml["width_multiple"]

use_p6 = False
if len(strides) == 4:
use_p6 = True

if use_p6:
inner_block_maps = {"0": "11", "1": "12", "3": "15", "4": "16", "6": "19", "7": "20"}
layer_block_maps = {"0": "23", "1": "24", "2": "26", "3": "27", "4": "29", "5": "30", "6": "32"}
p6_block_maps = {"0": "9", "1": "10"}
head_ind = 33
head_name = "m"
else:
inner_block_maps = {"0": "9", "1": "10", "3": "13", "4": "14"}
layer_block_maps = {"0": "17", "1": "18", "2": "20", "3": "21", "4": "23"}
p6_block_maps = None
head_ind = 24
head_name = "m"

module_state_updater = ModuleStateUpdate(
depth_multiple,
width_multiple,
inner_block_maps=inner_block_maps,
layer_block_maps=layer_block_maps,
p6_block_maps=p6_block_maps,
strides=strides,
anchor_grids=anchor_grids,
head_ind=head_ind,
head_name=head_name,
num_classes=num_classes,
version=version,
use_p6=use_p6,
)
module_state_updater.updating(checkpoint_yolov5)
state_dict = module_state_updater.model.half().state_dict()

size = get_yolov5_size(depth_multiple, width_multiple)

return {
"num_classes": num_classes,
"depth_multiple": depth_multiple,
"width_multiple": width_multiple,
"strides": strides,
"anchor_grids": anchor_grids,
"use_p6": use_p6,
"size": size,
"state_dict": state_dict,
}


class ModuleStateUpdate:
"""
Update checkpoint from ultralytics yolov5.
"""

def __init__(
self,
depth_multiple: float,
width_multiple: float,
inner_block_maps: Optional[Dict[str, str]] = None,
layer_block_maps: Optional[Dict[str, str]] = None,
p6_block_maps: Optional[Dict[str, str]] = None,
strides: Optional[List[int]] = None,
anchor_grids: Optional[List[List[float]]] = None,
head_ind: int = 24,
head_name: str = "m",
num_classes: int = 80,
version: str = "r6.0",
use_p6: bool = False,
) -> None:

# Configuration for making the keys consistent
if inner_block_maps is None:
inner_block_maps = {"0": "9", "1": "10", "3": "13", "4": "14"}
self.inner_block_maps = inner_block_maps
if layer_block_maps is None:
layer_block_maps = {"0": "17", "1": "18", "2": "20", "3": "21", "4": "23"}
self.layer_block_maps = layer_block_maps
self.p6_block_maps = p6_block_maps
self.head_ind = head_ind
self.head_name = head_name

# Set model
yolov5_size = get_yolov5_size(depth_multiple, width_multiple)
backbone_name = f"darknet_{yolov5_size}_{version.replace('.', '_')}"
self.model = yolo.build_model(
backbone_name,
depth_multiple,
width_multiple,
version,
num_classes=num_classes,
use_p6=use_p6,
strides=strides,
anchor_grids=anchor_grids,
)

def updating(self, state_dict):
# Obtain module state
state_dict = obtain_module_sequential(state_dict)

# Update backbone weights
for name, params in self.model.backbone.body.named_parameters():
params.data.copy_(self.attach_parameters_block(state_dict, name, None))

for name, buffers in self.model.backbone.body.named_buffers():
buffers.copy_(self.attach_parameters_block(state_dict, name, None))

# Update PAN weights
# Updating P6 weights
if self.p6_block_maps is not None:
for name, params in self.model.backbone.pan.intermediate_blocks.p6.named_parameters():
params.data.copy_(self.attach_parameters_block(state_dict, name, self.p6_block_maps))

for name, buffers in self.model.backbone.pan.intermediate_blocks.p6.named_buffers():
buffers.copy_(self.attach_parameters_block(state_dict, name, self.p6_block_maps))

# Updating inner_block weights
for name, params in self.model.backbone.pan.inner_blocks.named_parameters():
params.data.copy_(self.attach_parameters_block(state_dict, name, self.inner_block_maps))

for name, buffers in self.model.backbone.pan.inner_blocks.named_buffers():
buffers.copy_(self.attach_parameters_block(state_dict, name, self.inner_block_maps))

# Updating layer_block weights
for name, params in self.model.backbone.pan.layer_blocks.named_parameters():
params.data.copy_(self.attach_parameters_block(state_dict, name, self.layer_block_maps))

for name, buffers in self.model.backbone.pan.layer_blocks.named_buffers():
buffers.copy_(self.attach_parameters_block(state_dict, name, self.layer_block_maps))

# Update YOLOHead weights
for name, params in self.model.head.named_parameters():
params.data.copy_(self.attach_parameters_heads(state_dict, name))

for name, buffers in self.model.head.named_buffers():
buffers.copy_(self.attach_parameters_heads(state_dict, name))

@staticmethod
def attach_parameters_block(state_dict, name, block_maps=None):
keys = name.split(".")
ind = int(block_maps[keys[0]]) if block_maps else int(keys[0])
return rgetattr(state_dict[ind], keys[1:])

def attach_parameters_heads(self, state_dict, name):
keys = name.split(".")
ind = int(keys[1])
return rgetattr(getattr(state_dict[self.head_ind], self.head_name)[ind], keys[2:])


def rgetattr(obj, attr, *args):
"""
Nested version of getattr.
Ref: https://stackoverflow.com/questions/31174295/getattr-and-setattr-on-nested-objects
"""

def _getattr(obj, attr):
return getattr(obj, attr, *args)

return reduce(_getattr, [obj] + attr)


def obtain_module_sequential(state_dict):
if isinstance(state_dict, nn.Sequential):
return state_dict
else:
return obtain_module_sequential(state_dict.model)


def smooth_binary_cross_entropy(eps: float = 0.1) -> Tuple[float, float]:
# https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441
# return positive, negative label smoothing binary cross entropy targets
Expand Down
3 changes: 2 additions & 1 deletion yolort/models/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@

import torch
from torch import nn, Tensor
from yolort.utils import load_from_ultralytics, load_state_dict_from_url
from yolort.utils import load_state_dict_from_url

from ._utils import load_from_ultralytics
from .anchor_utils import AnchorGenerator
from .backbone_utils import darknet_pan_backbone
from .box_head import YOLOHead, SetCriterion, PostProcess
Expand Down
2 changes: 1 addition & 1 deletion yolort/models/yolov5.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torchvision
from torch import nn, Tensor
from torchvision.io import read_image
from yolort.data import contains_any_tensor
from yolort.utils import contains_any_tensor

from . import yolo
from .transform import YOLOTransform, _get_shape_onnx
Expand Down
2 changes: 1 addition & 1 deletion yolort/relay/trt_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
import torch
from torch import nn, Tensor
from yolort.models import YOLO
from yolort.models._utils import load_from_ultralytics
from yolort.models.anchor_utils import AnchorGenerator
from yolort.models.backbone_utils import darknet_pan_backbone
from yolort.utils import load_from_ultralytics

from .logits_decoder import LogitsDecoder

Expand Down
2 changes: 1 addition & 1 deletion yolort/runtime/y_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Any, Dict, List, Callable, Optional

import numpy as np
from yolort.data import contains_any_tensor
from yolort.utils import contains_any_tensor

try:
import onnxruntime as ort
Expand Down
2 changes: 1 addition & 1 deletion yolort/runtime/y_tensorrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
import torch
from torch import Tensor
from torchvision.io import read_image
from yolort.data import contains_any_tensor
from yolort.models.transform import YOLOTransform
from yolort.utils import contains_any_tensor

try:
import tensorrt as trt
Expand Down
21 changes: 18 additions & 3 deletions yolort/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (c) 2020, yolort team. All rights reserved.

from typing import Callable, Dict, Mapping, Sequence, Union
from typing import Any, Callable, Dict, Mapping, Sequence, Type, Union

from torch import Tensor

try:
from torch.hub import load_state_dict_from_url
Expand All @@ -10,17 +12,17 @@
from .dependency import check_version
from .hooks import FeatureExtractor
from .image_utils import cv2_imshow, get_image_from_url, read_image_to_tensor
from .update_module_state import convert_yolov5_to_yolort, load_from_ultralytics
from .update_module_state import convert_yolov5_to_yolort
from .visualizer import Visualizer


__all__ = [
"check_version",
"contains_any_tensor",
"cv2_imshow",
"get_image_from_url",
"get_callable_dict",
"convert_yolov5_to_yolort",
"load_from_ultralytics",
"load_state_dict_from_url",
"read_image_to_tensor",
"FeatureExtractor",
Expand All @@ -39,3 +41,16 @@ def get_callable_dict(fn: Union[Callable, Mapping, Sequence]) -> Union[Dict, Map
return {get_callable_name(f): f for f in fn}
elif callable(fn):
return {get_callable_name(fn): fn}


def contains_any_tensor(value: Any, dtype: Type = Tensor) -> bool:
"""
Determine whether or not a list contains any Type
"""
if isinstance(value, dtype):
return True
if isinstance(value, (list, tuple)):
return any(contains_any_tensor(v, dtype=dtype) for v in value)
elif isinstance(value, dict):
return any(contains_any_tensor(v, dtype=dtype) for v in value.values())
return False
Loading