-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataloader.py
82 lines (58 loc) · 2.14 KB
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import torch.utils.data as data
from torchvision import transforms
from PIL import Image
import glob
import random
import os
# Seed the global RNG once at import time so the random.shuffle calls in the
# populate_* helpers below give a repeatable order for a given input list.
random.seed(1143)
def populate_train_list(lowlight_images_path):
    """Collect the image files under ``lowlight_images_path`` and shuffle them.

    Args:
        lowlight_images_path: either a directory (with or without a trailing
            separator), whose direct children are collected, or a path prefix
            such as ``"data/img_"``, in which case every file whose path starts
            with that prefix is collected.

    Returns:
        A list of file paths (subdirectories excluded), shuffled in place
        with the module-level ``random`` RNG.
    """
    if os.path.isdir(lowlight_images_path):
        # Join so that "data/train" and "data/train/" behave the same; plain
        # concatenation would glob "data/train*" and match the directory itself.
        pattern = os.path.join(lowlight_images_path, "*")
    else:
        # Preserve the original prefix-matching behaviour for non-directories.
        pattern = lowlight_images_path + "*"
    # glob's ordering depends on the filesystem; sort first so the seeded
    # shuffle yields the same order on every machine.
    train_list = sorted(path for path in glob.glob(pattern) if os.path.isfile(path))
    random.shuffle(train_list)
    return train_list
class lowlight_loader(data.Dataset):
    """Dataset over the images returned by ``populate_train_list``.

    Each item is the image at ``self.data_list[index]``, converted to RGB,
    resized to ``(patch_size, patch_size)`` and converted to a tensor by
    ``transforms.ToTensor()`` (a float tensor of shape
    ``(3, patch_size, patch_size)`` with values in ``[0, 1]``).
    """

    def __init__(self, lowlight_images_path, patch_size):
        """Collect the (shuffled) image paths and build the per-item transform.

        Args:
            lowlight_images_path: directory or path prefix forwarded to
                ``populate_train_list``.
            patch_size: side length in pixels that every image is resized to.
        """
        self.train_list = populate_train_list(lowlight_images_path)
        self.patch = patch_size
        # Alias of train_list; __getitem__ and __len__ read this attribute.
        self.data_list = self.train_list
        # Built once here instead of on every __getitem__ call.
        self.transform = transforms.Compose(
            [
                transforms.Resize(size=(self.patch, self.patch)),
                transforms.ToTensor(),
            ]
        )
        print("Total training examples:", len(self.train_list))

    def __getitem__(self, index):
        """Load the image at ``index`` and return it resized and tensorized."""
        fn = self.data_list[index]
        # The context manager guarantees the underlying file handle is closed
        # (Pillow can otherwise keep it open, e.g. for multi-frame formats).
        with Image.open(fn) as img:
            im = img.convert("RGB")
        return self.transform(im)

    def __len__(self):
        """Return the number of collected image paths."""
        return len(self.data_list)
def populate_train_list_contrast(lowlight_images_path):
    """Collect image files from the ``low`` and ``normal`` subfolders and shuffle them.

    Args:
        lowlight_images_path: dataset root expected to contain ``low/`` and
            ``normal/`` subdirectories. A missing subdirectory contributes no
            files.

    Returns:
        A single list of file paths from both subfolders combined
        (subdirectories excluded), shuffled in place with the module-level
        ``random`` RNG.
    """
    candidates = []
    for subdir in ("low", "normal"):
        candidates.extend(glob.glob(os.path.join(lowlight_images_path, subdir, "*")))
    # glob's ordering depends on the filesystem; sort first so the seeded
    # shuffle yields the same order on every machine.
    train_list = sorted(path for path in candidates if os.path.isfile(path))
    random.shuffle(train_list)
    return train_list
class lowlight_loader_contrast(data.Dataset):
    """Dataset over the images returned by ``populate_train_list_contrast``.

    The path list mixes files from the root's ``low/`` and ``normal/``
    subfolders. Each item is the image at ``self.data_list[index]``, converted
    to RGB, resized to ``(patch_size, patch_size)`` and converted to a tensor
    by ``transforms.ToTensor()`` (a float tensor of shape
    ``(3, patch_size, patch_size)`` with values in ``[0, 1]``). No label is
    returned, so an item does not say which subfolder it came from.
    """

    def __init__(self, lowlight_images_path, patch_size):
        """Collect the (shuffled) image paths and build the per-item transform.

        Args:
            lowlight_images_path: dataset root forwarded to
                ``populate_train_list_contrast``.
            patch_size: side length in pixels that every image is resized to.
        """
        self.train_list = populate_train_list_contrast(lowlight_images_path)
        self.patch = patch_size
        # Alias of train_list; __getitem__ and __len__ read this attribute.
        self.data_list = self.train_list
        # Built once here instead of on every __getitem__ call.
        self.transform = transforms.Compose(
            [
                transforms.Resize(size=(self.patch, self.patch)),
                transforms.ToTensor(),
            ]
        )
        print("Total training examples:", len(self.train_list))

    def __getitem__(self, index):
        """Load the image at ``index`` and return it resized and tensorized."""
        fn = self.data_list[index]
        # The context manager guarantees the underlying file handle is closed
        # (Pillow can otherwise keep it open, e.g. for multi-frame formats).
        with Image.open(fn) as img:
            im = img.convert("RGB")
        return self.transform(im)

    def __len__(self):
        """Return the number of collected image paths."""
        return len(self.data_list)