Skip to content

Commit

Permalink
add refuge
Browse files Browse the repository at this point in the history
  • Loading branch information
WuJunde authored Jan 4, 2024
1 parent 5904e9a commit f2a298b
Showing 1 changed file with 74 additions and 1 deletion.
75 changes: 74 additions & 1 deletion dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import torch
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms.functional as F
import torch.nn.functional as F
import torchvision.transforms as transforms
import pandas as pd
from skimage.transform import rotate
Expand Down Expand Up @@ -90,4 +90,77 @@ def __getitem__(self, index):
'pt':pt,
'image_meta_dict':image_meta_dict,
}


class REFUGE(Dataset):
def __init__(self, args, data_path , transform = None, transform_msk = None, mode = 'Training',prompt = 'click', plane = False):
self.data_path = data_path
self.subfolders = [f.path for f in os.scandir(os.path.join(data_path, mode + '-400')) if f.is_dir()]
self.mode = mode
self.prompt = prompt
self.img_size = args.image_size
self.mask_size = args.out_size

self.transform = transform
self.transform_msk = transform_msk

def __len__(self):
return len(self.subfolders)

def __getitem__(self, index):
inout = 1
point_label = 1

"""Get the images"""
subfolder = self.subfolders[index]
name = subfolder.split('/')[-1]

# raw image and raters path
img_path = os.path.join(subfolder, name + '.jpg')
multi_rater_cup_path = [os.path.join(subfolder, name + '_seg_cup_' + str(i) + '.png') for i in range(1, 8)]
multi_rater_disc_path = [os.path.join(subfolder, name + '_seg_disc_' + str(i) + '.png') for i in range(1, 8)]

# raw image and raters images
img = Image.open(img_path).convert('RGB')
multi_rater_cup = [Image.open(path).convert('L') for path in multi_rater_cup_path]
multi_rater_disc = [Image.open(path).convert('L') for path in multi_rater_disc_path]

# resize raters images for generating initial point click
newsize = (self.img_size, self.img_size)
multi_rater_cup_np = [np.array(single_rater.resize(newsize)) for single_rater in multi_rater_cup]
multi_rater_disc_np = [np.array(single_rater.resize(newsize)) for single_rater in multi_rater_disc]

# first click is the target agreement among all raters
if self.prompt == 'click':
pt_cup = random_click(np.array(np.mean(np.stack(multi_rater_cup_np), axis=0)) / 255, point_label, inout)
pt_disc = random_click(np.array(np.mean(np.stack(multi_rater_disc_np), axis=0)) / 255, point_label, inout)

if self.transform:
state = torch.get_rng_state()
img = self.transform(img)
multi_rater_cup = [torch.as_tensor((self.transform(single_rater) >0.5).float(), dtype=torch.float32) for single_rater in multi_rater_cup]
multi_rater_cup = torch.stack(multi_rater_cup, dim=0)
# transform to mask size (out_size) for mask define
mask_cup = F.interpolate(multi_rater_cup, size=(self.mask_size, self.mask_size), mode='bilinear', align_corners=False).mean(dim=0)

multi_rater_disc = [torch.as_tensor((self.transform(single_rater) >0.5).float(), dtype=torch.float32) for single_rater in multi_rater_disc]
multi_rater_disc = torch.stack(multi_rater_disc, dim=0)
mask_disc = F.interpolate(multi_rater_disc, size=(self.mask_size, self.mask_size), mode='bilinear', align_corners=False).mean(dim=0)
torch.set_rng_state(state)

image_meta_dict = {'filename_or_obj':name}
return {
'image':img,
'multi_rater_cup': multi_rater_cup,
'multi_rater_disc': multi_rater_disc,
'mask_cup': mask_cup,
'mask_disc': mask_disc,
'label': mask_disc,
'p_label':point_label,
'pt_cup':pt_cup,
'pt_disc':pt_disc,
'pt':pt_disc,
'selected_rater': torch.tensor(np.arange(7)),
'image_meta_dict':image_meta_dict,
}

0 comments on commit f2a298b

Please sign in to comment.