Disable self-masking when FlashAttentionV2 is enabled.
zhongkaifu committed Jul 29, 2024
1 parent: f812cd3 · commit: e698de1
Showing 5 changed files with 34 additions and 21 deletions.
42 changes: 24 additions & 18 deletions Seq2SeqSharp/Applications/Decoder.cs
```diff
@@ -282,16 +282,19 @@ public static (float, List<List<BeamSearchStatus>>) DecodeTransformer(List<List<
                 srcTgtMask = g.View(srcTgtMask, new long[] { srcTgtMask.Sizes[0], 1, srcTgtMask.Sizes[1], srcTgtMask.Sizes[2] });
             }
 
-            IWeightTensor tgtSelfTriMask;
-            if (paddingType == PaddingEnums.NoPadding || paddingType == PaddingEnums.NoPaddingInTgt || batchSize == 1)
+            IWeightTensor tgtSelfTriMask = null;
+            if (decoder.AttentionType == AttentionTypeEnums.Classic)
             {
-                tgtSelfTriMask = g.BuildTriMask(tgtSeqLen, batchSize, amp ? TensorSharp.DType.Float16 : TensorSharp.DType.Float32);
-                tgtSelfTriMask = g.View(tgtSelfTriMask, new long[] { 1, 1, tgtSeqLen, tgtSeqLen });
-            }
-            else
-            {
-                tgtSelfTriMask = g.BuildSelfTriMask(tgtSeqLen, tgtOriginalLengths, amp ? TensorSharp.DType.Float16 : TensorSharp.DType.Float32);
-                tgtSelfTriMask = g.View(tgtSelfTriMask, new long[] { batchSize, 1, tgtSeqLen, tgtSeqLen });
+                if (paddingType == PaddingEnums.NoPadding || paddingType == PaddingEnums.NoPaddingInTgt || batchSize == 1)
+                {
+                    tgtSelfTriMask = g.BuildTriMask(tgtSeqLen, batchSize, amp ? TensorSharp.DType.Float16 : TensorSharp.DType.Float32);
+                    tgtSelfTriMask = g.View(tgtSelfTriMask, new long[] { 1, 1, tgtSeqLen, tgtSeqLen });
+                }
+                else
+                {
+                    tgtSelfTriMask = g.BuildSelfTriMask(tgtSeqLen, tgtOriginalLengths, amp ? TensorSharp.DType.Float16 : TensorSharp.DType.Float32);
+                    tgtSelfTriMask = g.View(tgtSelfTriMask, new long[] { batchSize, 1, tgtSeqLen, tgtSeqLen });
+                }
             }
 
             IWeightTensor inputEmbs = TensorUtils.CreateTokensEmbeddings(tgtSeqs, g, tgtEmbedding, segmentEmbeddings, tgtVocab, scaleFactor: (float)Math.Sqrt(tgtEmbedding.Columns), amp: amp);
```
```diff
@@ -497,16 +500,19 @@ public static (float, List<List<BeamSearchStatus>>) GPTDecode(List<List<int>> tg
             var tgtOriginalLengths = BuildInTokens.PadSentences(tgtSeqs, eosTokenId);
             int tgtSeqLen = tgtSeqs[0].Count;
 
-            IWeightTensor tgtSelfTriMask;
-            if (paddingType == PaddingEnums.NoPadding || paddingType == PaddingEnums.NoPaddingInTgt || batchSize == 1)
+            IWeightTensor tgtSelfTriMask = null;
+            if (decoder.AttentionType == AttentionTypeEnums.Classic)
             {
-                tgtSelfTriMask = g.BuildTriMask(tgtSeqLen, batchSize, amp ? TensorSharp.DType.Float16 : TensorSharp.DType.Float32);
-                tgtSelfTriMask = g.View(tgtSelfTriMask, new long[] { 1, 1, tgtSeqLen, tgtSeqLen });
-            }
-            else
-            {
-                tgtSelfTriMask = g.BuildSelfTriMask(tgtSeqLen, tgtOriginalLengths, amp ? TensorSharp.DType.Float16 : TensorSharp.DType.Float32);
-                tgtSelfTriMask = g.View(tgtSelfTriMask, new long[] { batchSize, 1, tgtSeqLen, tgtSeqLen });
+                if (paddingType == PaddingEnums.NoPadding || paddingType == PaddingEnums.NoPaddingInTgt || batchSize == 1)
+                {
+                    tgtSelfTriMask = g.BuildTriMask(tgtSeqLen, batchSize, amp ? TensorSharp.DType.Float16 : TensorSharp.DType.Float32);
+                    tgtSelfTriMask = g.View(tgtSelfTriMask, new long[] { 1, 1, tgtSeqLen, tgtSeqLen });
+                }
+                else
+                {
+                    tgtSelfTriMask = g.BuildSelfTriMask(tgtSeqLen, tgtOriginalLengths, amp ? TensorSharp.DType.Float16 : TensorSharp.DType.Float32);
+                    tgtSelfTriMask = g.View(tgtSelfTriMask, new long[] { batchSize, 1, tgtSeqLen, tgtSeqLen });
+                }
             }
 
             IWeightTensor inputEmbs = TensorUtils.CreateTokensEmbeddings(tgtSeqs, g, tgtEmbedding, segmentEmbeddings, tgtVocab, scaleFactor: (float)Math.Sqrt(tgtEmbedding.Columns), amp: amp);
```
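In both DecodeTransformer and GPTDecode the change has the same shape: the explicit triangular self-attention mask is now built only for the Classic attention path, and `tgtSelfTriMask` stays null when FlashAttentionV2 is selected, presumably because the fused kernel enforces causality itself rather than consuming a mask tensor. For readers unfamiliar with what BuildTriMask produces, here is a minimal self-contained sketch of an additive causal mask; the helper name and the -1e9f "negative infinity" sentinel are illustrative assumptions, not Seq2SeqSharp's actual API.

```csharp
using System;

static class CausalMaskSketch
{
    // Sketch of an additive causal mask: 0 on and below the diagonal, a
    // large negative value above it. Added to attention scores before
    // softmax, it drives the probability of attending to future tokens
    // to ~0. (Assumed shape; BuildTriMask's internals are not in this diff.)
    static float[,] BuildTriMask(int seqLen)
    {
        var mask = new float[seqLen, seqLen];
        for (int q = 0; q < seqLen; q++)
            for (int k = 0; k < seqLen; k++)
                mask[q, k] = (k <= q) ? 0f : -1e9f;
        return mask;
    }

    static void Main()
    {
        var mask = BuildTriMask(4);
        for (int q = 0; q < 4; q++)
        {
            for (int k = 0; k < 4; k++)
                Console.Write(mask[q, k] == 0f ? "     0" : "  -1e9");
            Console.WriteLine();
        }
    }
}
```

The `View` to `{ 1, 1, tgtSeqLen, tgtSeqLen }` in the Classic branch broadcasts a single such mask across batch and heads, while BuildSelfTriMask, judging by its `tgtOriginalLengths` argument, also accounts for per-sentence padding and therefore keeps the per-batch shape `{ batchSize, 1, tgtSeqLen, tgtSeqLen }`.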
5 changes: 5 additions & 0 deletions Seq2SeqSharp/Applications/Options.cs
```diff
@@ -345,6 +345,11 @@ public void ValidateOptions()
             {
                 throw new FileNotFoundException($"Model '{ModelFilePath}' doesn't exist for task '{Task}'");
             }
+
+            if (AttentionType == AttentionTypeEnums.FlashAttentionV2 && ProcessorType != ProcessorTypeEnums.GPU)
+            {
+                throw new ArgumentException("FlashAttentionV2 runs on GPU only, please use the classic attention layer instead.");
+            }
         }
     }
 }
```
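This guard turns a misconfiguration into an immediate, descriptive error at option-validation time; FlashAttentionV2 lives in a CUDA kernel (see AdvFuncKernels.cs below), so a CPU run could otherwise only fail later and more obscurely. A standalone sketch of the same check, with stand-in enum definitions since only the guard itself appears in this commit:

```csharp
using System;

// Stand-in enums: the real AttentionTypeEnums / ProcessorTypeEnums live in
// Seq2SeqSharp and may carry more members than shown here.
enum AttentionTypeEnums { Classic, FlashAttentionV2 }
enum ProcessorTypeEnums { CPU, GPU }

static class OptionGuardSketch
{
    // Mirrors the new check in Options.ValidateOptions(): FlashAttentionV2
    // is only valid when the processor type is GPU.
    static void Validate(AttentionTypeEnums attentionType, ProcessorTypeEnums processorType)
    {
        if (attentionType == AttentionTypeEnums.FlashAttentionV2 && processorType != ProcessorTypeEnums.GPU)
        {
            throw new ArgumentException("FlashAttentionV2 runs on GPU only, please use the classic attention layer instead.");
        }
    }

    static void Main()
    {
        Validate(AttentionTypeEnums.FlashAttentionV2, ProcessorTypeEnums.GPU);  // passes silently

        try
        {
            Validate(AttentionTypeEnums.FlashAttentionV2, ProcessorTypeEnums.CPU);
        }
        catch (ArgumentException e)
        {
            Console.WriteLine(e.Message);  // the guard's message
        }
    }
}
```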
2 changes: 2 additions & 0 deletions Seq2SeqSharp/Networks/GPTDecoder.cs
```diff
@@ -43,6 +43,8 @@ public class GPTDecoder : IDecoder
         private readonly NormEnums m_normType;
         private readonly AttentionTypeEnums m_attentionType;
 
+        public AttentionTypeEnums AttentionType => m_attentionType;
+
         public GPTDecoder(string name, int multiHeadNum, int hiddenDim, int intermediateDim, int inputDim, int depth, float dropoutRatio, int deviceId,
             bool isTrainable, float learningRateFactor = 1.0f, ActivateFuncEnums activateFunc = ActivateFuncEnums.ReLU, int expertNum = 1,
             int expertsPerTokenFactor = 1, DType elementType = DType.Float32, PositionEmbeddingEnums peType = PositionEmbeddingEnums.APE, NormEnums normType = NormEnums.LayerNorm, AttentionTypeEnums attentionType = AttentionTypeEnums.Classic)
```
1 change: 1 addition & 0 deletions Seq2SeqSharp/Networks/TransformerDecoder.cs
```diff
@@ -44,6 +44,7 @@ public class TransformerDecoder : IDecoder
         private readonly NormEnums m_normType;
         private readonly AttentionTypeEnums m_attentionType;
 
+        public AttentionTypeEnums AttentionType => m_attentionType;
         public TransformerDecoder(string name, int multiHeadNum, int hiddenDim, int intermediateDim, int inputDim, int depth, float dropoutRatio,
             int deviceId, bool isTrainable, float learningRateFactor = 1.0f, ActivateFuncEnums activateFunc = ActivateFuncEnums.ReLU,
             int expertNum = 1, int expertsPerTokenFactor = 1, DType elementType = DType.Float32, PositionEmbeddingEnums peType = PositionEmbeddingEnums.APE, NormEnums normType = NormEnums.LayerNorm, AttentionTypeEnums attentionType = AttentionTypeEnums.Classic)
```
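GPTDecoder and TransformerDecoder gain the same read-only property so the shared decoding code above can branch on `decoder.AttentionType`. For that call site to compile against the common decoder abstraction, `IDecoder` itself presumably declares the member as well; that file is not part of this commit, so the sketch below is an assumption, pared down to the pieces relevant here.

```csharp
// Hypothetical wiring inferred from the call site in Decoder.cs; IDecoder's
// real definition is not shown in this commit's diff.
public enum AttentionTypeEnums { Classic, FlashAttentionV2 }

public interface IDecoder
{
    AttentionTypeEnums AttentionType { get; }
}

public class TransformerDecoder : IDecoder
{
    private readonly AttentionTypeEnums m_attentionType;

    public TransformerDecoder(AttentionTypeEnums attentionType = AttentionTypeEnums.Classic)
    {
        m_attentionType = attentionType;
    }

    // The one-line addition from this commit: an expression-bodied,
    // read-only pass-through over the existing field.
    public AttentionTypeEnums AttentionType => m_attentionType;
}
```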
5 changes: 2 additions & 3 deletions TensorSharp.CUDA/DeviceCode/AdvFuncKernels.cs
```diff
@@ -2616,7 +2616,7 @@ private void FlashAttention(TSCudaContext context, Tensor Q, Tensor K, Tensor V,
             int N = (int)Q.Sizes[2];
             int d = (int)Q.Sizes[3];
 
-            int Br = 32;
+            int Br = 112;
             while (Br > 1)
             {
                 if (N % Br == 0)
@@ -2685,7 +2685,7 @@ private void FlashAttentionGrad(TSCudaContext context, Tensor Q, Tensor K, Tenso
             int N = (int)Q.Sizes[2];
             int d = (int)Q.Sizes[3];
 
-            int Br = 32;
+            int Br = 64;
             while (Br > 1)
             {
                 if (N % Br == 0)
@@ -2695,7 +2695,6 @@ private void FlashAttentionGrad(TSCudaContext context, Tensor Q, Tensor K, Tenso
                 Br--;
             }
             int Bc = Br;
-            //Logger.WriteLine($"Grad: N = '{N}', Br = '{Br}'");
 
             int Tc = (int)Math.Ceiling((float)N / Bc);
             int Tr = (int)Math.Ceiling((float)N / Br);
```
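The only functional change in the CUDA wrapper is the starting point of the tile-size search: `Br` now begins at 112 in the forward pass and 64 in the backward pass instead of 32, and the loop still decrements until `Br` evenly divides the sequence length N, with `Bc = Br` and `Tr = Tc = ceil(N / Br)` tiles. A larger cap yields fewer, larger tiles whenever N has a suitable divisor, presumably to make better use of each kernel launch; the selection logic itself is untouched. A small sketch of that search and its effect:

```csharp
using System;

static class FlashTileSizeSketch
{
    // Mirrors the selection loop in FlashAttention/FlashAttentionGrad:
    // start from a cap (112 forward, 64 backward after this commit) and
    // decrement until Br evenly divides N, so the tiles cover the
    // sequence without a ragged remainder.
    static int PickBlockRows(int n, int cap)
    {
        int br = cap;
        while (br > 1)
        {
            if (n % br == 0)
            {
                break;
            }
            br--;
        }
        return br;
    }

    static void Main()
    {
        int n = 448;                      // example sequence length
        int br = PickBlockRows(n, 112);   // forward cap after this commit
        int bc = br;                      // the kernel sets Bc = Br
        int tr = (n + br - 1) / br;       // Tr = ceil(N / Br)
        Console.WriteLine($"N={n}: Br={br}, Bc={bc}, tiles={tr}");
        // Prints: N=448: Br=112, Bc=112, tiles=4
        // With the old cap of 32: Br=32, 14 tiles per pass.
    }
}
```

Note one unchanged property of the search: if N is prime it falls all the way to `Br = 1`, exactly as before; the new caps raise the upper end but do not change the worst case.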
