-
Notifications
You must be signed in to change notification settings - Fork 21
/
infer.py
54 lines (40 loc) · 1.55 KB
/
infer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import argparse
import logging
import os
import sys

import torch

from machine.evaluator import Predictor
from machine.util.checkpoint import Checkpoint
# Command-line options for interactive inference from a saved checkpoint.
parser = argparse.ArgumentParser()
# Required: the script cannot do anything without a checkpoint, and omitting
# it previously surfaced later as an opaque TypeError.
parser.add_argument('--checkpoint_path', required=True,
                    help='Give the checkpoint path from which to load the model')
parser.add_argument('--cuda_device', default=0, type=int,
                    help='Set cuda device to use')
# Hidden flag (help suppressed); checked later to exit before the prompt loop.
parser.add_argument('--debug', action='store_true', help=argparse.SUPPRESS)
# Level names are matched case-insensitively and must name a level constant
# of the logging module, so the later getattr(logging, ...) lookup succeeds.
parser.add_argument('--log-level', default='info', type=str.lower,
                    choices=['notset', 'debug', 'info', 'warning', 'warn',
                             'error', 'critical', 'fatal'],
                    help='Logging level.')
opt = parser.parse_args()
# Configure root logging from --log-level, then select the CUDA device.
LOG_FORMAT = '%(asctime)s %(name)-12s %(levelname)-8s %(message)s'
log_level = getattr(logging, opt.log_level.upper(), None)
if not isinstance(log_level, int):
    # An unknown name would otherwise raise AttributeError, or resolve to a
    # non-level attribute (e.g. logging.BASIC_FORMAT) that basicConfig rejects.
    parser.error('invalid --log-level: {!r}'.format(opt.log_level))
logging.basicConfig(format=LOG_FORMAT, level=log_level)
logging.info(opt)

if torch.cuda.is_available():
    # Device selection is skipped entirely when CUDA is unavailable.
    logging.info("Cuda device set to %i", opt.cuda_device)
    torch.cuda.set_device(opt.cuda_device)
##########################################################################
# load model
# Restore the model and its input/output vocabularies from the checkpoint.
logging.info("loading checkpoint from %s", opt.checkpoint_path)
checkpoint = Checkpoint.load(opt.checkpoint_path)
seq2seq = checkpoint.model
input_vocab = checkpoint.input_vocab
output_vocab = checkpoint.output_vocab
##########################################################################
# Generate predictor
# Bundle the model with its vocabularies; predictor.predict() takes a list of
# whitespace-split tokens from the prompt loop.
predictor = Predictor(seq2seq, input_vocab, output_vocab)
if opt.debug:
    # --debug stops before the interactive prompt (presumably a smoke test
    # that the checkpoint loads — confirm). sys.exit() is used because the
    # bare exit() builtin is injected by the site module and may be absent.
    sys.exit()
# Interactive prompt: read a whitespace-separated source sequence and print
# the predictor's output. 'q', EOF (Ctrl-D) or Ctrl-C ends the session; this
# is the last statement of the script, so leaving the loop ends the program.
while True:
    try:
        seq_str = input("\n\nType in a source sequence: ")
    except (EOFError, KeyboardInterrupt):
        # Exit cleanly instead of dumping a traceback.
        print()
        break
    seq_str = seq_str.strip()
    # Compare after stripping so ' q' quits instead of being fed to the model.
    if seq_str == 'q':
        break
    seq = seq_str.split()
    if not seq:
        # Blank line: nothing to predict on, prompt again.
        continue
    print(predictor.predict(seq))