-
Notifications
You must be signed in to change notification settings - Fork 14
/
train_ontoemma.py
100 lines (90 loc) · 3.49 KB
/
train_ontoemma.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
#!/usr/bin/env python
import os
import sys
import getopt
import nltk
import ssl
from emma.OntoEmma import OntoEmma
import emma.constants
def _print_usage():
    """Write the command-line help text to stdout."""
    sys.stdout.write('Options: \n')
    sys.stdout.write('-c, --config <configuration_file>\n')
    sys.stdout.write('-m, --model <model_location>\n')
    sys.stdout.write('-p, --model-type <model_type>\n')
    sys.stdout.write('-e, --evaluate # evaluation mode\n')
    sys.stdout.write('-d, --eval-data-file <evaluation_data_file>\n')
    sys.stdout.write('-g, --cuda-device <cuda_device>\n\n')
    sys.stdout.write('Example usages: \n')
    sys.stdout.write(
        ' ./train_ontoemma.py -c configuration_file.json -m model_file_path -p nn\n'
    )
    sys.stdout.write(
        ' ./train_ontoemma.py -e -m model_file_path -d evaluation_data_path -g 5\n'
    )
    sys.stdout.write('-------------------------\n')
    sys.stdout.write('Accepted model types: nn (neural network), lr (logistic regression)\n')
    sys.stdout.write('-------------------------\n')
    sys.stdout.write('\n')


def _parse_options(argv):
    """Parse command-line arguments into a settings dict.

    Accepts the short options ``-h -e -c: -m: -p: -d: -g:`` and the long
    options shown by ``--help``. The underscore spellings that were declared
    previously (``--model_path``, ``--model_type``, ``--evaluation_data_file``,
    ``--cuda_device``) are still accepted for backward compatibility.

    :param argv: argument list without the program name.
    :return: dict with keys ``model_path``, ``model_type``, ``config_file``,
        ``evaluate_flag``, ``evaluation_data_file`` and ``cuda_device``.
        Path-valued options are converted to absolute paths.

    Exits with status 0 after printing help, and with status 1 on an unknown
    option, an unknown model type, or a non-integer CUDA device.
    """
    settings = {
        "model_path": None,
        "model_type": "nn",
        "config_file": None,
        "evaluate_flag": False,
        "evaluation_data_file": None,
        "cuda_device": -1,
    }
    # Every long spelling the handlers below test for must be declared here,
    # otherwise getopt either rejects it or parses it without any handler.
    long_options = [
        "help",
        "evaluate",
        "config=",
        "model=", "model_path=",
        "model-type=", "model_type=",
        "eval-data-file=", "evaluation_data_file=",
        "cuda-device=", "cuda_device=",
    ]
    try:
        # TODO(waleeda): use argparse instead of getopt to parse command line arguments.
        opts, _ = getopt.getopt(argv, "hec:m:p:d:g:", long_options)
    except getopt.GetoptError:
        sys.stdout.write('Unknown option... -h or --help for help.\n')
        sys.exit(1)

    for opt, arg in opts:
        if opt in ("-h", "--help"):
            _print_usage()
            sys.exit(0)
        elif opt in ("-e", "--evaluate"):
            settings["evaluate_flag"] = True
            sys.stdout.write('Evaluation mode\n')
        elif opt in ("-c", "--config"):
            settings["config_file"] = os.path.abspath(arg)
            sys.stdout.write('Configuration file is %s\n' % settings["config_file"])
        elif opt in ("-m", "--model", "--model_path"):
            settings["model_path"] = os.path.abspath(arg)
            sys.stdout.write('Model output path is %s\n' % settings["model_path"])
        elif opt in ("-p", "--model-type", "--model_type"):
            if arg not in emma.constants.IMPLEMENTED_MODEL_TYPES:
                sys.stdout.write('Error: Unknown model type...\n')
                sys.exit(1)
            settings["model_type"] = arg
            sys.stdout.write(
                'Model type is %s\n' % emma.constants.IMPLEMENTED_MODEL_TYPES[arg]
            )
        elif opt in ("-d", "--eval-data-file", "--evaluation_data_file"):
            settings["evaluation_data_file"] = os.path.abspath(arg)
            sys.stdout.write(
                'Evaluation data file is %s\n' % settings["evaluation_data_file"]
            )
        elif opt in ("-g", "--cuda-device", "--cuda_device"):
            try:
                settings["cuda_device"] = int(arg)
            except ValueError:
                sys.stdout.write('Error: CUDA device must be an integer...\n')
                sys.exit(1)
            sys.stdout.write('Using CUDA device %i\n' % settings["cuda_device"])
    sys.stdout.write('\n')
    return settings


def _ensure_nltk_stopwords():
    """Download the NLTK ``stopwords`` corpus if it is not already installed."""
    try:
        nltk.data.find("corpora/stopwords")
        return
    except LookupError:
        pass

    unverified_context_factory = getattr(ssl, "_create_unverified_context", None)
    if unverified_context_factory is None:
        nltk.download("stopwords")
        return

    # HACK: certificate verification is switched off for this one download,
    # presumably to work around SSL verification failures when fetching NLTK
    # data. NOTE(review): this trusts any server for the duration of the
    # download; installing proper root certificates would be safer.
    original_context_factory = ssl._create_default_https_context
    ssl._create_default_https_context = unverified_context_factory
    try:
        nltk.download("stopwords")
    finally:
        # Restore verification so later HTTPS requests in this process are checked.
        ssl._create_default_https_context = original_context_factory


def main(argv):
    """Train or evaluate an OntoEmma matcher as directed by command-line arguments.

    Arguments are parsed first, so ``-h`` and bad options exit before any
    network access. The NLTK stopwords corpus is then made available.
    Finally, ``OntoEmma.evaluate`` is called when ``-e`` was given, and
    ``OntoEmma.train`` otherwise.

    :param argv: command-line arguments without the program name
        (typically ``sys.argv[1:]``).
    """
    options = _parse_options(argv)
    _ensure_nltk_stopwords()

    matcher = OntoEmma()
    if options["evaluate_flag"]:
        matcher.evaluate(
            options["model_type"],
            options["model_path"],
            options["evaluation_data_file"],
            options["cuda_device"],
        )
    else:
        matcher.train(
            options["model_type"], options["model_path"], options["config_file"]
        )
if __name__ == "__main__":
    # Script entry point: hand everything after the program name to main().
    main(argv=sys.argv[1:])