From 485b7e92ef3381e7fa877b0b1b0cbcf95a1acfd1 Mon Sep 17 00:00:00 2001
From: Zhongkai Fu
Date: Sat, 21 Sep 2024 12:23:03 -0700
Subject: [PATCH] minor update

---
 Seq2SeqSharp/MultiProcessorNetworkWrapper.cs | 26 +++-----------------
 Seq2SeqSharp/Tools/ComputeGraphTensor.cs     |  9 +++----
 2 files changed, 7 insertions(+), 28 deletions(-)

diff --git a/Seq2SeqSharp/MultiProcessorNetworkWrapper.cs b/Seq2SeqSharp/MultiProcessorNetworkWrapper.cs
index 75ba841..c54c90f 100644
--- a/Seq2SeqSharp/MultiProcessorNetworkWrapper.cs
+++ b/Seq2SeqSharp/MultiProcessorNetworkWrapper.cs
@@ -22,7 +22,6 @@ public class MultiProcessorNetworkWrapper : IMultiProcessorNetworkWrapper whe
     {
         private readonly T[] m_networks;
         private readonly int m_defaultDeviceId;
-        // private readonly T m_networkOnDefaultDevice;
         private readonly bool m_isStaticWeights;
         private bool m_weightsSynced;
 
@@ -33,12 +32,10 @@ public MultiProcessorNetworkWrapper(T networkOnDefaultDevice, int[] deviceIds, b
         {
             m_networks = new T[deviceIds.Length];
             m_defaultDeviceId = networkOnDefaultDevice.GetDeviceId();
-            // m_networkOnDefaultDevice = networkOnDefaultDevice;
             m_isStaticWeights = isStaticWeights;
             m_weightsSynced = false;
 
-            object locker = new object();
-            Parallel.For(0, deviceIds.Length, i =>
+            for (int i = 0; i < deviceIds.Length; i++)
             {
                 if (deviceIds[i] == m_defaultDeviceId)
                 {
@@ -49,25 +46,8 @@ public MultiProcessorNetworkWrapper(T networkOnDefaultDevice, int[] deviceIds, b
                     m_networks[i] = (T)networkOnDefaultDevice.CloneToDeviceAt(deviceIds[i]);
                 }
 
-                lock (locker)
-                {
-                    m_deviceId2Network.Add(deviceIds[i], m_networks[i]);
-                }
-            });
-
-            //for (int i = 0; i < deviceIds.Length; i++)
-            //{
-            //    if (deviceIds[i] == m_defaultDeviceId)
-            //    {
-            //        m_networks[i] = networkOnDefaultDevice;
-            //    }
-            //    else
-            //    {
-            //        m_networks[i] = (T)networkOnDefaultDevice.CloneToDeviceAt(deviceIds[i]);
-            //    }
-
-            //    m_deviceId2Network.Add(deviceIds[i], m_networks[i]);
-            //}
+                m_deviceId2Network.Add(deviceIds[i], m_networks[i]);
+            }
 
             var raDeviceIds = new RoundArray(deviceIds);
             var weights = networkOnDefaultDevice.GetParams();
diff --git a/Seq2SeqSharp/Tools/ComputeGraphTensor.cs b/Seq2SeqSharp/Tools/ComputeGraphTensor.cs
index 962df68..b813cc4 100644
--- a/Seq2SeqSharp/Tools/ComputeGraphTensor.cs
+++ b/Seq2SeqSharp/Tools/ComputeGraphTensor.cs
@@ -4035,7 +4035,7 @@ void backward()
 
         private (float, IWeightTensor) CalculateEntropyLoss(IWeightTensor probs, IWeightTensor truthTgtSeqs, float label_smoothing = 0.1f)
         {
-//            float N = (float)probs.Sizes[0];
+            float N = (float)probs.Sizes[0];
             float num_classes = (float)probs.Sizes[1];
             float eps = 1e-9f;
 
@@ -4050,11 +4050,10 @@ void backward()
             probs = Clip(probs, eps, 1.0f);
             var logProbs = Log(probs);
             var smooth_LogProbs = EltMul(smooth_targets, logProbs);
-            smooth_LogProbs = Sum(smooth_LogProbs, 1); // [seq_size * batch_size, 1]
-            smooth_LogProbs = Mean(smooth_LogProbs, 0); //[1,1]
-            smooth_LogProbs = Mul(smooth_LogProbs, -1.0f, inPlace: true);
+            smooth_LogProbs = Mul(smooth_LogProbs, -1.0f / N, inPlace: true);
 
-            var lossValue = smooth_LogProbs.ToWeightArray().Sum() / smooth_LogProbs.ElementCount;
+            var lossTrue = Gather(smooth_LogProbs, scatterIdxTensor, 1, runGradients: false);
+            var lossValue = lossTrue.ToWeightArray().Sum() / smooth_LogProbs.ElementCount;
 
             return (lossValue, smooth_LogProbs);
         }