kenlm_perplexity.py
#!/usr/bin/env python3

# Add KenLM model perplexity values to JSONL documents whose text is
# stored under the key 'text'.

import sys
import json
import logging
import unicodedata

import kenlm

from argparse import ArgumentParser

from berttokenizer import basic_tokenize


# Default key for the perplexity value in JSONL
DEFAULT_KEY = 'ppl_kenlm'


def argparser():
    ap = ArgumentParser()
    ap.add_argument('-k', '--key', default=DEFAULT_KEY,
                    help='key for perplexity value')
    ap.add_argument('model', help='KenLM model')
    ap.add_argument('jsonl', nargs='+')
    return ap


def is_punct(string):
    return all(unicodedata.category(c).startswith('P') for c in string)


def word_count(tokenized, args):
    # Count non-punctuation tokens only.
    return sum(not is_punct(t) for t in tokenized)


def tokenize(text, args):
    # Tokenize nonempty lines; each line is scored as a sentence.
    for line in text.split('\n'):
        if line and not line.isspace():
            yield basic_tokenize(line)


def add_perplexity(fn, model, args):
    with open(fn) as f:
        for ln, line in enumerate(f, start=1):
            try:
                data = json.loads(line)
                text = data['text']
            except Exception:
                logging.error(f'parsing line {ln} in {fn}: {line}')
                raise
            total_score, total_words = 0, 0
            for tokenized in tokenize(text, args):
                # model.score() returns a log10 probability.
                total_score += model.score(' '.join(tokenized))
                total_words += word_count(tokenized, args) + 1    # +1 for EOS
            perplexity = 10**(-total_score/total_words)
            if 'meta' not in data:
                data['meta'] = {}
            else:
                assert args.key not in data['meta']
            data['meta'][args.key] = int(perplexity)
            if ln % 100 == 0:
                print(f'Processed {ln} ...', file=sys.stderr, flush=True)
                flush = True
            else:
                flush = False
            print(json.dumps(data, ensure_ascii=False), flush=flush)


def main(argv):
    args = argparser().parse_args(argv[1:])
    print('loading model ... ', file=sys.stderr, end='', flush=True)
    model = kenlm.Model(args.model)
    print('done.', file=sys.stderr, flush=True)
    for fn in args.jsonl:
        add_perplexity(fn, model, args)


if __name__ == '__main__':
    sys.exit(main(sys.argv))
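

# Usage sketch (the model and data file names below are hypothetical; any
# binary or ARPA file accepted by kenlm.Model works as the first argument):
#
#     python3 kenlm_perplexity.py wiki.arpa.bin docs.jsonl > docs.scored.jsonl
#
# Each input line such as {"text": "..."} is echoed to stdout with the
# perplexity added under "meta", e.g.
#
#     {"text": "...", "meta": {"ppl_kenlm": 42}}
#
# Pass -k/--key to store the value under a different key than 'ppl_kenlm'.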