diff --git a/anylabeling/configs/auto_labeling/models.yaml b/anylabeling/configs/auto_labeling/models.yaml index f810013e..78a21060 100644 --- a/anylabeling/configs/auto_labeling/models.yaml +++ b/anylabeling/configs/auto_labeling/models.yaml @@ -62,6 +62,16 @@ config_file: ":/ram_plus_swin_large_14m.yaml" - model_name: "ram_swin_large_14m-r20231024" config_file: ":/ram_swin_large_14m.yaml" +- model_name: "rtdetrv2l-r20240730" + config_file: ":/rtdetrv2l.yaml" +- model_name: "rtdetrv2m-r20240730" + config_file: ":/rtdetrv2m.yaml" +- model_name: "rtdetrv2m7x-r20240730" + config_file: ":/rtdetrv2m7x.yaml" +- model_name: "rtdetrv2s-r20240730" + config_file: ":/rtdetrv2s.yaml" +- model_name: "rtdetrv2x-r20240730" + config_file: ":/rtdetrv2x.yaml" - model_name: "rtdetr_r50-r20230520" config_file: ":/rtdetr_r50.yaml" - model_name: "rtmdet_m_coco_person_rtmo_m-r20240112" diff --git a/anylabeling/configs/auto_labeling/rtdetrv2l.yaml b/anylabeling/configs/auto_labeling/rtdetrv2l.yaml new file mode 100644 index 00000000..5768a564 --- /dev/null +++ b/anylabeling/configs/auto_labeling/rtdetrv2l.yaml @@ -0,0 +1,86 @@ +type: rtdetrv2 +name: rtdetrv2l-r20240730 +display_name: RT-DETRv2-L PaddleDetection +model_path: https://github.com/CVHub520/X-AnyLabeling/releases/download/v2.4.0/rtdetrv2_r50vd_6x_coco.onnx +score_threshold: 0.45 +classes: + - person + - bicycle + - car + - motorcycle + - airplane + - bus + - train + - truck + - boat + - traffic light + - fire hydrant + - stop sign + - parking meter + - bench + - bird + - cat + - dog + - horse + - sheep + - cow + - elephant + - bear + - zebra + - giraffe + - backpack + - umbrella + - handbag + - tie + - suitcase + - frisbee + - skis + - snowboard + - sports ball + - kite + - baseball bat + - baseball glove + - skateboard + - surfboard + - tennis racket + - bottle + - wine glass + - cup + - fork + - knife + - spoon + - bowl + - banana + - apple + - sandwich + - orange + - broccoli + - carrot + - hot dog + - pizza + - donut + - cake + - chair + - couch + - potted plant + - bed + - dining table + - toilet + - tv + - laptop + - mouse + - remote + - keyboard + - cell phone + - microwave + - oven + - toaster + - sink + - refrigerator + - book + - clock + - vase + - scissors + - teddy bear + - hair drier + - toothbrush diff --git a/anylabeling/configs/auto_labeling/rtdetrv2m.yaml b/anylabeling/configs/auto_labeling/rtdetrv2m.yaml new file mode 100644 index 00000000..e0176180 --- /dev/null +++ b/anylabeling/configs/auto_labeling/rtdetrv2m.yaml @@ -0,0 +1,86 @@ +type: rtdetrv2 +name: rtdetrv2m-r20240730 +display_name: RT-DETRv2-M PaddleDetection +model_path: https://github.com/CVHub520/X-AnyLabeling/releases/download/v2.4.0/rtdetrv2_r34vd_120e_coco.onnx +score_threshold: 0.45 +classes: + - person + - bicycle + - car + - motorcycle + - airplane + - bus + - train + - truck + - boat + - traffic light + - fire hydrant + - stop sign + - parking meter + - bench + - bird + - cat + - dog + - horse + - sheep + - cow + - elephant + - bear + - zebra + - giraffe + - backpack + - umbrella + - handbag + - tie + - suitcase + - frisbee + - skis + - snowboard + - sports ball + - kite + - baseball bat + - baseball glove + - skateboard + - surfboard + - tennis racket + - bottle + - wine glass + - cup + - fork + - knife + - spoon + - bowl + - banana + - apple + - sandwich + - orange + - broccoli + - carrot + - hot dog + - pizza + - donut + - cake + - chair + - couch + - potted plant + - bed + - dining table + - toilet + - tv + - laptop + - mouse + - remote + - keyboard + - cell phone + - microwave + - oven + - toaster + - sink + - refrigerator + - book + - clock + - vase + - scissors + - teddy bear + - hair drier + - toothbrush diff --git a/anylabeling/configs/auto_labeling/rtdetrv2m7x.yaml b/anylabeling/configs/auto_labeling/rtdetrv2m7x.yaml new file mode 100644 index 00000000..6c0f92cf --- /dev/null +++ b/anylabeling/configs/auto_labeling/rtdetrv2m7x.yaml @@ -0,0 +1,86 @@ +type: rtdetrv2 +name: rtdetrv2m7x-r20240730 +display_name: RT-DETRv2-M* PaddleDetection +model_path: https://github.com/CVHub520/X-AnyLabeling/releases/download/v2.4.0/rtdetrv2_r50vd_m_7x_coco.onnx +score_threshold: 0.45 +classes: + - person + - bicycle + - car + - motorcycle + - airplane + - bus + - train + - truck + - boat + - traffic light + - fire hydrant + - stop sign + - parking meter + - bench + - bird + - cat + - dog + - horse + - sheep + - cow + - elephant + - bear + - zebra + - giraffe + - backpack + - umbrella + - handbag + - tie + - suitcase + - frisbee + - skis + - snowboard + - sports ball + - kite + - baseball bat + - baseball glove + - skateboard + - surfboard + - tennis racket + - bottle + - wine glass + - cup + - fork + - knife + - spoon + - bowl + - banana + - apple + - sandwich + - orange + - broccoli + - carrot + - hot dog + - pizza + - donut + - cake + - chair + - couch + - potted plant + - bed + - dining table + - toilet + - tv + - laptop + - mouse + - remote + - keyboard + - cell phone + - microwave + - oven + - toaster + - sink + - refrigerator + - book + - clock + - vase + - scissors + - teddy bear + - hair drier + - toothbrush diff --git a/anylabeling/configs/auto_labeling/rtdetrv2s.yaml b/anylabeling/configs/auto_labeling/rtdetrv2s.yaml new file mode 100644 index 00000000..a76e88e6 --- /dev/null +++ b/anylabeling/configs/auto_labeling/rtdetrv2s.yaml @@ -0,0 +1,86 @@ +type: rtdetrv2 +name: rtdetrv2s-r20240730 +display_name: RT-DETRv2-S PaddleDetection +model_path: https://github.com/CVHub520/X-AnyLabeling/releases/download/v2.4.0/rtdetrv2_r18vd_120e_coco.onnx +score_threshold: 0.45 +classes: + - person + - bicycle + - car + - motorcycle + - airplane + - bus + - train + - truck + - boat + - traffic light + - fire hydrant + - stop sign + - parking meter + - bench + - bird + - cat + - dog + - horse + - sheep + - cow + - elephant + - bear + - zebra + - giraffe + - backpack + - umbrella + - handbag + - tie + - suitcase + - frisbee + - skis + - snowboard + - sports ball + - kite + - baseball bat + - baseball glove + - skateboard + - surfboard + - tennis racket + - bottle + - wine glass + - cup + - fork + - knife + - spoon + - bowl + - banana + - apple + - sandwich + - orange + - broccoli + - carrot + - hot dog + - pizza + - donut + - cake + - chair + - couch + - potted plant + - bed + - dining table + - toilet + - tv + - laptop + - mouse + - remote + - keyboard + - cell phone + - microwave + - oven + - toaster + - sink + - refrigerator + - book + - clock + - vase + - scissors + - teddy bear + - hair drier + - toothbrush diff --git a/anylabeling/configs/auto_labeling/rtdetrv2x.yaml b/anylabeling/configs/auto_labeling/rtdetrv2x.yaml new file mode 100644 index 00000000..e08a77d7 --- /dev/null +++ b/anylabeling/configs/auto_labeling/rtdetrv2x.yaml @@ -0,0 +1,86 @@ +type: rtdetrv2 +name: rtdetrv2x-r20240730 +display_name: RT-DETRv2-X PaddleDetection +model_path: https://github.com/CVHub520/X-AnyLabeling/releases/download/v2.4.0/rtdetrv2_r101vd_6x_coco.onnx +score_threshold: 0.45 +classes: + - person + - bicycle + - car + - motorcycle + - airplane + - bus + - train + - truck + - boat + - traffic light + - fire hydrant + - stop sign + - parking meter + - bench + - bird + - cat + - dog + - horse + - sheep + - cow + - elephant + - bear + - zebra + - giraffe + - backpack + - umbrella + - handbag + - tie + - suitcase + - frisbee + - skis + - snowboard + - sports ball + - kite + - baseball bat + - baseball glove + - skateboard + - surfboard + - tennis racket + - bottle + - wine glass + - cup + - fork + - knife + - spoon + - bowl + - banana + - apple + - sandwich + - orange + - broccoli + - carrot + - hot dog + - pizza + - donut + - cake + - chair + - couch + - potted plant + - bed + - dining table + - toilet + - tv + - laptop + - mouse + - remote + - keyboard + - cell phone + - microwave + - oven + - toaster + - sink + - refrigerator + - book + - clock + - vase + - scissors + - teddy bear + - hair drier + - toothbrush diff --git a/anylabeling/services/auto_labeling/model_manager.py b/anylabeling/services/auto_labeling/model_manager.py index f269e729..f618070f 100644 --- a/anylabeling/services/auto_labeling/model_manager.py +++ b/anylabeling/services/auto_labeling/model_manager.py @@ -64,6 +64,7 @@ class ModelManager(QObject): "yolov10", "depth_anything_v2", "yolow_ram", + "rtdetrv2", ] model_configs_changed = pyqtSignal(list) @@ -1054,6 +1055,28 @@ def _load_model(self, model_id): ) ) return + elif model_config["type"] == "rtdetrv2": + from .rtdetrv2 import RTDETRv2 + + try: + model_config["model"] = RTDETRv2( + model_config, on_message=self.new_model_status.emit + ) + self.auto_segmentation_model_unselected.emit() + except Exception as e: # noqa + self.new_model_status.emit( + self.tr( + "Error in loading model: {error_message}".format( + error_message=str(e) + ) + ) + ) + print( + "Error in loading model: {error_message}".format( + error_message=str(e) + ) + ) + return elif model_config["type"] == "yolov6_face": from .yolov6_face import YOLOv6Face diff --git a/anylabeling/services/auto_labeling/rtdetrv2.py b/anylabeling/services/auto_labeling/rtdetrv2.py new file mode 100644 index 00000000..e370038e --- /dev/null +++ b/anylabeling/services/auto_labeling/rtdetrv2.py @@ -0,0 +1,150 @@ +import logging +import os + +import cv2 +import numpy as np +from PyQt5 import QtCore +from PyQt5.QtCore import QCoreApplication + +from anylabeling.app_info import __preferred_device__ +from anylabeling.views.labeling.shape import Shape +from anylabeling.views.labeling.utils.opencv import qt_img_to_rgb_cv_img +from .model import Model +from .types import AutoLabelingResult +from .engines.build_onnx_engine import OnnxBaseModel +from .utils.points_conversion import cxywh2xyxy + + +class RTDETRv2(Model): + """Object detection model using RTDETRv2""" + + class Meta: + required_config_names = [ + "type", + "name", + "display_name", + "model_path", + "score_threshold", + "classes", + ] + widgets = [ + "button_run", + "input_conf", + "edit_conf", + "toggle_preserve_existing_annotations", + ] + output_modes = { + "rectangle": QCoreApplication.translate("Model", "Rectangle"), + } + default_output_mode = "rectangle" + + def __init__(self, model_config, on_message) -> None: + # Run the parent class's init method + super().__init__(model_config, on_message) + model_name = self.config["type"] + model_abs_path = self.get_model_abs_path(self.config, "model_path") + if not model_abs_path or not os.path.isfile(model_abs_path): + raise FileNotFoundError( + QCoreApplication.translate( + "Model", + f"Could not download or initialize {model_name} model.", + ) + ) + self.net = OnnxBaseModel(model_abs_path, __preferred_device__) + self.classes = self.config["classes"] + self.input_shape = self.net.get_input_shape()[-2:] + self.conf_thres = self.config["score_threshold"] + self.replace = True + + def set_auto_labeling_conf(self, value): + """ set auto labeling confidence threshold """ + self.conf_thres = value + + def set_auto_labeling_preserve_existing_annotations_state(self, state): + """ Toggle the preservation of existing annotations based on the checkbox state. """ + self.replace = not state + + def preprocess(self, input_image): + """ + Pre-processes the input image before feeding it to the network. + + Args: + input_image (numpy.ndarray): The input image to be processed. + + Returns: + numpy.ndarray: The pre-processed output. + """ + # Get the image width and height + image_h, image_w = input_image.shape[:2] + input_h, input_w = self.input_shape + # Perform the pre-processing steps + image = cv2.resize(input_image, (input_w, input_h)) + image = image.transpose((2, 0, 1)) # HWC to CHW + image = np.ascontiguousarray(image).astype("float32") + image /= 255 # 0 - 255 to 0.0 - 1.0 + if len(image.shape) == 3: + image = image[None] + orig_size = np.array([image_w, image_h], np.int64)[None, :] + blob = {"images": image, "orig_target_sizes": orig_size} + return blob + + def postprocess(self, outputs): + """ + Post-processes the network's output. + + Args: + outputs (numpy.ndarray): The output from the network. + + Returns: + scores (List[float]): prediction score + indexs (List[int]): category index + bboxes (List[list[int]]): xyxy format + """ + indexs, boxes, scores = outputs + scores = scores[0] + indexs = indexs[0][scores > self.conf_thres] + bboxes = boxes[0][scores > self.conf_thres] + + return scores, indexs, bboxes + + def predict_shapes(self, image, image_path=None): + """ + Predict shapes from image + """ + + if image is None: + return [] + + try: + image = qt_img_to_rgb_cv_img(image, image_path) + except Exception as e: # noqa + logging.warning("Could not inference model") + logging.warning(e) + return [] + + blob = self.preprocess(image) + detections = self.net.get_ort_inference( + None, inputs=blob, extract=False + ) + scores, indexs, bboxes = self.postprocess(detections) + shapes = [] + + for score, index, box in zip(scores, indexs, bboxes): + xmin, ymin, xmax, ymax = box + label = self.classes[int(index)] + shape = Shape( + label=str(label), + score=float(score), + shape_type="rectangle" + ) + shape.add_point(QtCore.QPointF(xmin, ymin)) + shape.add_point(QtCore.QPointF(xmax, ymin)) + shape.add_point(QtCore.QPointF(xmax, ymax)) + shape.add_point(QtCore.QPointF(xmin, ymax)) + shapes.append(shape) + + result = AutoLabelingResult(shapes, replace=self.replace) + return result + + def unload(self): + del self.net diff --git a/docs/en/model_zoo.md b/docs/en/model_zoo.md index d50bc275..88f1a022 100644 --- a/docs/en/model_zoo.md +++ b/docs/en/model_zoo.md @@ -65,6 +65,11 @@ | Gold_m_pre_dist.onnx | [Gold-YOLO](https://github.com/huawei-noah/Efficient-Computing/tree/master/Detection/Gold-YOLO)-COCO | [gold_yolo_m.yaml](../../anylabeling/configs/auto_labeling/gold_yolo_m.yaml) | 169.88MB | [baidu](https://pan.baidu.com/s/1lY5VljoL9pZxadpe9DEUdQ?pwd=ephp) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v0.3.0/Gold_m_pre_dist.onnx) | | Gold_l_pre_dist.onnx | [Gold-YOLO](https://github.com/huawei-noah/Efficient-Computing/tree/master/Detection/Gold-YOLO)-COCO | [gold_yolo_l.yaml](../../anylabeling/configs/auto_labeling/gold_yolo_l.yaml) | 286.79MB | [baidu](https://pan.baidu.com/s/1ySxB3R18oWuIdYzKRLrflg?pwd=1wlk) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v0.3.0/Gold_l_pre_dist.onnx) | | rtdetr_r50vd_6x_coco.onnx | [RT-DETR](https://github.com/PaddlePaddle/PaddleDetection/blob/develop/configs/rtdetr/README.md)-COCO | [rtdetr_r50.yaml](../../anylabeling/configs/auto_labeling/rtdetr_r50.yaml) | 160.96MB | [baidu](https://pan.baidu.com/s/11vkwveTeZWwOFUpk8qrqVg?pwd=3sc4) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v0.1.0/rtdetr_r50vd_6x_coco.onnx) | +| rtdetrv2_r101vd_6x_coco.onnx | [RT-DETRv2-X](https://github.com/lyuwenyu/RT-DETR/tree/main/rtdetrv2_pytorch)-COCO | [rtdetrv2x.yaml](../../anylabeling/configs/auto_labeling/rtdetrv2x.yaml) | 286.48MB | [baidu](https://pan.baidu.com/s/11lpaWt4IyFb39jJxKtlxnQ?pwd=t2kv) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v2.4.0/rtdetrv2_r101vd_6x_coco.onnx) | +| rtdetrv2_r50vd_6x_coco.onnx | [RT-DETRv2-L](https://github.com/lyuwenyu/RT-DETR/tree/main/rtdetrv2_pytorch)-COCO | [rtdetrv2l.yaml](../../anylabeling/configs/auto_labeling/rtdetrv2l.yaml) | 161.38MB | [baidu](https://pan.baidu.com/s/1ex1fzc18wnIa7YREBGD3_g?pwd=qrab) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v2.4.0/rtdetrv2_r50vd_6x_coco.onnx) | +| rtdetrv2_r50vd_m_7x_coco.onnx | [RT-DETRv2-M*](https://github.com/lyuwenyu/RT-DETR/tree/main/rtdetrv2_pytorch)-COCO | [rtdetrv2m7x.yaml](../../anylabeling/configs/auto_labeling/rtdetrv2m7x.yaml) | 126.52MB | [baidu](https://pan.baidu.com/s/1wzf6wJoysHfvAbgtv337oA?pwd=v8lw) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v2.4.0/rtdetrv2_r50vd_m_7x_coco.onnx) | +| rtdetrv2_r34vd_120e_coco.onnx | [RT-DETRv2-M](https://github.com/lyuwenyu/RT-DETR/tree/main/rtdetrv2_pytorch)-COCO | [rtdetrv2m.yaml](../../anylabeling/configs/auto_labeling/rtdetrv2m.yaml) | 119.73MB | [baidu](https://pan.baidu.com/s/1ByaV67J7yBKtznvwHvtRiQ?pwd=oeny) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v2.4.0/rtdetrv2_r34vd_120e_coco.onnx) | +| rtdetrv2_r18vd_120e_coco.onnx | [RT-DETRv2-S](https://github.com/lyuwenyu/RT-DETR/tree/main/rtdetrv2_pytorch)-COCO | [rtdetrv2s.yaml](../../anylabeling/configs/auto_labeling/rtdetrv2s.yaml) | 76.80MB | [baidu](https://pan.baidu.com/s/1jEuzrrn3P1GQcQ28qY-Q2Q?pwd=axm2) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v2.4.0/rtdetrv2_r18vd_120e_coco.onnx) | | yolo_nas_l.onnx | [YOLO-NAS](https://github.com/Deci-AI/super-gradients/tree/master)-COCO | [yolo_nas_l.yaml](../../anylabeling/configs/auto_labeling/yolo_nas_l.yaml) | 160.38MB | [baidu](https://pan.baidu.com/s/1ckVfkul8ZckiQd3PfM-Htw?pwd=wmk8) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v0.1.0/yolo_nas_l.onnx) | | yolo_nas_m.onnx | [YOLO-NAS](https://github.com/Deci-AI/super-gradients/tree/master)-COCO | [yolo_nas_m.yaml](../../anylabeling/configs/auto_labeling/yolo_nas_m.yaml) | 121.87MB | [baidu](https://pan.baidu.com/s/1bPH1m6dSxxG2KG3zrDywAA?pwd=1mlh) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v0.1.0/yolo_nas_m.onnx) | | yolo_nas_s.onnx | [YOLO-NAS](https://github.com/Deci-AI/super-gradients/tree/master)-COCO | [yolo_nas_s.yaml](../../anylabeling/configs/auto_labeling/yolo_nas_s.yaml) | 46.62MB | [baidu](https://pan.baidu.com/s/1yA_n6QxJGXoR38xlY25q_w?pwd=4wdw) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v0.1.0/yolo_nas_s.onnx) | @@ -90,10 +95,8 @@ | yolov8x-oiv7.onnx | [YOLOv8](https://github.com/ultralytics/ultralytics)-Open Image V7 | [yolov8x_oiv7.yaml](../../anylabeling/configs/auto_labeling/yolov8x_oiv7.yaml) | 262.24MB | [baidu](https://pan.baidu.com/s/1in76UG8GhsXw6ACNqOVYKA?pwd=o07s) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v2.3.7/yolov8x-oiv7.onnx) | | yolov8l-oiv7.onnx | [YOLOv8](https://github.com/ultralytics/ultralytics)-Open Image V7 | [yolov8l_oiv7.yaml](../../anylabeling/configs/auto_labeling/yolov8l_oiv7.yaml) | 168.28MB | [baidu](https://pan.baidu.com/s/1p9-oYvkV-IXvy6RlTxLWfQ?pwd=fs07) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v2.3.7/yolov8l-oiv7.onnx) | | yolov8m-oiv7.onnx | [YOLOv8](https://github.com/ultralytics/ultralytics)-Open Image V7 | [yolov8m_oiv7.yaml](../../anylabeling/configs/auto_labeling/yolov8m_oiv7.yaml) | 100.05MB | [baidu](https://pan.baidu.com/s/1E7c-XIriTLH-gmBm5UmdWA?pwd=36n6) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v2.3.7/yolov8m-oiv7.onnx) | -| yolov8s-oiv7.onnx | [YOLOv8](https://github.com/ultralytics/ultralytics)-Open Image V7 | -[yolov8s_oiv7.yaml](../../anylabeling/configs/auto_labeling/yolov8s_oiv7.yaml) | 43.47MB | [baidu](https://pan.baidu.com/s/1a6bP_x76Gk7LIf8uisWqfA?pwd=53lu) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v2.3.7/yolov8s-oiv7.onnx) | -| yolov8n-oiv7.onnx | [YOLOv8](https://github.com/ultralytics/ultralytics)-Open Image V7 | -[yolov8n_oiv7.yaml](../../anylabeling/configs/auto_labeling/yolov8n_oiv7.yaml) | 13.47MB | [baidu](https://pan.baidu.com/s/1u8sUFr2LmBGwqZCbGb6Xww?pwd=l6ip) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v2.3.7/yolov8n-oiv7.onnx) | +| yolov8s-oiv7.onnx | [YOLOv8](https://github.com/ultralytics/ultralytics)-Open Image V7 | [yolov8s_oiv7.yaml](../../anylabeling/configs/auto_labeling/yolov8s_oiv7.yaml) | 43.47MB | [baidu](https://pan.baidu.com/s/1a6bP_x76Gk7LIf8uisWqfA?pwd=53lu) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v2.3.7/yolov8s-oiv7.onnx) | +| yolov8n-oiv7.onnx | [YOLOv8](https://github.com/ultralytics/ultralytics)-Open Image V7 | [yolov8n_oiv7.yaml](../../anylabeling/configs/auto_labeling/yolov8n_oiv7.yaml) | 13.47MB | [baidu](https://pan.baidu.com/s/1u8sUFr2LmBGwqZCbGb6Xww?pwd=l6ip) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v2.3.7/yolov8n-oiv7.onnx) | | yolov9c.onnx | [YOLOv9](https://github.com/WongKinYiu/yolov9)-COCO | [yolov9c.yaml](../../anylabeling/configs/auto_labeling/yolov9c.yaml) | 195.34MB | [baidu](https://pan.baidu.com/s/1iHZH29HE8s4fG8X7aHcn5g?pwd=zb5w) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v2.3.2/yolov9c.onnx) | | yolov9e.onnx | [YOLOv9](https://github.com/WongKinYiu/yolov9)-COCO | [yolov9e.yaml](../../anylabeling/configs/auto_labeling/yolov9e.yaml) | 265.43MB | [baidu](https://pan.baidu.com/s/1oJJQ1L5UfKi43kT96tauSA?pwd=vpa1) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v2.3.2/yolov9e.onnx) | | gelan-c.onnx | [YOLOv9](https://github.com/WongKinYiu/yolov9)-COCO | [gelan-c.yaml](../../anylabeling/configs/auto_labeling/yolov9_gelan_c.yaml) | 97.43MB | [baidu](https://pan.baidu.com/s/1cM0Tc056ICuA5jvesacCng?pwd=whb3) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v2.3.2/gelan-c.onnx) | diff --git a/docs/zh_cn/model_zoo.md b/docs/zh_cn/model_zoo.md index 98a1d43c..e0d0dcba 100644 --- a/docs/zh_cn/model_zoo.md +++ b/docs/zh_cn/model_zoo.md @@ -63,6 +63,11 @@ | Gold_m_pre_dist.onnx | [Gold-YOLO](https://github.com/huawei-noah/Efficient-Computing/tree/master/Detection/Gold-YOLO)-COCO | [gold_yolo_m.yaml](../../anylabeling/configs/auto_labeling/gold_yolo_m.yaml) | 169.88MB | [百度网盘](https://pan.baidu.com/s/1lY5VljoL9pZxadpe9DEUdQ?pwd=ephp) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v0.3.0/Gold_m_pre_dist.onnx) | | Gold_l_pre_dist.onnx | [Gold-YOLO](https://github.com/huawei-noah/Efficient-Computing/tree/master/Detection/Gold-YOLO)-COCO | [gold_yolo_l.yaml](../../anylabeling/configs/auto_labeling/gold_yolo_l.yaml) | 286.79MB | [百度网盘](https://pan.baidu.com/s/1ySxB3R18oWuIdYzKRLrflg?pwd=1wlk) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v0.3.0/Gold_l_pre_dist.onnx) | | rtdetr_r50vd_6x_coco.onnx | [RT-DETR](https://github.com/PaddlePaddle/PaddleDetection/blob/develop/configs/rtdetr/README.md)-COCO | [rtdetr_r50.yaml](../../anylabeling/configs/auto_labeling/rtdetr_r50.yaml) | 160.96MB | [百度网盘](https://pan.baidu.com/s/11vkwveTeZWwOFUpk8qrqVg?pwd=3sc4) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v0.1.0/rtdetr_r50vd_6x_coco.onnx) | +| rtdetrv2_r101vd_6x_coco.onnx | [RT-DETRv2-X](https://github.com/lyuwenyu/RT-DETR/tree/main/rtdetrv2_pytorch)-COCO | [rtdetrv2x.yaml](../../anylabeling/configs/auto_labeling/rtdetrv2x.yaml) | 286.48MB | [百度网盘](https://pan.baidu.com/s/11lpaWt4IyFb39jJxKtlxnQ?pwd=t2kv) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v2.4.0/rtdetrv2_r101vd_6x_coco.onnx) | +| rtdetrv2_r50vd_6x_coco.onnx | [RT-DETRv2-L](https://github.com/lyuwenyu/RT-DETR/tree/main/rtdetrv2_pytorch)-COCO | [rtdetrv2l.yaml](../../anylabeling/configs/auto_labeling/rtdetrv2l.yaml) | 161.38MB | [百度网盘](https://pan.baidu.com/s/1ex1fzc18wnIa7YREBGD3_g?pwd=qrab) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v2.4.0/rtdetrv2_r50vd_6x_coco.onnx) | +| rtdetrv2_r50vd_m_7x_coco.onnx | [RT-DETRv2-M*](https://github.com/lyuwenyu/RT-DETR/tree/main/rtdetrv2_pytorch)-COCO | [rtdetrv2m7x.yaml](../../anylabeling/configs/auto_labeling/rtdetrv2m7x.yaml) | 126.52MB | [百度网盘](https://pan.baidu.com/s/1wzf6wJoysHfvAbgtv337oA?pwd=v8lw) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v2.4.0/rtdetrv2_r50vd_m_7x_coco.onnx) | +| rtdetrv2_r34vd_120e_coco.onnx | [RT-DETRv2-M](https://github.com/lyuwenyu/RT-DETR/tree/main/rtdetrv2_pytorch)-COCO | [rtdetrv2m.yaml](../../anylabeling/configs/auto_labeling/rtdetrv2m.yaml) | 119.73MB | [百度网盘](https://pan.baidu.com/s/1ByaV67J7yBKtznvwHvtRiQ?pwd=oeny) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v2.4.0/rtdetrv2_r34vd_120e_coco.onnx) | +| rtdetrv2_r18vd_120e_coco.onnx | [RT-DETRv2-S](https://github.com/lyuwenyu/RT-DETR/tree/main/rtdetrv2_pytorch)-COCO | [rtdetrv2s.yaml](../../anylabeling/configs/auto_labeling/rtdetrv2s.yaml) | 76.80MB | [百度网盘](https://pan.baidu.com/s/1jEuzrrn3P1GQcQ28qY-Q2Q?pwd=axm2) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v2.4.0/rtdetrv2_r18vd_120e_coco.onnx) | | yolo_nas_l.onnx | [YOLO-NAS](https://github.com/Deci-AI/super-gradients/tree/master)-COCO | [yolo_nas_l.yaml](../../anylabeling/configs/auto_labeling/yolo_nas_l.yaml) | 160.38MB | [百度网盘](https://pan.baidu.com/s/1ckVfkul8ZckiQd3PfM-Htw?pwd=wmk8) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v0.1.0/yolo_nas_l.onnx) | | yolo_nas_m.onnx | [YOLO-NAS](https://github.com/Deci-AI/super-gradients/tree/master)-COCO | [yolo_nas_m.yaml](../../anylabeling/configs/auto_labeling/yolo_nas_m.yaml) | 121.87MB | [百度网盘](https://pan.baidu.com/s/1bPH1m6dSxxG2KG3zrDywAA?pwd=1mlh) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v0.1.0/yolo_nas_m.onnx) | | yolo_nas_s.onnx | [YOLO-NAS](https://github.com/Deci-AI/super-gradients/tree/master)-COCO | [yolo_nas_s.yaml](../../anylabeling/configs/auto_labeling/yolo_nas_s.yaml) | 46.62MB | [百度网盘](https://pan.baidu.com/s/1yA_n6QxJGXoR38xlY25q_w?pwd=4wdw) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v0.1.0/yolo_nas_s.onnx) | @@ -88,10 +93,8 @@ | yolov8x-oiv7.onnx | [YOLOv8](https://github.com/ultralytics/ultralytics)-Open Image V7 | [yolov8x_oiv7.yaml](../../anylabeling/configs/auto_labeling/yolov8x_oiv7.yaml) | 262.24MB | [百度网盘](https://pan.baidu.com/s/1in76UG8GhsXw6ACNqOVYKA?pwd=o07s) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v2.3.7/yolov8x-oiv7.onnx) | | yolov8l-oiv7.onnx | [YOLOv8](https://github.com/ultralytics/ultralytics)-Open Image V7 | [yolov8l_oiv7.yaml](../../anylabeling/configs/auto_labeling/yolov8l_oiv7.yaml) | 168.28MB | [百度网盘](https://pan.baidu.com/s/1p9-oYvkV-IXvy6RlTxLWfQ?pwd=fs07) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v2.3.7/yolov8l-oiv7.onnx) | | yolov8m-oiv7.onnx | [YOLOv8](https://github.com/ultralytics/ultralytics)-Open Image V7 | [yolov8m_oiv7.yaml](../../anylabeling/configs/auto_labeling/yolov8m_oiv7.yaml) | 100.05MB | [百度网盘](https://pan.baidu.com/s/1E7c-XIriTLH-gmBm5UmdWA?pwd=36n6) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v2.3.7/yolov8m-oiv7.onnx) | -| yolov8s-oiv7.onnx | [YOLOv8](https://github.com/ultralytics/ultralytics)-Open Image V7 | -[yolov8s_oiv7.yaml](../../anylabeling/configs/auto_labeling/yolov8s_oiv7.yaml) | 43.47MB | [百度网盘](https://pan.baidu.com/s/1a6bP_x76Gk7LIf8uisWqfA?pwd=53lu) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v2.3.7/yolov8s-oiv7.onnx) | -| yolov8n-oiv7.onnx | [YOLOv8](https://github.com/ultralytics/ultralytics)-Open Image V7 | -[yolov8n_oiv7.yaml](../../anylabeling/configs/auto_labeling/yolov8n_oiv7.yaml) | 13.47MB | [百度网盘](https://pan.baidu.com/s/1u8sUFr2LmBGwqZCbGb6Xww?pwd=l6ip) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v2.3.7/yolov8n-oiv7.onnx) | +| yolov8s-oiv7.onnx | [YOLOv8](https://github.com/ultralytics/ultralytics)-Open Image V7 | [yolov8s_oiv7.yaml](../../anylabeling/configs/auto_labeling/yolov8s_oiv7.yaml) | 43.47MB | [百度网盘](https://pan.baidu.com/s/1a6bP_x76Gk7LIf8uisWqfA?pwd=53lu) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v2.3.7/yolov8s-oiv7.onnx) | +| yolov8n-oiv7.onnx | [YOLOv8](https://github.com/ultralytics/ultralytics)-Open Image V7 | [yolov8n_oiv7.yaml](../../anylabeling/configs/auto_labeling/yolov8n_oiv7.yaml) | 13.47MB | [百度网盘](https://pan.baidu.com/s/1u8sUFr2LmBGwqZCbGb6Xww?pwd=l6ip) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v2.3.7/yolov8n-oiv7.onnx) | | yolov9c.onnx | [YOLOv9](https://github.com/WongKinYiu/yolov9)-COCO | [yolov9c.yaml](../../anylabeling/configs/auto_labeling/yolov9c.yaml) | 195.34MB | [百度网盘](https://pan.baidu.com/s/1iHZH29HE8s4fG8X7aHcn5g?pwd=zb5w) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v2.3.2/yolov9c.onnx) | | yolov9e.onnx | [YOLOv9](https://github.com/WongKinYiu/yolov9)-COCO | [yolov9e.yaml](../../anylabeling/configs/auto_labeling/yolov9e.yaml) | 265.43MB | [百度网盘](https://pan.baidu.com/s/1oJJQ1L5UfKi43kT96tauSA?pwd=vpa1) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v2.3.2/yolov9e.onnx) | | gelan-c.onnx | [YOLOv9](https://github.com/WongKinYiu/yolov9)-COCO | [gelan-c.yaml](../../anylabeling/configs/auto_labeling/yolov9_gelan_c.yaml) | 97.43MB | [百度网盘](https://pan.baidu.com/s/1cM0Tc056ICuA5jvesacCng?pwd=whb3) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v2.3.2/gelan-c.onnx) |