From 0437e399fb4a3c1202abd00069c13d6bbb0e3982 Mon Sep 17 00:00:00 2001 From: cvhub Date: Tue, 3 Sep 2024 22:15:23 +0800 Subject: [PATCH] [Model] Added support for interactive video object tracking by SAM2 (#602) --- README.md | 14 +- README_zh-CN.md | 14 +- anylabeling/configs/auto_labeling/models.yaml | 8 + .../auto_labeling/sam2_hiera_base_video.yaml | 5 + .../auto_labeling/sam2_hiera_large_video.yaml | 5 + .../auto_labeling/sam2_hiera_small_video.yaml | 5 + .../auto_labeling/sam2_hiera_tiny_video.yaml | 5 + .../services/auto_labeling/model_manager.py | 60 ++- .../auto_labeling/segment_anything_2_video.py | 386 ++++++++++++++++++ anylabeling/views/labeling/label_widget.py | 15 +- .../widgets/auto_labeling/auto_labeling.py | 6 +- docs/en/model_zoo.md | 12 +- docs/en/user_guide.md | 19 +- docs/zh_cn/model_zoo.md | 12 +- docs/zh_cn/user_guide.md | 5 + .../README.md | 149 +++++++ 16 files changed, 689 insertions(+), 31 deletions(-) create mode 100644 anylabeling/configs/auto_labeling/sam2_hiera_base_video.yaml create mode 100644 anylabeling/configs/auto_labeling/sam2_hiera_large_video.yaml create mode 100644 anylabeling/configs/auto_labeling/sam2_hiera_small_video.yaml create mode 100644 anylabeling/configs/auto_labeling/sam2_hiera_tiny_video.yaml create mode 100644 anylabeling/services/auto_labeling/segment_anything_2_video.py create mode 100644 examples/interactive_video_object_segmentation/README.md diff --git a/README.md b/README.md index cc7e1727..2ea742ca 100644 --- a/README.md +++ b/README.md @@ -33,18 +33,20 @@ ## 🥳 What's New -- Aug. 2024: - - 🤗 Release the latest version [2.4.1](https://github.com/CVHub520/X-AnyLabeling/releases/tag/v2.4.1) 🤗 - - 🔥🔥🔥 Support [tracking-by-det/obb/seg/pose](./examples/multiple_object_tracking/README.md) tasks. - - ✨✨✨ Support [Segment-Anything-2](https://github.com/facebookresearch/segment-anything-2) model! (Recommended) - - 👏👏👏 Support [Grounding-SAM2](./docs/en/model_zoo.md) model. - - Support lightweight model for Japanese recognition. +- Sep. 2024: + - 🔥🔥🔥 Added support for interactive video object tracking based on [Segment-Anything-2](https://github.com/CVHub520/segment-anything-2). [[Tutorial](examples/interactive_video_object_segmentation/README.md)]
Click to view more news. +- Aug. 2024: + - Release version [2.4.1](https://github.com/CVHub520/X-AnyLabeling/releases/tag/v2.4.1) + - Support [tracking-by-det/obb/seg/pose](./examples/multiple_object_tracking/README.md) tasks. + - Support [Segment-Anything-2](https://github.com/facebookresearch/segment-anything-2) model! (Recommended) + - Support [Grounding-SAM2](./docs/en/model_zoo.md) model. + - Support lightweight model for Japanese recognition. - Jul. 2024: - Add PPOCR-Recognition and KIE import/export functionality for training PP-OCR task. - Add ODVG import/export functionality for training grounding task. diff --git a/README_zh-CN.md b/README_zh-CN.md index 52f2d8bf..98a3f13c 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -32,18 +32,20 @@ ## 🥳 新功能 -- 2024年8月: - - 🤗 发布[X-AnyLabeling v2.4.1](https://github.com/CVHub520/X-AnyLabeling/releases/tag/v2.4.1)最新版本 🤗 - - 🔥🔥🔥 支持[tracking-by-det/obb/seg/pose](./examples/multiple_object_tracking/README.md)任务。 - - ✨✨✨ 支持[Segment-Anything-2](https://github.com/facebookresearch/segment-anything-2)模型。 - - 👏👏👏 支持[Grounding-SAM2](./docs/zh_cn/model_zoo.md)模型。 - - 支持[日文字符识别](./anylabeling/configs/auto_labeling/japan_ppocr.yaml)模型。 +- 2024年9月: + - 🔥🔥🔥 支持基于[Segment-Anything-2](https://github.com/CVHub520/segment-anything-2)交互式视频目标追踪功能。【[教程](examples/interactive_video_object_segmentation/README.md)】
点击查看历史更新。 +- 2024年8月: + - 发布[X-AnyLabeling v2.4.1](https://github.com/CVHub520/X-AnyLabeling/releases/tag/v2.4.1)版本。 + - 支持[tracking-by-det/obb/seg/pose](./examples/multiple_object_tracking/README.md)任务。 + - 支持[Segment-Anything-2](https://github.com/facebookresearch/segment-anything-2)模型。 + - 支持[Grounding-SAM2](./docs/zh_cn/model_zoo.md)模型。 + - 支持[日文字符识别](./anylabeling/configs/auto_labeling/japan_ppocr.yaml)模型。 - 2024年7月: - 新增 PPOCR 识别和关键信息提取标签导入/导出功能。 - 新增 ODVG 标签导入/导出功能,以支持 Grounding 模型训练。 diff --git a/anylabeling/configs/auto_labeling/models.yaml b/anylabeling/configs/auto_labeling/models.yaml index 224cc963..257bf134 100644 --- a/anylabeling/configs/auto_labeling/models.yaml +++ b/anylabeling/configs/auto_labeling/models.yaml @@ -1,5 +1,7 @@ - model_name: "sam2_hiera_base-r20240801" config_file: ":/sam2_hiera_base.yaml" +- model_name: "sam2_hiera_large_video-r20240901" + config_file: ":/sam2_hiera_large_video.yaml" - model_name: "yolov5s-r20230520" config_file: ":/yolov5s.yaml" - model_name: "yolov5_car_plate-r20230112" @@ -120,6 +122,12 @@ config_file: ":/sam2_hiera_small.yaml" - model_name: "sam2_hiera_tiny-r20240801" config_file: ":/sam2_hiera_tiny.yaml" +- model_name: "sam2_hiera_base_video-r20240901" + config_file: ":/sam2_hiera_base_video.yaml" +- model_name: "sam2_hiera_small_video-r20240901" + config_file: ":/sam2_hiera_small_video.yaml" +- model_name: "sam2_hiera_tiny_video-r20240901" + config_file: ":/sam2_hiera_tiny_video.yaml" - model_name: "sam-hq_vit_b-r20231111" config_file: ":/sam_hq_vit_b.yaml" - model_name: "sam-hq_vit_h_quant-r20231111" diff --git a/anylabeling/configs/auto_labeling/sam2_hiera_base_video.yaml b/anylabeling/configs/auto_labeling/sam2_hiera_base_video.yaml new file mode 100644 index 00000000..270be3be --- /dev/null +++ b/anylabeling/configs/auto_labeling/sam2_hiera_base_video.yaml @@ -0,0 +1,5 @@ +type: segment_anything_2_video +name: sam2_hiera_base_video-r20240901 +display_name: Segment Anything 2 Video (Base) +model_cfg: sam2_hiera_b+.yaml +model_path: https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt \ No newline at end of file diff --git a/anylabeling/configs/auto_labeling/sam2_hiera_large_video.yaml b/anylabeling/configs/auto_labeling/sam2_hiera_large_video.yaml new file mode 100644 index 00000000..c1577e09 --- /dev/null +++ b/anylabeling/configs/auto_labeling/sam2_hiera_large_video.yaml @@ -0,0 +1,5 @@ +type: segment_anything_2_video +name: sam2_hiera_large_video-r20240901 +display_name: Segment Anything 2 Video (Large) +model_cfg: sam2_hiera_l.yaml +model_path: https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt \ No newline at end of file diff --git a/anylabeling/configs/auto_labeling/sam2_hiera_small_video.yaml b/anylabeling/configs/auto_labeling/sam2_hiera_small_video.yaml new file mode 100644 index 00000000..7c5a484a --- /dev/null +++ b/anylabeling/configs/auto_labeling/sam2_hiera_small_video.yaml @@ -0,0 +1,5 @@ +type: segment_anything_2_video +name: sam2_hiera_small_video-r20240901 +display_name: Segment Anything 2 Video (Small) +model_cfg: sam2_hiera_s.yaml +model_path: https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt \ No newline at end of file diff --git a/anylabeling/configs/auto_labeling/sam2_hiera_tiny_video.yaml b/anylabeling/configs/auto_labeling/sam2_hiera_tiny_video.yaml new file mode 100644 index 00000000..56fe2267 --- /dev/null +++ b/anylabeling/configs/auto_labeling/sam2_hiera_tiny_video.yaml @@ -0,0 +1,5 @@ +type: segment_anything_2_video +name: sam2_hiera_tiny_video-r20240901 +display_name: Segment Anything 2 Video (Tiny) +model_cfg: sam2_hiera_t.yaml +model_path: https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt \ No newline at end of file diff --git a/anylabeling/services/auto_labeling/model_manager.py b/anylabeling/services/auto_labeling/model_manager.py index f7852a73..0c8e6709 100644 --- a/anylabeling/services/auto_labeling/model_manager.py +++ b/anylabeling/services/auto_labeling/model_manager.py @@ -21,6 +21,7 @@ class ModelManager(QObject): CUSTOM_MODELS = [ "segment_anything", "segment_anything_2", + "segment_anything_2_video" "sam_med2d", "sam_hq", "yolov5", @@ -967,6 +968,29 @@ def _load_model(self, model_id): return # Request next files for prediction self.request_next_files_requested.emit() + elif model_config["type"] == "segment_anything_2_video": + try: + from .segment_anything_2_video import SegmentAnything2Video + model_config["model"] = SegmentAnything2Video( + model_config, on_message=self.new_model_status.emit + ) + self.auto_segmentation_model_selected.emit() + except Exception as e: # noqa + print( + "Error in loading model: {error_message}".format( + error_message=str(e) + ) + ) + self.new_model_status.emit( + self.tr( + "Error in loading model: {error_message}".format( + error_message=str(e) + ) + ) + ) + return + # Request next files for prediction + self.request_next_files_requested.emit() elif model_config["type"] == "efficientvit_sam": from .efficientvit_sam import EfficientViT_SAM @@ -1472,6 +1496,7 @@ def set_auto_labeling_marks(self, marks): marks_model_list = [ "segment_anything", "segment_anything_2", + "segment_anything_2_video", "sam_med2d", "sam_hq", "yolov5_sam", @@ -1498,6 +1523,7 @@ def set_auto_labeling_reset_tracker(self): "yolov8_obb_track", "yolov8_seg_track", "yolov8_pose_track", + "segment_anything_2_video", ] if ( self.loaded_model_config is None @@ -1606,13 +1632,23 @@ def set_auto_labeling_preserve_existing_annotations_state(self, state): "model" ].set_auto_labeling_preserve_existing_annotations_state(state) + def set_auto_labeling_prompt(self): + model_list = ['segment_anything_2_video'] + if ( + self.loaded_model_config is not None + and self.loaded_model_config["type"] in model_list + ): + self.loaded_model_config[ + "model" + ].set_auto_labeling_prompt() + def unload_model(self): """Unload model""" if self.loaded_model_config is not None: self.loaded_model_config["model"].unload() self.loaded_model_config = None - def predict_shapes(self, image, filename=None, text_prompt=None): + def predict_shapes(self, image, filename=None, text_prompt=None, run_tracker=False): """Predict shapes. NOTE: This function is blocking. The model can take a long time to predict. So it is recommended to use predict_shapes_threading instead. @@ -1624,14 +1660,18 @@ def predict_shapes(self, image, filename=None, text_prompt=None): self.prediction_finished.emit() return try: - if text_prompt is None: + if text_prompt is not None: auto_labeling_result = self.loaded_model_config[ "model" - ].predict_shapes(image, filename) + ].predict_shapes(image, filename, text_prompt=text_prompt) + elif run_tracker is True: + auto_labeling_result = self.loaded_model_config[ + "model" + ].predict_shapes(image, filename, run_tracker=run_tracker) else: auto_labeling_result = self.loaded_model_config[ "model" - ].predict_shapes(image, filename, text_prompt) + ].predict_shapes(image, filename) self.new_auto_labeling_result.emit(auto_labeling_result) self.new_model_status.emit( self.tr("Finished inferencing AI model. Check the result.") @@ -1646,7 +1686,7 @@ def predict_shapes(self, image, filename=None, text_prompt=None): self.prediction_finished.emit() @pyqtSlot() - def predict_shapes_threading(self, image, filename=None, text_prompt=None): + def predict_shapes_threading(self, image, filename=None, text_prompt=None, run_tracker=False): """Predict shapes. This function starts a thread to run the prediction. """ @@ -1675,13 +1715,17 @@ def predict_shapes_threading(self, image, filename=None, text_prompt=None): return self.model_execution_thread = QThread() - if text_prompt is None: + if text_prompt is not None: self.model_execution_worker = GenericWorker( - self.predict_shapes, image, filename + self.predict_shapes, image, filename, text_prompt=text_prompt + ) + elif run_tracker is True: + self.model_execution_worker = GenericWorker( + self.predict_shapes, image, filename, run_tracker=run_tracker ) else: self.model_execution_worker = GenericWorker( - self.predict_shapes, image, filename, text_prompt + self.predict_shapes, image, filename ) self.model_execution_worker.finished.connect( self.model_execution_thread.quit diff --git a/anylabeling/services/auto_labeling/segment_anything_2_video.py b/anylabeling/services/auto_labeling/segment_anything_2_video.py new file mode 100644 index 00000000..824ffcaf --- /dev/null +++ b/anylabeling/services/auto_labeling/segment_anything_2_video.py @@ -0,0 +1,386 @@ +import logging +import os +import traceback + +import warnings +warnings.filterwarnings('ignore') + +import cv2 +import numpy as np +from PyQt5 import QtCore +from PyQt5.QtCore import QCoreApplication + +from anylabeling.app_info import __preferred_device__ +from anylabeling.views.labeling.shape import Shape +from anylabeling.views.labeling.utils.opencv import qt_img_to_rgb_cv_img + +from .model import Model +from .types import AutoLabelingResult + +import torch +from sam2.build_sam import build_sam2, build_sam2_camera_predictor +from sam2.sam2_image_predictor import SAM2ImagePredictor + + +class SegmentAnything2Video(Model): + """Segmentation model using SegmentAnything2 for video processing. + + This class provides methods to perform image segmentation on video frames + using the SegmentAnything2 model. It supports interactive marking and + tracking of objects across frames. + """ + + class Meta: + """Meta class to define required configurations and UI elements.""" + required_config_names = [ + "type", + "name", + "display_name", + "model_cfg", + "model_path", + ] + widgets = [ + "output_label", + "output_select_combobox", + "button_add_point", + "button_remove_point", + "button_add_rect", + "button_clear", + "button_finish_object", + "button_reset_tracker", + ] + output_modes = { + "polygon": QCoreApplication.translate("Model", "Polygon"), + "rectangle": QCoreApplication.translate("Model", "Rectangle"), + "rotation": QCoreApplication.translate("Model", "Rotation"), + } + default_output_mode = "polygon" + + def __init__(self, config_path, on_message) -> None: + """Initialize the segmentation model with given configuration. + + Args: + config_path (str): Path to the configuration file. + on_message (callable): Callback for logging messages. + """ + super().__init__(config_path, on_message) + + # Enable automatic mixed precision for faster computations + torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__() + + if torch.cuda.get_device_properties(0).major >= 8: + # turn on tfloat32 for Ampere GPUs + # (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices) + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + # Load the SAM2 predictor models + model_abs_path = self.get_model_abs_path( + self.config, "model_path" + ) + if not model_abs_path or not os.path.isfile( + model_abs_path + ): + raise FileNotFoundError( + QCoreApplication.translate( + "Model", + "Could not download or initialize model of Segment Anything 2.", + ) + ) + model_cfg = self.config['model_cfg'] + sam2_image_model = build_sam2(model_cfg, model_abs_path) + self.image_predictor = SAM2ImagePredictor(sam2_image_model) + self.video_predictor = build_sam2_camera_predictor(model_cfg, model_abs_path) + self.is_first_init = True + + # Initialize marking and prompting structures + self.marks = [] + self.prompts = [] + + def set_auto_labeling_marks(self, marks): + """Set marks for auto labeling. + + Args: + marks (list): List of marks (points or rectangles). + """ + self.marks = marks + + def set_auto_labeling_reset_tracker(self): + """Reset the tracker to its initial state.""" + self.is_first_init = True + self.prompts = [] + try: + self.video_predictor.reset_state() + except Exception as e: # noqa + print(f'An error occurred while resetting the tracker: {e}') + + def set_auto_labeling_prompt(self): + """Convert marks to prompts for the model.""" + point_coords, point_labels, box = self.marks_to_prompts() + if box: + promot = { + 'type': 'rectangle', + 'data': np.array([[*box[:2]], [*box[2:]]], dtype=np.float32) + } + self.prompts.append(promot) + elif (point_coords and point_labels): + promot = { + 'type': 'point', + 'data': { + 'point_coords': np.array(point_coords, dtype=np.float32), + 'point_labels': np.array(point_labels, dtype=np.int32), + } + } + self.prompts.append(promot) + + def marks_to_prompts(self): + """Convert marks to prompts for the model.""" + point_coords, point_labels, box = None, None, None + for marks in self.marks: + if marks['type'] == 'rectangle': + box = marks['data'] + elif marks['type'] == 'point': + if point_coords is None and point_labels is None: + point_coords = [marks['data']] + point_labels = [marks['label']] + else: + point_coords.append(marks['data']) + point_labels.append(marks['label']) + return point_coords, point_labels, box + + def post_process(self, masks, label=None): + """Post-process the masks produced by the model. + + Args: + masks (np.array): The masks to post-process. + label (str, optional): Label for the masks. Defaults to None. + + Returns: + list: A list of Shape objects representing the masks. + """ + # Convert masks to binary format + masks[masks > 0.0] = 255 + masks[masks <= 0.] = 0 + masks = masks.astype(np.uint8) + + # Find contours of the masks + contours, _ = cv2.findContours( + masks, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE + ) + + # Refine and filter contours + approx_contours = [] + for contour in contours: + # Approximate contour + epsilon = 0.001 * cv2.arcLength(contour, True) + approx = cv2.approxPolyDP(contour, epsilon, True) + approx_contours.append(approx) + + # Remove large contours (likely background) + if len(approx_contours) > 1: + image_size = masks.shape[0] * masks.shape[1] + areas = [cv2.contourArea(contour) for contour in approx_contours] + filtered_approx_contours = [ + contour + for contour, area in zip(approx_contours, areas) + if area < image_size * 0.9 + ] + + # Remove small contours (likely noise) + if len(approx_contours) > 1: + areas = [cv2.contourArea(contour) for contour in approx_contours] + avg_area = np.mean(areas) + + filtered_approx_contours = [ + contour + for contour, area in zip(approx_contours, areas) + if area > avg_area * 0.2 + ] + approx_contours = filtered_approx_contours + + # Convert contours to shapes + shapes = [] + if self.output_mode == "polygon": + for approx in approx_contours: + # Scale points + points = approx.reshape(-1, 2) + points[:, 0] = points[:, 0] + points[:, 1] = points[:, 1] + points = points.tolist() + if len(points) < 3: + continue + points.append(points[0]) + shape = Shape(flags={}) + for point in points: + point[0] = int(point[0]) + point[1] = int(point[1]) + shape.add_point(QtCore.QPointF(point[0], point[1])) + shape.shape_type = "polygon" + shape.closed = True + shape.label = "AUTOLABEL_OBJECT" if label is None else label + shape.selected = False + shapes.append(shape) + elif self.output_mode in ["rectangle", "rotation"]: + x_min = 100000000 + y_min = 100000000 + x_max = 0 + y_max = 0 + for approx in approx_contours: + # Scale points + points = approx.reshape(-1, 2) + points[:, 0] = points[:, 0] + points[:, 1] = points[:, 1] + points = points.tolist() + if len(points) < 3: + continue + # Get min/max + for point in points: + x_min = min(x_min, point[0]) + y_min = min(y_min, point[1]) + x_max = max(x_max, point[0]) + y_max = max(y_max, point[1]) + # Create shape + shape = Shape(flags={}) + shape.add_point(QtCore.QPointF(x_min, y_min)) + shape.add_point(QtCore.QPointF(x_max, y_min)) + shape.add_point(QtCore.QPointF(x_max, y_max)) + shape.add_point(QtCore.QPointF(x_min, y_max)) + shape.shape_type = ( + "rectangle" if self.output_mode == "rectangle" else "rotation" + ) + shape.closed = True + shape.label = "AUTOLABEL_OBJECT" if label is None else label + shape.selected = False + shapes.append(shape) + + return shapes + + def image_process(self, rgb_image): + """Process a single image using the SAM2 predictor. + + Args: + rgb_image (np.array): The RGB image to process. + + Returns: + list: A list of Shape objects representing the segmented regions. + """ + self.image_predictor.set_image(rgb_image) + + # prompt SAM 2 image predictor to get the mask for the object + point_coords, point_labels, box = self.marks_to_prompts() + if not box and not (point_coords and point_labels): + return [] + masks, _, _ = self.image_predictor.predict( + point_coords=point_coords, + point_labels=point_labels, + box=box, + multimask_output=False, + ) + + if len(masks.shape) == 4: + masks = masks[0][0] + else: + masks = masks[0] + shapes = self.post_process(masks) + return shapes + + def video_process(self, cv_image, filename): + """Process a video frame using the SAM2 predictor. + + Args: + cv_image (np.array): The OpenCV image to process. + filename (str): The filename of the image. + + Returns: + tuple: A tuple containing a list of Shape objects and a boolean indicating if the frame was replaced. + """ + if not self.prompts: + return [] + + if self.is_first_init: + self.video_predictor.load_first_frame(cv_image) + ann_frame_idx = self.get_ann_frame_idx(filename) # the frame index we interact with + if ann_frame_idx == -1: + print(f"No .jpg or .jpeg files found in the directory.") + return [], False + for i, prompt in enumerate(self.prompts): + ann_obj_id = i + 1 # give a unique id to each object we interact with (it can be any integers) + if prompt['type'] == 'rectangle': + bbox = prompt['data'] + _, out_obj_ids, out_mask_logits = self.video_predictor.add_new_prompt( + frame_idx=ann_frame_idx, obj_id=ann_obj_id, bbox=bbox + ) + elif prompt['type'] == 'point': + points = prompt['data']['point_coords'] + labels = prompt['data']['point_labels'] + _, out_obj_ids, out_mask_logits = self.video_predictor.add_new_prompt( + frame_idx=ann_frame_idx, obj_id=ann_obj_id, points=points, labels=labels + ) + self.is_first_init = False + return [], False + else: + shapes = [] + out_obj_ids, out_mask_logits = self.video_predictor.track(cv_image) + for i in range(0, len(out_obj_ids)): + masks = out_mask_logits[i].cpu().numpy() + if len(masks.shape) == 4: + masks = masks[0][0] + else: + masks = masks[0] + shapes.extend(self.post_process(masks, label=f'object{i}')) + return shapes, True + + def predict_shapes(self, image, filename=None, run_tracker=False) -> AutoLabelingResult: + """Predict shapes from an image or video frame. + + Args: + image (QtImage): The image to process. + filename (str, optional): The filename of the image. Required for video processing. Defaults to None. + run_tracker (bool, optional): Whether to run the tracker. Defaults to False. + + Returns: + AutoLabelingResult: The result containing the predicted shapes and a flag indicating if the frame was replaced. + """ + if image is None or not self.marks: + return AutoLabelingResult([], replace=False) + + shapes = [] + cv_image = qt_img_to_rgb_cv_img(image, filename) + try: + if run_tracker is True: + shapes, replace = self.video_process(cv_image, filename) + result = AutoLabelingResult(shapes, replace=replace) + else: + shapes = self.image_process(cv_image) + result = AutoLabelingResult(shapes, replace=False) + except Exception as e: # noqa + logging.warning("Could not inference model") + logging.warning(e) + traceback.print_exc() + return AutoLabelingResult([], replace=False) + + return result + + @staticmethod + def get_ann_frame_idx(filename): + """Get the annotation frame index for a given filename. + + Args: + filename (str): The filename of the image. + + Returns: + int: The index of the frame in the sorted list of frames, or -1 if not found. + """ + frame_names = [ + p for p in os.listdir(os.path.dirname(filename)) + if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"] + ] + if not frame_names: + return -1 + frame_names.sort(key=lambda p: int(os.path.splitext(p)[0])) + return frame_names.index(os.path.basename(filename)) + + def unload(self): + """Unload the model and predictors.""" + del self.image_predictor + del self.video_predictor diff --git a/anylabeling/views/labeling/label_widget.py b/anylabeling/views/labeling/label_widget.py index 9d46dca3..5021f4c3 100644 --- a/anylabeling/views/labeling/label_widget.py +++ b/anylabeling/views/labeling/label_widget.py @@ -5783,6 +5783,7 @@ def run_all_images(self): self.current_index = self.image_list.index(self.filename) self.image_index = self.current_index self.text_prompt = "" + self.run_tracker = False if self.auto_labeling_widget.model_manager.loaded_model_config[ "type" ] in [ @@ -5794,6 +5795,13 @@ def run_all_images(self): self.text_prompt = text_input_dialog.get_input_text() if self.text_prompt: self.show_progress_dialog_and_process() + elif self.auto_labeling_widget.model_manager.loaded_model_config[ + "type" + ] in [ + "segment_anything_2_video", + ]: + self.run_tracker = True + self.show_progress_dialog_and_process() else: self.show_progress_dialog_and_process() @@ -5830,7 +5838,11 @@ def process_next_image(self, progress_dialog): self.load_file(self.filename) if self.text_prompt: self.auto_labeling_widget.model_manager.predict_shapes( - self.image, self.filename, self.text_prompt + self.image, self.filename, text_prompt=self.text_prompt + ) + elif self.run_tracker: + self.auto_labeling_widget.model_manager.predict_shapes( + self.image, self.filename, run_tracker=self.run_tracker ) else: self.auto_labeling_widget.model_manager.predict_shapes( @@ -5858,6 +5870,7 @@ def finish_processing(self, progress_dialog): self.filename = self.image_list[self.current_index] self.load_file(self.filename) del self.text_prompt + del self.run_tracker del self.image_index del self.current_index progress_dialog.close() diff --git a/anylabeling/views/labeling/widgets/auto_labeling/auto_labeling.py b/anylabeling/views/labeling/widgets/auto_labeling/auto_labeling.py index bed530f4..9e90a177 100644 --- a/anylabeling/views/labeling/widgets/auto_labeling/auto_labeling.py +++ b/anylabeling/views/labeling/widgets/auto_labeling/auto_labeling.py @@ -106,6 +106,7 @@ def set_enable_tools(enable): self.clear_auto_labeling_action_requested ) self.button_clear.setShortcut("B") + self.button_finish_object.clicked.connect(self.add_new_prompt) self.button_finish_object.clicked.connect( self.finish_auto_labeling_object_action_requested ) @@ -205,7 +206,7 @@ def run_vl_prediction(self): """Run visual-language prediction""" if self.parent.filename is not None and self.edit_text: self.model_manager.predict_shapes_threading( - self.parent.image, self.parent.filename, self.edit_text.text() + self.parent.image, self.parent.filename, text_prompt=self.edit_text.text() ) def unload_and_hide(self): @@ -350,3 +351,6 @@ def on_preserve_existing_annotations_state_changed(self, state): def on_reset_tracker(self): self.model_manager.set_auto_labeling_reset_tracker() + + def add_new_prompt(self): + self.model_manager.set_auto_labeling_prompt() diff --git a/docs/en/model_zoo.md b/docs/en/model_zoo.md index 91325c8b..5cb1ad10 100644 --- a/docs/en/model_zoo.md +++ b/docs/en/model_zoo.md @@ -258,4 +258,14 @@ | depth_anything_vitl14.onnx | [DepthAnything](https://github.com/LiheYoung/Depth-Anything.git) | [depth_anything_vit_l.yaml](../../anylabeling/configs/auto_labeling/depth_anything_vit_l.yaml) | 1.25GB | [baidu](https://pan.baidu.com/s/1MeEcbyJa6ysoGzK-8EjqYw?pwd=p3j4) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v2.3.1/depth_anything_vitl14.onnx) | | depth_anything_v2_vits.onnx | [DepthAnythingV2](https://github.com/DepthAnything/Depth-Anything-V2) | [depth_anything_v2_vit_s.yaml](../../anylabeling/configs/auto_labeling/depth_anything_v2_vit_s.yaml) | 94.77MB | [baidu](https://pan.baidu.com/s/1mO8UEWAEgYW2_bDnQFSVpQ?pwd=3sf0) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v2.4.0/depth_anything_v2_vits.onnx) | | depth_anything_v2_vitb.onnx | [DepthAnythingV2](https://github.com/DepthAnything/Depth-Anything-V2) | [depth_anything_v2_vit_b.yaml](../../anylabeling/configs/auto_labeling/depth_anything_v2_vit_b.yaml) | 371.20MB | [baidu](https://pan.baidu.com/s/1wo8xYJiuMjie5THPjr4DWg?pwd=kcal) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v2.4.0/depth_anything_v2_vitb.onnx) | -| depth_anything_v2_vitl.onnx | [DepthAnythingV2](https://github.com/DepthAnything/Depth-Anything-V2) | [depth_anything_v2_vit_l.yaml](../../anylabeling/configs/auto_labeling/depth_anything_v2_vit_l.yaml) | 1.25GB | [baidu](https://pan.baidu.com/s/134WYgOdhzWeyap_xk0rBhw?pwd=cnqt) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v2.4.0/depth_anything_v2_vitl.onnx) | \ No newline at end of file +| depth_anything_v2_vitl.onnx | [DepthAnythingV2](https://github.com/DepthAnything/Depth-Anything-V2) | [depth_anything_v2_vit_l.yaml](../../anylabeling/configs/auto_labeling/depth_anything_v2_vit_l.yaml) | 1.25GB | [baidu](https://pan.baidu.com/s/134WYgOdhzWeyap_xk0rBhw?pwd=cnqt) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v2.4.0/depth_anything_v2_vitl.onnx) | + + +### Interactive Video Object Segmentation + +|Name|Description|Configuration|Size|Link| +| --- | --- | --- | --- | --- | +| sam2_hiera_tiny.pt | [SAM 2](https://github.com/CVHub520/segment-anything-2) | [sam2_hiera_tiny_video.yaml](../../anylabeling/configs/auto_labeling/sam2_hiera_tiny_video.yaml) | 148.68MB | [baidu](https://pan.baidu.com/s/1sIiVHQb5MMrK49x5ngB8ig?pwd=q38g) \| [github](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt) | +| sam2_hiera_small.pt | [SAM 2](https://github.com/CVHub520/segment-anything-2) | [sam2_hiera_small_video.yaml](../../anylabeling/configs/auto_labeling/sam2_hiera_small_video.yaml) | 175.77MB | [baidu](https://pan.baidu.com/s/17P5PvRUlT0xYr4wlqqeKVw?pwd=rsim) \| [github](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt) | +| sam2_hiera_base_plus.pt | [SAM 2](https://github.com/CVHub520/segment-anything-2) | [sam2_hiera_base_video.yaml](../../anylabeling/configs/auto_labeling/sam2_hiera_base_video.yaml) | 308.51MB | [baidu](https://pan.baidu.com/s/1nP-ysdw1t7JgUQtIv-adNA?pwd=44u4) \| [github](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt) | +| sam2_hiera_large.pt | [SAM 2](https://github.com/CVHub520/segment-anything-2) | [sam2_hiera_large_video.yaml](../../anylabeling/configs/auto_labeling/sam2_hiera_large_video.yaml) | 856.35MB | [baidu](https://pan.baidu.com/s/1hFYQc1IKnEw2l-vW6Eoi2w?pwd=gyby) \| [github](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt) | diff --git a/docs/en/user_guide.md b/docs/en/user_guide.md index 66ff40e0..31de28f3 100644 --- a/docs/en/user_guide.md +++ b/docs/en/user_guide.md @@ -49,6 +49,7 @@ * [8.5 Multi-Object Tracking](#85-multi-object-tracking) * [8.6 Depth Estimation](#86-depth-estimation) * [8.7 Optical Character Recognition](#87-optical-character-recognition) + * [8.8 Interactive Video Object Segmentation](#88-interactive-video-object-segmentation) * [9. Models](#9-models) @@ -670,31 +671,31 @@ In X-AnyLabeling v2.4.0 and above, the **hover auto-highlight mode** feature is Note: In `multi-label classification tasks`, if the user manually uploads a property file, the `auto_highlight_shape` field will be set to `false` to prevent accidental switching of the property window status bar, thus improving user experience. -### 8. Tasks +## 8. Tasks -#### 8.1 Image Classification +### 8.1 Image Classification - Image-level classification: [Link](../../examples/classification/image-level/README.md) - Object-level classification: [Link](../../examples/classification/shape-level/README.md) -#### 8.2 Object Detection +### 8.2 Object Detection - Horizontal Bounding Box Detection: [Link](../../examples/detection/hbb/README.md) - Oriented Bounding Box Detection: [Link](../../examples/detection/obb/README.md) -#### 8.3 Image Segmentation +### 8.3 Image Segmentation - Semantic & Instance Segmentation: [Link](../../examples/segmentation/README.md) -#### 8.4 Pose Estimation +### 8.4 Pose Estimation - Keypoint Detection: [Link](../../examples/estimation/pose_estimation/README.md) -#### 8.5 Multi-Object Tracking +### 8.5 Multi-Object Tracking - Multi-Object Tracking: [Link](../../examples/multiple_object_tracking/README.md) -#### 8.6 Depth Estimation +### 8.6 Depth Estimation - Depth Estimation: [Link](../../examples/estimation/depth_estimation/README.md) @@ -703,6 +704,10 @@ Note: In `multi-label classification tasks`, if the user manually uploads a prop - Text Detection and Recognition: [Link](../../examples/optical_character_recognition/text_recognition/README.md) - Key Information Extraction: [Link](../../examples/optical_character_recognition/kie/README.md) +### 8.8 Interactive Video Object Segmentation + +- Interactive Video Object Segmentation: [Link](../../examples/interactive_video_object_segmentation/README.md) + ## 9. Models diff --git a/docs/zh_cn/model_zoo.md b/docs/zh_cn/model_zoo.md index 14db28b6..c88233b3 100644 --- a/docs/zh_cn/model_zoo.md +++ b/docs/zh_cn/model_zoo.md @@ -256,4 +256,14 @@ | depth_anything_vitl14.onnx | [DepthAnything](https://github.com/LiheYoung/Depth-Anything.git) | [depth_anything_vit_l.yaml](../../anylabeling/configs/auto_labeling/depth_anything_vit_l.yaml) | 1.25GB | [百度网盘](https://pan.baidu.com/s/1MeEcbyJa6ysoGzK-8EjqYw?pwd=p3j4) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v2.3.1/depth_anything_vitl14.onnx) | | depth_anything_v2_vits.onnx | [DepthAnythingV2](https://github.com/DepthAnything/Depth-Anything-V2) | [depth_anything_v2_vit_s.yaml](../../anylabeling/configs/auto_labeling/depth_anything_v2_vit_s.yaml) | 94.77MB | [百度网盘](https://pan.baidu.com/s/1mO8UEWAEgYW2_bDnQFSVpQ?pwd=3sf0) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v2.4.0/depth_anything_v2_vits.onnx) | | depth_anything_v2_vitb.onnx | [DepthAnythingV2](https://github.com/DepthAnything/Depth-Anything-V2) | [depth_anything_v2_vit_b.yaml](../../anylabeling/configs/auto_labeling/depth_anything_v2_vit_b.yaml) | 371.20MB | [百度网盘](https://pan.baidu.com/s/1wo8xYJiuMjie5THPjr4DWg?pwd=kcal) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v2.4.0/depth_anything_v2_vitb.onnx) | -| depth_anything_v2_vitl.onnx | [DepthAnythingV2](https://github.com/DepthAnything/Depth-Anything-V2) | [depth_anything_v2_vit_l.yaml](../../anylabeling/configs/auto_labeling/depth_anything_v2_vit_l.yaml) | 1.25GB | [百度网盘](https://pan.baidu.com/s/134WYgOdhzWeyap_xk0rBhw?pwd=cnqt) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v2.4.0/depth_anything_v2_vitl.onnx) | \ No newline at end of file +| depth_anything_v2_vitl.onnx | [DepthAnythingV2](https://github.com/DepthAnything/Depth-Anything-V2) | [depth_anything_v2_vit_l.yaml](../../anylabeling/configs/auto_labeling/depth_anything_v2_vit_l.yaml) | 1.25GB | [百度网盘](https://pan.baidu.com/s/134WYgOdhzWeyap_xk0rBhw?pwd=cnqt) \| [github](https://github.com/CVHub520/X-AnyLabeling/releases/download/v2.4.0/depth_anything_v2_vitl.onnx) | + + +### 交互式视频目标分割 + +|名称|描述|配置|大小|链接| +| --- | --- | --- | --- | --- | +| sam2_hiera_tiny.pt | [SAM 2](https://github.com/CVHub520/segment-anything-2) | [sam2_hiera_tiny_video.yaml](../../anylabeling/configs/auto_labeling/sam2_hiera_tiny_video.yaml) | 148.68MB | [百度网盘](https://pan.baidu.com/s/1sIiVHQb5MMrK49x5ngB8ig?pwd=q38g) \| [github](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt) | +| sam2_hiera_small.pt | [SAM 2](https://github.com/CVHub520/segment-anything-2) | [sam2_hiera_small_video.yaml](../../anylabeling/configs/auto_labeling/sam2_hiera_small_video.yaml) | 175.77MB | [百度网盘](https://pan.baidu.com/s/17P5PvRUlT0xYr4wlqqeKVw?pwd=rsim) \| [github](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt) | +| sam2_hiera_base_plus.pt | [SAM 2](https://github.com/CVHub520/segment-anything-2) | [sam2_hiera_base_video.yaml](../../anylabeling/configs/auto_labeling/sam2_hiera_base_video.yaml) | 308.51MB | [百度网盘](https://pan.baidu.com/s/1nP-ysdw1t7JgUQtIv-adNA?pwd=44u4) \| [github](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt) | +| sam2_hiera_large.pt | [SAM 2](https://github.com/CVHub520/segment-anything-2) | [sam2_hiera_large_video.yaml](../../anylabeling/configs/auto_labeling/sam2_hiera_large_video.yaml) | 856.35MB | [百度网盘](https://pan.baidu.com/s/1hFYQc1IKnEw2l-vW6Eoi2w?pwd=gyby) \| [github](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt) | \ No newline at end of file diff --git a/docs/zh_cn/user_guide.md b/docs/zh_cn/user_guide.md index 644dce80..a0dceee8 100644 --- a/docs/zh_cn/user_guide.md +++ b/docs/zh_cn/user_guide.md @@ -49,6 +49,7 @@ * [8.5 多目标跟踪](#85-多目标跟踪) * [8.6 深度估计](#86-深度估计) * [8.7 光学字符识别](#87-光学字符识别) + * [8.8 交互式视频目标分割](#88-交互式视频目标分割) * [9. 模型](#9-模型) ## 1. 文件 @@ -707,6 +708,10 @@ labels: - 文本检测与识别:[链接](../../examples/optical_character_recognition/text_recognition/README.md) - 关键信息提取:[链接](../../examples/optical_character_recognition/kie/README.md) +### 8.8 交互式视频目标分割 + +- 交互式视频目标分割: [链接](../../examples/interactive_video_object_segmentation/README.md) + ## 9. 模型 diff --git a/examples/interactive_video_object_segmentation/README.md b/examples/interactive_video_object_segmentation/README.md new file mode 100644 index 00000000..79e25b7e --- /dev/null +++ b/examples/interactive_video_object_segmentation/README.md @@ -0,0 +1,149 @@ +# Interactive Video Object Segmentation Example + +## Introduction + +**Interactive Video Object Segmentation (iVOS)** has become an essential task for efficiently obtaining object segmentations in videos, often guided by user inputs like scribbles, clicks, or bounding boxes. In this tutorial, you'll learn how to leverage the video tracking feature of [SAM2](https://github.com/facebookresearch/segment-anything-2) on X-AnyLabeling to accomplish iVOS tasks. + + + +Let's get started! + +## Installation + +Before you begin, make sure you have the following prerequisites installed: + +**Step 0:** Download and install Miniconda from the [official website](https://docs.anaconda.com/miniconda/). + +**Step 1:** Create a new Conda environment with Python version `3.10` or higher, and activate it: + +```bash +conda create -n x-anylabeling-sam2 python=3.10 -y +conda activate x-anylabeling-sam2 +``` + +You'll need to install SAM2 first. The code requires `torch>=2.3.1` and `torchvision>=0.18.1`. Follow the instructions [here](https://pytorch.org/get-started/locally/) to install both PyTorch and TorchVision dependencies. + +Afterward, you can install SAM2 on a GPU-enabled machine using: + +```bash +git clone https://github.com/CVHub520/segment-anything-2 +cd segment-anything-2 +pip install -e . +``` + +Finally, install the necessary dependencies for X-AnyLabeling (v2.4.2+): + +```bash +cd .. +git clone https://github.com/CVHub520/X-AnyLabeling +cd X-AnyLabeling + +# For Windows or Linux +pip install -r requirements.txt + +# For macOS +pip install -r requirements-macos.txt +conda install -c conda-forge pyqt=5.15.9 +``` + +## Getting Started + +### Prerequisites + +**Step 0:** Launch the app: + +```bash +python3 anylabeling/app.py +``` + +**Step 1:** Load the SAM 2 Video model + +![Load-Model](https://github.com/user-attachments/assets/8c3e0593-ccb5-45a8-bb61-73f4b9f5f82f) + +
+Note: If the model fails to load due to network issues, please refer to the following settings. + +First, you'll need to download a model checkpoint. For this tutorial, we'll use the [sam2_hiera_large.pt](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt) checkpoint as an example. + +After downloading, place the checkpoint file in the corresponding model folder within your user directory (create the folder if it doesn't exist): + +```bash +# Windows +C:\Users\${User}\xanylabeling_data\models\sam2_hiera_large_video-r20240901 + +# Linux or macOS +~/xanylabeling_data/models/sam2_hiera_large_video-r20240901 +``` + +Additionally, if you want to use other sizes of SAM2 models or modify the model loading path, refer to this documentation for custom settings: [简体中文](../../docs/zh_cn/custom_model.md) | [English](../../docs/en/custom_model.md). + +
+ +**Step 2:** Add a video file (Ctrl + O) or a folder of split video frames (Ctrl + U). + +> [!NOTE] +> As of now, the supported file formats are limited to [*.jpg, *.jpeg, *.JPG, *.JPEG]. When loading video files, they will be automatically converted to jpg format by default. + + +### Usage + +**Step 0:** Add Prompts + + + +> [!TIP] +> - **Point (q):** Add a positive point. +> - **Point (e):** Add a negative point. +> - **+Rect:** Draw a rectangle around the object. +> - **Clear (b):** Erase all added marks. +> - **Finish Object (f):** Confirm the object. + +For the initial frame, you can add prompts such as positive points, negative points, and rectangles (Marks) to guide the tracking of the desired object. Follow these steps: + +1. If the segmentation result meets your expectations, click the `Finish Object (f)` button at the top of the screen or press the `f` key to confirm the object. If not, click the `Clear (b)` button or press the `b` key to quickly clear any invalid marks. +2. We strongly recommend assigning labels like `object0`, `object1`, ..., `objectN` to each added target sequentially. + +> [!WARNING] +> If you need to delete a confirmed object, follow these steps:
+> a. Open the edit mode (Ctrl + J) and remove all added objects from the current frame;
+> b. Click the `Reset Tracker` button at the top of the screen to reset the tracker;
+> c. Reapply the prompts (Marks) as described above. + +![rectangle_tracklet](https://github.com/user-attachments/assets/1dbe1d41-1792-4c45-9ea0-51c26a08c6af) + +Alternatively, if you only want to set up object detection tracking, you simply need to filter the output mode to Rectangle. + + +**Step 1:** Propagate the prompts to get the tracklet across the video + +![run_video](https://github.com/user-attachments/assets/e4763f32-bfdb-4b0a-be23-4885e3cc9f96) + +Once you've finished setting the prompts, you can start the video tracking by either clicking the video start button on the left-hand menu or using the shortcut `Ctrl+M` to get the tracklet throughout the entire video. + +**Step 2:** Add New Prompts to Further Refine the tracklet + +After tracking the entire video, if you notice any of the following issues in the middle frames: + +- Target is lost +- Imperfections in boundary details +- New objects need to be tracked + +You can treat the current frame as the starting frame and follow these steps: + +a. Open the edit mode (`Ctrl + J`) and remove all added objects from the current frame.
+b. Click the `Reset Tracker` button at the top of the screen to reset the tracker.
+c. Reapply the prompts (Marks) as described earlier. + +Then, repeat the steps in **Step 0** and **Step 1**. + +![rename](https://github.com/user-attachments/assets/04707624-b13d-490f-a75d-7e35d5dee1c7) + +Upon completion of all tasks, you can access the `Tool` -> `Label Manager` option from the top menu to assign specific class names. + +> [!NOTE] +> Just a reminder to click the `Reset Tracker` button at the top of the screen after uploading a new video file to reset the tracker. + +--- + +Congratulations! 🎉 You’ve now mastered the basics of X-AnyLabeling. Feel free to experiment with it on your own videos and various use cases!