import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from torchvision import transforms
import numpy as np
import random
from VOC import PascalVOC
from SPair71k import SPair71K
# from data.willow_obj import WillowObject
from utils.build_graphs import build_graphs
from utils.config import cfg
from torch_geometric.data import Data, Batch
datasets = {
    "PascalVOC": PascalVOC,
    # "WillowObject": WillowObject,
    "SPair71k": SPair71K,
}


class GMDataset(Dataset):
    """Graph-matching dataset wrapping the PascalVOC / SPair71k samplers."""

    def __init__(self, name, length, **args):
        self.name = name
        self.ds = datasets[name](**args)
        self.true_epochs = length is None
        # If length is None ("true epochs"), one pass covers the whole dataset;
        # otherwise length is the number of iterations between two checkpoints.
        self.length = (
            self.ds.total_size if self.true_epochs else length
        )  # NOTE: image pairs are sampled randomly, so there is no exact definition of dataset size
        if self.true_epochs:
            print(f"Initializing {self.ds.set}-set with all {self.length} examples.")
        else:
            print(f"Initializing {self.ds.set}-set. Randomly sampling {self.length} examples.")
        self.obj_size = self.ds.obj_resize
        self.classes = self.ds.classes
        self.cls = None
        self.num_graphs_in_matching_instance = None

    def set_cls(self, cls):
        if cls == "none":
            cls = None
        self.cls = cls
        if self.true_epochs:  # Update length of dataset for dataloader according to class
            self.length = self.ds.total_size if cls is None else self.ds.size_by_cls[cls]

    def set_num_graphs(self, num_graphs_in_matching_instance):
        self.num_graphs_in_matching_instance = num_graphs_in_matching_instance

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # sampling_strategy = cfg.train_sampling if self.ds.set == "train" else cfg.eval_sampling
        sampling_strategy = "intersection"
        if self.num_graphs_in_matching_instance is None:
            raise ValueError("Num_graphs has to be set to an integer value.")
        idx = idx if self.true_epochs else None
        anno_list, perm_mat_list = self.ds.get_k_samples(
            idx, k=self.num_graphs_in_matching_instance, cls=self.cls, mode=sampling_strategy
        )
        for perm_mat in perm_mat_list:
            # Resample if a permutation matrix is degenerate. The check is skipped
            # for true epochs, where all data is assumed to be valid.
            if (
                not perm_mat.size
                or (perm_mat.size < 2 * 2 and sampling_strategy == "intersection")
            ) and not self.true_epochs:
                next_idx = None if idx is None else idx + 1
                return self.__getitem__(next_idx)

        points_gt = [np.array([(kp["x"], kp["y"]) for kp in anno_dict["keypoints"]]) for anno_dict in anno_list]
        n_points_gt = [len(p_gt) for p_gt in points_gt]

        graph_list = []
        for p_gt, n_p_gt in zip(points_gt, n_points_gt):
            edge_indices, edge_features = build_graphs(p_gt, n_p_gt)
            # Keypoint coordinates, normalized by the object size (assumed 256x256 here).
            pos = torch.tensor(p_gt).to(torch.float32) / 256.0
            assert (pos > -1e-5).all(), p_gt
            # Duplicate pos as dummy node features x so their __slices__ are saved when creating a batch.
            graph = Data(
                edge_attr=torch.tensor(edge_features).to(torch.float32),
                edge_index=torch.tensor(edge_indices, dtype=torch.long),
                x=pos,
                pos=pos,
            )
            graph.num_nodes = n_p_gt
            graph_list.append(graph)

        ret_dict = {
            "Ps": [torch.Tensor(x).unsqueeze(0) for x in points_gt],
            "ns": [torch.tensor(x).unsqueeze(0) for x in n_points_gt],
            "gt_perm_mat": [torch.from_numpy(x) for x in perm_mat_list],
            "edges": graph_list,
        }

        imgs = [anno["image"] for anno in anno_list]
        if imgs[0] is not None:
            trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize(cfg.NORM_MEANS, cfg.NORM_STD)])
            imgs = [trans(img).unsqueeze(0) for img in imgs]
            ret_dict["images"] = imgs
        elif "feat" in anno_list[0]["keypoints"][0]:
            # No raw images available: fall back to precomputed keypoint features.
            feat_list = [np.stack([kp["feat"] for kp in anno_dict["keypoints"]], axis=-1) for anno_dict in anno_list]
            ret_dict["features"] = [torch.Tensor(x) for x in feat_list]

        return ret_dict


def collate_fn(data: list):
    """
    Create mini-batch data for training.

    :param data: list of sample dicts
    :return: mini-batch dict with padded, stacked tensors
    """

    def pad_tensor(inp):
        assert type(inp[0]) == torch.Tensor
        # Element-wise maximum shape over all tensors in the list.
        max_shape = list(inp[0].shape)
        for t in inp[1:]:
            for i in range(len(max_shape)):
                max_shape[i] = int(max(max_shape[i], t.shape[i]))
        max_shape = np.array(max_shape)

        padded_ts = []
        for t in inp:
            # F.pad takes (left, right) pairs starting from the last dimension;
            # filling every other slot from the end right-pads each dimension
            # up to max_shape.
            pad_pattern = np.zeros(2 * len(max_shape), dtype=np.int64)
            pad_pattern[::-2] = max_shape - np.array(t.shape)
            padded_ts.append(F.pad(t, tuple(pad_pattern.tolist()), "constant", 0))
        return padded_ts

    def stack(inp):
        if type(inp[0]) == list:
            ret = []
            for vs in zip(*inp):
                ret.append(stack(vs))
        elif type(inp[0]) == dict:
            ret = {}
            for kvs in zip(*[x.items() for x in inp]):
                ks, vs = zip(*kvs)
                for k in ks:
                    assert k == ks[0], "Key value mismatch."
                ret[ks[0]] = stack(vs)
        elif type(inp[0]) == torch.Tensor:
            ret = torch.stack(pad_tensor(inp), 0)
        elif type(inp[0]) == np.ndarray:
            ret = torch.stack(pad_tensor([torch.from_numpy(x) for x in inp]), 0)
        elif type(inp[0]) == str:
            ret = inp
        elif type(inp[0]) == Data:  # graphs from torch_geometric: merge into a single Batch
            ret = Batch.from_data_list(inp)
        else:
            raise ValueError("Cannot handle type {}".format(type(inp[0])))
        return ret

    return stack(data)
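
# Note: collating samples whose tensors differ in size (e.g. a 3x3 and a 5x5
# "gt_perm_mat") zero-pads them to a common shape before stacking; downstream
# code can recover the valid regions via the "ns" (keypoint count) tensors.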


def worker_init_fix(worker_id):
    """Initialize dataloader workers with a fixed, per-worker seed (reproducible runs)."""
    random.seed(cfg.RANDOM_SEED + worker_id)
    np.random.seed(cfg.RANDOM_SEED + worker_id)


def worker_init_rand(worker_id):
    """
    Initialize dataloader workers with torch.initial_seed(), which differs between
    dataloader workers. numpy seeds must fit into 32 bits, hence the modulo.
    """
    random.seed(torch.initial_seed())
    np.random.seed(torch.initial_seed() % 2 ** 32)


def get_dataloader(dataset, fix_seed=True, shuffle=False):
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=cfg.BATCH_SIZE,
        shuffle=shuffle,
        num_workers=2,
        collate_fn=collate_fn,
        pin_memory=False,
        worker_init_fn=worker_init_fix if fix_seed else worker_init_rand,
    )
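

# Example usage (a sketch; the keyword arguments accepted by the dataset
# constructors are assumptions, not taken from this file):
#
#   dataset = GMDataset("PascalVOC", length=None, sets="train", obj_resize=(256, 256))
#   dataset.set_num_graphs(2)       # match pairs of graphs
#   dataset.set_cls("none")         # sample from all classes
#   dataloader = get_dataloader(dataset, fix_seed=True, shuffle=True)
#   batch = next(iter(dataloader))  # dict with "Ps", "ns", "gt_perm_mat", "edges", ...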