-
Notifications
You must be signed in to change notification settings - Fork 0
/
DataLoader.py
63 lines (51 loc) · 2.1 KB
/
DataLoader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import numpy as np
import pandas as pd
import torch
from skimage import io, transform
"""
Reference documentation for torch.utils.data.DataLoader:
https://pytorch.org/docs/stable/data.html
"""
def get_train_dataloader(p, dataset, collate_fn=None):
    """Build the training ``torch.utils.data.DataLoader`` for *dataset*.

    The loader shuffles the data, drops the last incomplete batch and
    requests pinned memory.

    Args:
        p: Mapping that provides the ``'num_workers'`` and ``'batch_size'``
            entries forwarded to the DataLoader.
        dataset: Dataset object handed to the DataLoader unchanged.
        collate_fn: Optional callable that merges a list of samples into a
            batch. When omitted, a module-level ``collate_custom`` is used
            if one is defined; otherwise the DataLoader's default collation
            applies.

    Returns:
        The configured ``torch.utils.data.DataLoader``.
    """
    if collate_fn is None:
        # The original code referenced ``collate_custom`` unconditionally,
        # which raised NameError because this module never defines it.
        # Look it up lazily so it is still honoured when it does exist.
        collate_fn = globals().get('collate_custom')

    return torch.utils.data.DataLoader(
        dataset,
        num_workers=p['num_workers'],
        batch_size=p['batch_size'],
        pin_memory=True,
        collate_fn=collate_fn,  # None selects the DataLoader default
        drop_last=True,
        shuffle=True,
    )
class P2DataLoader:
    """Map-style dataset of bounding-box crops described by a CSV file.

    Each CSV row names an image file, a label (column ``'Sex (subj)'``),
    the recorded image height/width and a bounding box given by its
    top-left and bottom-right corners. Indexing the dataset loads the
    image, crops it to the bounding box and resizes the crop to
    ``OUTPUT_SHAPE``.
    """

    # (rows, cols) passed to skimage.transform.resize. The original code
    # asked for (89, 80) in PIL's (width, height) order, i.e. 80 rows by
    # 89 columns.
    OUTPUT_SHAPE = (80, 89)

    def __init__(self, csv_file, root='', train=True, transform=None,):
        """Read the CSV metadata and remember the loading options.

        Args:
            csv_file: Path (or buffer) accepted by ``pandas.read_csv``.
            root: Directory prefix prepended to every ``'Image File'``
                entry; an empty string means the entries are used as-is.
            train: Stored on the instance; not otherwise used here.
            transform: Optional callable applied to the resized crop in
                ``__getitem__``.
        """
        # Previously these arguments were silently discarded, which made
        # ``__getitem__`` fail with NameError on ``root``.
        self.root = root
        self.train = train
        self.transform = transform

        self.csv_data = pd.read_csv(csv_file)
        self.target = np.array(self.csv_data['Sex (subj)'])  # label of image
        self.im_file = np.array(self.csv_data['Image File'])  # image file name
        self.h = np.array(self.csv_data['Image Height'])
        self.w = np.array(self.csv_data['Image Width'])
        self.x1 = np.array(self.csv_data['X (top left)'])
        self.x2 = np.array(self.csv_data['X (bottom right)'])
        self.y1 = np.array(self.csv_data['Y (top left)'])
        self.y2 = np.array(self.csv_data['Y (bottom right)'])

    def __getitem__(self, index):
        """Load, crop and resize the sample at *index*.

        Args:
            index (int): Row index into the CSV metadata.

        Returns:
            dict: ``{'image': image, 'target': label, 'meta': dict}`` where
            ``meta`` holds ``'im_size'`` (the height/width recorded in the
            CSV, not the resized shape), ``'index'`` and ``'class_ID'``.
        """
        filename = self.im_file[index]
        # Avoid turning a relative file name into "/name" when root is empty.
        path = f'{self.root}/{filename}' if self.root else filename
        img = io.imread(path)

        target = self.target[index]
        h, w = self.h[index], self.w[index]

        # NumPy images are indexed [row, col], so y selects rows and x columns.
        img = img[self.y1[index]:self.y2[index], self.x1[index]:self.x2[index]]

        # Resize to a standard size. The image is a NumPy array, so the
        # PIL-style ``img.resize(size, Image.ANTIALIAS)`` call could not work;
        # use the already-imported skimage transform module instead.
        # NOTE(review): skimage's resize returns a floating-point image by
        # default - confirm downstream code expects that.
        img = transform.resize(img, self.OUTPUT_SHAPE, anti_aliasing=True)

        if self.transform is not None:
            img = self.transform(img)

        return {
            'image': img,
            'target': target,
            'meta': {'im_size': (h, w), 'index': index, 'class_ID': target},
        }

    def get_image(self, index):
        """Return *index* unchanged.

        NOTE(review): this looks like an unfinished placeholder; the
        behaviour is kept as-is so existing callers are not affected.
        """
        return index

    def __len__(self):
        """Return the number of rows in the CSV metadata."""
        # ``self.data`` never existed; the row count of the CSV is the size.
        return len(self.csv_data)