Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move bias initializations from private methods to constructors #351

Merged
merged 3 commits into from
Mar 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion yolort/models/anchor_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) 2020, Zhiqiang Wang. All Rights Reserved.
# Copyright (c) 2020, yolort team. All rights reserved.

from typing import Tuple, List

import torch
Expand Down
54 changes: 28 additions & 26 deletions yolort/models/box_head.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) 2020, yolort team. All rights reserved.

import math
from typing import Tuple, List, Dict
from typing import Dict, List, Tuple

import torch
import torch.nn.functional as F
Expand All @@ -11,13 +12,18 @@


class YOLOHead(nn.Module):
def __init__(
self,
in_channels: List[int],
num_anchors: int,
strides: List[int],
num_classes: int,
):
"""
A regression and classification head for use in YOLO.

Args:
in_channels (List[int]): number of input channels of each feature map, one entry per stride
num_anchors (int): number of anchors to be predicted
strides (List[int]): stride of each feature map, one entry per head block
num_classes (int): number of classes to be predicted
"""

def __init__(self, in_channels: List[int], num_anchors: int, strides: List[int], num_classes: int):

super().__init__()
if not isinstance(in_channels, list):
in_channels = [in_channels] * len(strides)
Expand All @@ -26,25 +32,21 @@ def __init__(
self.num_outputs = num_classes + 5 # number of outputs per anchor
self.strides = strides

self.head = nn.ModuleList(
head_blocks = nn.ModuleList(
nn.Conv2d(ch, self.num_outputs * self.num_anchors, 1) for ch in in_channels
) # output conv

self._initialize_biases() # Init weights, biases
)

def _initialize_biases(self, cf=None):
"""
Initialize biases into YOLOHead, cf is class frequency
Check section 3.3 in <https://arxiv.org/abs/1708.02002>
"""
for mi, s in zip(self.head, self.strides):
# Initialize the biases of the head blocks
for mi, s in zip(head_blocks, self.strides):
b = mi.bias.view(self.num_anchors, -1) # conv.bias(255) to (3,85)
# obj (8 objects per 640 image)
b.data[:, 4] += math.log(8 / (640 / s) ** 2)
# classes
b.data[:, 5:] += torch.log(cf / cf.sum()) if cf else math.log(0.6 / (self.num_classes - 0.99))
b.data[:, 5:] += math.log(0.6 / (self.num_classes - 0.999999))
mi.bias = nn.Parameter(b.view(-1), requires_grad=True)

self.head = head_blocks

def get_result_from_head(self, features: Tensor, idx: int) -> Tensor:
"""
This is equivalent to self.head[idx](features),
Expand Down Expand Up @@ -358,6 +360,12 @@ def _decode_pred_logits(pred_logits: Tensor):
class PostProcess(nn.Module):
"""
Performs Non-Maximum Suppression (NMS) on inference results

Args:
strides (List[int]): Strides of the AnchorGenerator.
score_thresh (float): Score threshold used for postprocessing the detections.
nms_thresh (float): NMS threshold used for postprocessing the detections.
detections_per_img (int): Number of best detections to keep after NMS.
"""

def __init__(
Expand All @@ -367,13 +375,7 @@ def __init__(
nms_thresh: float,
detections_per_img: int,
) -> None:
"""
Args:
strides (List[int]): Strides of the AnchorGenerator.
score_thresh (float): Score threshold used for postprocessing the detections.
nms_thresh (float): NMS threshold used for postprocessing the detections.
detections_per_img (int): Number of best detections to keep after NMS.
"""

super().__init__()
self.strides = strides
self.score_thresh = score_thresh
Expand Down
10 changes: 3 additions & 7 deletions yolort/utils/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,7 @@ def get_image_from_url(url: str, flags: int = 1) -> np.ndarray:
return image


def read_image_to_tensor(
image: np.ndarray,
is_half: bool = False,
) -> Tensor:
def read_image_to_tensor(image: np.ndarray, is_half: bool = False) -> Tensor:
"""
Parse an image to Tensor.

Expand All @@ -122,9 +119,8 @@ def read_image_to_tensor(
image = np.ascontiguousarray(image, dtype=np.float32) # uint8 to float32
image = np.transpose(image / 255.0, [2, 0, 1])

image = torch.from_numpy(image)
image = image.half() if is_half else image.float()
return image
_dtype = torch.float16 if is_half else torch.float32
return torch.from_numpy(image).to(dtype=_dtype)


def load_names(category_path):
Expand Down