From f2a298b9412235cf2c11429410e8026e296a5017 Mon Sep 17 00:00:00 2001 From: Junde Wu Date: Thu, 4 Jan 2024 15:47:55 +0000 Subject: [PATCH] add refuge --- dataset.py | 75 +++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 74 insertions(+), 1 deletion(-) diff --git a/dataset.py b/dataset.py index 68d027d2..b53e01fe 100644 --- a/dataset.py +++ b/dataset.py @@ -12,7 +12,7 @@ import torch from torch.utils.data import Dataset from PIL import Image -import torchvision.transforms.functional as F +import torch.nn.functional as F import torchvision.transforms as transforms import pandas as pd from skimage.transform import rotate @@ -90,4 +90,77 @@ def __getitem__(self, index): 'pt':pt, 'image_meta_dict':image_meta_dict, } + + + class REFUGE(Dataset): + def __init__(self, args, data_path , transform = None, transform_msk = None, mode = 'Training',prompt = 'click', plane = False): + self.data_path = data_path + self.subfolders = [f.path for f in os.scandir(os.path.join(data_path, mode + '-400')) if f.is_dir()] + self.mode = mode + self.prompt = prompt + self.img_size = args.image_size + self.mask_size = args.out_size + + self.transform = transform + self.transform_msk = transform_msk + + def __len__(self): + return len(self.subfolders) + + def __getitem__(self, index): + inout = 1 + point_label = 1 + + """Get the images""" + subfolder = self.subfolders[index] + name = subfolder.split('/')[-1] + + # raw image and raters path + img_path = os.path.join(subfolder, name + '.jpg') + multi_rater_cup_path = [os.path.join(subfolder, name + '_seg_cup_' + str(i) + '.png') for i in range(1, 8)] + multi_rater_disc_path = [os.path.join(subfolder, name + '_seg_disc_' + str(i) + '.png') for i in range(1, 8)] + + # raw image and raters images + img = Image.open(img_path).convert('RGB') + multi_rater_cup = [Image.open(path).convert('L') for path in multi_rater_cup_path] + multi_rater_disc = [Image.open(path).convert('L') for path in multi_rater_disc_path] + + # resize raters images for generating initial point click + newsize = (self.img_size, self.img_size) + multi_rater_cup_np = [np.array(single_rater.resize(newsize)) for single_rater in multi_rater_cup] + multi_rater_disc_np = [np.array(single_rater.resize(newsize)) for single_rater in multi_rater_disc] + + # first click is the target agreement among all raters + if self.prompt == 'click': + pt_cup = random_click(np.array(np.mean(np.stack(multi_rater_cup_np), axis=0)) / 255, point_label, inout) + pt_disc = random_click(np.array(np.mean(np.stack(multi_rater_disc_np), axis=0)) / 255, point_label, inout) + + if self.transform: + state = torch.get_rng_state() + img = self.transform(img) + multi_rater_cup = [torch.as_tensor((self.transform(single_rater) >0.5).float(), dtype=torch.float32) for single_rater in multi_rater_cup] + multi_rater_cup = torch.stack(multi_rater_cup, dim=0) + # transform to mask size (out_size) for mask define + mask_cup = F.interpolate(multi_rater_cup, size=(self.mask_size, self.mask_size), mode='bilinear', align_corners=False).mean(dim=0) + + multi_rater_disc = [torch.as_tensor((self.transform(single_rater) >0.5).float(), dtype=torch.float32) for single_rater in multi_rater_disc] + multi_rater_disc = torch.stack(multi_rater_disc, dim=0) + mask_disc = F.interpolate(multi_rater_disc, size=(self.mask_size, self.mask_size), mode='bilinear', align_corners=False).mean(dim=0) + torch.set_rng_state(state) + + image_meta_dict = {'filename_or_obj':name} + return { + 'image':img, + 'multi_rater_cup': multi_rater_cup, + 'multi_rater_disc': multi_rater_disc, + 'mask_cup': mask_cup, + 'mask_disc': mask_disc, + 'label': mask_disc, + 'p_label':point_label, + 'pt_cup':pt_cup, + 'pt_disc':pt_disc, + 'pt':pt_disc, + 'selected_rater': torch.tensor(np.arange(7)), + 'image_meta_dict':image_meta_dict, + }