Skip to content

Commit

Permalink
Merge pull request #11 from batra-mlp-lab/beamsearch-fix
Browse files (browse the repository at this point in the history)
Beamsearch bug fix
  • Loading branch information
nirbhayjm authored Aug 22, 2018
2 parents 60cf6a5 + e5d4416 commit 1fb7e88
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions visdial/models/decoders/gen.py
Original file line number | Diff line number | Diff line change
@@ -368,6 +368,28 @@ def beamSearchDecoder(self, initStates, beamSize, maxSeqLen):
beamTokensTable[:, :, t] = tokensArray.gather(1, topIdx)
backIndices[:, :, t] = backIndexArray.gather(1, topIdx)

# Update corresponding hidden and cell states for next time step
hiddenCurrent, cellCurrent = hiddenStates

# Reshape to get explicit beamSize dim
original_state_size = hiddenCurrent.size()
num_layers, _, rnnHiddenSize = original_state_size
hiddenCurrent = hiddenCurrent.view(
num_layers, batchSize, beamSize, rnnHiddenSize)
cellCurrent = cellCurrent.view(
num_layers, batchSize, beamSize, rnnHiddenSize)

# Update states according to the next top beams
backIndexVector = backIndices[:, :, t].unsqueeze(0)\
.unsqueeze(-1).repeat(num_layers, 1, 1, rnnHiddenSize)
hiddenCurrent = hiddenCurrent.gather(2, backIndexVector)
cellCurrent = cellCurrent.gather(2, backIndexVector)

# Restore original shape for next rnn forward
hiddenCurrent = hiddenCurrent.view(*original_state_size)
cellCurrent = cellCurrent.view(*original_state_size)
hiddenStates = (hiddenCurrent, cellCurrent)

# Detecting endToken to end beams
aliveVector = beamTokensTable[:, :, t:t + 1].ne(self.endToken)
aliveBeams = aliveVector.data.long().sum()
Expand Down

0 comments on commit 1fb7e88

Please sign in to comment.