Skip to content

Commit

Permalink
Add support for the yolov8-pose algorithm (#103)
Browse files Browse the repository at this point in the history
  • Loading branch information
CVHub520 committed Nov 4, 2023
1 parent b854ec7 commit 118ba48
Show file tree
Hide file tree
Showing 12 changed files with 266 additions and 25 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
## 🥳 What's New [⏏️](#📄-table-of-contents)

- Nov. 2023:
- Support pose estimation: [YOLOv8-Pose](https://github.com/ultralytics/ultralytics).
- Support object-level tag with yolov5_ram.
- Oct. 2023:
- Release the latest version [1.0.0](https://github.com/CVHub520/X-AnyLabeling/releases/tag/v1.0.0).
Expand Down
1 change: 1 addition & 0 deletions README_zh-CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
## 🥳 新功能 [⏏️](#📄-目录)

- Nov. 2023:
- Support pose estimation: [YOLOv8-Pose](https://github.com/ultralytics/ultralytics).
- Support object-level tag with yolov5_ram.
- Oct. 2023:
- Release the latest version [1.0.0](https://github.com/CVHub520/X-AnyLabeling/releases/tag/v1.0.0).
Expand Down
4 changes: 4 additions & 0 deletions anylabeling/configs/auto_labeling/models.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@
config_file: ":/yolov8m.yaml"
- model_name: "yolov8n_efficientvit_sam_l0_vit_h-r20231020"
config_file: ":/yolov8n_efficientvit_sam_l0_vit_h.yaml"
- model_name: "yolov8n-pose-r20231103"
config_file: ":/yolov8n_pose.yaml"
- model_name: "yolov8n-seg-r20230620"
config_file: ":/yolov8n_seg.yaml"
- model_name: "yolov8n-r20230520"
Expand All @@ -128,6 +130,8 @@
config_file: ":/yolov8s_seg.yaml"
- model_name: "yolov8s-r20230520"
config_file: ":/yolov8s.yaml"
- model_name: "yolov8x-pose-p6-r20231103"
config_file: ":/yolov8x_pose_p6.yaml"
- model_name: "yolov8x-seg-r20230620"
config_file: ":/yolov8x_seg.yaml"
- model_name: "yolov8x-r20230520"
Expand Down
27 changes: 27 additions & 0 deletions anylabeling/configs/auto_labeling/yolov8n_pose.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
type: yolov8_pose
name: yolov8n-pose-r20231103
display_name: YOLOv8n-Pose Ultralytics
model_path: https://github.com/CVHub520/X-AnyLabeling/releases/download/v1.0.0/yolov8n-pose.onnx
confidence_threshold: 0.5
nms_threshold: 0.6
hide_box: False
classes:
- person
keypoints:
- nose
- left_eye
- right_eye
- left_ear
- right_ear
- left_shoulder
- right_shoulder
- left_elbow
- right_elbow
- left_wrist
- right_wrist
- left_hip
- right_hip
- left_knee
- right_knee
- left_ankle
- right_ankle
27 changes: 27 additions & 0 deletions anylabeling/configs/auto_labeling/yolov8x_pose_p6.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
type: yolov8_pose
name: yolov8x-pose-p6-r20231103
display_name: YOLOv8x-Pose-P6 Ultralytics
model_path: https://github.com/CVHub520/X-AnyLabeling/releases/download/v1.0.0/yolov8x-pose-p6.onnx
confidence_threshold: 0.25
nms_threshold: 0.6
hide_box: False
classes:
- person
keypoints:
- nose
- left_eye
- right_eye
- left_ear
- right_ear
- left_shoulder
- right_shoulder
- left_elbow
- right_elbow
- left_wrist
- right_wrist
- left_hip
- right_hip
- left_knee
- right_knee
- left_ankle
- right_ankle
2 changes: 2 additions & 0 deletions anylabeling/services/auto_labeling/__base__/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ def __init__(self, model_config, on_message) -> None:
self.net = OnnxBaseModel(model_abs_path, __preferred_device__)
self.classes = self.config["classes"]
self.input_shape = self.net.get_input_shape()[-2:]
self.hide_box = self.config.get("hide_box", True)
self.keypoints = self.config.get("keypoints", [])
self.nms_thres = self.config["nms_threshold"]
self.conf_thres = self.config["confidence_threshold"]
self.stride = self.config.get("stride", 32)
Expand Down
23 changes: 23 additions & 0 deletions anylabeling/services/auto_labeling/model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ def load_custom_model(self, config_file):
"ram",
"yolov5_seg",
"yolov5_ram",
"yolov8_pose",
]
):
self.new_model_status.emit(
Expand Down Expand Up @@ -470,6 +471,28 @@ def _load_model(self, model_id):
)
)
return
elif model_config["type"] == "yolov8_pose":
from .yolov8_pose import YOLOv8_Pose

try:
model_config["model"] = YOLOv8_Pose(
model_config, on_message=self.new_model_status.emit
)
self.auto_segmentation_model_unselected.emit()
except Exception as e: # noqa
self.new_model_status.emit(
self.tr(
"Error in loading model: {error_message}".format(
error_message=str(e)
)
)
)
print(
"Error in loading model: {error_message}".format(
error_message=str(e)
)
)
return
elif model_config["type"] == "yolox":
from .yolox import YOLOX

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .tracker.byte_tracker import BYTETracker
import numpy as np
from ...utils.points_conversion import tlwh_to_xyxy
from ...utils.points_conversion import tlwh2xyxy


class ByteTrack(object):
Expand Down Expand Up @@ -51,7 +51,7 @@ def _tracker_update(self, dets: np.ndarray, image_info: dict):
track_id = online_target.track_id
vertical = tlwh[2] / tlwh[3] > self.aspect_ratio_thresh
if tlwh[2] * tlwh[3] > self.min_box_area and not vertical:
online_xyxys.append(tlwh_to_xyxy(tlwh))
online_xyxys.append(tlwh2xyxy(tlwh))
online_ids.append(track_id)
online_scores.append(online_target.score)
return online_xyxys, online_ids, online_scores
35 changes: 34 additions & 1 deletion anylabeling/services/auto_labeling/utils/box.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def box_iou(box1, box2):
return iou # NxM


def rescale_box(input_shape, boxes, image_shape):
def rescale_box(input_shape, boxes, image_shape, kpts=False):
'''Rescale the output to the original image shape'''
ratio = min(
input_shape[0] / image_shape[0],
Expand All @@ -35,6 +35,11 @@ def rescale_box(input_shape, boxes, image_shape):
boxes[:, 1] = np.clip(boxes[:, 1], 0, image_shape[0]) # y1
boxes[:, 2] = np.clip(boxes[:, 2], 0, image_shape[1]) # x2
boxes[:, 3] = np.clip(boxes[:, 3], 0, image_shape[0]) # y2
if kpts:
num_kpts = boxes.shape[1] // 3
for i in range(2, num_kpts + 1):
boxes[:, i * 3 - 1] = (boxes[:, i * 3 - 1] - padding[0]) / ratio
boxes[:, i * 3] = (boxes[:, i * 3] - padding[1]) / ratio
return boxes


Expand Down Expand Up @@ -72,6 +77,34 @@ def rescale_box_and_landmark(input_shape, boxes, lmdks, image_shape):
return np.round(boxes), np.round(lmdks)


def rescale_tlwh(input_shape, boxes, image_shape, kpts=False):
'''Rescale the output to the original image shape'''
ratio = min(
input_shape[0] / image_shape[0],
input_shape[1] / image_shape[1],
)
padding = (
(input_shape[1] - image_shape[1] * ratio) / 2,
(input_shape[0] - image_shape[0] * ratio) / 2,
)
boxes[:, 0] -= padding[0]
boxes[:, 1] -= padding[1]
boxes[:, :4] /= ratio
boxes[:, 0] = np.clip(boxes[:, 0], 0, image_shape[1]) # x1
boxes[:, 1] = np.clip(boxes[:, 1], 0, image_shape[0]) # y1
boxes[:, 2] = np.clip(
(boxes[:, 0] + boxes[:, 2]), 0, image_shape[1]
) # x2
boxes[:, 3] = np.clip(
(boxes[:, 1] + boxes[:, 3]), 0, image_shape[0]
) # y2
if kpts:
num_kpts = boxes.shape[1] // 3
for i in range(2, num_kpts + 1):
boxes[:, i * 3 - 1] = (boxes[:, i * 3 - 1] - padding[0]) / ratio
boxes[:, i * 3] = (boxes[:, i * 3] - padding[1]) / ratio
return boxes


def numpy_nms(boxes, scores, iou_threshold):
idxs = scores.argsort()
Expand Down
28 changes: 11 additions & 17 deletions anylabeling/services/auto_labeling/utils/points_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,26 +25,20 @@ def xywh2xyxy(x):
return y


def tlwh_to_xyxy(x):
def tlwh2xyxy(x):
"""" Convert tlwh to xyxy """
x1 = x[0]
y1 = x[1]
x2 = x[2] + x1
y2 = x[3] + y1
return [x1, y1, x2, y2]
y = np.copy(x)
y[:, 2] = x[:, 2] + x[:, 0]
y[:, 3] = x[:, 3] + x[:, 1]
return y


def xyxy_to_tlwh(x):
tlwh_bboxs = []
for i, box in enumerate(x):
x1, y1, x2, y2 = [int(i) for i in box]
top = x1
left = y1
w = int(x2 - x1)
h = int(y2 - y1)
tlwh_obj = [top, left, w, h]
tlwh_bboxs.append(tlwh_obj)
return tlwh_bboxs
def xyxy2tlwh(x):
"""" Convert xyxy to tlwh """
y = np.copy(x)
y[:, 2] = x[:, 2] - x[:, 0]
y[:, 3] = x[:, 3] - x[:, 1]
return y


def bbox_cxcywh_to_xyxy(x):
Expand Down
113 changes: 113 additions & 0 deletions anylabeling/services/auto_labeling/yolov8_pose.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import logging
import numpy as np

from PyQt5 import QtCore
from PyQt5.QtCore import QCoreApplication

from anylabeling.app_info import __preferred_device__
from anylabeling.views.labeling.shape import Shape
from anylabeling.views.labeling.utils.opencv import qt_img_to_rgb_cv_img
from .types import AutoLabelingResult
from .__base__.yolo import YOLO
from .utils import (
numpy_nms,
xywh2xyxy,
xyxy2tlwh,
rescale_tlwh,
)

class YOLOv8_Pose(YOLO):

class Meta:
required_config_names = [
"type",
"name",
"display_name",
"model_path",
"nms_threshold",
"confidence_threshold",
"classes",
]
widgets = ["button_run"]
output_modes = {
"rectangle": QCoreApplication.translate("Model", "Rectangle"),
"point": QCoreApplication.translate("Model", "Point"),
}
default_output_mode = "rectangle"

def postprocess(
self,
prediction,
max_det=1000,
):

"""
Args:
prediction: (1, 56, *), where 56 = 4 + 1 + 3 * 17
4 -> box_xywh
1 -> box_score
3*17 -> (x, y, kpt_score) * 17 keypoints
"""
prediction = prediction.transpose((0, 2, 1))[0]
x = prediction[prediction[:, 4] > self.conf_thres]
if len(x) == 0:
return []
x[:, :4] = xywh2xyxy(x[:, :4])
keep_idx = numpy_nms(x[:, :4], x[:, 4], self.nms_thres) # NMS
if keep_idx.shape[0] > max_det: # limit detections
keep_idx = keep_idx[:max_det]
keep_label = []
for i in keep_idx:
keep_label.append(x[i].tolist())
xyxy = np.array(keep_label)
tlwh = xyxy2tlwh(xyxy)
return tlwh

def predict_shapes(self, image, image_path=None):
"""
Predict shapes from image
"""

if image is None:
return []

try:
image = qt_img_to_rgb_cv_img(image, image_path)
except Exception as e: # noqa
logging.warning("Could not inference model")
logging.warning(e)
return []

blob = self.preprocess(image)
predictions = self.net.get_ort_inference(blob)
results = self.postprocess(predictions)

if len(results) == 0:
return AutoLabelingResult([], replace=True)
results = rescale_tlwh(
self.input_shape, results, image.shape, kpts=True
)

shapes = []
for r in reversed(results):
xyxy, _, kpts = r[:4], r[4], r[5:]

if not self.hide_box:
rectangle_shape = Shape(
label=str(self.classes[0]),
shape_type="rectangle",
)
rectangle_shape.add_point(QtCore.QPointF(xyxy[0], xyxy[1]))
rectangle_shape.add_point(QtCore.QPointF(xyxy[2], xyxy[3]))
shapes.append(rectangle_shape)

interval = 3
for i in range(0, len(kpts), interval):
x, y, kpt_score = kpts[i: i + 3]
if kpt_score > self.conf_thres:
label = self.keypoints[int(i//interval)]
point_shape = Shape(label=label, shape_type="point")
point_shape.add_point(QtCore.QPointF(x, y))
shapes.append(point_shape)
result = AutoLabelingResult(shapes, replace=True)
return result
26 changes: 21 additions & 5 deletions docs/models_list.md
Original file line number Diff line number Diff line change
Expand Up @@ -133,15 +133,31 @@
| yolov5m.onnx | [YOLOv5m-ByteTrack](https://github.com/ifzhang/ByteTrack) | [yolov5m_bytetrack.yaml](../anylabeling/configs/auto_labeling/yolov5m_bytetrack.yaml) | 81.19MB | [baidu](https://pan.baidu.com/s/1oB9Vp-s7viOaLxAhjzAnFQ?pwd=vc4v) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v0.1.0/yolov5m.onnx) |
| yolov5m.onnx | [YOLOv5m-OCSort](https://github.com/noahcao/OC_SORT) | [yolov5m_ocsort.yaml](../anylabeling/configs/auto_labeling/yolov5m_ocsort.yaml) | 81.19MB | [baidu](https://pan.baidu.com/s/1oB9Vp-s7viOaLxAhjzAnFQ?pwd=vc4v) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v0.1.0/yolov5m.onnx) |


### Keypoint Detection

- Facial Landmark Detection

|Name|Description|Configuration|Size|Link|
| --- | --- | --- | --- | --- |
| yolov6lite_l_face.onnx | [Facial Landmark Detection](https://github.com/meituan/YOLOv6/tree/yolov6-face) | [yolov6lite_l_face.yaml](../anylabeling/configs/auto_labeling/yolov6lite_l_face.yaml) | 4.16MB | [baidu](https://pan.baidu.com/s/1Ot14duf7GUPMcHHIUY4TSg?pwd=xy1m) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v0.1.0/yolov6lite_l_face.onnx) |
| yolov6lite_m_face.onnx | [Facial Landmark Detection](https://github.com/meituan/YOLOv6/tree/yolov6-face) | [yolov6lite_m_face.yaml](../anylabeling/configs/auto_labeling/yolov6lite_m_face.yaml) | 3.00MB | [baidu](https://pan.baidu.com/s/1uL01_pqldSXdFnlDM0OFZg?pwd=n9xl) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v0.1.0/yolov6lite_m_face.onnx) |
| yolov6lite_s_face.onnx | [Facial Landmark Detection](https://github.com/meituan/YOLOv6/tree/yolov6-face) | [yolov6lite_s_face.yaml](../anylabeling/configs/auto_labeling/yolov6lite_s_face.yaml) | 2.10MB | [baidu](https://pan.baidu.com/s/1nN_b4_WutwhnJYEhyqjDYg?pwd=ig4d) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v0.1.0/yolov6lite_s_face.onnx) |

- Pose Estimation

|Name|Description|Configuration|Size|Link|
| --- | --- | --- | --- | --- |
| yolov8n-pose.onnx | [YOLOv8](https://github.com/ultralytics/ultralytics)-COCO | [yolov8n_pose.yaml](../anylabeling/configs/auto_labeling/yolov8n_pose.yaml) | 12.75MB | [baidu](https://pan.baidu.com/s/1nFxtNTuLn9vQJWId-EOAKQ?pwd=37ej) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v1.0.0/yolov8n_pose.onnx) |
| yolov8x-pose-p6.onnx | [YOLOv8](https://github.com/ultralytics/ultralytics)-COCO | [yolov8x_pose_p6.yaml](../anylabeling/configs/auto_labeling/yolov8x_pose_p6.yaml) | 378.92MB | [baidu](https://pan.baidu.com/s/1VRszCtMnHo2zeshuk93Aew?pwd=xvko) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v1.0.0/yolov8x_pose_p6.onnx) |
| dw-ll_ucoco_384.onnx | [DWPose](https://github.com/IDEA-Research/DWPose/tree/main)(2D human whole-body pose estimation) | [yolox_l_dwpose_ucoco.yaml](../anylabeling/configs/auto_labeling/yolox_l_dwpose_ucoco.yaml) | 128.17MB | [baidu](https://pan.baidu.com/s/1I6CAFhW2YAowN80yweGVpg?pwd=pzf4) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v0.2.0/dw-ll_ucoco_384.onnx) |
| yolox_l.onnx | [YOLOX](https://github.com/Megvii-BaseDetection/YOLOX)(2D human whole-body pose estimation) | [yolox_l_dwpose_ucoco.yaml](../anylabeling/configs/auto_labeling/yolox_l_dwpose_ucoco.yaml) | 206.71MB | [baidu](https://pan.baidu.com/s/1NpFiX1JN-0jIvd38tQIDcQ?pwd=aqk5) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v0.2.0/yolox_l.onnx) |


### Union Task

|Name|Description|Configuration|Size|Link|
| --- | --- | --- | --- | --- |
| yolov6lite_l_face.onnx | [Facial Landmark Detection](https://github.com/meituan/YOLOv6/tree/yolov6-face) | [yolov6lite_l_face.yaml](../anylabeling/configs/auto_labeling/yolov6lite_l_face.yaml) | 4.16MB | [baidu](https://pan.baidu.com/s/1Ot14duf7GUPMcHHIUY4TSg?pwd=xy1m) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v0.1.0/yolov6lite_l_face.onnx) |
| yolov6lite_m_face.onnx | [Facial Landmark Detection](https://github.com/meituan/YOLOv6/tree/yolov6-face) | [yolov6lite_m_face.yaml](../anylabeling/configs/auto_labeling/yolov6lite_m_face.yaml) | 3.00MB | [baidu](https://pan.baidu.com/s/1uL01_pqldSXdFnlDM0OFZg?pwd=n9xl) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v0.1.0/yolov6lite_m_face.onnx) |
| yolov6lite_s_face.onnx | [Facial Landmark Detection](https://github.com/meituan/YOLOv6/tree/yolov6-face) | [yolov6lite_s_face.yaml](../anylabeling/configs/auto_labeling/yolov6lite_s_face.yaml) | 2.10MB | [baidu](https://pan.baidu.com/s/1nN_b4_WutwhnJYEhyqjDYg?pwd=ig4d) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v0.1.0/yolov6lite_s_face.onnx) |
| dw-ll_ucoco_384.onnx | [DWPose](https://github.com/IDEA-Research/DWPose/tree/main)(全身人体姿态估计) | [yolox_l_dwpose_ucoco.yaml](../anylabeling/configs/auto_labeling/yolox_l_dwpose_ucoco.yaml) | 128.17MB | [baidu](https://pan.baidu.com/s/1I6CAFhW2YAowN80yweGVpg?pwd=pzf4) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v0.2.0/dw-ll_ucoco_384.onnx) |
| yolox_l.onnx | [YOLOX](https://github.com/Megvii-BaseDetection/YOLOX)(全身人体姿态估计) | [yolox_l_dwpose_ucoco.yaml](../anylabeling/configs/auto_labeling/yolox_l_dwpose_ucoco.yaml) | 206.71MB | [baidu](https://pan.baidu.com/s/1NpFiX1JN-0jIvd38tQIDcQ?pwd=aqk5) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v0.2.0/yolox_l.onnx) |
| resnet50.onnx | [ResNet50](https://arxiv.org/abs/1512.03385)-ImageNet(检测+分类级联模型) | [yolov5s_resnet50.yaml](../anylabeling/configs/auto_labeling/yolov5s_resnet50.yaml) | 97.42MB | [baidu](https://pan.baidu.com/s/1byapPRVib7rarAMTmoSkKQ?pwd=xllt) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v0.1.0/resnet50.onnx) |
| yolov5s.onnx | [YOLOv5](https://github.com/ultralytics/yolov5)-COCO(检测+分类级联模型) | [yolov5s_resnet50.yaml](../anylabeling/configs/auto_labeling/yolov5s_resnet50.yaml) | 27.98MB | [baidu](https://pan.baidu.com/s/18I8ugM29NKjNlVEsnYYuWA?pwd=z8dl) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v0.1.0/yolov5s.onnx) |
| mobile_sam.encoder.onnx | [MobileSAM](https://arxiv.org/abs/2306.14289) encoder(YOLOv5-SAM) | [yolov5s_mobile_sam_vit_h.yaml](../anylabeling/configs/auto_labeling/yolov5s_mobile_sam_vit_h.yaml) | 26.85MB | [baidu](https://pan.baidu.com/s/1knn_NkrRAgfJms6d9FsErg?pwd=xbi6) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v0.2.0/mobile_sam.encoder.onnx) |
Expand Down

0 comments on commit 118ba48

Please sign in to comment.