Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integrate Efficient-SAM into Labelme #1375

Merged
merged 9 commits into from
Dec 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion labelme/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def main():
)
parser.add_argument(
"--logger-level",
default="info",
default="debug",
choices=["debug", "info", "warning", "fatal", "error"],
help="logger level",
)
Expand Down
125 changes: 86 additions & 39 deletions labelme/ai/__init__.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,93 @@
import collections
import gdown

from .models.segment_anything import SegmentAnythingModel # NOQA
from .efficient_sam import EfficientSam
from .segment_anything_model import SegmentAnythingModel


Model = collections.namedtuple(
"Model", ["name", "encoder_weight", "decoder_weight"]
)
class SegmentAnythingModelVitB(SegmentAnythingModel):
    """ViT-B Segment-Anything variant: the 'speed' preset in the model menu."""

    name = "SegmentAnything (speed)"

    def __init__(self):
        # Fetch (and locally cache) the quantized ONNX weights, verifying
        # their md5 checksums, before handing the paths to the base class.
        encoder_path = gdown.cached_download(
            url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_b_01ec64.quantized.encoder.onnx",  # NOQA
            md5="80fd8d0ab6c6ae8cb7b3bd5f368a752c",
        )
        decoder_path = gdown.cached_download(
            url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_b_01ec64.quantized.decoder.onnx",  # NOQA
            md5="4253558be238c15fc265a7a876aaec82",
        )
        super().__init__(encoder_path=encoder_path, decoder_path=decoder_path)


class SegmentAnythingModelVitL(SegmentAnythingModel):
    """ViT-L Segment-Anything variant: the 'balanced' preset in the model menu."""

    name = "SegmentAnything (balanced)"

    def __init__(self):
        # Download-and-cache both ONNX weight files (md5-checked) up front,
        # then delegate session construction to the base class.
        encoder_path = gdown.cached_download(
            url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_l_0b3195.quantized.encoder.onnx",  # NOQA
            md5="080004dc9992724d360a49399d1ee24b",
        )
        decoder_path = gdown.cached_download(
            url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_l_0b3195.quantized.decoder.onnx",  # NOQA
            md5="851b7faac91e8e23940ee1294231d5c7",
        )
        super().__init__(encoder_path=encoder_path, decoder_path=decoder_path)


class SegmentAnythingModelVitH(SegmentAnythingModel):
    """ViT-H Segment-Anything variant: the 'accuracy' preset in the model menu."""

    name = "SegmentAnything (accuracy)"

    def __init__(self):
        # Resolve both quantized ONNX weights through gdown's download cache
        # (md5-verified) and pass the local paths up to the base class.
        encoder_path = gdown.cached_download(
            url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_h_4b8939.quantized.encoder.onnx",  # NOQA
            md5="958b5710d25b198d765fb6b94798f49e",
        )
        decoder_path = gdown.cached_download(
            url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_h_4b8939.quantized.decoder.onnx",  # NOQA
            md5="a997a408347aa081b17a3ffff9f42a80",
        )
        super().__init__(encoder_path=encoder_path, decoder_path=decoder_path)


class EfficientSamVitT(EfficientSam):
    """ViT-T EfficientSAM variant: the 'speed' preset in the model menu."""

    name = "EfficientSam (speed)"

    def __init__(self):
        # Download-and-cache the ONNX encoder/decoder (md5-verified) before
        # initializing the EfficientSam base with their local paths.
        encoder_path = gdown.cached_download(
            url="https://github.com/labelmeai/efficient-sam/releases/download/onnx-models-20231225/efficient_sam_vitt_encoder.onnx",  # NOQA
            md5="2d4a1303ff0e19fe4a8b8ede69c2f5c7",
        )
        decoder_path = gdown.cached_download(
            url="https://github.com/labelmeai/efficient-sam/releases/download/onnx-models-20231225/efficient_sam_vitt_decoder.onnx",  # NOQA
            md5="be3575ca4ed9b35821ac30991ab01843",
        )
        super().__init__(encoder_path=encoder_path, decoder_path=decoder_path)


class EfficientSamVitS(EfficientSam):
    """ViT-S EfficientSAM variant: the 'accuracy' preset in the model menu."""

    name = "EfficientSam (accuracy)"

    def __init__(self):
        # Resolve both ONNX weight files via gdown's md5-checked download
        # cache, then delegate to the EfficientSam base class.
        encoder_path = gdown.cached_download(
            url="https://github.com/labelmeai/efficient-sam/releases/download/onnx-models-20231225/efficient_sam_vits_encoder.onnx",  # NOQA
            md5="7d97d23e8e0847d4475ca7c9f80da96d",
        )
        decoder_path = gdown.cached_download(
            url="https://github.com/labelmeai/efficient-sam/releases/download/onnx-models-20231225/efficient_sam_vits_decoder.onnx",  # NOQA
            md5="d9372f4a7bbb1a01d236b0508300b994",
        )
        super().__init__(encoder_path=encoder_path, decoder_path=decoder_path)

Weight = collections.namedtuple("Weight", ["url", "md5"])

MODELS = [
Model(
name="Segment-Anything (speed)",
encoder_weight=Weight(
url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_b_01ec64.quantized.encoder.onnx", # NOQA
md5="80fd8d0ab6c6ae8cb7b3bd5f368a752c",
),
decoder_weight=Weight(
url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_b_01ec64.quantized.decoder.onnx", # NOQA
md5="4253558be238c15fc265a7a876aaec82",
),
),
Model(
name="Segment-Anything (balanced)",
encoder_weight=Weight(
url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_l_0b3195.quantized.encoder.onnx", # NOQA
md5="080004dc9992724d360a49399d1ee24b",
),
decoder_weight=Weight(
url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_l_0b3195.quantized.decoder.onnx", # NOQA
md5="851b7faac91e8e23940ee1294231d5c7",
),
),
Model(
name="Segment-Anything (accuracy)",
encoder_weight=Weight(
url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_h_4b8939.quantized.encoder.onnx", # NOQA
md5="958b5710d25b198d765fb6b94798f49e",
),
decoder_weight=Weight(
url="https://github.com/wkentaro/labelme/releases/download/sam-20230416/sam_vit_h_4b8939.quantized.decoder.onnx", # NOQA
md5="a997a408347aa081b17a3ffff9f42a80",
),
),
SegmentAnythingModelVitB,
SegmentAnythingModelVitL,
SegmentAnythingModelVitH,
EfficientSamVitT,
EfficientSamVitS,
]
36 changes: 36 additions & 0 deletions labelme/ai/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import imgviz
import numpy as np
import skimage


def _get_contour_length(contour):
contour_start = contour
contour_end = np.r_[contour[1:], contour[0:1]]
return np.linalg.norm(contour_end - contour_start, axis=1).sum()


def compute_polygon_from_mask(mask):
    """Trace a binary mask and return an approximate polygon outline.

    Parameters
    ----------
    mask: np.ndarray of bool, shape (H, W)
        Foreground mask to convert to a polygon.

    Returns
    -------
    np.ndarray of float, shape (N, 2)
        Polygon vertices in xy order, clipped to the mask's bounds.

    Raises
    ------
    ValueError
        If the mask contains no foreground region to trace.
    """
    # Pad by one pixel so contours touching the image border are closed.
    contours = skimage.measure.find_contours(np.pad(mask, pad_width=1))
    if not contours:
        # Previously this fell through to max() on an empty sequence, which
        # raised an opaque "max() arg is an empty sequence" error.
        raise ValueError("mask has no foreground region to trace")
    # Keep only the longest contour; shorter ones are speckle/holes.
    contour = max(contours, key=_get_contour_length)
    # Tolerance scales with the contour's extent, so the approximation
    # quality is independent of image resolution.
    POLYGON_APPROX_TOLERANCE = 0.004
    polygon = skimage.measure.approximate_polygon(
        coords=contour,
        tolerance=np.ptp(contour, axis=0).max() * POLYGON_APPROX_TOLERANCE,
    )
    # Clip back into the (unpadded) mask bounds.
    polygon = np.clip(polygon, (0, 0), (mask.shape[0] - 1, mask.shape[1] - 1))
    polygon = polygon[:-1]  # drop last point that is duplicate of first point
    return polygon[:, ::-1]  # yx -> xy
109 changes: 109 additions & 0 deletions labelme/ai/efficient_sam.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import collections
import threading

import imgviz
import numpy as np
import onnxruntime
import skimage

from ..logger import logger

from . import _utils


class EfficientSam:
    """EfficientSAM inference wrapper around ONNX encoder/decoder sessions.

    Image embeddings are computed on a background thread (started from
    ``set_image``) and memoized in an LRU-style cache keyed by the raw
    image bytes, so re-selecting a recent image skips the encoder pass.
    """

    def __init__(self, encoder_path, decoder_path):
        # Paths point at local ONNX weight files for the two-stage model.
        self._encoder_session = onnxruntime.InferenceSession(encoder_path)
        self._decoder_session = onnxruntime.InferenceSession(decoder_path)

        # Guards _image, _image_embedding, and the embedding cache, which
        # are shared with the background embedding thread.
        self._lock = threading.Lock()
        # OrderedDict used as an LRU cache: oldest entries evicted first.
        self._image_embedding_cache = collections.OrderedDict()

        # Background thread computing the current image's embedding, or None.
        self._thread = None

    def set_image(self, image: np.ndarray):
        """Set the working image; kick off embedding computation if uncached.

        ``image`` is expected to be an HWC uint8 array (RGB or RGBA, given
        the rgba2rgb call below) — NOTE(review): confirm against callers.
        """
        with self._lock:
            self._image = image
            # Cache key is the raw pixel buffer; identical images hit.
            self._image_embedding = self._image_embedding_cache.get(
                self._image.tobytes()
            )

        # Cache miss: compute the embedding asynchronously so the UI thread
        # is not blocked; _get_image_embedding() joins this thread later.
        if self._image_embedding is None:
            self._thread = threading.Thread(
                target=self._compute_and_cache_image_embedding
            )
            self._thread.start()

    def _compute_and_cache_image_embedding(self):
        """Run the encoder on the current image and store the result (thread target)."""
        with self._lock:
            logger.debug("Computing image embedding...")
            image = imgviz.rgba2rgb(self._image)
            # HWC uint8 -> NCHW float32 in [0, 1], batch size 1.
            batched_images = (
                image.transpose(2, 0, 1)[None].astype(np.float32) / 255.0
            )
            (self._image_embedding,) = self._encoder_session.run(
                output_names=None,
                input_feed={"batched_images": batched_images},
            )
            # Bound the cache: evict the least-recently-inserted entry.
            if len(self._image_embedding_cache) > 10:
                self._image_embedding_cache.popitem(last=False)
            self._image_embedding_cache[
                self._image.tobytes()
            ] = self._image_embedding
            logger.debug("Done computing image embedding.")

    def _get_image_embedding(self):
        """Return the current image's embedding, waiting for the worker if needed."""
        if self._thread is not None:
            # Block until the background encoder pass has finished.
            self._thread.join()
            self._thread = None
        with self._lock:
            return self._image_embedding

    def predict_mask_from_points(self, points, point_labels):
        """Decode a boolean (H, W) mask from prompt points on the current image."""
        return _compute_mask_from_points(
            decoder_session=self._decoder_session,
            image=self._image,
            image_embedding=self._get_image_embedding(),
            points=points,
            point_labels=point_labels,
        )

    def predict_polygon_from_points(self, points, point_labels):
        """Decode a mask from prompt points and convert it to a polygon (xy)."""
        mask = self.predict_mask_from_points(
            points=points, point_labels=point_labels
        )
        return _utils.compute_polygon_from_mask(mask=mask)


def _compute_mask_from_points(
    decoder_session, image, image_embedding, points, point_labels
):
    """Run the EfficientSAM decoder and return a cleaned boolean mask.

    Parameters
    ----------
    decoder_session: onnxruntime.InferenceSession
        Decoder model taking the embedding plus prompt points.
    image: np.ndarray
        Source image; only its (H, W) shape is consumed here.
    image_embedding: np.ndarray
        Encoder output for ``image``.
    points: array-like, shape (num_points, 2)
        Prompt point coordinates (xy order, per the decoder's
        ``batched_point_coords`` input — TODO confirm against callers).
    point_labels: array-like, shape (num_points,)
        Per-point labels fed to the decoder; presumably 1=foreground,
        0=background — verify against the callers.

    Returns
    -------
    np.ndarray of bool, shape (H, W)
        Thresholded mask with tiny speckle regions removed.
    """
    input_point = np.array(points, dtype=np.float32)
    input_label = np.array(point_labels, dtype=np.float32)

    # batch_size, num_queries, num_points, 2
    batched_point_coords = input_point[None, None, :, :]
    # batch_size, num_queries, num_points
    batched_point_labels = input_label[None, None, :]

    decoder_inputs = {
        "image_embeddings": image_embedding,
        "batched_point_coords": batched_point_coords,
        "batched_point_labels": batched_point_labels,
        "orig_im_size": np.array(image.shape[:2], dtype=np.int64),
    }

    masks, _, _ = decoder_session.run(None, decoder_inputs)
    mask = masks[0, 0, 0, :, :]  # (1, 1, 3, H, W) -> (H, W)
    mask = mask > 0.0  # logits -> boolean foreground

    # Drop connected components smaller than 5% of the total foreground
    # area; operates in place on the boolean mask.
    MIN_SIZE_RATIO = 0.05
    skimage.morphology.remove_small_objects(
        mask, min_size=mask.sum() * MIN_SIZE_RATIO, out=mask
    )

    return mask
Empty file removed labelme/ai/models/__init__.py
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,15 @@
import imgviz
import numpy as np
import onnxruntime
import PIL.Image
import skimage

from ...logger import logger
from ..logger import logger

from . import _utils

class SegmentAnythingModel:
def __init__(self, name, encoder_path, decoder_path):
self.name = name

class SegmentAnythingModel:
def __init__(self, encoder_path, decoder_path):
self._image_size = 1024

self._encoder_session = onnxruntime.InferenceSession(encoder_path)
Expand Down Expand Up @@ -59,29 +58,21 @@ def _get_image_embedding(self):
with self._lock:
return self._image_embedding

def predict_polygon_from_points(self, points, point_labels):
image_embedding = self._get_image_embedding()
polygon = _compute_polygon_from_points(
def predict_mask_from_points(self, points, point_labels):
return _compute_mask_from_points(
image_size=self._image_size,
decoder_session=self._decoder_session,
image=self._image,
image_embedding=image_embedding,
image_embedding=self._get_image_embedding(),
points=points,
point_labels=point_labels,
)
return polygon

def predict_mask_from_points(self, points, point_labels):
image_embedding = self._get_image_embedding()
mask = _compute_mask_from_points(
image_size=self._image_size,
decoder_session=self._decoder_session,
image=self._image,
image_embedding=image_embedding,
points=points,
point_labels=point_labels,
def predict_polygon_from_points(self, points, point_labels):
mask = self.predict_mask_from_points(
points=points, point_labels=point_labels
)
return mask
return _utils.compute_polygon_from_mask(mask=mask)


def _compute_scale_to_resize_image(image_size, image):
Expand Down Expand Up @@ -133,12 +124,6 @@ def _compute_image_embedding(image_size, encoder_session, image):
return image_embedding


def _get_contour_length(contour):
contour_start = contour
contour_end = np.r_[contour[1:], contour[0:1]]
return np.linalg.norm(contour_end - contour_start, axis=1).sum()


def _compute_mask_from_points(
image_size, decoder_session, image, image_embedding, points, point_labels
):
Expand Down Expand Up @@ -186,36 +171,3 @@ def _compute_mask_from_points(
"mask.jpg", imgviz.label2rgb(mask, imgviz.rgb2gray(image))
)
return mask


def _compute_polygon_from_points(
image_size, decoder_session, image, image_embedding, points, point_labels
):
mask = _compute_mask_from_points(
image_size=image_size,
decoder_session=decoder_session,
image=image,
image_embedding=image_embedding,
points=points,
point_labels=point_labels,
)

contours = skimage.measure.find_contours(np.pad(mask, pad_width=1))
contour = max(contours, key=_get_contour_length)
POLYGON_APPROX_TOLERANCE = 0.004
polygon = skimage.measure.approximate_polygon(
coords=contour,
tolerance=np.ptp(contour, axis=0).max() * POLYGON_APPROX_TOLERANCE,
)
polygon = np.clip(polygon, (0, 0), (mask.shape[0] - 1, mask.shape[1] - 1))
polygon = polygon[:-1] # drop last point that is duplicate of first point
if 0:
image_pil = PIL.Image.fromarray(image)
imgviz.draw.line_(image_pil, yx=polygon, fill=(0, 255, 0))
for point in polygon:
imgviz.draw.circle_(
image_pil, center=point, diameter=10, fill=(0, 255, 0)
)
imgviz.io.imsave("contour.jpg", np.asarray(image_pil))

return polygon[:, ::-1] # yx -> xy
Loading
Loading