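"""test.py: evaluation entry point.

Loads a trained model, runs extraction over the test data, and scores the
saved extractions against the gold file; the 'evaluate' or 'carb' benchmark
is selected from the gold path.

Example invocation (paths are just the argparse defaults below; adjust them
to your own checkpoint and data locations):

    python test.py --model_path ./results/model.bin \
        --test_data_path ./datasets/carb_test.pkl \
        --test_gold_path ./carb/CaRB_test.tsv \
        --save_path ./results/carb_test
"""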
import argparse
import os
import time

import torch

from utils import utils
from dataset import load_data
from extract import extract
from evaluate.evaluate import Benchmark
from evaluate.matcher import Matcher
from evaluate.generalReader import GeneralReader
from carb.carb import Benchmark as CarbBenchmark
from carb.matcher import Matcher as CarbMatcher
from carb.tabReader import TabReader


def get_performance(output_path, gold_path):
    """Score the extractions saved under `output_path` against the gold file."""
    auc, precision, recall, f1 = [None for _ in range(4)]
    if 'evaluate' in gold_path:
        # 'evaluate' benchmark: lexical match between predicted and gold extractions.
        matching_func = Matcher.lexicalMatch
        error_fn = os.path.join(output_path, 'error_idxs.txt')
        evaluator = Benchmark(gold_path)
        reader = GeneralReader()
        reader.read(os.path.join(output_path, 'extraction.txt'))
        (precision, recall, f1), auc = evaluator.compare(
            predicted=reader.oie,
            matchingFunc=matching_func,
            output_fn=os.path.join(output_path, 'pr_curve.txt'),
            error_file=error_fn)
    elif 'carb' in gold_path:
        # CaRB benchmark: lenient binary tuple matching.
        matching_func = CarbMatcher.binary_linient_tuple_match
        error_fn = os.path.join(output_path, 'error_idxs.txt')
        evaluator = CarbBenchmark(gold_path)
        reader = TabReader()
        reader.read(os.path.join(output_path, 'extraction.txt'))
        auc, (precision, recall, f1) = evaluator.compare(
            predicted=reader.oie,
            matchingFunc=matching_func,
            output_fn=os.path.join(output_path, 'pr_curve.txt'),
            error_file=error_fn)
    return auc, precision, recall, f1


def do_eval(output_path, gold_path):
    auc, prec, rec, f1 = get_performance(output_path, gold_path)
    eval_results = [f1, prec, rec, auc]
    return eval_results


def main(args):
    # Build the model and restore the trained weights from the checkpoint.
    model = utils.get_models(
        bert_config=args.bert_config,
        pred_n_labels=args.pred_n_labels,
        arg_n_labels=args.arg_n_labels,
        n_arg_heads=args.n_arg_heads,
        n_arg_layers=args.n_arg_layers,
        pos_emb_dim=args.pos_emb_dim,
        use_lstm=args.use_lstm,
        device=args.device)
    model.load_state_dict(torch.load(args.model_path))
    model.zero_grad()
    model.eval()

    # Test-set loader (train=False: evaluation mode).
    loader = load_data(
        data_path=args.test_data_path,
        batch_size=args.batch_size,
        tokenizer_config=args.bert_config,
        train=False)

    # Run extraction, then score the saved extractions against the gold file.
    start = time.time()
    extract(args, model, loader, args.save_path)
    print("TIME: ", time.time() - start)

    test_results = do_eval(args.save_path, args.test_gold_path)
    utils.print_results("TEST RESULT", test_results, ["F1 ", "PREC", "REC ", "AUC "])


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', default='./results/model.bin')
    parser.add_argument('--save_path', default='./results/carb_test')
    parser.add_argument('--bert_config', default='bert-base-cased')
    parser.add_argument('--test_data_path', default='./datasets/carb_test.pkl')
    parser.add_argument('--test_gold_path', default='./carb/CaRB_test.tsv')
    parser.add_argument('--device', default='cuda:0')
    parser.add_argument('--visible_device', default="0")
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--pos_emb_dim', type=int, default=64)
    parser.add_argument('--n_arg_heads', type=int, default=8)
    parser.add_argument('--n_arg_layers', type=int, default=4)
    parser.add_argument('--use_lstm', nargs='?', const=True, default=False, type=utils.str2bool)
    parser.add_argument('--binary', nargs='?', const=True, default=False, type=utils.str2bool)
    main_args = parser.parse_args()

    os.environ["CUDA_VISIBLE_DEVICES"] = main_args.visible_device

    # Label-set sizes expected by the model heads.
    main_args.pred_n_labels = 3
    main_args.arg_n_labels = 9

    # Fall back to CPU if CUDA is unavailable.
    device = torch.device(main_args.device if torch.cuda.is_available() else 'cpu')
    main_args.device = device

    main(main_args)