-
Notifications
You must be signed in to change notification settings - Fork 0
/
run.py
71 lines (55 loc) · 2.19 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
#!/usr/bin/env python
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
# See the LICENSE file for more details.
import experiment
import vis
import argparse
def main():
    """Parse command-line options, build the experiment wrapper and visualize it.

    The individual flags ``--fixed-size``, ``--threshold`` and ``--post-hoc``
    configure the run directly. When ``--preset`` is given, it replaces all
    three of those settings with the preset's values. ``--force-retraining``
    is never touched by a preset.

    The resulting configuration is passed to ``experiment.make_wrapper``, and
    the returned wrapper is handed to ``vis.visualize``.
    """
    # Preset name -> (fixed_size, threshold, train_head).
    presets = {
        'L2X-F': (True, False, True),
        'BLA': (False, False, True),
        'BLA-T': (False, True, True),
        'BLA-PH': (False, True, False),
    }

    parser = argparse.ArgumentParser(description='Bounded logit attention.')
    parser.add_argument('-d', '--dataset', required=False,
                        default='cats_vs_dogs',
                        choices=['cats_vs_dogs', 'stanford_dogs',
                                 'caltech_birds2011'])
    parser.add_argument('-f', '--fixed-size', action='store_true')
    parser.add_argument('-t', '--threshold', action='store_true')
    parser.add_argument('-p', '--post-hoc', action='store_true')
    parser.add_argument('-r', '--force-retraining', action='store_true')
    parser.add_argument('-s', '--preset', required=False,
                        choices=list(presets))
    opts = parser.parse_args()

    # argparse restricts --preset to the table's keys, so it is either None
    # (use the individual flags) or a valid lookup key.
    if opts.preset is None:
        fixed_size = opts.fixed_size
        threshold = opts.threshold
        train_head = not opts.post_hoc
    else:
        fixed_size, threshold, train_head = presets[opts.preset]

    wrapper = experiment.make_wrapper(
        opts.dataset,
        fixed_size=fixed_size,
        threshold=threshold,
        train_head=train_head,
        force_retrain=opts.force_retraining,
    )
    vis.visualize(wrapper)
if __name__ == '__main__':
    # Only run the CLI when executed as a script, not when imported.
    main()