-
Notifications
You must be signed in to change notification settings - Fork 0
/
samplers.py
140 lines (120 loc) · 4.93 KB
/
samplers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
from __future__ import absolute_import
import copy
import random
import torch
from collections import defaultdict
import numpy as np
from torch.utils.data.sampler import Sampler
class RandomIdentitySampler(Sampler):
    """Yield dataset indices grouped by identity.

    Identities are visited in a fresh random order on every pass. For each
    identity, ``num_instances`` indices are drawn, so any run of
    ``N * num_instances`` consecutive indices covers ``N`` identities.

    Adapted from
    https://github.com/Cysu/open-reid/blob/master/reid/utils/data/sampler.py.

    Args:
        data_source (Dataset): iterable of 3-tuples whose middle element is
            the identity label (pid); the other two elements are ignored.
        num_instances (int): number of indices drawn per identity.
    """

    def __init__(self, data_source, num_instances=4):
        # NOTE: the base ``Sampler`` initializer is deliberately not invoked.
        # It keeps no state and its signature differs between PyTorch
        # releases, so skipping it (as the other samplers in this module do)
        # is the portable choice.
        self.data_source = data_source
        self.num_instances = num_instances

        # pid -> positions in ``data_source`` carrying that pid.
        self.index_dic = defaultdict(list)
        for index, (_, pid, _) in enumerate(data_source):
            self.index_dic[pid].append(index)

        self.pids = list(self.index_dic.keys())
        self.num_identities = len(self.pids)

    def __iter__(self):
        """Return an iterator over one freshly sampled pass of indices."""
        ret = []
        for i in torch.randperm(self.num_identities).tolist():
            candidates = self.index_dic[self.pids[i]]
            # Fall back to sampling with replacement only when the identity
            # has fewer items than requested.
            with_replacement = len(candidates) < self.num_instances
            ret.extend(
                np.random.choice(
                    candidates,
                    size=self.num_instances,
                    replace=with_replacement,
                )
            )
        return iter(ret)

    def __len__(self):
        """Return the number of indices produced by one pass."""
        return self.num_identities * self.num_instances
class RandomIdentitySamplerStrongBasaline(Sampler):
    """Yield indices so that each batch holds N identities with K items each.

    Every identity's indices are shuffled and cut into chunks of
    ``num_instances`` (K). Batches are then assembled by repeatedly picking
    ``batch_size // num_instances`` (N) identities at random and taking one
    chunk from each, until fewer than N identities still have chunks left.

    Args:
        data_source (list): list of ``(img_path, pid, camid)``.
        batch_size (int): number of examples in a batch.
        num_instances (int): number of instances per identity in a batch.

    Raises:
        ValueError: if ``num_instances`` is smaller than 1 or ``batch_size``
            is smaller than ``num_instances``.
    """

    def __init__(self, data_source, batch_size, num_instances):
        if num_instances < 1:
            raise ValueError(
                "num_instances must be at least 1, got {}".format(num_instances)
            )
        if batch_size < num_instances:
            # Otherwise num_pids_per_batch would be 0 and __iter__ would
            # loop forever without consuming any chunk.
            raise ValueError(
                "batch_size ({}) must be at least num_instances ({})".format(
                    batch_size, num_instances
                )
            )

        self.data_source = data_source
        self.batch_size = batch_size
        self.num_instances = num_instances
        self.num_pids_per_batch = self.batch_size // self.num_instances

        # pid -> positions in ``data_source`` carrying that pid.
        self.index_dic = defaultdict(list)
        for index, (_, pid, _) in enumerate(self.data_source):
            self.index_dic[pid].append(index)
        self.pids = list(self.index_dic.keys())

        # Estimate the number of examples in an epoch: each identity
        # contributes its item count (padded up to K if smaller) rounded down
        # to a multiple of K. __iter__ overwrites this with the exact count.
        self.length = 0
        for pid in self.pids:
            num = max(len(self.index_dic[pid]), self.num_instances)
            self.length += num - num % self.num_instances

    def __iter__(self):
        """Build one epoch of indices and return an iterator over it.

        Also updates ``self.length`` to the exact number of indices produced.
        """
        k = self.num_instances

        chunks_by_pid = {}
        for pid in self.pids:
            idxs = list(self.index_dic[pid])
            if len(idxs) < k:
                # Too few items: pad by sampling with replacement. tolist()
                # keeps every emitted index a plain Python int.
                idxs = np.random.choice(idxs, size=k, replace=True).tolist()
            random.shuffle(idxs)
            # Leftover items that do not fill a whole chunk are dropped.
            usable = len(idxs) - len(idxs) % k
            chunks = [idxs[start:start + k] for start in range(0, usable, k)]
            # Reversed so that pop() hands chunks out in their built order.
            chunks.reverse()
            chunks_by_pid[pid] = chunks

        avai_pids = list(self.pids)
        final_idxs = []
        while len(avai_pids) >= self.num_pids_per_batch:
            for pid in random.sample(avai_pids, self.num_pids_per_batch):
                remaining = chunks_by_pid[pid]
                final_idxs.extend(remaining.pop())
                if not remaining:
                    avai_pids.remove(pid)

        self.length = len(final_idxs)
        return iter(final_idxs)

    def __len__(self):
        """Return the estimated length, or the exact one after iterating."""
        return self.length
# Newly added by gu
class RandomIdentitySampler_alignedreid(Sampler):
    """Identity-grouped index sampler.

    On each pass the identities are shuffled, and for every identity
    ``num_instances`` indices are drawn (with replacement only when the
    identity owns fewer items than that). A batch of ``N * num_instances``
    consecutive indices therefore covers ``N`` identities.

    Adapted from
    https://github.com/Cysu/open-reid/blob/master/reid/utils/data/sampler.py.

    Args:
        data_source (Dataset): iterable of 3-tuples whose middle element is
            the identity label (pid).
        num_instances (int): number of indices drawn per identity.
    """

    def __init__(self, data_source, num_instances):
        self.data_source = data_source
        self.num_instances = num_instances

        # Group dataset positions by the identity they belong to.
        groups = defaultdict(list)
        for position, record in enumerate(data_source):
            _, pid, _ = record
            groups[pid].append(position)

        self.index_dic = groups
        self.pids = list(groups)
        self.num_identities = len(self.pids)

    def __iter__(self):
        """Return an iterator over one freshly sampled pass of indices."""
        picked = []
        for slot in torch.randperm(self.num_identities).tolist():
            pool = self.index_dic[self.pids[slot]]
            need_replacement = len(pool) < self.num_instances
            drawn = np.random.choice(
                pool, size=self.num_instances, replace=need_replacement
            )
            picked.extend(drawn)
        return iter(picked)

    def __len__(self):
        """Return the number of indices produced by one pass."""
        return self.num_instances * self.num_identities