Commit b2ec08a: added lung cropper dataset
datonefaridze committed Jul 8, 2021
1 parent a43149d

Showing 2 changed files with 87 additions and 37 deletions.
108 changes: 72 additions & 36 deletions tools/datasets.py
@@ -1,4 +1,5 @@
from typing import List, Tuple, Union
import warnings

import cv2
import torch
@@ -8,6 +9,7 @@
import torchvision.transforms as transforms

from tools.supervisely_tools import convert_ann_to_mask
from tools.utils import find_obj_bbox


class SegmentationDataset(Dataset):
@@ -66,7 +68,6 @@ def __getitem__(self,
return image, mask


# TODO: Fix LungsCropper in order to crop images
class LungsCropper(Dataset):
def __init__(self,
img_paths: List[str],
@@ -76,15 +77,17 @@ def __init__(self,
output_size: Union[int, List[int]] = (512, 512),
class_name: str = 'COVID-19',
transform_params=None,
augmentation_params=None) -> None:
flag_type: str = None) -> None:
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
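        # flag_type selects the cropping strategy used in __getitem__:
        #   'crop'        - mask the image with the predicted lungs and resize
        #   'single_crop' - crop to the union bounding box of all detected lung regions
        #   'double_crop' - crop the two largest lung regions separately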

assert flag_type in ['single_crop', 'double_crop', 'crop'], 'invalid flag type'
self.flag_type = flag_type

self.img_paths, self.ann_paths = img_paths, ann_paths
self.class_name = class_name
self.model_input_size = (model_input_size, model_input_size) if isinstance(model_input_size, int) else model_input_size
self.model_input_size = (model_input_size, model_input_size) if isinstance(model_input_size,
int) else model_input_size
self.output_size = (output_size, output_size) if isinstance(output_size, int) else output_size

self.augmentation_params = augmentation_params
self.transform_params = transform_params
self.lung_segmentation_model = lung_segmentation_model.to(self.device)
self.lung_segmentation_model = self.lung_segmentation_model.eval()
@@ -94,18 +97,15 @@ def __init__(self,
interpolation=Image.BICUBIC),
transforms.Normalize(mean=self.transform_params['mean'],
std=self.transform_params['std'])])
self.preprocess_model_mask = transforms.Compose([transforms.ToTensor(),
transforms.Resize(size=self.model_input_size,
interpolation=Image.NEAREST)])

self.preprocess_output_image = transforms.Compose([transforms.ToTensor(),
transforms.Resize(size=self.output_size,
interpolation=Image.BICUBIC),
transforms.Normalize(mean=self.transform_params['mean'],
std=self.transform_params['std'])])
transforms.Resize(size=self.output_size,
interpolation=Image.BICUBIC),
transforms.Normalize(mean=self.transform_params['mean'],
std=self.transform_params['std'])])
self.preprocess_output_mask = transforms.Compose([transforms.ToTensor(),
transforms.Resize(size=self.output_size,
interpolation=Image.NEAREST)])
transforms.Resize(size=self.output_size,
interpolation=Image.NEAREST)])

def __len__(self):
return len(self.img_paths)
@@ -118,43 +118,79 @@ def __getitem__(self, idx: int) -> Tuple[np.ndarray, np.ndarray]:
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image = cv2.resize(image, self.model_input_size)

# TODO (David): convert_ann_to_mask was changed, test correct mask conversion using the following:
# TODO (David): mask = convert_ann_to_mask(ann_path=ann_path, class_name=self.class_name)
if ('rsna_normal' in image_path) or ('chest_xray_normal' in image_path):
mask = np.zeros(image.shape[:2], dtype=np.uint8)
else:
mask = convert_ann_to_mask(ann_path=ann_path, class_name=self.class_name)
mask = cv2.resize(mask, self.model_input_size)

# Apply augmentation
if self.augmentation_params:
sample = self.augmentation_params(image=image, mask=mask)
image, mask = sample['image'], sample['mask']
mask = convert_ann_to_mask(ann_path=ann_path, class_name=self.class_name)
        mask = cv2.resize(mask, self.model_input_size)  # dsize is a single (width, height) tuple

# TODO: avoid transferring tensors to numpy and numpy to tensors
if self.transform_params:
transformed_image = self.preprocess_model_image(image)
transformed_image = transformed_image.to(self.device)
transformed_image = self.preprocess_model_image(image).to(self.device)

with torch.no_grad():
lungs_prediction = self.lung_segmentation_model(torch.unsqueeze(transformed_image, 0))
predicted_mask = lungs_prediction.permute(0, 2, 3, 1).cpu().detach().numpy()[0, :, :, :] > 0.5

intersection_mask = mask * predicted_mask[:, :, 0]
mask = self.preprocess_output_mask(intersection_mask)
image = self.preprocess_output_image(image * predicted_mask)
            intersection_image = image * predicted_mask

            if self.flag_type == 'crop':
                image = self.preprocess_output_image(intersection_image)
                mask = self.preprocess_output_mask(intersection_mask)
                return image, mask

            elif self.flag_type == 'single_crop':
                bbox_coordinates = find_obj_bbox(predicted_mask)
                if len(bbox_coordinates) > 2:
                    warnings.warn('there are {} objects; this might create problems'.format(len(bbox_coordinates)))

                # Merge all detected lung regions into one enclosing bounding box
                bbox_min_x = np.min([x[0] for x in bbox_coordinates])
                bbox_min_y = np.min([x[1] for x in bbox_coordinates])
                bbox_max_x = np.max([x[2] for x in bbox_coordinates])
                bbox_max_y = np.max([x[3] for x in bbox_coordinates])

                single_cropped_image = intersection_image[bbox_min_y:bbox_max_y, bbox_min_x:bbox_max_x]
                single_cropped_mask = intersection_mask[bbox_min_y:bbox_max_y, bbox_min_x:bbox_max_x]

                image = self.preprocess_output_image(single_cropped_image)
                mask = self.preprocess_output_mask(single_cropped_mask)
                return image, mask

            elif self.flag_type == 'double_crop':
                bbox_coordinates = find_obj_bbox(predicted_mask)
                if len(bbox_coordinates) > 2:
                    warnings.warn('there are {} objects; this might create problems'.format(len(bbox_coordinates)))

                # Sort bounding boxes by area, largest first, and keep at most two
                bbox_coordinates.sort(key=lambda x: -(x[2] - x[0]) * (x[3] - x[1]))
                images = []
                masks = []

                for bbox in bbox_coordinates[:2]:
                    x_min, y_min, x_max, y_max = bbox
                    single_cropped_image = intersection_image[y_min:y_max, x_min:x_max]
                    single_cropped_mask = intersection_mask[y_min:y_max, x_min:x_max]

                    image = self.preprocess_output_image(single_cropped_image)
                    mask = self.preprocess_output_mask(single_cropped_mask)
                    images.append(image)
                    masks.append(mask)

                return tuple(images), tuple(masks)

return image, mask


if __name__ == '__main__':

# The code snippet below is used only for debugging
from tools.supervisely_tools import read_supervisely_project
image_paths, ann_paths, dataset_names = read_supervisely_project(sly_project_dir='dataset/covid_segmentation_single_crop',
included_datasets=[
'Actualmed-COVID-chestxray-dataset',
'rsna_normal'
])

image_paths, ann_paths, dataset_names = read_supervisely_project(
sly_project_dir='dataset/covid_segmentation_single_crop',
included_datasets=[
'Actualmed-COVID-chestxray-dataset',
'rsna_normal'
])
dataset = SegmentationDataset(img_paths=image_paths,
ann_paths=ann_paths,
input_size=[512, 512],
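For reference, a minimal usage sketch of the new flag_type modes (not part of the commit; the checkpoint path, normalization values, and project directory below are hypothetical):

import torch

from tools.datasets import LungsCropper
from tools.supervisely_tools import read_supervisely_project

# Hypothetical checkpoint: any torch.nn.Module mapping a (N, 3, H, W) batch
# to a (N, C, H, W) lung probability map should work here.
lung_model = torch.load('models/lung_segmentation.pth', map_location='cpu')

image_paths, ann_paths, _ = read_supervisely_project(
    sly_project_dir='dataset/covid_segmentation',
    included_datasets=['Actualmed-COVID-chestxray-dataset'])

cropper = LungsCropper(img_paths=image_paths,
                       ann_paths=ann_paths,
                       lung_segmentation_model=lung_model,
                       model_input_size=512,
                       output_size=512,
                       transform_params={'mean': [0.5, 0.5, 0.5],   # assumed values
                                         'std': [0.5, 0.5, 0.5]},
                       flag_type='single_crop')

# 'crop' and 'single_crop' yield (image, mask); 'double_crop' yields up to
# two crops as (tuple(images), tuple(masks)).
image, mask = cropper[0]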
16 changes: 15 additions & 1 deletion tools/utils.py
@@ -103,4 +103,18 @@ def divide_lung(lung: np.array):
img2 = cv2.rotate(pad_2, cv2.ROTATE_90_COUNTERCLOCKWISE)
img3 = cv2.rotate(pad_3, cv2.ROTATE_90_COUNTERCLOCKWISE)

return img1, img2, img3
return img1, img2, img3


def find_obj_bbox(mask: np.ndarray):
assert np.max(mask) <= 1 and np.min(mask) >= 0, 'mask values should be in [0,1] scale, max {}' \
' min {}'.format(np.max(mask), np.min(mask))
binary_map = (mask > 0.5).astype(np.uint8)
num_labels, _, stats, _ = cv2.connectedComponentsWithStats(binary_map, connectivity=8, ltype=cv2.CV_32S)
bbox_coordinates = []

for i in range(1, num_labels):
x0, y0 = stats[i, cv2.CC_STAT_LEFT], stats[i, cv2.CC_STAT_TOP]
x1, y1 = x0 + stats[i, cv2.CC_STAT_WIDTH], y0 + stats[i, cv2.CC_STAT_HEIGHT]
bbox_coordinates.append((x0, y0, x1, y1))
return bbox_coordinates
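
A quick synthetic check of find_obj_bbox (illustrative only, not part of the commit):

import numpy as np

from tools.utils import find_obj_bbox

mask = np.zeros((100, 100), dtype=np.float32)
mask[10:30, 20:40] = 1.0   # first blob
mask[60:90, 50:70] = 1.0   # second blob

# Boxes are (x0, y0, x1, y1) per connected component:
# [(20, 10, 40, 30), (50, 60, 70, 90)]
print(find_obj_bbox(mask))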
