Skip to content

Commit

Permalink
Plotting outlines with the holes.
Browse files Browse the repository at this point in the history
Added missing arguments, more consistent documentation/type-hints, little refactoring of the function for making polygons.
  • Loading branch information
hey2homie committed Mar 21, 2024
1 parent 5f35ef3 commit 8ea29d3
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 42 deletions.
11 changes: 7 additions & 4 deletions cellpose/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,7 +757,7 @@ def get_masks(p, iscell=None, rpad=20):

def resize_and_compute_masks(dP, cellprob, p=None, niter=200, cellprob_threshold=0.0,
flow_threshold=0.4, interp=True, do_3D=False, min_size=15, fill_holes=True,
resize=None, device=None):
area_threshold=None, resize=None, device=None):
"""Compute masks using dynamics from dP and cellprob, and resizes masks if resize is not None.
Args:
Expand All @@ -771,6 +771,7 @@ def resize_and_compute_masks(dP, cellprob, p=None, niter=200, cellprob_threshold
do_3D (bool, optional): Whether to perform mask computation in 3D. Defaults to False.
min_size (int, optional): The minimum size of the masks. Defaults to 15.
fill_holes (bool, optional): Whether to fill holes in the masks. Defaults to True.
area_threshold (int, optional): If filling holes, fills holes smaller than this threshold. Default is None.
resize (tuple, optional): The desired size for resizing the masks. Defaults to None.
device (str, optional): The torch device to use for computation. Defaults to None.
Expand All @@ -780,7 +781,7 @@ def resize_and_compute_masks(dP, cellprob, p=None, niter=200, cellprob_threshold
mask, p = compute_masks(dP, cellprob, p=p, niter=niter,
cellprob_threshold=cellprob_threshold,
flow_threshold=flow_threshold, interp=interp, do_3D=do_3D,
min_size=min_size, fill_holes=fill_holes, device=device)
min_size=min_size, fill_holes=fill_holes, area_threshold=area_threshold, device=device)

if resize is not None:
mask = transforms.resize_image(mask, resize[0], resize[1],
Expand All @@ -795,7 +796,7 @@ def resize_and_compute_masks(dP, cellprob, p=None, niter=200, cellprob_threshold

def compute_masks(dP, cellprob, p=None, niter=200, cellprob_threshold=0.0,
flow_threshold=0.4, interp=True, do_3D=False, min_size=15,
fill_holes=True, device=None):
fill_holes=True, area_threshold=None, device=None):
"""Compute masks using dynamics from dP and cellprob.
Args:
Expand All @@ -809,6 +810,7 @@ def compute_masks(dP, cellprob, p=None, niter=200, cellprob_threshold=0.0,
do_3D (bool, optional): Whether to perform mask computation in 3D. Defaults to False.
min_size (int, optional): The minimum size of the masks. Defaults to 15.
fill_holes (bool, optional): Whether to fill holes in the masks. Defaults to True.
area_threshold (int, optional): If filling holes, fills holes smaller than this threshold. Default is None.
device (str, optional): The torch device to use for computation. Defaults to None.
Returns:
Expand Down Expand Up @@ -858,7 +860,8 @@ def compute_masks(dP, cellprob, p=None, niter=200, cellprob_threshold=0.0,
p = np.zeros((len(shape), *shape), np.uint16)
return mask, p

mask = utils.fill_holes_and_remove_small_masks(mask, min_size=min_size, fill_holes=fill_holes)
mask = utils.fill_holes_and_remove_small_masks(mask, min_size=min_size, fill_holes=fill_holes,
area_threshold=area_threshold)

if mask.dtype == np.uint32:
dynamics_logger.warning(
Expand Down
25 changes: 21 additions & 4 deletions cellpose/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,15 @@ def outlines_to_text(base, outlines):
f.write("\n")


def polygons_to_geojson(base, polygons):
def polygons_to_geojson(base, polygons) -> None:
"""
Create a geojson file from polygons.
Args:
base (str): base name of the file to save
polygons (list): list of polygons
Returns:
None
"""
geojson = {
"type": "FeatureCollection",
"features": []
Expand Down Expand Up @@ -613,7 +621,7 @@ def save_rois(masks, file_name):
def save_masks(images, masks, flows, file_names, png=True, tif=False, channels=[0, 0],
suffix="", save_flows=False, save_outlines=False,
dir_above=False, in_folders=False, savedir=None, save_txt=False,
save_geojson=False, save_mpl=False):
save_geojson=False, keep_holes=False, save_mpl=False):
""" Save masks + nicely plotted segmentation image to png and/or tiff.
Can save masks, flows to different directories, if in_folders is True.
Expand Down Expand Up @@ -642,6 +650,7 @@ def save_masks(images, masks, flows, file_names, png=True, tif=False, channels=[
savedir (str, optional): Absolute path where images will be saved. If None, saves to image directory. Defaults to None.
save_txt (bool, optional): Save masks as list of outlines for ImageJ. Defaults to False.
save_geojson (bool, optional): Save masks as geojson. Defaults to False.
keep_holes (bool, optional): Keep holes outlines inside polygons. Default is False.
save_mpl (bool, optional): If True, saves a matplotlib figure of the original image/segmentation/flows. Does not work for 3D.
This takes a long time for large images. Defaults to False.
Expand Down Expand Up @@ -741,14 +750,22 @@ def save_masks(images, masks, flows, file_names, png=True, tif=False, channels=[

# QuPath geojson files
if masks.ndim < 3 and save_geojson:
polygons = utils.outlines_polygons(masks)
polygons = utils.outlines_polygons(masks, keep_holes=keep_holes)
if polygons is not None:
polygons_to_geojson(os.path.join(txtdir, basename), polygons)

# RGB outline images
if masks.ndim < 3 and save_outlines:
check_dir(outlinedir)
outlines = utils.masks_to_outlines(masks)
polygons = utils.outlines_polygons(masks, keep_holes=True)
image_shape = images.shape[1:] if images.shape[0] < 4 else images.shape[:2]
outlines = np.zeros(shape=image_shape)
for polygon in polygons: # TODO: Little ad-hoc
for outline in polygon:
for coordinates in range(len(outline)):
x = outline[coordinates][0]
y = outline[coordinates][1]
outlines[int(y), int(x)] = 255
outX, outY = np.nonzero(outlines)
img0 = transforms.normalize99(images)
if img0.shape[0] < 4:
Expand Down
12 changes: 7 additions & 5 deletions cellpose/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None,
flow_threshold=0.4, cellprob_threshold=0.0, do_3D=False, anisotropy=None,
stitch_threshold=0.0, min_size=15, niter=None, augment=False, tile=True,
tile_overlap=0.1, bsize=224, interp=True, compute_masks=True, fill_holes=True,
progress=None):
area_threshold=None, progress=None):
""" segment list of images x, or 4D array - Z x nchan x Y x X
Args:
Expand Down Expand Up @@ -353,6 +353,7 @@ def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None,
interp (bool, optional): interpolate during 2D dynamics (not available in 3D) . Defaults to True.
compute_masks (bool, optional): Whether or not to compute dynamics and return masks. This is set to False when retrieving the styles for the size model. Defaults to True.
fill_holes (bool, optional): Whether or not to fill holes in masks. Defaults to True.
area_threshold (int, optional): If filling holes, fills holes smaller than this threshold. Default is None.
progress (QProgressBar, optional): pyqt progress bar. Defaults to None.
Returns:
Expand Down Expand Up @@ -386,7 +387,7 @@ def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None,
interp=interp, flow_threshold=flow_threshold,
cellprob_threshold=cellprob_threshold, compute_masks=compute_masks,
min_size=min_size, fill_holes=fill_holes, stitch_threshold=stitch_threshold,
progress=progress, niter=niter)
area_threshold=area_threshold, progress=progress, niter=niter)
masks.append(maski)
flows.append(flowi)
styles.append(stylei)
Expand Down Expand Up @@ -414,15 +415,15 @@ def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None,
tile_overlap=tile_overlap, bsize=bsize, flow_threshold=flow_threshold,
cellprob_threshold=cellprob_threshold, interp=interp, min_size=min_size,
do_3D=do_3D, anisotropy=anisotropy, niter=niter, fill_holes=fill_holes,
stitch_threshold=stitch_threshold)
area_threshold=area_threshold, stitch_threshold=stitch_threshold)

flows = [plot.dx_to_circ(dP), dP, cellprob, p]
return masks, flows, styles

def _run_cp(self, x, compute_masks=True, normalize=True, invert=False, niter=None,
rescale=1.0, resample=True, augment=False, tile=True, tile_overlap=0.1,
cellprob_threshold=0.0, bsize=224, flow_threshold=0.4, min_size=15,
interp=True, anisotropy=1.0, do_3D=False, stitch_threshold=0.0, fill_holes=True):
interp=True, anisotropy=1.0, do_3D=False, stitch_threshold=0.0, fill_holes=True, area_threshold=None):

if isinstance(normalize, dict):
normalize_params = {**normalize_default, **normalize}
Expand Down Expand Up @@ -506,7 +507,7 @@ def _run_cp(self, x, compute_masks=True, normalize=True, invert=False, niter=Non
masks, p = dynamics.resize_and_compute_masks(
dP, cellprob, niter=niter, cellprob_threshold=cellprob_threshold,
flow_threshold=flow_threshold, interp=interp, do_3D=do_3D,
min_size=min_size, resize=None, fill_holes=fill_holes,
min_size=min_size, resize=None, fill_holes=fill_holes, area_threshold=area_threshold,
device=self.device if self.gpu else None)
else:
masks, p = [], []
Expand All @@ -526,6 +527,7 @@ def _run_cp(self, x, compute_masks=True, normalize=True, invert=False, niter=Non
min_size=min_size if stitch_threshold == 0 or nimg == 1 else
-1, # turn off for 3D stitching
fill_holes=fill_holes,
area_threshold=area_threshold,
device=self.device if self.gpu else None)
masks.append(outputs[0])
p.append(outputs[1])
Expand Down
72 changes: 43 additions & 29 deletions cellpose/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,6 @@ def outlines_list_single(masks):
list: List of outlines as pixel coordinates.
"""
# TODO: Computing masks
outpix = []
for n in np.unique(masks)[1:]:
mn = masks == n
Expand Down Expand Up @@ -290,16 +289,31 @@ def outlines_list_multi(masks, num_processes=None):
return outpix


def get_polygon(outline: np.ndarray[int], bb: tuple[int, int, int, int], dim: int = 0,
keep_holes: bool = False) -> list:
def make_polygon(outline: np.ndarray, bb: tuple) -> list:
"""
Enclose a polygon by adding the first point to the end of the list.
Args:
outline (np.ndarray): A list of points in the polygon.
Returns:
polygon (list): The enclosed polygon.
"""
coordinates = outline + np.array([bb[1], bb[0]])
polygon = [list(map(float, point)) for point in coordinates]
if polygon[0] != polygon[-1]:
polygon.append(polygon[0])
return polygon


def get_polygon(outline: np.ndarray, bb: tuple, dim: int, keep_holes: bool = False) -> list:
"""
Compute contour contours from binary mask, translate to bounding box coordinates, and return as polygon.
Args:
outline (np.ndarray[int]): Binary mask.
bb (tuple[int, int, int, int]): Bounding box coordinates.
outline (np.ndarray): Binary mask.
bb (tuple): Bounding box coordinates.
dim (int): In which dimension to look for the contour.
keep_holes (bool, optional): Whether to keep holes in the mask. Default is False.
Returns:
polygon (list[list[float]]): Polygon coordinates compatible with geojson format. # TODO: Holes case
polygon (list): Polygon coordinates compatible with geojson format.
"""
cv2_method = cv2.RETR_TREE if keep_holes else cv2.RETR_EXTERNAL
contours, hierarchy = cv2.findContours(
Expand All @@ -308,40 +322,34 @@ def get_polygon(outline: np.ndarray[int], bb: tuple[int, int, int, int], dim: in
method=cv2.CHAIN_APPROX_NONE,
)
outline = contours[dim].squeeze()
coordinates = outline + np.array([bb[1], bb[0]])
polygon = [list(map(float, point)) for point in coordinates]
if polygon[0] != polygon[-1]:
polygon.append(polygon[0])
to_return = [polygon]
if hierarchy.shape[1] > 1:
polygon = make_polygon(outline, bb)
polygon = [polygon]
if keep_holes and hierarchy.shape[1] > 1:
inner_contours = np.where(hierarchy[0, :, 3].squeeze() != -1)[0]
inner_polygons = []
for c_idx in inner_contours:
inner_outline = contours[c_idx].squeeze()
inner_coordinates = inner_outline + np.array([bb[1], bb[0]])
inner_polygon = [list(map(float, point)) for point in inner_coordinates]
if inner_polygon[0] != inner_polygon[-1]:
inner_polygon.append(inner_polygon[0])
inner_polygon = make_polygon(inner_outline, bb)
inner_polygons.append(inner_polygon)
to_return.extend(inner_polygons)
return to_return
polygon.extend(inner_polygons)
return polygon


def outlines_polygons(masks: np.ndarray[int], keep_holes: bool = False) -> list[list[list[float]]]:
def outlines_polygons(masks: np.ndarray, keep_holes: bool = False) -> list:
"""
Get outlines of masks as polygons writing geojson.
Args:
masks (np.ndarray[int]): masks (0=no cells, 1=first cell, 2=second cell,...)
masks (np.ndarray): masks (0=no cells, 1=first cell, 2=second cell,...)
keep_holes (bool, optional): Whether to keep holes in the mask. Default is False.
Returns:
polygons (list[list[list[float]]]): List of polygons as pixel coordinates.
polygons (list): List of polygons as pixel coordinates.
"""
polygons: list[list[list[float]]] = []
polygons: list = []
objects = find_objects(masks)
for i, sl in enumerate(objects):
lb = i + 1
image = masks[sl] == lb
bbox = tuple([sl[i].start for i in range(masks.ndim)]
+ [sl[i].stop for i in range(masks.ndim)])
bbox = tuple([sl[i].start for i in range(masks.ndim)] + [sl[i].stop for i in range(masks.ndim)])
outline = image.astype(np.uint8)
try:
polygon = get_polygon(outline, bbox, 0, keep_holes)
Expand Down Expand Up @@ -674,7 +682,7 @@ def size_distribution(masks):
counts = np.unique(masks, return_counts=True)[1][1:]
return np.percentile(counts, 25) / np.percentile(counts, 75)

def fill_holes_and_remove_small_masks(masks, min_size=15, fill_holes=True):
def fill_holes_and_remove_small_masks(masks, min_size=15, fill_holes=True, area_threshold=None):
""" Fills holes in masks (2D/3D) and discards masks smaller than min_size.
This function fills holes in each mask using scipy.ndimage.morphology.binary_fill_holes.
Expand All @@ -688,7 +696,8 @@ def fill_holes_and_remove_small_masks(masks, min_size=15, fill_holes=True):
Masks smaller than min_size will be removed.
Set to -1 to turn off this functionality. Default is 15.
fill_holes (bool, optional): Whether to fill holes in masks. Default is True.
area_threshold (int, optional): If filling holes, fills holes smaller than this threshold.
If None or SKIMAGE_ENABLED is False, fills all holes. Default is None.
Returns:
ndarray: Int, 2D or 3D array of masks with holes filled and small masks removed.
0 represents no mask, while positive integers represent mask labels.
Expand All @@ -711,10 +720,15 @@ def fill_holes_and_remove_small_masks(masks, min_size=15, fill_holes=True):
if fill_holes:
if msk.ndim == 3:
for k in range(msk.shape[0]):
# TODO: Replace binary_fill_holes with remove_small_holes
msk[k] = remove_small_holes(msk[k], area_threshold=5)
if area_threshold is not None and SKIMAGE_ENABLED:
msk[k] = remove_small_holes(msk[k], area_threshold=area_threshold)
else:
msk[k] = binary_fill_holes(msk[k])
else:
msk = remove_small_holes(msk, area_threshold=5)
if area_threshold is not None and SKIMAGE_ENABLED:
msk = remove_small_holes(msk, area_threshold=area_threshold)
else:
msk = binary_fill_holes(msk)
masks[slc][msk] = (j + 1)
j += 1
return masks

0 comments on commit 8ea29d3

Please sign in to comment.