Improve training stability in float16 mode.
zhongkaifu committed May 16, 2024
1 parent 7fe451d commit 1a7621e
Showing 1 changed file with 11 additions and 11 deletions.
Seq2SeqSharp/Layers/MultiHeadAttention.cs
@@ -171,12 +171,6 @@ public IWeightTensor Perform(IWeightTensor inputQ, IWeightTensor keyMask, int ba
 }
 var attn = g.MulBatch(Qs, Ks, scale); // Shape: [batchSize * m_multiHeadNum, relPosSize, seqLenQ]

-// Convert it back to Float16 for the following parts
-if (useF16)
-{
-    attn = g.Float2Half(attn);
-}
-
 // Add mask
 attn = g.View(attn, dims: new long[] { batchSize, m_multiHeadNum, newTokensIdx, seqLenQ });
 if (keyMask != null)
@@ -185,14 +179,20 @@ public IWeightTensor Perform(IWeightTensor inputQ, IWeightTensor keyMask, int ba
 {
     keyMask = g.Peek(keyMask, 2, seqLenQ - newTokensIdx, newTokensIdx);
 }
-//if (useF16)
-//{
-//    keyMask = g.Half2Float(keyMask);
-//}
+if (useF16)
+{
+    keyMask = g.Half2Float(keyMask);
+}
 attn = g.Add(attn, keyMask, inPlace: true);
 }

 attn = g.Softmax(attn, inPlace: true);

+// Convert it back to Float16 for the following parts
+if (useF16)
+{
+    attn = g.Float2Half(attn);
+}
+
 attn = g.View(attn, dims: new long[] { batchSize * m_multiHeadNum, newTokensIdx, seqLenQ });
 IWeightTensor o = g.View(g.MulBatch(attn, Vs), dims: new long[] { batchSize, m_multiHeadNum, newTokensIdx, m_d });
 IWeightTensor W = g.View(g.AsContiguous(g.Transpose(o, 1, 2)), dims: new long[] { batchSize * newTokensIdx, m_multiHeadNum * m_d });
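Note on the change: the Float2Half conversion of attn moves from before the mask addition to after the softmax, so the mask add and the softmax now run in float32, and the key mask (stored as half when useF16 is set) is widened with Half2Float before being added to the float32 scores. A plausible motivation, beyond what the one-line commit message states, is that float16's largest finite value is about 65504, so exp(x) inside softmax already overflows for x above roughly 11. Below is a minimal, self-contained C# sketch of that failure mode and of the compute-in-float32-then-narrow ordering this commit adopts; it is illustrative only, not Seq2SeqSharp code, and the helper names SoftmaxInHalf and SoftmaxInFloatThenHalf are made up for the demo.

using System;

class HalfSoftmaxDemo
{
    // Naive softmax that stores exp() results as Half (float16).
    // Half's largest finite value is ~65504, so MathF.Exp(x) overflows
    // to +Infinity for x > ~11.09 once narrowed to Half.
    static Half[] SoftmaxInHalf(float[] logits)
    {
        var exps = new Half[logits.Length];
        float sum = 0f;
        for (int i = 0; i < logits.Length; i++)
        {
            exps[i] = (Half)MathF.Exp(logits[i]); // overflow happens here
            sum += (float)exps[i];                // sum becomes +Infinity
        }
        for (int i = 0; i < exps.Length; i++)
            exps[i] = (Half)((float)exps[i] / sum); // Infinity / Infinity => NaN
        return exps;
    }

    // Stable ordering: do the whole softmax in float32, then narrow once,
    // mirroring this commit (Softmax first, Float2Half afterwards).
    static Half[] SoftmaxInFloatThenHalf(float[] logits)
    {
        // Standard max-subtraction keeps exp() in range even for large logits.
        float max = float.NegativeInfinity;
        foreach (var x in logits) max = MathF.Max(max, x);

        var probs = new float[logits.Length];
        float sum = 0f;
        for (int i = 0; i < logits.Length; i++)
        {
            probs[i] = MathF.Exp(logits[i] - max);
            sum += probs[i];
        }

        var result = new Half[logits.Length];
        for (int i = 0; i < logits.Length; i++)
            result[i] = (Half)(probs[i] / sum); // probabilities lie in [0, 1], safe in Half
        return result;
    }

    static void Main()
    {
        // Unmasked attention logits can easily reach this scale.
        float[] logits = { 12f, 3f, -1f };

        Console.WriteLine(string.Join(" ", SoftmaxInHalf(logits)));          // NaN 0 0
        Console.WriteLine(string.Join(" ", SoftmaxInFloatThenHalf(logits))); // ~0.9999 ~0.000123 ~0
    }
}

The first call prints NaN 0 0 because exp(12) overflows Half and the infinite sum poisons the normalization; the second prints well-formed probabilities, which is the same reason the diff adds the mask and runs g.Softmax in float32 and only then calls g.Float2Half.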
