-
Notifications
You must be signed in to change notification settings - Fork 6
/
examine.py
107 lines (90 loc) · 2.83 KB
/
examine.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import argparse
import logging
import pickle
from pathlib import Path
from pdb import set_trace
import coloredlogs
import enlighten
import networkx as nx
import torch
import yaml
from torchvision import io
from dnn.dnn_factory import DNN_Factory
from utilities.bbox_utils import jaccard
from utilities.results_utils import read_results, write_results
from utilities.video_utils import read_bandwidth
def main(args):
    """Score each input video's inference results against the ground truth.

    The bandwidth of every video in ``args.inputs`` is read up front. The
    model named by ``args.app`` is then instantiated and the ground-truth
    results for ``args.ground_truth`` are loaded. For every input video,
    ``app.calc_accuracy`` is called and the returned metrics are merged into a
    record, which is appended to the ``args.stats`` file as YAML.

    Args:
        args: Parsed command-line namespace produced by this script's parser.
            It is also forwarded unchanged to ``app.calc_accuracy``.
    """
    logger = logging.getLogger("examine")
    # NullHandler discards records itself; records still propagate to
    # ancestor loggers such as the root logger set up by coloredlogs.
    logger.addHandler(logging.NullHandler())

    bandwidths = [read_bandwidth(name) for name in args.inputs]

    app = DNN_Factory().get_model(args.app)
    gt_results = read_results(args.ground_truth, app.name, logger)

    for name, bandwidth in zip(args.inputs, bandwidths):
        results = read_results(name, app.name, logger)
        accuracy = app.calc_accuracy(results, gt_results, args)

        record = dict(
            application=app.name,
            video_name=name,
            bw=bandwidth,
            ground_truth_name=args.ground_truth,
            gt_conf=float(args.gt_confidence_threshold),
            conf=float(args.confidence_threshold),
        )
        record.update(accuracy)

        # Each record is dumped as a one-element list, so successive appends
        # accumulate into a single YAML sequence. The file is reopened per
        # video, so every record is written out as soon as it is computed.
        with open(args.stats, "a") as stats_file:
            stats_file.write(yaml.dump([record]))
def build_parser():
    """Build the command-line parser for this script.

    Returns:
        argparse.ArgumentParser: parser whose namespace is passed to ``main``
        (and, through it, to the application's ``calc_accuracy``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--stats",
        type=str,
        required=True,
        help="The YAML file that one accuracy record per input video is appended to.",
    )
    parser.add_argument(
        "-i",
        "--inputs",
        type=str,
        help="The video file names to obtain inference results.",
        required=True,
        nargs="+",
    )
    parser.add_argument(
        "--app", type=str, help="The name of the model.", required=True,
    )
    parser.add_argument(
        "-g",
        "--ground_truth",
        type=str,
        help="The ground-truth video name.",
        required=True,
    )
    parser.add_argument(
        "--confidence_threshold",
        type=float,
        help="The confidence score threshold for calculating accuracy.",
        default=0.7,
    )
    parser.add_argument(
        "--gt_confidence_threshold",
        type=float,
        help="The confidence score threshold for the ground-truth results when calculating accuracy.",
        default=0.7,
    )
    parser.add_argument(
        "--iou_threshold",
        type=float,
        help="The IoU threshold for calculating accuracy in object detection.",
        default=0.5,
    )
    # NOTE(review): --size_bound is not read in main(); presumably consumed by
    # app.calc_accuracy via args. Add help text once its meaning is confirmed.
    parser.add_argument("--size_bound", type=float, default=0.05)
    parser.add_argument(
        "--dist_thresh",
        type=float,
        help="Distance thresh for accuracy calculation.",
        # Float literal on purpose: argparse applies ``type`` only to string
        # defaults, so an int default would leave dist_thresh an int when the
        # option is omitted.
        default=3.0,
    )
    return parser


if __name__ == "__main__":
    # set the format of the logger
    coloredlogs.install(
        fmt="%(asctime)s [%(levelname)s] %(name)s:%(funcName)s[%(lineno)s] -- %(message)s",
        level="INFO",
    )
    args = build_parser().parse_args()
    main(args)