feat: add SAM2CameraPredictor #124

Open · wants to merge 3 commits into base: main
38 changes: 36 additions & 2 deletions README.md
@@ -52,7 +52,7 @@ or individually from:
- [sam2_hiera_base_plus.pt](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt)
- [sam2_hiera_large.pt](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt)

Then SAM 2 can be used in a few lines as follows for image, video, and camera prediction.

### Image prediction

@@ -98,9 +98,43 @@
```python
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
        ...
```

Please refer to the examples in [video_predictor_example.ipynb](./notebooks/video_predictor_example.ipynb) for details on how to add prompts, make refinements, and track multiple objects in videos.
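As a rough illustration of the prompt format: point prompts are (x, y) pixel coordinates with per-point positive/negative labels. The sketch below shows only how the arrays are built; the commented `add_new_points` call shape follows the notebook and is an assumption that may differ across versions.

```python
import numpy as np

# Point prompts are (x, y) pixel coordinates as float32; labels are int32,
# where 1 marks a positive (foreground) click and 0 a negative (background) one.
# The coordinates and object id here are purely illustrative.
points = np.array([[210, 350], [250, 220]], dtype=np.float32)
labels = np.array([1, 0], dtype=np.int32)

# With an initialized video predictor and inference state (assumed call shape):
# _, out_obj_ids, out_mask_logits = predictor.add_new_points(
#     inference_state=state, frame_idx=0, obj_id=1, points=points, labels=labels
# )
```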


### Camera prediction


```python
import cv2
import torch
from sam2.build_sam import build_sam2_camera_predictor

checkpoint = "./checkpoints/sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"
predictor = build_sam2_camera_predictor(model_cfg, checkpoint)

cap = cv2.VideoCapture(<your video or camera>)

if_init = False

with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        width, height = frame.shape[:2][::-1]

        if not if_init:
            predictor.load_first_frame(frame)
            if_init = True
            _, out_obj_ids, out_mask_logits = predictor.add_new_points(<your prompt>)
        else:
            out_obj_ids, out_mask_logits = predictor.track(frame)
            ...
```
Please refer to the examples in [camera_predictor_example.ipynb](./notebooks/camera_predictor_example.ipynb) for details on how to add prompts, make refinements, and track multiple objects on live video.
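The `out_mask_logits` returned by `track` are raw logits rather than binary masks. A common post-processing step is to threshold them at zero and alpha-blend the result onto the frame. A minimal NumPy sketch (the helpers `logits_to_mask` and `overlay_mask` are illustrative, not part of the SAM 2 API):

```python
import numpy as np

def logits_to_mask(mask_logits, threshold=0.0):
    """Binarize predicted mask logits (positive logit => foreground pixel)."""
    return (np.asarray(mask_logits) > threshold).astype(np.uint8)

def overlay_mask(frame, mask, color=(0, 255, 0), alpha=0.5):
    """Alpha-blend a solid color onto an HxWx3 uint8 frame where mask == 1."""
    out = frame.astype(np.float32)
    color = np.array(color, dtype=np.float32)
    sel = mask.astype(bool)
    out[sel] = (1 - alpha) * out[sel] + alpha * color
    return out.astype(np.uint8)

# Tiny synthetic demo: a 4x4 black frame, logits positive in the top-left 2x2.
frame = np.zeros((4, 4, 3), dtype=np.uint8)
logits = np.full((4, 4), -5.0)
logits[:2, :2] = 5.0
mask = logits_to_mask(logits)
blended = overlay_mask(frame, mask)
```

In a live loop, the same blend would be applied per frame before `cv2.imshow`.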


## Model Description

| **Model** | **Size (M)** | **Speed (FPS)** | **SA-V test (J&F)** | **MOSE val (J&F)** | **LVOS v2 (J&F)** |
728 changes: 728 additions & 0 deletions notebooks/camera_predictor_example.ipynb


Binary file added notebooks/videos/blackswan/blackswan.mp4
37 changes: 36 additions & 1 deletion sam2/build_sam.py
@@ -76,6 +76,41 @@ def build_sam2_video_predictor(
    return model


def build_sam2_camera_predictor(
    config_file,
    ckpt_path=None,
    device="cuda",
    mode="eval",
    hydra_overrides_extra=[],
    apply_postprocessing=True,
):
    hydra_overrides = [
        "++model._target_=sam2.sam2_camera_predictor.SAM2CameraPredictor",
    ]
    if apply_postprocessing:
        hydra_overrides_extra = hydra_overrides_extra.copy()
        hydra_overrides_extra += [
            # dynamically fall back to multi-mask if the single mask is not stable
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
            # binarize the sigmoid mask logits on interacted frames with clicks in the memory encoder
            # so that the encoded masks are exactly as what users see from clicking
            "++model.binarize_mask_from_pts_for_mem_enc=true",
            # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution)
            "++model.fill_hole_area=8",
        ]
    hydra_overrides.extend(hydra_overrides_extra)

    # Read config and init model
    cfg = compose(config_name=config_file, overrides=hydra_overrides)
    OmegaConf.resolve(cfg)
    model = instantiate(cfg.model, _recursive_=True)
    _load_checkpoint(model, ckpt_path)
    model = model.to(device)
    if mode == "eval":
        model.eval()
    return model

def _load_checkpoint(model, ckpt_path):
    if ckpt_path is not None:
        sd = torch.load(ckpt_path, map_location="cpu")["model"]
@@ -86,4 +121,4 @@ def _load_checkpoint(model, ckpt_path):
        if unexpected_keys:
            logging.error(unexpected_keys)
            raise RuntimeError()
        logging.info("Loaded checkpoint successfully")
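The strict checkpoint check above (log and raise when `load_state_dict` reports unexpected keys) amounts to comparing two key sets. A toy sketch of that comparison (the helper `check_state_dict` is illustrative, not part of the codebase):

```python
def check_state_dict(model_keys, ckpt_keys):
    """Report keys the model expects but the checkpoint lacks (missing),
    and keys the checkpoint carries but the model does not (unexpected)."""
    missing = sorted(set(model_keys) - set(ckpt_keys))
    unexpected = sorted(set(ckpt_keys) - set(model_keys))
    return missing, unexpected

# Hypothetical key names, for illustration only.
missing, unexpected = check_state_dict(
    ["image_encoder.weight", "mask_decoder.weight"],
    ["image_encoder.weight", "extra.bias"],
)
```

A non-empty `unexpected` list here is what would trigger the `RuntimeError` in `_load_checkpoint`.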