-
Notifications
You must be signed in to change notification settings - Fork 161
/
beam_search.py
119 lines (96 loc) · 3.46 KB
/
beam_search.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
"""Beam search implementation in PyTorch."""
#
#
# hyp1#-hyp1---hyp1 -hyp1
# \ /
# hyp2 \-hyp2 /-hyp2#hyp2
# / \
# hyp3#-hyp3---hyp3 -hyp3
# ========================
#
# Takes care of beams, back pointers, and scores.
# Code borrowed from PyTorch OpenNMT example
# https://github.com/pytorch/examples/blob/master/OpenNMT/onmt/Beam.py
import torch
class Beam(object):
    """Ordered beam of candidate outputs.

    Maintains, for each decoding time-step, the `size` best partial
    hypotheses, the backpointers into the previous beam, and the
    cumulative scores. Code borrowed from the PyTorch OpenNMT example.
    """

    def __init__(self, size, vocab, cuda=False):
        """Initialize params.

        Args:
            size: beam width (number of hypotheses kept per step).
            vocab: mapping that contains the '<pad>', '<s>' and '</s>' ids.
            cuda: when True, allocate the beam tensors on the GPU.
        """
        self.size = size
        self.done = False
        self.pad = vocab['<pad>']
        self.bos = vocab['<s>']
        self.eos = vocab['</s>']
        self.tt = torch.cuda if cuda else torch

        # The score for each translation on the beam.
        self.scores = self.tt.FloatTensor(size).zero_()

        # The backpointers at each time-step.
        self.prevKs = []

        # The outputs at each time-step; slot 0 starts with <s>, the
        # remaining slots are padding until the first advance().
        self.nextYs = [self.tt.LongTensor(size).fill_(self.pad)]
        self.nextYs[0][0] = self.bos

        # The attentions (matrix) for each time.
        self.attn = []

    def get_current_state(self):
        """Get the outputs for the current timestep."""
        return self.nextYs[-1]

    def get_current_origin(self):
        """Get the backpointer to the beam at this step."""
        return self.prevKs[-1]

    def advance(self, workd_lk):
        """Compute and update the beam search for one step.

        Args:
            workd_lk: (K x words) scores/log-probs of advancing each beam
                from the last step.

        Returns:
            True if beam search is complete (top-of-beam emitted EOS).
        """
        num_words = workd_lk.size(1)

        # Sum the previous scores. On the very first step only beam 0
        # (the <s> hypothesis) is a valid continuation, so use row 0 alone.
        if len(self.prevKs) > 0:
            beam_lk = workd_lk + self.scores.unsqueeze(1).expand_as(workd_lk)
        else:
            beam_lk = workd_lk[0]

        flat_beam_lk = beam_lk.view(-1)

        bestScores, bestScoresId = flat_beam_lk.topk(self.size, 0, True, True)
        self.scores = bestScores

        # bestScoresId indexes the flattened (beam x word) array, so recover
        # which beam and which word each score came from.
        # BUG FIX: use floor division -- plain `/` is true division in
        # modern PyTorch and yields a float tensor, which cannot serve as
        # a backpointer index.
        prev_k = bestScoresId // num_words
        self.prevKs.append(prev_k)
        self.nextYs.append(bestScoresId - prev_k * num_words)

        # End condition is when top-of-beam is EOS.
        if self.nextYs[-1][0] == self.eos:
            self.done = True

        return self.done

    def sort_best(self):
        """Sort beam scores descending; returns (sorted_scores, indices)."""
        return torch.sort(self.scores, 0, True)

    def get_best(self):
        """Get the most likely candidate's (score, beam index).

        BUG FIX: sort_best() sorts in descending order, so the best
        hypothesis is at position 0 -- the original indexed position 1,
        which is the runner-up.
        """
        scores, ids = self.sort_best()
        return scores[0], ids[0]

    def get_hyp(self, k):
        """Walk back through the backpointers to construct hypothesis `k`.

        Args:
            k: the position in the beam to construct.

        Returns:
            The hypothesis as a list of token-id tensors, oldest first
            (the initial <s> is not included).
        """
        hyp = []
        for j in range(len(self.prevKs) - 1, -1, -1):
            hyp.append(self.nextYs[j + 1][k])
            k = self.prevKs[j][k]
        return hyp[::-1]