diff --git a/sam2/sam2_video_predictor.py b/sam2/sam2_video_predictor.py
index c7e01ccf..b7ffcc9d 100644
--- a/sam2/sam2_video_predictor.py
+++ b/sam2/sam2_video_predictor.py
@@ -46,15 +46,16 @@ def init_state(
         video_path,
         offload_video_to_cpu=False,
         offload_state_to_cpu=False,
-        async_loading_frames=False,
+        frame_load_config=None,
     ):
         """Initialize an inference state."""
+        frame_load_config = frame_load_config or {}
         compute_device = self.device  # device of the model
         images, video_height, video_width = load_video_frames(
             video_path=video_path,
             image_size=self.image_size,
             offload_video_to_cpu=offload_video_to_cpu,
-            async_loading_frames=async_loading_frames,
+            frame_load_config=frame_load_config,
             compute_device=compute_device,
         )
         inference_state = {}
diff --git a/sam2/utils/misc.py b/sam2/utils/misc.py
index b65ee825..c02a6037 100644
--- a/sam2/utils/misc.py
+++ b/sam2/utils/misc.py
@@ -4,14 +4,18 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+import asyncio
 import os
 import warnings
+
+from abc import ABC, abstractmethod
 from threading import Thread
+from types import LambdaType
 
 import numpy as np
 import torch
 from PIL import Image
 from tqdm import tqdm
 
 
 def get_sdpa_settings():
@@ -89,39 +93,35 @@ def mask_to_box(masks: torch.Tensor):
     return bbox_coords
 
 
-def _load_img_as_tensor(img_path, image_size):
-    img_pil = Image.open(img_path)
+def _load_img_pil_as_tensor(img_id, img_pil, image_size):
     img_np = np.array(img_pil.convert("RGB").resize((image_size, image_size)))
     if img_np.dtype == np.uint8:  # np.uint8 is expected for JPEG images
         img_np = img_np / 255.0
     else:
-        raise RuntimeError(f"Unknown image dtype: {img_np.dtype} on {img_path}")
+        raise RuntimeError(f"Unknown image dtype: {img_np.dtype} on {img_id}")
     img = torch.from_numpy(img_np).permute(2, 0, 1)
     video_width, video_height = img_pil.size  # the original video size
     return img, video_height, video_width
 
 
-class AsyncVideoFrameLoader:
+class LazyVideoFrameLoader(ABC):
     """
-    A list of video frames to be load asynchronously without blocking session start.
+    Abstract class that defines primitives to load frames lazily.
""" - def __init__( self, - img_paths, image_size, offload_video_to_cpu, img_mean, img_std, compute_device, ): - self.img_paths = img_paths self.image_size = image_size self.offload_video_to_cpu = offload_video_to_cpu self.img_mean = img_mean self.img_std = img_std # items in `self.images` will be loaded asynchronously - self.images = [None] * len(img_paths) + self.images = [None] * self.__len__() # catch and raise any exceptions in the async loading thread self.exception = None # video_height and video_width be filled when loading the first image @@ -131,18 +131,25 @@ def __init__( # load the first frame to fill video_height and video_width and also # to cache it (since it's most likely where the user will click) - self.__getitem__(0) + self.__getitem__(self.get_first_frame_num()) - # load the rest of frames asynchronously without blocking the session start - def _load_frames(): + if self.should_preload(): + self.thread = Thread( + target=self.load_frames, + daemon=True, + ) + self.thread.start() + + def load_frames(self): + asyncio.run(self.preload()) + + async def preload(self): + async for index in self.get_preload_generator(): try: - for n in tqdm(range(len(self.images)), desc="frame loading (JPEG)"): - self.__getitem__(n) + self.__getitem__(index) except Exception as e: - self.exception = e - - self.thread = Thread(target=_load_frames, daemon=True) - self.thread.start() + if self.propagate_preload_errors(): + self.exception = e def __getitem__(self, index): if self.exception is not None: @@ -152,8 +159,8 @@ def __getitem__(self, index): if img is not None: return img - img, video_height, video_width = _load_img_as_tensor( - self.img_paths[index], self.image_size + img, video_height, video_width = _load_img_pil_as_tensor( + self.get_image_id(index), self.load_image(index), self.image_size ) self.video_height = video_height self.video_width = video_width @@ -166,16 +173,132 @@ def __getitem__(self, index): return img def __len__(self): + return self.get_length() + + @abstractmethod + def get_first_frame_num(self): + raise NotImplementedError + + @abstractmethod + def should_preload(self): + raise NotImplementedError + + @abstractmethod + def get_preload_generator(self): + raise NotImplementedError + + @abstractmethod + def propagate_preload_errors(self): + raise NotImplementedError + + @abstractmethod + def load_image(self, index): + raise NotImplementedError + + @abstractmethod + def get_image_id(self, index): + raise NotImplementedError + + @abstractmethod + def get_length(self): + raise NotImplementedError + + +class AsyncVideoFrameLoader(LazyVideoFrameLoader): + """ + A list of video frames to be load asynchronously without blocking session start. 
+ """ + + def __init__( + self, + img_paths, + image_size, + offload_video_to_cpu, + img_mean, + img_std, + compute_device, + ): + self.img_paths = img_paths + LazyVideoFrameLoader.__init__( + self, image_size, offload_video_to_cpu, img_mean, img_std, compute_device + ) + + def get_first_frame_num(self): + return 0 + + def should_preload(self): + return True + + def get_preload_generator(self): + async def _available(img_paths): + for i in tqdm(len(img_paths), desc="frame loading (JPEG)"): + yield i + + return _available(self.img_paths) + + def propagate_preload_errors(self): + return True + + def load_image(self, index): + return Image.load(self.img_paths[index]) + + def get_image_id(self, index): + return self.img_paths[index] + + def get_length(self): return len(self.images) + +class StreamingVideoFrameLoader(LazyVideoFrameLoader): + """ + A list of video frames that can be loaded lazily even if they are produced after session start. + """ + def __init__( + self, + loader_func, + stream_config, + image_size, + offload_video_to_cpu, + img_mean, + img_std, + compute_device, + ): + self.loader_func = loader_func + self.stream_config = stream_config + LazyVideoFrameLoader.__init__( + self, image_size, offload_video_to_cpu, img_mean, img_std, compute_device + ) + + def get_first_frame_num(self): + return self.stream_config.get("first_frame_num", 0) + + def should_preload(self): + return self.stream_config.get("preload_gen", None) is not None + + def get_preload_generator(self): + return self.stream_config.get("preload_gen") + + def propagate_preload_errors(self): + return self.stream_config.get("propagate_preload_errors", True) + + def load_image(self, index): + return self.loader_func(index) + + def get_image_id(self, index): + return str(index) + + def get_length(self): + return self.stream_config.get("max_frames") + + def load_video_frames( video_path, image_size, offload_video_to_cpu, img_mean=(0.485, 0.456, 0.406), img_std=(0.229, 0.224, 0.225), - async_loading_frames=False, + frame_load_config=None, compute_device=torch.device("cuda"), ): """ @@ -184,6 +307,7 @@ def load_video_frames( """ is_bytes = isinstance(video_path, bytes) is_str = isinstance(video_path, str) + is_func = isinstance(video_path, LambdaType) is_mp4_path = is_str and os.path.splitext(video_path)[-1] in [".mp4", ".MP4"] if is_bytes or is_mp4_path: return load_video_frames_from_video_file( @@ -201,7 +325,18 @@ def load_video_frames( offload_video_to_cpu=offload_video_to_cpu, img_mean=img_mean, img_std=img_std, - async_loading_frames=async_loading_frames, + frame_load_config=frame_load_config, + compute_device=compute_device, + ) + + elif is_func: + return load_video_frames_from_lambda( + loader_func=video_path, + image_size=image_size, + offload_video_to_cpu=offload_video_to_cpu, + img_mean=img_mean, + img_std=img_std, + frame_load_config=frame_load_config, compute_device=compute_device, ) else: @@ -210,13 +345,37 @@ def load_video_frames( ) +def load_video_frames_from_lambda( + loader_func, + image_size, + offload_video_to_cpu, + img_mean=(0.485, 0.456, 0.406), + img_std=(0.229, 0.224, 0.225), + frame_load_config=None, + compute_device=torch.device("cuda"), +): + img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None] + img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None] + + lazy_images = StreamingVideoFrameLoader( + loader_func, + frame_load_config, + image_size, + offload_video_to_cpu, + img_mean, + img_std, + compute_device, + ) + return lazy_images, lazy_images.video_height, 
+
+
 def load_video_frames_from_jpg_images(
     video_path,
     image_size,
     offload_video_to_cpu,
     img_mean=(0.485, 0.456, 0.406),
     img_std=(0.229, 0.224, 0.225),
-    async_loading_frames=False,
+    frame_load_config=None,
     compute_device=torch.device("cuda"),
 ):
     """
@@ -253,7 +412,7 @@ def load_video_frames_from_jpg_images(
     img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]
     img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None]
 
-    if async_loading_frames:
+    if frame_load_config and frame_load_config.get("async", False):
         lazy_images = AsyncVideoFrameLoader(
             img_paths,
             image_size,
@@ -266,7 +425,9 @@
 
     images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32)
     for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")):
-        images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size)
+        images[n], video_height, video_width = _load_img_pil_as_tensor(
+            img_path, Image.open(img_path), image_size
+        )
     if not offload_video_to_cpu:
         images = images.to(compute_device)
         img_mean = img_mean.to(compute_device)
diff --git a/tools/vos_inference.py b/tools/vos_inference.py
index 5c40cda9..437bded2 100644
--- a/tools/vos_inference.py
+++ b/tools/vos_inference.py
@@ -135,7 +135,7 @@ def vos_inference(
     ]
     frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
     inference_state = predictor.init_state(
-        video_path=video_dir, async_loading_frames=False
+        video_path=video_dir, frame_load_config=None
     )
     height = inference_state["video_height"]
     width = inference_state["video_width"]
@@ -273,7 +273,7 @@ def vos_separate_inference_per_object(
     ]
     frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
     inference_state = predictor.init_state(
-        video_path=video_dir, async_loading_frames=False
+        video_path=video_dir, frame_load_config=None
     )
     height = inference_state["video_height"]
     width = inference_state["video_width"]
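
Usage sketch (not part of the diff): the snippet below shows how the new frame_load_config argument might be exercised, both for the existing JPEG-folder path and for the new callable/streaming path. The config keys ("async", "max_frames", "first_frame_num", "preload_gen", "propagate_preload_errors") come from the code above; the checkpoint/config paths, file locations, and frame count are placeholders. Note that dispatch uses types.LambdaType, so the streaming source must be a plain function or lambda.

from PIL import Image

from sam2.build_sam import build_sam2_video_predictor

# Placeholder model config / checkpoint paths.
predictor = build_sam2_video_predictor("sam2_hiera_l.yaml", "./checkpoints/sam2_hiera_large.pt")

# 1) JPEG folder, preloaded on a background thread (replaces async_loading_frames=True).
state = predictor.init_state(
    video_path="./videos/bedroom",  # placeholder folder of 00000.jpg, 00001.jpg, ...
    frame_load_config={"async": True},
)

# 2) Streaming: video_path is a callable mapping a frame index to a PIL image.
def load_frame(index):
    # Placeholder source; could instead read from a camera buffer or a queue.
    return Image.open(f"./stream/{index:05d}.jpg")

async def frames_ready():
    # Optional async generator yielding indices as frames become available ("preload_gen").
    for i in range(300):
        yield i

state = predictor.init_state(
    video_path=load_frame,
    frame_load_config={
        "max_frames": 300,               # total number of frames the loader reports
        "first_frame_num": 0,            # frame loaded eagerly at session start
        "preload_gen": frames_ready(),   # omit to disable background preloading
        "propagate_preload_errors": True,
    },
)

Frames produced after session start are handled because StreamingVideoFrameLoader only calls loader_func(index) when a frame is requested or when the preload generator yields that index.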