-
Notifications
You must be signed in to change notification settings - Fork 0
/
generate.py
84 lines (74 loc) · 1.62 KB
/
generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import argparse
from generator import Generator
from parse_dataset import comment_to_tokens
import os
import re
# Seed files must be named exactly "<digits>.txt" (e.g. "42.txt").
_SEQUENCE_FILE_RE = re.compile(r'\d+\.txt')


def _load_sequences(data_dir, min_len=10, max_len=80):
    """Return the token lists of all usable comments in *data_dir*'s seed files.

    Files are visited in plain lexicographic name order (so "10.txt" sorts
    before "2.txt"). Each file is decoded as UTF-8 and split into comments on
    blank lines (two consecutive newline characters). A comment is kept only
    when ``comment_to_tokens`` yields between *min_len* and *max_len* tokens,
    inclusive.

    Args:
        data_dir: directory to scan for seed files.
        min_len: minimum number of tokens for a comment to be kept.
        max_len: maximum number of tokens for a comment to be kept.

    Returns:
        A list of token sequences, in file order then comment order.
    """
    sequences = []
    for name in sorted(os.listdir(data_dir)):
        # fullmatch rather than match: names like "1.txt.bak" must be skipped.
        if not _SEQUENCE_FILE_RE.fullmatch(name):
            continue
        # Read bytes and decode explicitly; text mode would translate "\r\n"
        # line endings and change how the content is split below.
        with open(os.path.join(data_dir, name), 'rb') as f:
            content = f.read().decode('utf-8')
        for comment in content.split('\n\n'):
            tokens = comment_to_tokens(comment)
            if min_len <= len(tokens) <= max_len:
                sequences.append(tokens)
    return sequences


def _print_result(i, res_tokens, seed_tokens):
    """Print a seed sequence followed by the tokens generated from it.

    Used as the ``callback`` for ``Generator.generate``. *i* is accepted to
    match the callback signature but is not used here.
    NOTE(review): presumably *i* is the index of the seed sequence — confirm
    against Generator.generate.
    """
    print('')
    print(' '.join(seed_tokens))
    print('>>>>>>', ' '.join(res_tokens))


def main(opts=None):
    """Load seed sequences from the test data directory and run generation.

    Args:
        opts: parsed command-line options. When omitted, the module-level
            ``OPTS`` assigned by the ``__main__`` block is used, so a plain
            ``main()`` call still works.
    """
    if opts is None:
        opts = OPTS
    sequences = _load_sequences(opts.test_data_dir)
    print('sequences:', len(sequences))
    # An empty --forbidden_tokens string means "no forbidden tokens".
    forbidden = opts.forbidden_tokens.split(',') if opts.forbidden_tokens else ()
    Generator(
        opts.weights_file,
        opts.id2token_file,
        opts.embedding_size,
        opts.hidden_size,
    ).generate(
        sequences,
        forbidden_tokens=forbidden,
        max_res_len=opts.max_res_len,
        callback=_print_result,
    )
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--weights_file',
type=str,
default='./models/weights.h5'
)
parser.add_argument(
'--id2token_file',
type=str,
default='./models/id2token.json'
)
parser.add_argument(
'--test_data_dir',
type=str,
default='./tests/dataset/threads/'
)
parser.add_argument(
'--forbidden_tokens',
type=str,
default='<unk>'
)
parser.add_argument(
'--max_res_len',
type=int,
default=200
)
parser.add_argument(
'--embedding_size',
type=int,
default=1024
)
parser.add_argument(
'--hidden_size',
type=int,
default=1024
)
OPTS = parser.parse_args()
main()