forked from alexandru-dinu/cae
-
Notifications
You must be signed in to change notification settings - Fork 0
/
image_folder.py
executable file
·62 lines (48 loc) · 1.82 KB
/
image_folder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import glob
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset
class BaseImageFolder(Dataset):
    """Dataset over the files matching ``<folder_path>/*.*``, split into square patches.

    Subclasses implement ``__getitem__`` by calling :meth:`getitem` with the
    patch-grid size (``w`` x ``h``) they want each image resized to.
    """

    def __init__(self, folder_path):
        """Collect the sorted paths of all files with an extension in ``folder_path``.

        The folder path is escaped so glob metacharacters in a directory name
        (``[``, ``*``, ``?``) are matched literally; only ``*.*`` is a pattern.
        """
        self.files = sorted(glob.glob('%s/*.*' % glob.escape(folder_path)))

    def getitem(self, index, w, h, patch_size=128):
        """Load one image, resize it to a ``w`` x ``h`` patch grid and split it.

        Args:
            index: Position in ``self.files``; wrapped modulo the file count.
            w: Number of patches along the image width.
            h: Number of patches along the image height.
            patch_size: Side length, in pixels, of each square patch.

        Returns:
            ``(img, patches, path)``: ``img`` is a float32 tensor of shape
            ``(3, h * patch_size, w * patch_size)`` with values divided by 255;
            ``patches`` is a view of ``img`` with shape
            ``(3, h, w, patch_size, patch_size)``; ``path`` is the loaded file.
        """
        path = self.files[index % len(self.files)]
        # The with-block closes the underlying file handle; resize() returns a
        # new in-memory image, so the result stays usable afterwards.
        with Image.open(path) as src:
            img = src.resize((w * patch_size, h * patch_size))
        # To ensure that the input image is RGB mode
        # For Ex. ADE_train_00008455.jpg is not an RGB image.
        # It will cause axes don't match in numpy.transpose().
        img = img.convert('RGB')
        pixels = np.array(img) / 255.0  # (H, W, 3), float64
        chw = torch.from_numpy(np.transpose(pixels, (2, 0, 1))).float()
        # Use native tensor ops: np.reshape/np.transpose on a tensor only work
        # through numpy's array round-trip and Tensor.__array_wrap__.
        patches = chw.reshape(3, h, patch_size, w, patch_size)
        patches = patches.permute(0, 1, 3, 2, 4)
        return chw, patches, path

    def __getitem__(self, index):
        """Must be overridden to pick a patch grid (see :meth:`getitem`)."""
        raise NotImplementedError(
            '%s must implement __getitem__' % type(self).__name__)

    def get_random(self):
        """Return the item at an index drawn with ``np.random.randint``."""
        i = np.random.randint(0, len(self.files))
        return self[i]

    def __len__(self):
        """Return the number of files found in the folder."""
        return len(self.files)
# Image shape is 6x10 128x128 patches
class ImageFolder720p(BaseImageFolder):
    """Serve each image as a grid 10 patches wide by 6 patches tall.

    With the base class's 128-px patches this is 1280x768 pixels.
    """

    def __getitem__(self, index):
        cols, rows = 10, 6
        return self.getitem(index=index, w=cols, h=rows)
# Image shape is 8x16 128x128 patches
class ImageFolder2K(BaseImageFolder):
    """Serve each image as a grid 16 patches wide by 8 patches tall.

    With the base class's 128-px patches this is 2048x1024 pixels.
    """

    def __getitem__(self, index):
        cols, rows = 16, 8
        return self.getitem(index=index, w=cols, h=rows)
# Image shape is 16x16 128x128 patches
class ImageFolder1024sqr(BaseImageFolder):
    """Serve each image as a square grid of 16 x 16 patches.

    NOTE(review): with the base class's 128-px patches this is 2048x2048
    pixels, not 1024x1024 as the class name suggests -- confirm intent.
    """

    def __getitem__(self, index):
        side = 16
        return self.getitem(index=index, w=side, h=side)
# Automatically resize input image
class ImageFolderAuto(BaseImageFolder):
    """Derive each image's patch grid from its own size in 128-px units.

    The grid is ``width // 128`` by ``height // 128`` patches, clamped to at
    least one patch per axis so images smaller than 128 px are upscaled to a
    single patch rather than requesting a zero-size resize.
    """

    def __getitem__(self, index):
        path = self.files[index % len(self.files)]
        # Only the image size is read here; the with-block closes the file.
        with Image.open(path) as src:
            width, height = src.size
        return self.getitem(
            index=index,
            w=max(1, width // 128),
            h=max(1, height // 128),
        )