-
Notifications
You must be signed in to change notification settings - Fork 21
/
dataset.py
81 lines (72 loc) · 3.15 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
from torch.utils.data import Dataset
from torchvision import transforms, utils
import numpy as np
from scipy import ndimage
import torch
class ICLRDataset(Dataset):
    """Multi-date image-patch dataset with per-sample field masks.

    Samples are taken from ``imgs``, ``areas``, ``gts`` and ``field_masks``
    (all indexed along axis 0) according to ``split_type``:

    * ``'test'``: every sample whose label is ``-1``; ``index`` is ignored.
    * any other split: the labelled samples (``gts > -1``), sub-selected by
      the positions in ``index``.

    Only the ``'train'`` split is augmented (tile mixup, date dropping,
    vertical flip, small rotation). Every item, whatever the split, is reduced
    to a random 16x16 window that overlaps the field mask.

    Args:
        tile: Mixup source array, indexed ``tile[t, :, :, y, x]``; presumably
            (tiles, dates, bands, H, W) matching ``imgs[k]`` -- confirm.
        imgs: Per-sample stacks; ``imgs[k]`` is indexed ``[date, band, y, x]``.
        areas: Per-sample value, returned as a 1-element tensor.
        gts: Labels; ``-1`` marks unlabelled (test) samples.
        field_masks: Per-sample masks; ``field_masks[k][0]`` is the 2-D mask.
        split_type: ``'train'``, ``'test'`` or another split name.
        index: Positions into the labelled samples (non-test splits only).
        mix_aug: Whether training images may be blended with tile crops.
    """

    _MIX_WEIGHT = 0.85      # share of the original image kept by tile mixup
    _MAX_ROTATION_DEG = 10  # rotation angle is drawn from U(-10, 10) degrees
    _CROP_ATTEMPTS = 100    # random draws before crop() enumerates windows
    _DROPPED_BAND = 10      # band index 10 (B11) is excluded from features

    def __init__(self, tile, imgs, areas, gts, field_masks, split_type, index, mix_aug=True):
        self.tile = tile
        if split_type == 'test':
            keep = gts == -1
            self.imgs = imgs[keep]
            self.areas = areas[keep]
            self.gts = gts[keep]
            self.field_masks = field_masks[keep]
        else:
            labelled = gts > -1  # computed once, reused for every array
            idx = np.asarray(index)
            self.imgs = imgs[labelled][idx]
            self.areas = areas[labelled][idx]
            self.field_masks = field_masks[labelled][idx]
            self.gts = gts[labelled][idx]
        self.split_type = split_type
        # Feature bands fed to the model: every band except B11.
        self.feat_arr = [i for i in range(imgs.shape[2]) if i != self._DROPPED_BAND]
        self.mix_aug = mix_aug

    def __len__(self):
        """Return the number of samples in this split."""
        return self.imgs.shape[0]

    def augment(self, img, mask):
        """Return randomly augmented versions of one training sample.

        Steps, in order:

        1. with probability 0.5 (and only if ``mix_aug``), blend ``img`` with
           a same-sized random window from a random tile;
        2. remove one randomly chosen date (history augmentation);
        3. with probability 0.5, flip image and ``mask[0]`` vertically;
        4. with probability 0.5, rotate both by the same angle in
           ``[-10, 10)`` degrees (no reshape).

        The caller's arrays are never modified.

        Args:
            img: Array indexed ``[date, band, y, x]``.
            mask: Array whose ``mask[0]`` is the 2-D field mask.

        Returns:
            ``(img, mask)`` -- new arrays; ``img`` has one date fewer when
            more than one date was available.
        """
        p = np.random.random(3)
        ang = np.random.uniform(-self._MAX_ROTATION_DEG, self._MAX_ROTATION_DEG)
        # img/mask are views into self.imgs/self.field_masks; editing them in
        # place would make the augmentation accumulate across epochs.
        mask = mask.copy()
        if self.mix_aug and p[0] > 0.5:
            img = self._mix_with_tile(img)
        img = self._drop_dates(img)
        if p[1] > 0.5:
            mask[0] = np.flipud(mask[0])
            # Flip rows of every [y, x] slice; copy so no negative strides
            # reach torch later.
            img = np.flip(img, axis=-2).copy()
        if p[2] > 0.5:
            mask[0] = ndimage.rotate(mask[0], ang, reshape=False)
            # Rotates each [y, x] plane exactly like a per-slice loop would.
            img = ndimage.rotate(img, ang, axes=(-1, -2), reshape=False)
        return img, mask

    def _mix_with_tile(self, img):
        """Blend ``img`` with a same-sized window from a randomly chosen tile."""
        height, width = img.shape[-2], img.shape[-1]
        # randint's upper bound is exclusive: +1 keeps the last start valid.
        start_x = np.random.randint(0, self.tile.shape[-2] - height + 1)
        start_y = np.random.randint(0, self.tile.shape[-1] - width + 1)
        t = np.random.randint(0, self.tile.shape[0])
        rnd_crop = self.tile[t, :, :, start_x:start_x + height, start_y:start_y + width]
        d = self._MIX_WEIGHT
        return img * d + rnd_crop * (1 - d)

    @staticmethod
    def _drop_dates(img, n_drop=1):
        """Return ``img`` without ``n_drop`` distinct random dates (axis 0)."""
        n_dates = img.shape[0]
        if n_dates <= n_drop:
            return img  # dropping would leave no dates at all
        drop = np.random.choice(n_dates, size=n_drop, replace=False)
        return np.delete(img, drop, axis=0)  # keeps the remaining date order

    def crop(self, img, mask, size=16):
        """Randomly take a ``size`` x ``size`` window that overlaps the field.

        The window is drawn uniformly among all windows whose mask sum is
        positive. If no such window exists, an unconstrained random window is
        returned instead of searching forever.

        Args:
            img: Array indexed ``[date, band, y, x]``.
            mask: Array whose ``mask[0]`` is the 2-D field mask; expected to
                match ``img``'s spatial size.
            size: Side length of the square window (default 16).

        Returns:
            ``(img_window, mask_window)`` with ``mask_window`` 2-D.

        Raises:
            ValueError: If ``mask[0]`` is smaller than ``size`` in either axis.
        """
        plane = mask[0]
        max_i = plane.shape[0] - size
        max_j = plane.shape[1] - size
        if max_i < 0 or max_j < 0:
            raise ValueError(
                f"mask of shape {plane.shape} is smaller than the {size}x{size} crop"
            )

        def has_field(i, j):
            return plane[i:i + size, j:j + size].sum() > 0

        start = None
        # Fast path: rejection sampling. randint's upper bound is exclusive,
        # hence +1 so windows touching the last row/column can be drawn.
        for _ in range(self._CROP_ATTEMPTS):
            i = np.random.randint(0, max_i + 1)
            j = np.random.randint(0, max_j + 1)
            if has_field(i, j):
                start = (i, j)
                break
        if start is None:
            candidates = [
                (i, j)
                for i in range(max_i + 1)
                for j in range(max_j + 1)
                if has_field(i, j)
            ]
            if candidates:
                start = candidates[np.random.randint(len(candidates))]
            else:
                # No positive mask pixel (e.g. rotated out of the frame).
                start = (np.random.randint(0, max_i + 1), np.random.randint(0, max_j + 1))
        i, j = start
        return img[:, :, i:i + size, j:j + size], plane[i:i + size, j:j + size]

    def __getitem__(self, idx):
        """Return ``(features, area, field_mask, label)`` for sample ``idx``.

        ``features`` is ``img[:, self.feat_arr]`` (B11 removed) cropped to
        16x16; ``area`` is a 1-element tensor; ``label`` is ``self.gts[idx]``.
        """
        img = self.imgs[idx]
        field_mask = self.field_masks[idx]
        if self.split_type == 'train':
            img, field_mask = self.augment(img, field_mask)
        # NOTE(review): cropping applies to every split, so evaluation also
        # sees a random 16x16 window -- confirm this is intended.
        img, field_mask = self.crop(img, field_mask)
        return (
            torch.FloatTensor(img[:, self.feat_arr]),
            torch.FloatTensor(self.areas[idx:idx + 1]),
            torch.FloatTensor(field_mask),
            self.gts[idx],
        )