-
Notifications
You must be signed in to change notification settings - Fork 0
/
my_dataset_coco.py
156 lines (124 loc) · 5.16 KB
/
my_dataset_coco.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import os
import json
import torch
from PIL import Image
import torch.utils.data as data
from pycocotools.coco import COCO
from train_utils import coco_remove_images_without_annotations, convert_coco_poly_mask
class CocoDetection(data.Dataset):
    """`MS Coco Detection <https://cocodataset.org/>`_ Dataset (boxes + instance masks).

    Args:
        root (string): Root directory of the COCO dataset. It must contain the
            image folder ``{dataset}{years}`` (e.g. ``train2017``) and the
            annotation file ``annotations/instances_{dataset}{years}.json``.
        dataset (string): "train" or "val".
        transforms (callable, optional): A function/transform that takes an image
            and its target dict and returns transformed versions of both.
        years (string): Year suffix used to build the folder/file names
            (default "2017").

    Side effect: constructing the "train" split writes the category-id -> name
    table to ``coco91_indices.json`` in the current working directory.
    """

    def __init__(self, root, dataset="train", transforms=None, years="2017"):
        super().__init__()
        assert dataset in ["train", "val"], 'dataset must be in ["train", "val"]'
        anno_file = f"instances_{dataset}{years}.json"
        assert os.path.exists(root), "file '{}' does not exist.".format(root)
        self.img_root = os.path.join(root, f"{dataset}{years}")
        assert os.path.exists(self.img_root), "path '{}' does not exist.".format(
            self.img_root
        )
        self.anno_path = os.path.join(root, "annotations", anno_file)
        assert os.path.exists(self.anno_path), "file '{}' does not exist.".format(
            self.anno_path
        )

        self.mode = dataset
        self.transforms = transforms
        self.coco = COCO(self.anno_path)

        # Map COCO category id -> category name. Note that the 80 object
        # category ids are NOT contiguous: although there are only 80 classes,
        # the ids still follow the 91-category numbering.
        data_classes = {cat["id"]: cat["name"] for cat in self.coco.cats.values()}
        # Highest category id (90 for standard COCO); `default=0` avoids a
        # crash on an annotation file that declares no categories.
        max_index = max(data_classes, default=0)
        # Fill ids that have no category with "N/A" so keys 1..max_index are dense.
        coco_classes = {
            k: data_classes.get(k, "N/A") for k in range(1, max_index + 1)
        }

        if dataset == "train":
            # Same content and formatting as json.dumps(..., indent=4) + write.
            with open("coco91_indices.json", "w") as f:
                json.dump(coco_classes, f, indent=4)

        self.coco_classes = coco_classes

        ids = sorted(self.coco.imgs.keys())
        if dataset == "train":
            # Drop unusable training images (per the original comment: images
            # with no objects, or only objects with a very small area).
            self.ids = coco_remove_images_without_annotations(self.coco, ids)
        else:
            self.ids = ids

    def parse_targets(
        self, img_id: int, coco_targets: list, w: int = None, h: int = None
    ):
        """Convert the raw COCO annotations of one image into a target dict.

        Args:
            img_id: COCO image id, stored as ``target["image_id"]``.
            coco_targets: annotation dicts as returned by ``COCO.loadAnns``.
            w: image width in pixels; required, must be positive.
            h: image height in pixels; required, must be positive.

        Returns:
            dict with keys "boxes" (float32, shape [N, 4] as xmin, ymin, xmax,
            ymax), "labels" (int64, shape [N]), "masks", "image_id", "area"
            and "iscrowd". Only non-crowd objects with a non-degenerate box
            are kept.

        Raises:
            ValueError: if ``w`` or ``h`` is missing or not positive.
        """
        # Explicit checks instead of `assert`: asserts are stripped under
        # `python -O`, and with the `None` default `None > 0` would raise an
        # unhelpful TypeError.
        if w is None or w <= 0:
            raise ValueError(f"image width must be a positive number, got {w!r}")
        if h is None or h <= 0:
            raise ValueError(f"image height must be a positive number, got {h!r}")

        # Keep only non-crowd annotations (single objects).
        anno = [obj for obj in coco_targets if obj["iscrowd"] == 0]

        boxes = [obj["bbox"] for obj in anno]
        # reshape(-1, 4) guards against the no-box case (gives shape [0, 4]).
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        # [xmin, ymin, w, h] -> [xmin, ymin, xmax, ymax]
        boxes[:, 2:] += boxes[:, :2]
        # Clip x coordinates to [0, w] and y coordinates to [0, h], in place.
        boxes[:, 0::2].clamp_(min=0, max=w)
        boxes[:, 1::2].clamp_(min=0, max=h)

        classes = torch.tensor(
            [obj["category_id"] for obj in anno], dtype=torch.int64
        )
        area = torch.tensor([obj["area"] for obj in anno])
        iscrowd = torch.tensor([obj["iscrowd"] for obj in anno])

        segmentations = [obj["segmentation"] for obj in anno]
        masks = convert_coco_poly_mask(segmentations, h, w)

        # Keep only valid boxes, i.e. x_max > x_min and y_max > y_min.
        keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])

        return {
            "boxes": boxes[keep],
            "labels": classes[keep],
            "masks": masks[keep],
            "image_id": torch.tensor([img_id]),
            # "area" and "iscrowd" are kept for conversion back to the COCO API.
            "area": area[keep],
            "iscrowd": iscrowd[keep],
        }

    def __getitem__(self, index):
        """Load one sample.

        Args:
            index (int): Index into ``self.ids``.

        Returns:
            tuple: ``(image, target)``. ``image`` is the RGB PIL image and
            ``target`` is the dict built by ``parse_targets``; both are passed
            through ``self.transforms`` when it is set.
        """
        coco = self.coco
        img_id = self.ids[index]
        ann_ids = coco.getAnnIds(imgIds=img_id)
        coco_target = coco.loadAnns(ann_ids)

        path = coco.loadImgs(img_id)[0]["file_name"]
        # Close the underlying file deterministically; `convert` returns an
        # independent, fully loaded image, so it stays valid after the `with`.
        with Image.open(os.path.join(self.img_root, path)) as raw_img:
            img = raw_img.convert("RGB")

        w, h = img.size
        target = self.parse_targets(img_id, coco_target, w, h)
        if self.transforms is not None:
            img, target = self.transforms(img, target)
        return img, target

    def __len__(self):
        """Return the number of images in this split (after filtering)."""
        return len(self.ids)

    def get_height_and_width(self, index):
        """Return ``(height, width)`` of sample ``index`` from the annotation
        metadata, without opening the image file."""
        img_info = self.coco.loadImgs(self.ids[index])[0]
        return img_info["height"], img_info["width"]

    @staticmethod
    def collate_fn(batch):
        """Transpose a list of ``(image, target)`` pairs into
        ``(images, targets)`` tuples; needed because targets vary in size."""
        return tuple(zip(*batch))
if __name__ == '__main__':
    # Quick smoke test: build the train split and load the first sample.
    import sys

    # Optional first CLI argument overrides the dataset root.
    coco_root = sys.argv[1] if len(sys.argv) > 1 else "/data/coco2017"
    train = CocoDetection(coco_root, dataset="train")
    print(len(train))
    img, target = train[0]
    # No transforms are set, so `img` is a PIL image and target values are tensors.
    print(img.size, {k: tuple(v.shape) for k, v in target.items()})