Skip to content

Commit

Permalink
Add blocked-token support for greedy search in the decoder.
Browse files Browse the repository at this point in the history
  • Loading branch information
zhongkaifu committed Mar 26, 2024
1 parent 97dc779 commit 1ad9b02
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 11 deletions.
27 changes: 23 additions & 4 deletions Seq2SeqSharp/Applications/Decoder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
using System.Collections.Generic;
using TensorSharp;
using Seq2SeqSharp.Enums;
using ProtoBuf;

namespace Seq2SeqSharp.Applications
{
Expand Down Expand Up @@ -396,14 +397,24 @@ public static (float, List<List<BeamSearchStatus>>) DecodeTransformer(List<List<
}
else
{
if (decodingOptions.BlockedTokens != null && decodingOptions.BlockedTokens.Count > 0)
{
var btList = new List<List<int>>();
btList.Add(decodingOptions.BlockedTokens);
var blockTokensTensor = g.CreateTokensTensor(btList, elementType: DType.Float32); // [1, BlockedTokens.Count]
blockTokensTensor = g.Scatter(blockTokensTensor, -1.0f, 1, false, shape: new long[] { 1, probs.Sizes[1] });
blockTokensTensor = g.Expand(blockTokensTensor, dims: probs.Sizes);
probs = g.Add(blockTokensTensor, probs);
}

// Transformer decoder with beam search at inference time
List<List<BeamSearchStatus>> bssSeqList = new List<List<BeamSearchStatus>>(); //shape: (beam_search_size, batch_size)
int beamSearchSize = decodingOptions.BeamSearchSize;
while (beamSearchSize > 0)
{
// Output "i"th target word
using var targetIdxTensor = (decodingOptions.DecodingStrategy == DecodingStrategyEnums.GreedySearch) ? g.Argmax(probs, 1) :
g.TopPSample(probs, decodingOptions.TopP, decodingOptions.RepeatPenalty, decodingOptions.BlockedTokens, decodedSequences: tgtSeqs);
g.TopPSample(probs, decodingOptions.TopP, decodingOptions.RepeatPenalty, decodedSequences: tgtSeqs);
IWeightTensor gatherTensor = null;
if (outputSentScore)
{
Expand Down Expand Up @@ -516,10 +527,8 @@ public static (float, List<List<BeamSearchStatus>>) GPTDecode(List<List<int>> tg
decOutputIdx[i] = tgtSeqLen * (i + 1) - 1;
}


var indice = g.CreateTensorWeights(new long[] { decOutputIdx.Length, 1 }, decOutputIdx);
decOutput = g.IndexSelect(decOutput, indice);
tgtSeqLen = 1;
}

IWeightTensor ffLayer = decoderFFLayer.Process(decOutput, batchSize, g);
Expand Down Expand Up @@ -548,14 +557,24 @@ public static (float, List<List<BeamSearchStatus>>) GPTDecode(List<List<int>> tg
}
else
{
if (decodingOptions.BlockedTokens != null && decodingOptions.BlockedTokens.Count > 0)
{
var btList = new List<List<int>>();
btList.Add(decodingOptions.BlockedTokens);
var blockTokensTensor = g.CreateTokensTensor(btList, elementType: DType.Float32); // [1, BlockedTokens.Count]
blockTokensTensor = g.Scatter(blockTokensTensor, -1.0f, 1, false, shape: new long[] { 1, probs.Sizes[1] });
blockTokensTensor = g.Expand(blockTokensTensor, dims: probs.Sizes);
probs = g.Add(blockTokensTensor, probs);
}

// Transformer decoder with beam search at inference time
List<List<BeamSearchStatus>> bssSeqList = new List<List<BeamSearchStatus>>(); //shape: (beam_search_size, batch_size)
int beamSearchSize = decodingOptions.BeamSearchSize;
while (beamSearchSize > 0)
{
// Output "i"th target word
using var targetIdxTensor = (decodingOptions.DecodingStrategy == DecodingStrategyEnums.GreedySearch) ? g.Argmax(probs, 1) :
g.TopPSample(probs, decodingOptions.TopP, decodingOptions.RepeatPenalty, decodingOptions.BlockedTokens, decodedSequences: tgtSeqs);
g.TopPSample(probs, decodingOptions.TopP, decodingOptions.RepeatPenalty, decodedSequences: tgtSeqs);
IWeightTensor gatherTensor = null;
if (outputSentScore)
{
Expand Down
7 changes: 1 addition & 6 deletions Seq2SeqSharp/Tools/ComputeGraphTensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1569,7 +1569,7 @@ public IWeightTensor GreaterThan(IWeightTensor w, float val)
/// <param name="seqs"></param>
/// <param name="topP"></param>
/// <returns>The sampled index</returns>
public IWeightTensor TopPSample(IWeightTensor w, float topP = 1.0f, float repeatPenalty = 2.0f, List<int> blockedTokens = null, List<List<int>> decodedSequences = null)
public IWeightTensor TopPSample(IWeightTensor w, float topP = 1.0f, float repeatPenalty = 2.0f, List<List<int>> decodedSequences = null)
{
int K = w.Columns;
WeightTensor m = w as WeightTensor;
Expand Down Expand Up @@ -1597,11 +1597,6 @@ public IWeightTensor TopPSample(IWeightTensor w, float topP = 1.0f, float repeat
float weight = weights[offset + j];
int idx = j;

if (blockedTokens != null && blockedTokens.Contains(idx))
{
continue;
}

// Decay the weight if this token has already been generated before
if (tokenId2Distance.ContainsKey(idx))
{
Expand Down
2 changes: 1 addition & 1 deletion Seq2SeqSharp/Tools/IComputeGraph.cs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ public interface IComputeGraph : IDisposable
IWeightTensor LessOrEqual(IWeightTensor w, float val);
IWeightTensor GreaterThan(IWeightTensor w, float val);

IWeightTensor TopPSample(IWeightTensor w, float topP = 1.0f, float repeatPenalty = 2.0f, List<int> blockedTokens = null, List<List<int>> decodedSequences = null);
IWeightTensor TopPSample(IWeightTensor w, float topP = 1.0f, float repeatPenalty = 2.0f, List<List<int>> decodedSequences = null);

IWeightTensor Zero(long[] sizes);
IWeightTensor CreateTensorWeights(long[] sizes, float[] values);
Expand Down

0 comments on commit 1ad9b02

Please sign in to comment.