Update loss scaling
zhongkaifu committed Apr 6, 2024
1 parent b4e8c72 · commit 9d989ba
Showing 7 changed files with 21 additions and 32 deletions.
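This commit reworks how mixed-precision (AMP) loss scaling flows through the loss functions: the dedicated lossScaling parameter on CrossEntropyLoss (where 0 meant disabled) is removed, the scale is instead passed as the backward seed through the existing graident parameter, and the "disabled" default becomes 1.0. For context, here is a minimal, self-contained C# sketch of the scale/unscale round trip that loss scaling performs; the array-based gradients are illustrative only, not Seq2SeqSharp types.

    using System;

    // Minimal sketch of static loss scaling under AMP. fp16 gradients can
    // underflow to zero; multiplying the loss gradient by a scale S keeps them
    // representable, and dividing by S before the optimizer step restores the
    // original magnitudes, leaving the parameter update mathematically unchanged.
    class LossScalingSketch
    {
        static void Main()
        {
            const float lossScaling = 1024.0f;  // after this commit, 1.0f means "disabled"
            float[] grads = { 1e-6f, 2e-6f };   // tiny gradients that risk fp16 underflow

            // Backward pass: every gradient is effectively multiplied by the scale
            // because the loss gradient is seeded with lossScaling instead of 1.0.
            for (int i = 0; i < grads.Length; i++) grads[i] *= lossScaling;

            // Before the parameter update: unscale, so the step size is unchanged.
            for (int i = 0; i < grads.Length; i++) grads[i] /= lossScaling;

            Console.WriteLine(string.Join(", ", grads));  // ~1E-06, 2E-06
        }
    }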
Seq2SeqSharp/Applications/Decoder.cs (4 additions, 4 deletions)
@@ -269,7 +269,7 @@ public static (float, List<List<BeamSearchStatus>>) DecodeTransformer(List<List<
 IWeightTensor tgtEmbedding, float[] srcOriginalLenghts, Vocab tgtVocab, PaddingEnums paddingType, float dropoutRatio, DecodingOptions decodingOptions, bool isTraining = true,
 bool outputSentScore = true, List<BeamSearchStatus> previousBeamSearchResults = null, IFeedForwardLayer pointerGenerator = null, List<List<int>> srcSeqs = null, Dictionary<string, IWeightTensor> cachedTensors = null,
 List<List<int>> alignmentsToSrc = null, List<List<float>> alignmentScoresToSrc = null, bool teacherForcedAlignment = false, LossEnums lossType = LossEnums.CrossEntropy, float focalLossGamma = 0.0f, float lossSmooth = 1e-9f,
-List<int> blockedTokens = null, IWeightTensor segmentEmbeddings = null, bool amp = false, IWeightTensor posEmbeddings = null, float lossScaling = 0.0f)
+List<int> blockedTokens = null, IWeightTensor segmentEmbeddings = null, bool amp = false, IWeightTensor posEmbeddings = null, float lossScaling = 1.0f)
 {
 int eosTokenId = tgtVocab.GetWordIndex(BuildInTokens.EOS, logUnk: true);
 int batchSize = tgtSeqs.Count;
@@ -391,7 +391,7 @@ public static (float, List<List<BeamSearchStatus>>) DecodeTransformer(List<List<
 if (isTraining)
 {
 var leftShiftTgtSeqs = g.LeftShiftTokens(tgtSeqs, eosTokenId);
-var cost = lossType == LossEnums.CrossEntropy ? g.CrossEntropyLoss(probs, leftShiftTgtSeqs, smooth: lossSmooth, gamma: focalLossGamma, lossScaling: lossScaling) : g.NLLLoss(probs, leftShiftTgtSeqs);
+var cost = lossType == LossEnums.CrossEntropy ? g.CrossEntropyLoss(probs, leftShiftTgtSeqs, graident: lossScaling, smooth: lossSmooth, gamma: focalLossGamma) : g.NLLLoss(probs, leftShiftTgtSeqs);

 return (cost, null);
 }
@@ -490,7 +490,7 @@ public static (float, List<List<BeamSearchStatus>>) GPTDecode(List<List<int>> tg
 IWeightTensor tgtEmbedding, Vocab tgtVocab, PaddingEnums paddingType, float dropoutRatio, DecodingOptions decodingOptions, bool isTraining = true,
 bool outputSentScore = true, List<BeamSearchStatus> previousBeamSearchResults = null, Dictionary<string, IWeightTensor> cachedTensors = null,
 LossEnums lossType = LossEnums.CrossEntropy, float focalLossGamma = 0.0f, float lossSmooth = 1e-9f, IWeightTensor segmentEmbeddings = null, bool amp = true,
-IWeightTensor posEmbeddings = null, float lossScaling = 0.0f)
+IWeightTensor posEmbeddings = null, float lossScaling = 1.0f)
 {
 int eosTokenId = tgtVocab.GetWordIndex(BuildInTokens.EOS, logUnk: true);
 int batchSize = tgtSeqs.Count;
@@ -551,7 +551,7 @@ public static (float, List<List<BeamSearchStatus>>) GPTDecode(List<List<int>> tg
 if (isTraining)
 {
 var leftShiftTgtSeqs = g.LeftShiftTokens(tgtSeqs, eosTokenId);
-var cost = lossType == LossEnums.CrossEntropy ? g.CrossEntropyLoss(probs, leftShiftTgtSeqs, graident:1.0f, smooth: lossSmooth, gamma: focalLossGamma, lossScaling: lossScaling) : g.NLLLoss(probs, leftShiftTgtSeqs);
+var cost = lossType == LossEnums.CrossEntropy ? g.CrossEntropyLoss(probs, leftShiftTgtSeqs, graident: lossScaling, smooth: lossSmooth, gamma: focalLossGamma) : g.NLLLoss(probs, leftShiftTgtSeqs);

 return (cost, null);
 }
Seq2SeqSharp/Applications/Options.cs (3 additions, 3 deletions)
@@ -303,9 +303,9 @@ public class Options
 [Range(-1, 9999999)]
 public int RandomSeed = -1;

-[Arg("Initial loss Scaling when AMP is enabled. Default is 0 which is disabled.", nameof(InitLossScaling))]
-[Range(0, 65000)]
-public float InitLossScaling = 0.0f;
+[Arg("Initial loss Scaling when AMP is enabled. Default is 1 which is disabled.", nameof(InitLossScaling))]
+[Range(1, 65000)]
+public float InitLossScaling = 1.0f;

 [Arg("The Positional Embeddings Type. It supports APE, NoPE and RoPE", nameof(PEType))]
 [RegularExpression("APE|NoPE|RoPE")]
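Note the semantics shift here: 0 previously meant "scaling disabled", while the new [Range(1, 65000)] rejects values below 1 and 1.0f becomes the no-op default. A hypothetical configuration sketch follows; the InitLossScaling field is real, but the namespace and direct construction of Options are assumptions.

    using Seq2SeqSharp.Applications;  // assumed namespace for Options

    // Hypothetical setup: enable AMP loss scaling with an initial scale of 1024.
    // Leaving InitLossScaling at its new default of 1.0f keeps scaling a no-op;
    // values below 1 are now rejected by the [Range(1, 65000)] attribute.
    var options = new Options
    {
        InitLossScaling = 1024.0f
    };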
Seq2SeqSharp/Applications/SeqClassification.cs (1 addition, 1 deletion)
@@ -152,7 +152,7 @@ public override List<NetworkResult> RunForwardOnSingleDevice(IComputeGraph compu
 var tgtSnts = sntPairBatch.GetTgtTokens();
 var tgtTokensLists = m_modelMetaData.TgtVocab.GetWordIndex(tgtSnts);
 var tgtTokensTensor = computeGraph.CreateTokensTensor(tgtTokensLists);
-nr.Cost = computeGraph.CrossEntropyLoss(probs, tgtTokensTensor, lossScaling: LossScaling);
+nr.Cost = computeGraph.CrossEntropyLoss(probs, tgtTokensTensor, graident: LossScaling);
 }
 else
 {
Seq2SeqSharp/Applications/SeqLabel.cs (2 additions, 2 deletions)
@@ -178,13 +178,13 @@ public override List<NetworkResult> RunForwardOnSingleDevice(IComputeGraph g, IP

 if (m_tagWeightsList == null)
 {
-cost = g.CrossEntropyLoss(probs, tgtTokensTensor, smooth: m_options.LossSmooth, gamma: m_options.FocalLossGamma, lossScaling: LossScaling);
+cost = g.CrossEntropyLoss(probs, tgtTokensTensor, smooth: m_options.LossSmooth, gamma: m_options.FocalLossGamma, graident: LossScaling);
 }
 else
 {
 var tagWeightsTensor = g.CreateTensorWeights(sizes: new long[] { 1, m_tagWeightsList.Length }, m_tagWeightsList);
 tagWeightsTensor = g.Expand(tagWeightsTensor, dims: probs.Sizes);
-cost = g.CrossEntropyLoss(probs, tgtTokensTensor, tagWeightsTensor, smooth: m_options.LossSmooth, gamma: m_options.FocalLossGamma, lossScaling: LossScaling);
+cost = g.CrossEntropyLoss(probs, tgtTokensTensor, tagWeightsTensor, smooth: m_options.LossSmooth, gamma: m_options.FocalLossGamma);
 }
 }
 else
Seq2SeqSharp/Tools/BaseSeq2SeqFramework.cs (3 additions, 3 deletions)
@@ -139,14 +139,14 @@ public abstract class BaseSeq2SeqFramework<T> where T : Model
 bool m_saveGPUMemoryMode = false;
 CudaMemoryDeviceAllocatorType m_cudaMemoryAllocatorType = CudaMemoryDeviceAllocatorType.CudaMemoryPool;
 DType m_elementType = DType.Float32;
-float m_initLossScaling = 0.0f;
+float m_initLossScaling = 1.0f;

-public float LossScaling = 0.0f;
+public float LossScaling = 1.0f;

 public BaseSeq2SeqFramework(string deviceIds, ProcessorTypeEnums processorType, string modelFilePath, float memoryUsageRatio = 0.9f,
 string compilerOptions = null, int runValidEveryUpdates = 10000, int primaryTaskId = 0, int updateFreq = 1, int startToRunValidAfterUpdates = 0,
 int maxDegressOfParallelism = 1, string mklInstructions = "AVX2", int weightsUpdateCount = 0, bool enableTensorCore = true, CudaMemoryDeviceAllocatorType cudaMemoryAllocatorType = CudaMemoryDeviceAllocatorType.CudaMemoryPool,
-DType elementType = DType.Float32, int randomSeed = -1, int saveModelEveryUpdats = 10000, bool saveGPUMemoryMode = false, float initLossScaling = 0.0f)
+DType elementType = DType.Float32, int randomSeed = -1, int saveModelEveryUpdats = 10000, bool saveGPUMemoryMode = false, float initLossScaling = 1.0f)
 {
 m_deviceIds = deviceIds.Split(',').Select(x => int.Parse(x)).ToArray();
 m_compilerOptions = compilerOptions;
Seq2SeqSharp/Tools/ComputeGraphTensor.cs (6 additions, 17 deletions)
@@ -2929,7 +2929,7 @@ void backward()



-private (float, IWeightTensor) CalculateEntropyLoss(IWeightTensor probs, IWeightTensor truthTgtSeqs, float smooth, float gamma, float lossScaling)
+private (float, IWeightTensor) CalculateEntropyLoss(IWeightTensor probs, IWeightTensor truthTgtSeqs, float smooth, float gamma)
 {
 var scatterIdxTensor = View(truthTgtSeqs, new long[] { -1, 1 });
 var scatterTrue = Scatter(scatterIdxTensor, 1.0f, 1, needGradient: false, shape: probs.Sizes);
@@ -2949,41 +2949,30 @@
 }

 loss = Log(loss);
+loss = Mul(loss, -1.0f, inPlace: true);

-if (lossScaling > 0.0f)
-{
-loss = Mul(loss, -1.0f * lossScaling, inPlace: true);
-}
-else
-{
-loss = Mul(loss, -1.0f, inPlace: true);
-}

 if (focalFactor != null)
 {
 loss = EltMul(loss, focalFactor);
 }
 var lossTrue = Gather(loss, scatterIdxTensor, 1, runGradients: false);
 var lossValue = lossTrue.ToWeightArray().Sum() / loss.ElementCount;
-if (lossScaling > 0.0f)
-{
-lossValue = lossValue / lossScaling;
-}

 return (lossValue, loss);
 }

-public float CrossEntropyLoss(IWeightTensor probs, IWeightTensor truthTgtSeqs, float graident = 1.0f, float smooth = 0.0f, float gamma = 0.0f, float lossScaling = 0.0f)
+public float CrossEntropyLoss(IWeightTensor probs, IWeightTensor truthTgtSeqs, float graident = 1.0f, float smooth = 0.0f, float gamma = 0.0f)
 {
-(float lossValue, IWeightTensor loss) = CalculateEntropyLoss(probs, truthTgtSeqs, smooth, gamma, lossScaling);
+(float lossValue, IWeightTensor loss) = CalculateEntropyLoss(probs, truthTgtSeqs, smooth, gamma);
 loss.FillGradient(graident);

 return lossValue;
 }

-public float CrossEntropyLoss(IWeightTensor probs, IWeightTensor truthTgtSeqs, IWeightTensor graident, float smooth = 0.0f, float gamma = 0.0f, float lossScaling = 0.0f)
+public float CrossEntropyLoss(IWeightTensor probs, IWeightTensor truthTgtSeqs, IWeightTensor graident, float smooth = 0.0f, float gamma = 0.0f)
 {
-(float lossValue, IWeightTensor loss) = CalculateEntropyLoss(probs, truthTgtSeqs, smooth, gamma, lossScaling);
+(float lossValue, IWeightTensor loss) = CalculateEntropyLoss(probs, truthTgtSeqs, smooth, gamma);
 loss.CopyWeightsToGradients(graident);

 return lossValue;
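Net effect of this file's changes: CalculateEntropyLoss no longer multiplies the forward loss by lossScaling nor divides the reported value back down, so the returned loss is the true, unscaled cross entropy; the scale now enters backpropagation only through the graident seed written by FillGradient or CopyWeightsToGradients. A short sketch of the resulting call pattern, mirroring the updated call sites above; the helper is illustrative, and the namespace for the interfaces is an assumption.

    using Seq2SeqSharp.Tools;  // assumed namespace for IComputeGraph / IWeightTensor

    static class LossHelper
    {
        // Returns the unscaled loss value for logging. Backpropagation is seeded
        // with lossScaling, so parameter gradients come out multiplied by it and
        // must be divided by the same scale before the optimizer update.
        public static float ScaledCrossEntropy(IComputeGraph g, IWeightTensor probs,
                                               IWeightTensor goldTokens, float lossScaling)
        {
            // "graident" is the parameter's actual spelling in this code base.
            return g.CrossEntropyLoss(probs, goldTokens, graident: lossScaling,
                                      smooth: 1e-9f, gamma: 0.0f);
        }
    }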
Seq2SeqSharp/Tools/IComputeGraph.cs (2 additions, 2 deletions)
@@ -114,8 +114,8 @@ public interface IComputeGraph : IDisposable
 IWeightTensor Exp(IWeightTensor w);
 IWeightTensor Pow(IWeightTensor w, float n);

-float CrossEntropyLoss(IWeightTensor probs, IWeightTensor truthTgtSeqs, float graident = 1.0f, float smooth = 0.0f, float gamma = 0.0f, float lossScaling = 0.0f);
-float CrossEntropyLoss(IWeightTensor probs, IWeightTensor truthTgtSeqs, IWeightTensor graident, float smooth = 0.0f, float gamma = 0.0f, float lossScaling = 0.0f);
+float CrossEntropyLoss(IWeightTensor probs, IWeightTensor truthTgtSeqs, float graident = 1.0f, float smooth = 0.0f, float gamma = 0.0f);
+float CrossEntropyLoss(IWeightTensor probs, IWeightTensor truthTgtSeqs, IWeightTensor graident, float smooth = 0.0f, float gamma = 0.0f);
 float NLLLoss(IWeightTensor probs, IWeightTensor truthTgtSeqs, float graident = 1.0f, float smooth = 0.0f);

 IWeightTensor CreateUniformRandomTensor(long[] sizes, float minVal, float maxVal);
Expand Down
