-
Notifications
You must be signed in to change notification settings - Fork 2
/
track_vot.py
77 lines (56 loc) · 1.95 KB
/
track_vot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# Copyright (c) SenseTime. All Rights Reserved.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import argparse
import os
import sys
import cv2
import torch
import vot_tool as vot
from siamban.core.config import cfg
from siamban.models.model_builder import ModelBuilder
from siamban.tracker.tracker_builder import build_tracker
from siamban.utils.bbox import corner2center, Corner
from siamban.utils.model_load import load_pretrain
# Command-line interface. Parsed at import time so that ``main`` can read the
# module-level ``args``.
parser = argparse.ArgumentParser(description='siamese tracking')
parser.add_argument('--config', default='config.yaml', type=str,
                    help='config file')
# The help text previously said 'config file' (copy-paste slip); this option
# is the model snapshot that main() passes to load_pretrain.
parser.add_argument('--snapshot', default='model.pth', type=str,
                    help='model snapshot file')
args = parser.parse_args()

# Limit PyTorch intra-op CPU parallelism to a single thread.
torch.set_num_threads(1)
def _read_image(path):
    """Read an image with OpenCV, raising IOError instead of returning None."""
    image = cv2.imread(path)
    if image is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise IOError('cannot read image: {}'.format(path))
    return image


def main():
    """Run the tracker as a VOT toolkit client for one sequence.

    Loads the config and model snapshot given on the command line. Reads the
    initial rectangle and first frame from the VOT handle and initialises the
    tracker. Then, until the handle yields no more frames, reports one
    predicted rectangle and its confidence per frame.

    Raises:
        IOError: if a frame image cannot be read by OpenCV.
    """
    # load config
    cfg.merge_from_file(args.config)

    # create model and load the pretrained snapshot; the model is moved to
    # the GPU with .cuda(), so a CUDA device is required.
    model = ModelBuilder()
    model = load_pretrain(model, args.snapshot).cuda().eval()

    # build tracker
    tracker = build_tracker(model)

    handle = vot.VOT("rectangle")
    region = handle.region()
    imagefile = handle.frame()
    if not imagefile:
        # Nothing to track: exit cleanly.
        sys.exit(0)

    img = _read_image(imagefile)

    # Clip the initial rectangle's corners to the image bounds
    # (img.shape[1] is the width, img.shape[0] the height).
    left = max(region.x, 0)
    top = max(region.y, 0)
    right = min(region.x + region.width, img.shape[1] - 1)
    bottom = min(region.y + region.height, img.shape[0] - 1)

    # Build the [x, y, w, h] box handed to tracker.init from the centre form;
    # the (w - 1) / 2 half-extent convention is kept unchanged.
    cx, cy, w, h = corner2center(Corner(left, top, right, bottom))
    gt_bbox_ = [cx - (w - 1) / 2, cy - (h - 1) / 2, w, h]
    tracker.init(img, gt_bbox_)

    while True:
        imagefile = handle.frame()
        if not imagefile:
            break
        image = _read_image(imagefile)
        outputs = tracker.track(image)
        pred_bbox = outputs['bbox']
        conf = outputs['best_score']
        # pred_bbox is unpacked positionally into vot.Rectangle
        # (presumably x, y, width, height — confirm against vot_tool).
        handle.report(vot.Rectangle(*pred_bbox), conf)
if __name__ == "__main__":
    # Start the VOT client only when executed as a script, not on import.
    main()