From 118ba48b5ae54627add808d5a381f90bd407375f Mon Sep 17 00:00:00 2001
From: cvhub
Date: Sat, 4 Nov 2023 10:14:36 +0800
Subject: [PATCH] Add support for the yolov8-pose algorithm (#103)

---
 README.md                                     |   1 +
 README_zh-CN.md                               |   1 +
 anylabeling/configs/auto_labeling/models.yaml |   4 +
 .../configs/auto_labeling/yolov8n_pose.yaml   |  27 +++++
 .../auto_labeling/yolov8x_pose_p6.yaml        |  27 +++++
 .../services/auto_labeling/__base__/yolo.py   |   2 +
 .../services/auto_labeling/model_manager.py   |  23 ++++
 .../trackers/byte_track/bytetracker.py        |   4 +-
 .../services/auto_labeling/utils/box.py       |  35 +++++-
 .../auto_labeling/utils/points_conversion.py  |  28 ++---
 .../services/auto_labeling/yolov8_pose.py     | 113 ++++++++++++++++++
 docs/models_list.md                           |  26 +++-
 12 files changed, 266 insertions(+), 25 deletions(-)
 create mode 100644 anylabeling/configs/auto_labeling/yolov8n_pose.yaml
 create mode 100644 anylabeling/configs/auto_labeling/yolov8x_pose_p6.yaml
 create mode 100644 anylabeling/services/auto_labeling/yolov8_pose.py

diff --git a/README.md b/README.md
index 0cfd029b..b4ab490f 100644
--- a/README.md
+++ b/README.md
@@ -67,6 +67,7 @@
 ## 🥳 What's New [⏏️](#📄-table-of-contents)
 
 - Nov. 2023:
+  - Support pose estimation: [YOLOv8-Pose](https://github.com/ultralytics/ultralytics).
   - Support object-level tag with yolov5_ram.
 - Oct. 2023:
   - Release the latest version [1.0.0](https://github.com/CVHub520/X-AnyLabeling/releases/tag/v1.0.0).
diff --git a/README_zh-CN.md b/README_zh-CN.md
index cad749d5..5ee825da 100644
--- a/README_zh-CN.md
+++ b/README_zh-CN.md
@@ -68,6 +68,7 @@
 ## 🥳 新功能 [⏏️](#📄-目录)
 
 - Nov. 2023:
+  - Support pose estimation: [YOLOv8-Pose](https://github.com/ultralytics/ultralytics).
   - Support object-level tag with yolov5_ram.
 - Oct. 2023:
   - Release the latest version [1.0.0](https://github.com/CVHub520/X-AnyLabeling/releases/tag/v1.0.0).
diff --git a/anylabeling/configs/auto_labeling/models.yaml b/anylabeling/configs/auto_labeling/models.yaml
index e29c1f9e..235b2bd7 100644
--- a/anylabeling/configs/auto_labeling/models.yaml
+++ b/anylabeling/configs/auto_labeling/models.yaml
@@ -118,6 +118,8 @@
   config_file: ":/yolov8m.yaml"
 - model_name: "yolov8n_efficientvit_sam_l0_vit_h-r20231020"
   config_file: ":/yolov8n_efficientvit_sam_l0_vit_h.yaml"
+- model_name: "yolov8n-pose-r20231103"
+  config_file: ":/yolov8n_pose.yaml"
 - model_name: "yolov8n-seg-r20230620"
   config_file: ":/yolov8n_seg.yaml"
 - model_name: "yolov8n-r20230520"
@@ -128,6 +130,8 @@
   config_file: ":/yolov8s_seg.yaml"
 - model_name: "yolov8s-r20230520"
   config_file: ":/yolov8s.yaml"
+- model_name: "yolov8x-pose-p6-r20231103"
+  config_file: ":/yolov8x_pose_p6.yaml"
 - model_name: "yolov8x-seg-r20230620"
   config_file: ":/yolov8x_seg.yaml"
 - model_name: "yolov8x-r20230520"
diff --git a/anylabeling/configs/auto_labeling/yolov8n_pose.yaml b/anylabeling/configs/auto_labeling/yolov8n_pose.yaml
new file mode 100644
index 00000000..642845bf
--- /dev/null
+++ b/anylabeling/configs/auto_labeling/yolov8n_pose.yaml
@@ -0,0 +1,27 @@
+type: yolov8_pose
+name: yolov8n-pose-r20231103
+display_name: YOLOv8n-Pose Ultralytics
+model_path: https://github.com/CVHub520/X-AnyLabeling/releases/download/v1.0.0/yolov8n-pose.onnx
+confidence_threshold: 0.5
+nms_threshold: 0.6
+hide_box: False
+classes:
+  - person
+keypoints:
+  - nose
+  - left_eye
+  - right_eye
+  - left_ear
+  - right_ear
+  - left_shoulder
+  - right_shoulder
+  - left_elbow
+  - right_elbow
+  - left_wrist
+  - right_wrist
+  - left_hip
+  - right_hip
+  - left_knee
+  - right_knee
+  - left_ankle
+  - right_ankle
diff --git a/anylabeling/configs/auto_labeling/yolov8x_pose_p6.yaml b/anylabeling/configs/auto_labeling/yolov8x_pose_p6.yaml
new file mode 100644
index 00000000..fbba9bd0
--- /dev/null
+++ b/anylabeling/configs/auto_labeling/yolov8x_pose_p6.yaml
@@ -0,0 +1,27 @@
+type: yolov8_pose
+name: yolov8x-pose-p6-r20231103
+display_name: YOLOv8x-Pose-P6 Ultralytics
+model_path: https://github.com/CVHub520/X-AnyLabeling/releases/download/v1.0.0/yolov8x-pose-p6.onnx
+confidence_threshold: 0.25
+nms_threshold: 0.6
+hide_box: False
+classes:
+  - person
+keypoints:
+  - nose
+  - left_eye
+  - right_eye
+  - left_ear
+  - right_ear
+  - left_shoulder
+  - right_shoulder
+  - left_elbow
+  - right_elbow
+  - left_wrist
+  - right_wrist
+  - left_hip
+  - right_hip
+  - left_knee
+  - right_knee
+  - left_ankle
+  - right_ankle
diff --git a/anylabeling/services/auto_labeling/__base__/yolo.py b/anylabeling/services/auto_labeling/__base__/yolo.py
index b37cf6b3..ee04a99d 100644
--- a/anylabeling/services/auto_labeling/__base__/yolo.py
+++ b/anylabeling/services/auto_labeling/__base__/yolo.py
@@ -52,6 +52,8 @@ def __init__(self, model_config, on_message) -> None:
         self.net = OnnxBaseModel(model_abs_path, __preferred_device__)
         self.classes = self.config["classes"]
         self.input_shape = self.net.get_input_shape()[-2:]
+        self.hide_box = self.config.get("hide_box", True)
+        self.keypoints = self.config.get("keypoints", [])
         self.nms_thres = self.config["nms_threshold"]
         self.conf_thres = self.config["confidence_threshold"]
         self.stride = self.config.get("stride", 32)
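The two keys introduced in `__base__/yolo.py` above are optional: `hide_box` falls back to `True` and `keypoints` to an empty list when a custom config omits them, which is why both shipped YAML files set `hide_box: False` explicitly to keep rectangles visible. A minimal sketch of that fallback behavior, not part of the patch and assuming PyYAML for parsing:

```python
# Sketch: fallback behavior of the new optional config keys, mirroring the
# config.get(...) calls added to __base__/yolo.py (assumes PyYAML).
import yaml

minimal_pose_config = """
type: yolov8_pose
name: my-custom-pose            # hypothetical custom model entry
model_path: /path/to/pose.onnx
confidence_threshold: 0.5
nms_threshold: 0.6
classes:
  - person
"""

config = yaml.safe_load(minimal_pose_config)
hide_box = config.get("hide_box", True)   # omitted -> boxes are hidden
keypoints = config.get("keypoints", [])   # omitted -> no keypoint labels
print(hide_box, keypoints)                # True []
```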
diff --git a/anylabeling/services/auto_labeling/model_manager.py b/anylabeling/services/auto_labeling/model_manager.py
index d829949f..44d94f8c 100644
--- a/anylabeling/services/auto_labeling/model_manager.py
+++ b/anylabeling/services/auto_labeling/model_manager.py
@@ -193,6 +193,7 @@ def load_custom_model(self, config_file):
                 "ram",
                 "yolov5_seg",
                 "yolov5_ram",
+                "yolov8_pose",
             ]
         ):
             self.new_model_status.emit(
@@ -470,6 +471,28 @@ def _load_model(self, model_id):
                     )
                 )
                 return
+        elif model_config["type"] == "yolov8_pose":
+            from .yolov8_pose import YOLOv8_Pose
+
+            try:
+                model_config["model"] = YOLOv8_Pose(
+                    model_config, on_message=self.new_model_status.emit
+                )
+                self.auto_segmentation_model_unselected.emit()
+            except Exception as e:  # noqa
+                self.new_model_status.emit(
+                    self.tr(
+                        "Error in loading model: {error_message}".format(
+                            error_message=str(e)
+                        )
+                    )
+                )
+                print(
+                    "Error in loading model: {error_message}".format(
+                        error_message=str(e)
+                    )
+                )
+                return
         elif model_config["type"] == "yolox":
             from .yolox import YOLOX
diff --git a/anylabeling/services/auto_labeling/trackers/byte_track/bytetracker.py b/anylabeling/services/auto_labeling/trackers/byte_track/bytetracker.py
index aa09db97..9ee16cdd 100755
--- a/anylabeling/services/auto_labeling/trackers/byte_track/bytetracker.py
+++ b/anylabeling/services/auto_labeling/trackers/byte_track/bytetracker.py
@@ -1,6 +1,6 @@
 from .tracker.byte_tracker import BYTETracker
 import numpy as np
-from ...utils.points_conversion import tlwh_to_xyxy
+from ...utils.points_conversion import tlwh2xyxy
 
 
 class ByteTrack(object):
@@ -51,7 +51,7 @@ def _tracker_update(self, dets: np.ndarray, image_info: dict):
             track_id = online_target.track_id
             vertical = tlwh[2] / tlwh[3] > self.aspect_ratio_thresh
             if tlwh[2] * tlwh[3] > self.min_box_area and not vertical:
-                online_xyxys.append(tlwh_to_xyxy(tlwh))
+                online_xyxys.append(tlwh2xyxy(tlwh))
                 online_ids.append(track_id)
                 online_scores.append(online_target.score)
         return online_xyxys, online_ids, online_scores
diff --git a/anylabeling/services/auto_labeling/utils/box.py b/anylabeling/services/auto_labeling/utils/box.py
index 3dba41ce..54db2586 100644
--- a/anylabeling/services/auto_labeling/utils/box.py
+++ b/anylabeling/services/auto_labeling/utils/box.py
@@ -18,7 +18,7 @@ def box_iou(box1, box2):
     return iou  # NxM
 
 
-def rescale_box(input_shape, boxes, image_shape):
+def rescale_box(input_shape, boxes, image_shape, kpts=False):
     '''Rescale the output to the original image shape'''
     ratio = min(
         input_shape[0] / image_shape[0],
@@ -35,6 +35,11 @@
     boxes[:, 1] = np.clip(boxes[:, 1], 0, image_shape[0])  # y1
     boxes[:, 2] = np.clip(boxes[:, 2], 0, image_shape[1])  # x2
     boxes[:, 3] = np.clip(boxes[:, 3], 0, image_shape[0])  # y2
+    if kpts:
+        # Rows are [x1, y1, x2, y2, score, (x, y, score) * N]: for 1-based
+        # keypoint index i (starting at 2), x sits in column i*3-1 and y in i*3.
+        num_kpts = boxes.shape[1] // 3
+        for i in range(2, num_kpts + 1):
+            boxes[:, i * 3 - 1] = (boxes[:, i * 3 - 1] - padding[0]) / ratio
+            boxes[:, i * 3] = (boxes[:, i * 3] - padding[1]) / ratio
     return boxes
@@ -72,6 +77,34 @@ def rescale_box_and_landmark(input_shape, boxes, lmdks, image_shape):
     return np.round(boxes), np.round(lmdks)
 
 
+def rescale_tlwh(input_shape, boxes, image_shape, kpts=False):
+    '''Rescale the output to the original image shape'''
+    ratio = min(
+        input_shape[0] / image_shape[0],
+        input_shape[1] / image_shape[1],
+    )
+    padding = (
+        (input_shape[1] - image_shape[1] * ratio) / 2,
+        (input_shape[0] - image_shape[0] * ratio) / 2,
+    )
+    boxes[:, 0] -= padding[0]
+    boxes[:, 1] -= padding[1]
+    boxes[:, :4] /= ratio
+    boxes[:, 0] = np.clip(boxes[:, 0], 0, image_shape[1])  # x1
+    boxes[:, 1] = np.clip(boxes[:, 1], 0, image_shape[0])  # y1
+    boxes[:, 2] = np.clip(
+        (boxes[:, 0] + boxes[:, 2]), 0, image_shape[1]
+    )  # x2 = x1 + w
+    boxes[:, 3] = np.clip(
+        (boxes[:, 1] + boxes[:, 3]), 0, image_shape[0]
+    )  # y2 = y1 + h
+    if kpts:
+        # Same column layout as rescale_box above: undo padding, then ratio.
+        num_kpts = boxes.shape[1] // 3
+        for i in range(2, num_kpts + 1):
+            boxes[:, i * 3 - 1] = (boxes[:, i * 3 - 1] - padding[0]) / ratio
+            boxes[:, i * 3] = (boxes[:, i * 3] - padding[1]) / ratio
+    return boxes
+
+
 def numpy_nms(boxes, scores, iou_threshold):
     idxs = scores.argsort()
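Both `rescale_box(..., kpts=True)` and the new `rescale_tlwh` undo the same letterbox transform for keypoint coordinates: subtract the padding, then divide by the resize ratio. A self-contained check of that inverse mapping on a single point, with illustrative shapes and values (not part of the patch):

```python
# Sketch: the inverse-letterbox math used by rescale_box / rescale_tlwh,
# verified on one point. Shapes are (height, width); values are illustrative.
import numpy as np

input_shape = (640, 640)    # network input (h, w)
image_shape = (480, 1280)   # original image (h, w)

ratio = min(
    input_shape[0] / image_shape[0],
    input_shape[1] / image_shape[1],
)  # 0.5: the image is shrunk to 240 x 640
padding = (
    (input_shape[1] - image_shape[1] * ratio) / 2,  # x padding: 0.0
    (input_shape[0] - image_shape[0] * ratio) / 2,  # y padding: 200.0
)

orig = np.array([600.0, 300.0])                  # keypoint in the original image
letterboxed = orig * ratio + np.array(padding)   # where the network sees it
recovered = (letterboxed - np.array(padding)) / ratio
assert np.allclose(recovered, orig)              # the transform inverts exactly
```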
diff --git a/anylabeling/services/auto_labeling/utils/points_conversion.py b/anylabeling/services/auto_labeling/utils/points_conversion.py
index 0c6a78df..c464a1bb 100755
--- a/anylabeling/services/auto_labeling/utils/points_conversion.py
+++ b/anylabeling/services/auto_labeling/utils/points_conversion.py
@@ -25,26 +25,20 @@ def xywh2xyxy(x):
     return y
 
 
-def tlwh_to_xyxy(x):
-    """" Convert tlwh to xyxy """
-    x1 = x[0]
-    y1 = x[1]
-    x2 = x[2] + x1
-    y2 = x[3] + y1
-    return [x1, y1, x2, y2]
+def tlwh2xyxy(x):
+    """Convert tlwh boxes (top-left x, top-left y, w, h) to xyxy.
+
+    Accepts a single box of shape (4,), as passed by ByteTrack, or a
+    batch of shape (n, 4).
+    """
+    y = np.copy(np.atleast_2d(x))
+    y[:, 2] = y[:, 0] + y[:, 2]
+    y[:, 3] = y[:, 1] + y[:, 3]
+    return y[0] if np.ndim(x) == 1 else y
 
 
-def xyxy_to_tlwh(x):
-    tlwh_bboxs = []
-    for i, box in enumerate(x):
-        x1, y1, x2, y2 = [int(i) for i in box]
-        top = x1
-        left = y1
-        w = int(x2 - x1)
-        h = int(y2 - y1)
-        tlwh_obj = [top, left, w, h]
-        tlwh_bboxs.append(tlwh_obj)
-    return tlwh_bboxs
+def xyxy2tlwh(x):
+    """Convert nx4+ xyxy boxes to tlwh; extra columns pass through unchanged."""
+    y = np.copy(x)
+    y[:, 2] = x[:, 2] - x[:, 0]
+    y[:, 3] = x[:, 3] - x[:, 1]
+    return y
 
 
 def bbox_cxcywh_to_xyxy(x):
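The rewritten converters copy the whole input row and rewrite only columns 2 and 3, so trailing columns such as the box score and the 51 keypoint values pass through unchanged; that is what lets `postprocess` in the new `yolov8_pose.py` push full `(n, 56)` rows through `xyxy2tlwh`. A standalone round-trip check using local batch-form copies of the two functions (not part of the patch):

```python
# Sketch: round-trip sanity check for the vectorized converters above.
import numpy as np

def xyxy2tlwh(x):
    # copy, then rewrite w = x2 - x1 and h = y2 - y1
    y = np.copy(x)
    y[:, 2] = x[:, 2] - x[:, 0]
    y[:, 3] = x[:, 3] - x[:, 1]
    return y

def tlwh2xyxy(x):
    # inverse: x2 = x1 + w and y2 = y1 + h
    y = np.copy(x)
    y[:, 2] = x[:, 2] + x[:, 0]
    y[:, 3] = x[:, 3] + x[:, 1]
    return y

boxes = np.array([[10.0, 20.0, 110.0, 220.0, 0.9]])  # xyxy plus a score column
restored = tlwh2xyxy(xyxy2tlwh(boxes))
assert np.allclose(restored, boxes)   # round-trip is exact; score untouched
```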
diff --git a/anylabeling/services/auto_labeling/yolov8_pose.py b/anylabeling/services/auto_labeling/yolov8_pose.py
new file mode 100644
index 00000000..48dd1ae6
--- /dev/null
+++ b/anylabeling/services/auto_labeling/yolov8_pose.py
@@ -0,0 +1,113 @@
+import logging
+import numpy as np
+
+from PyQt5 import QtCore
+from PyQt5.QtCore import QCoreApplication
+
+from anylabeling.app_info import __preferred_device__
+from anylabeling.views.labeling.shape import Shape
+from anylabeling.views.labeling.utils.opencv import qt_img_to_rgb_cv_img
+from .types import AutoLabelingResult
+from .__base__.yolo import YOLO
+from .utils import (
+    numpy_nms,
+    xywh2xyxy,
+    xyxy2tlwh,
+    rescale_tlwh,
+)
+
+
+class YOLOv8_Pose(YOLO):
+    class Meta:
+        required_config_names = [
+            "type",
+            "name",
+            "display_name",
+            "model_path",
+            "nms_threshold",
+            "confidence_threshold",
+            "classes",
+        ]
+        widgets = ["button_run"]
+        output_modes = {
+            "rectangle": QCoreApplication.translate("Model", "Rectangle"),
+            "point": QCoreApplication.translate("Model", "Point"),
+        }
+        default_output_mode = "rectangle"
+
+    def postprocess(
+        self,
+        prediction,
+        max_det=1000,
+    ):
+        """
+        Args:
+            prediction: (1, 56, *), where 56 = 4 + 1 + 3 * 17
+                4 -> box_xywh
+                1 -> box_score
+                3 * 17 -> (x, y, kpt_score) * 17 keypoints
+        """
+        prediction = prediction.transpose((0, 2, 1))[0]
+        x = prediction[prediction[:, 4] > self.conf_thres]
+        if len(x) == 0:
+            return []
+        x[:, :4] = xywh2xyxy(x[:, :4])
+        keep_idx = numpy_nms(x[:, :4], x[:, 4], self.nms_thres)  # NMS
+        if keep_idx.shape[0] > max_det:  # limit detections
+            keep_idx = keep_idx[:max_det]
+        keep_label = []
+        for i in keep_idx:
+            keep_label.append(x[i].tolist())
+        xyxy = np.array(keep_label)
+        tlwh = xyxy2tlwh(xyxy)
+        return tlwh
+
+    def predict_shapes(self, image, image_path=None):
+        """
+        Predict shapes from image
+        """
+        if image is None:
+            return []
+
+        try:
+            image = qt_img_to_rgb_cv_img(image, image_path)
+        except Exception as e:  # noqa
+            logging.warning("Could not run inference on the model")
+            logging.warning(e)
+            return []
+
+        blob = self.preprocess(image)
+        predictions = self.net.get_ort_inference(blob)
+        results = self.postprocess(predictions)
+
+        if len(results) == 0:
+            return AutoLabelingResult([], replace=True)
+        results = rescale_tlwh(
+            self.input_shape, results, image.shape, kpts=True
+        )
+
+        shapes = []
+        for r in reversed(results):
+            xyxy, _, kpts = r[:4], r[4], r[5:]
+            if not self.hide_box:
+                rectangle_shape = Shape(
+                    label=str(self.classes[0]),
+                    shape_type="rectangle",
+                )
+                rectangle_shape.add_point(QtCore.QPointF(xyxy[0], xyxy[1]))
+                rectangle_shape.add_point(QtCore.QPointF(xyxy[2], xyxy[3]))
+                shapes.append(rectangle_shape)
+            interval = 3
+            for i in range(0, len(kpts), interval):
+                x, y, kpt_score = kpts[i : i + 3]
+                if kpt_score > self.conf_thres:
+                    label = self.keypoints[int(i // interval)]
+                    point_shape = Shape(label=label, shape_type="point")
+                    point_shape.add_point(QtCore.QPointF(x, y))
+                    shapes.append(point_shape)
+        result = AutoLabelingResult(shapes, replace=True)
+        return result
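For reviewers unfamiliar with the Ultralytics pose head, this is how one raw output column decodes under the `(1, 56, N)` layout documented in `postprocess`. Values are synthetic and the real `N` depends on the export resolution; the sketch is not part of the patch:

```python
# Sketch: unpacking one raw YOLOv8-Pose prediction column following the
# (1, 56, N) layout, where 56 = 4 box values + 1 score + 17 * (x, y, score).
import numpy as np

pred = np.random.rand(1, 56, 8400).astype(np.float32)  # dummy network output
cand = pred.transpose((0, 2, 1))[0]                    # -> (8400, 56) rows

row = cand[0]
box_xywh = row[:4]             # center x, center y, width, height
box_score = row[4]             # person confidence, thresholded by conf_thres
kpts = row[5:].reshape(17, 3)  # 17 x (x, y, kpt_score), COCO keypoint order
print(box_xywh.shape, float(box_score), kpts.shape)
```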
diff --git a/docs/models_list.md b/docs/models_list.md
index 70edc724..22b1b13d 100644
--- a/docs/models_list.md
+++ b/docs/models_list.md
@@ -133,15 +133,31 @@
 | yolov5m.onnx | [YOLOv5m-ByteTrack](https://github.com/ifzhang/ByteTrack) | [yolov5m_bytetrack.yaml](../anylabeling/configs/auto_labeling/yolov5m_bytetrack.yaml) | 81.19MB | [baidu](https://pan.baidu.com/s/1oB9Vp-s7viOaLxAhjzAnFQ?pwd=vc4v) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v0.1.0/yolov5m.onnx) |
 | yolov5m.onnx | [YOLOv5m-OCSort](https://github.com/noahcao/OC_SORT) | [yolov5m_ocsort.yaml](../anylabeling/configs/auto_labeling/yolov5m_ocsort.yaml) | 81.19MB | [baidu](https://pan.baidu.com/s/1oB9Vp-s7viOaLxAhjzAnFQ?pwd=vc4v) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v0.1.0/yolov5m.onnx) |
 
+
+### Keypoint Detection
+
+- Facial Landmark Detection
+
+|Name|Description|Configuration|Size|Link|
+| --- | --- | --- | --- | --- |
+| yolov6lite_l_face.onnx | [Facial Landmark Detection](https://github.com/meituan/YOLOv6/tree/yolov6-face) | [yolov6lite_l_face.yaml](../anylabeling/configs/auto_labeling/yolov6lite_l_face.yaml) | 4.16MB | [baidu](https://pan.baidu.com/s/1Ot14duf7GUPMcHHIUY4TSg?pwd=xy1m) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v0.1.0/yolov6lite_l_face.onnx) |
+| yolov6lite_m_face.onnx | [Facial Landmark Detection](https://github.com/meituan/YOLOv6/tree/yolov6-face) | [yolov6lite_m_face.yaml](../anylabeling/configs/auto_labeling/yolov6lite_m_face.yaml) | 3.00MB | [baidu](https://pan.baidu.com/s/1uL01_pqldSXdFnlDM0OFZg?pwd=n9xl) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v0.1.0/yolov6lite_m_face.onnx) |
+| yolov6lite_s_face.onnx | [Facial Landmark Detection](https://github.com/meituan/YOLOv6/tree/yolov6-face) | [yolov6lite_s_face.yaml](../anylabeling/configs/auto_labeling/yolov6lite_s_face.yaml) | 2.10MB | [baidu](https://pan.baidu.com/s/1nN_b4_WutwhnJYEhyqjDYg?pwd=ig4d) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v0.1.0/yolov6lite_s_face.onnx) |
+
+- Pose Estimation
+
+|Name|Description|Configuration|Size|Link|
+| --- | --- | --- | --- | --- |
+| yolov8n-pose.onnx | [YOLOv8](https://github.com/ultralytics/ultralytics)-COCO | [yolov8n_pose.yaml](../anylabeling/configs/auto_labeling/yolov8n_pose.yaml) | 12.75MB | [baidu](https://pan.baidu.com/s/1nFxtNTuLn9vQJWId-EOAKQ?pwd=37ej) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v1.0.0/yolov8n_pose.onnx) |
+| yolov8x-pose-p6.onnx | [YOLOv8](https://github.com/ultralytics/ultralytics)-COCO | [yolov8x_pose_p6.yaml](../anylabeling/configs/auto_labeling/yolov8x_pose_p6.yaml) | 378.92MB | [baidu](https://pan.baidu.com/s/1VRszCtMnHo2zeshuk93Aew?pwd=xvko) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v1.0.0/yolov8x_pose_p6.onnx) |
+| dw-ll_ucoco_384.onnx | [DWPose](https://github.com/IDEA-Research/DWPose/tree/main)(2D human whole-body pose estimation) | [yolox_l_dwpose_ucoco.yaml](../anylabeling/configs/auto_labeling/yolox_l_dwpose_ucoco.yaml) | 128.17MB | [baidu](https://pan.baidu.com/s/1I6CAFhW2YAowN80yweGVpg?pwd=pzf4) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v0.2.0/dw-ll_ucoco_384.onnx) |
+| yolox_l.onnx | [YOLOX](https://github.com/Megvii-BaseDetection/YOLOX)(2D human whole-body pose estimation) | [yolox_l_dwpose_ucoco.yaml](../anylabeling/configs/auto_labeling/yolox_l_dwpose_ucoco.yaml) | 206.71MB | [baidu](https://pan.baidu.com/s/1NpFiX1JN-0jIvd38tQIDcQ?pwd=aqk5) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v0.2.0/yolox_l.onnx) |
+
+
 ### Union Task
 
 |Name|Description|Configuration|Size|Link|
 | --- | --- | --- | --- | --- |
-| yolov6lite_l_face.onnx | [Facial Landmark Detection](https://github.com/meituan/YOLOv6/tree/yolov6-face) | [yolov6lite_l_face.yaml](../anylabeling/configs/auto_labeling/yolov6lite_l_face.yaml) | 4.16MB | [baidu](https://pan.baidu.com/s/1Ot14duf7GUPMcHHIUY4TSg?pwd=xy1m) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v0.1.0/yolov6lite_l_face.onnx) |
-| yolov6lite_m_face.onnx | [Facial Landmark Detection](https://github.com/meituan/YOLOv6/tree/yolov6-face) | [yolov6lite_m_face.yaml](../anylabeling/configs/auto_labeling/yolov6lite_m_face.yaml) | 3.00MB | [baidu](https://pan.baidu.com/s/1uL01_pqldSXdFnlDM0OFZg?pwd=n9xl) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v0.1.0/yolov6lite_m_face.onnx) |
-| yolov6lite_s_face.onnx | [Facial Landmark Detection](https://github.com/meituan/YOLOv6/tree/yolov6-face) | [yolov6lite_s_face.yaml](../anylabeling/configs/auto_labeling/yolov6lite_s_face.yaml) | 2.10MB | [baidu](https://pan.baidu.com/s/1nN_b4_WutwhnJYEhyqjDYg?pwd=ig4d) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v0.1.0/yolov6lite_s_face.onnx) |
-| dw-ll_ucoco_384.onnx | [DWPose](https://github.com/IDEA-Research/DWPose/tree/main)(全身人体姿态估计) | [yolox_l_dwpose_ucoco.yaml](../anylabeling/configs/auto_labeling/yolox_l_dwpose_ucoco.yaml) | 128.17MB | [baidu](https://pan.baidu.com/s/1I6CAFhW2YAowN80yweGVpg?pwd=pzf4) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v0.2.0/dw-ll_ucoco_384.onnx) |
-| yolox_l.onnx | [YOLOX](https://github.com/Megvii-BaseDetection/YOLOX)(全身人体姿态估计) | [yolox_l_dwpose_ucoco.yaml](../anylabeling/configs/auto_labeling/yolox_l_dwpose_ucoco.yaml) | 206.71MB | [baidu](https://pan.baidu.com/s/1NpFiX1JN-0jIvd38tQIDcQ?pwd=aqk5) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v0.2.0/yolox_l.onnx) |
 | resnet50.onnx | [ResNet50](https://arxiv.org/abs/1512.03385)-ImageNet(检测+分类级联模型) | [yolov5s_resnet50.yaml](../anylabeling/configs/auto_labeling/yolov5s_resnet50.yaml) | 97.42MB | [baidu](https://pan.baidu.com/s/1byapPRVib7rarAMTmoSkKQ?pwd=xllt) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v0.1.0/resnet50.onnx) |
 | yolov5s.onnx | [YOLOv5](https://github.com/ultralytics/yolov5)-COCO(检测+分类级联模型) | [yolov5s_resnet50.yaml](../anylabeling/configs/auto_labeling/yolov5s_resnet50.yaml) | 27.98MB | [baidu](https://pan.baidu.com/s/18I8ugM29NKjNlVEsnYYuWA?pwd=z8dl) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v0.1.0/yolov5s.onnx) |
 | mobile_sam.encoder.onnx | [MobileSAM](https://arxiv.org/abs/2306.14289) encoder(YOLOv5-SAM) | [yolov5s_mobile_sam_vit_h.yaml](../anylabeling/configs/auto_labeling/yolov5s_mobile_sam_vit_h.yaml) | 26.85MB | [baidu](https://pan.baidu.com/s/1knn_NkrRAgfJms6d9FsErg?pwd=xbi6) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v0.2.0/mobile_sam.encoder.onnx) |
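A quick way to confirm that a downloaded pose model matches the layout the plugin expects is to inspect it with onnxruntime. A sketch, assuming onnxruntime is installed and a static-shape, single-output export (dynamic exports report symbolic dimensions instead of integers):

```python
# Sketch: sanity-checking a released pose model with onnxruntime.
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession(
    "yolov8n-pose.onnx", providers=["CPUExecutionProvider"]
)
inp = sess.get_inputs()[0]
print(inp.name, inp.shape)            # e.g. images [1, 3, 640, 640]

h, w = inp.shape[2], inp.shape[3]     # static exports carry integer sizes
dummy = np.zeros((1, 3, h, w), dtype=np.float32)
(out,) = sess.run(None, {inp.name: dummy})
print(out.shape)                      # expect (1, 56, N) for 17 keypoints
```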