-
Notifications
You must be signed in to change notification settings - Fork 40
/
facade_dataset.py
55 lines (45 loc) · 1.79 KB
/
facade_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import os
import numpy
from PIL import Image
import six
import numpy as np
from io import BytesIO
import os
import pickle
import json
import numpy as np
import skimage.io as io
from chainer.dataset import dataset_mixin
# download `BASE` dataset from http://cmp.felk.cvut.cz/~tylecr1/facade/
class FacadeDataset(dataset_mixin.DatasetMixin):
    """CMP Facade ``base`` dataset, held fully in memory as (image, label) pairs.

    For every index ``i`` in ``range(data_range[0], data_range[1])`` the
    constructor loads ``cmp_b%04d.jpg`` (photo) and ``cmp_b%04d.png`` (label
    map) from ``dataDir``, rescales both so their shorter side is exactly
    ``LOAD_SIZE`` pixels, and stores:

    * the photo as a float32 ``(3, H, W)`` array scaled to ``[-1, 1)``;
    * the label map as an int32 one-hot ``(N_CLASSES, H, W)`` array, where a
      pixel value ``v`` in the PNG maps to class ``v - 1``. Pixels whose value
      falls outside ``1..N_CLASSES`` get an all-zero vector.

    ``get_example`` returns a random square crop of one pair as
    ``(label, image)``.
    """

    N_CLASSES = 12   # label PNG values 1..12 -> one-hot classes [0, 12)
    LOAD_SIZE = 286  # shorter side after resizing, before random cropping

    def __init__(self, dataDir='./facade/base', data_range=(1,300)):
        """Load and preprocess every pair with index in ``[start, stop)``.

        Args:
            dataDir: directory containing the ``cmp_bNNNN.jpg`` /
                ``cmp_bNNNN.png`` files.
            data_range: half-open ``(start, stop)`` range of file indices.

        Raises:
            ValueError: if a label image does not decode to a 2-D index map.
        """
        print("load dataset start")
        print(" from: %s"%dataDir)
        print(" range: [%d, %d)"%(data_range[0], data_range[1]))
        self.dataDir = dataDir
        self.dataset = []
        for i in range(data_range[0], data_range[1]):
            self.dataset.append(self._load_pair(dataDir, i))
        print("load dataset done")

    @staticmethod
    def _scaled_size(size, short_side):
        """Return ``(w, h)`` rescaled so that ``min(w, h) == short_side``."""
        w, h = size
        r = short_side / float(min(w, h))
        # Pin the short side explicitly: int(r * min(w, h)) can floor to
        # short_side - 1 because of floating-point rounding.
        new_w = short_side if w <= h else int(r * w)
        new_h = short_side if h <= w else int(r * h)
        return new_w, new_h

    def _load_pair(self, dataDir, i):
        """Read, resize and encode the i-th (photo, one-hot label) pair."""
        img_path = os.path.join(dataDir, "cmp_b%04d.jpg" % i)
        label_path = os.path.join(dataDir, "cmp_b%04d.png" % i)
        # Context managers close the underlying files deterministically.
        with Image.open(img_path) as src_img, Image.open(label_path) as src_label:
            size = self._scaled_size(src_img.size, self.LOAD_SIZE)
            # convert("RGB") guarantees 3 channels so the transpose below
            # also works for greyscale (or CMYK) JPEGs.
            img = src_img.convert("RGB").resize(size, Image.BILINEAR)
            # NEAREST keeps label values discrete (no blending between classes).
            label = src_label.resize(size, Image.NEAREST)

        # HWC uint8 -> CHW float32 in [-1, 1).
        img = np.asarray(img).astype("f").transpose(2, 0, 1) / 128.0 - 1.0

        # Cast before subtracting so a 0 pixel becomes -1 rather than wrapping
        # around in uint8; either way it matches no class.
        label_ = np.asarray(label).astype(np.int64) - 1  # valid pixels: [0, N_CLASSES)
        if label_.ndim != 2:
            raise ValueError(
                "label image %s must be a single-channel index map, got shape %r"
                % (label_path, label_.shape))
        classes = np.arange(self.N_CLASSES).reshape(-1, 1, 1)
        onehot = (label_[np.newaxis] == classes).astype("i")
        return img, onehot

    def __len__(self):
        """Number of loaded (image, label) pairs."""
        return len(self.dataset)

    def get_example(self, i, crop_width=256):
        """Return a random ``crop_width`` x ``crop_width`` crop of pair ``i``.

        Args:
            i: index into the loaded pairs.
            crop_width: side length of the square crop window.

        Returns:
            ``(label, img)``: the one-hot label crop and the image crop, both
            taken from the same window.

        Raises:
            ValueError: if the stored image is smaller than ``crop_width`` in
                either dimension.
        """
        img, label = self.dataset[i]
        _, h, w = img.shape
        if h < crop_width or w < crop_width:
            raise ValueError(
                "crop_width %d exceeds image size %dx%d" % (crop_width, w, h))
        # randint's upper bound is exclusive: +1 makes the last valid offset
        # (w - crop_width) reachable and allows images exactly crop_width wide.
        x_l = np.random.randint(0, w - crop_width + 1)
        x_r = x_l + crop_width
        y_l = np.random.randint(0, h - crop_width + 1)
        y_r = y_l + crop_width
        return label[:, y_l:y_r, x_l:x_r], img[:, y_l:y_r, x_l:x_r]