diff --git a/visdial/models/decoders/gen.py b/visdial/models/decoders/gen.py index a5469de..3b4bd17 100644 --- a/visdial/models/decoders/gen.py +++ b/visdial/models/decoders/gen.py @@ -368,6 +368,28 @@ def beamSearchDecoder(self, initStates, beamSize, maxSeqLen): beamTokensTable[:, :, t] = tokensArray.gather(1, topIdx) backIndices[:, :, t] = backIndexArray.gather(1, topIdx) + # Update corresponding hidden and cell states for next time step + hiddenCurrent, cellCurrent = hiddenStates + + # Reshape to get explicit beamSize dim + original_state_size = hiddenCurrent.size() + num_layers, _, rnnHiddenSize = original_state_size + hiddenCurrent = hiddenCurrent.view( + num_layers, batchSize, beamSize, rnnHiddenSize) + cellCurrent = cellCurrent.view( + num_layers, batchSize, beamSize, rnnHiddenSize) + + # Update states according to the next top beams + backIndexVector = backIndices[:, :, t].unsqueeze(0)\ + .unsqueeze(-1).repeat(num_layers, 1, 1, rnnHiddenSize) + hiddenCurrent = hiddenCurrent.gather(2, backIndexVector) + cellCurrent = cellCurrent.gather(2, backIndexVector) + + # Restore original shape for next rnn forward + hiddenCurrent = hiddenCurrent.view(*original_state_size) + cellCurrent = cellCurrent.view(*original_state_size) + hiddenStates = (hiddenCurrent, cellCurrent) + # Detecting endToken to end beams aliveVector = beamTokensTable[:, :, t:t + 1].ne(self.endToken) aliveBeams = aliveVector.data.long().sum()