
Commit

minor update
zhongkaifu committed Sep 21, 2024
1 parent 6c518e6 commit 485b7e9
Showing 2 changed files with 7 additions and 28 deletions.
Seq2SeqSharp/MultiProcessorNetworkWrapper.cs (26 changes: 3 additions & 23 deletions)

@@ -22,7 +22,6 @@ public class MultiProcessorNetworkWrapper<T> : IMultiProcessorNetworkWrapper whe
     {
         private readonly T[] m_networks;
         private readonly int m_defaultDeviceId;
-        // private readonly T m_networkOnDefaultDevice;
         private readonly bool m_isStaticWeights;
         private bool m_weightsSynced;

@@ -33,12 +32,10 @@ public MultiProcessorNetworkWrapper(T networkOnDefaultDevice, int[] deviceIds, b
         {
             m_networks = new T[deviceIds.Length];
             m_defaultDeviceId = networkOnDefaultDevice.GetDeviceId();
-            // m_networkOnDefaultDevice = networkOnDefaultDevice;
             m_isStaticWeights = isStaticWeights;
             m_weightsSynced = false;

-            object locker = new object();
-            Parallel.For(0, deviceIds.Length, i =>
+            for (int i = 0; i < deviceIds.Length; i++)
             {
                 if (deviceIds[i] == m_defaultDeviceId)
                 {
@@ -49,25 +46,8 @@ public MultiProcessorNetworkWrapper(T networkOnDefaultDevice, int[] deviceIds, b
                     m_networks[i] = (T)networkOnDefaultDevice.CloneToDeviceAt(deviceIds[i]);
                 }

-                lock (locker)
-                {
-                    m_deviceId2Network.Add(deviceIds[i], m_networks[i]);
-                }
-            });
-
-            //for (int i = 0; i < deviceIds.Length; i++)
-            //{
-            //    if (deviceIds[i] == m_defaultDeviceId)
-            //    {
-            //        m_networks[i] = networkOnDefaultDevice;
-            //    }
-            //    else
-            //    {
-            //        m_networks[i] = (T)networkOnDefaultDevice.CloneToDeviceAt(deviceIds[i]);
-            //    }
-
-            //    m_deviceId2Network.Add(deviceIds[i], m_networks[i]);
-            //}
+                m_deviceId2Network.Add(deviceIds[i], m_networks[i]);
+            }

             var raDeviceIds = new RoundArray<int>(deviceIds);
             var weights = networkOnDefaultDevice.GetParams();
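For readers skimming the first change: it swaps the Parallel.For construction, which needed a lock around the shared dictionary, for a plain sequential loop, and deletes the commented-out remains of the earlier loop. A minimal standalone sketch of the same replicate-per-device pattern, using a made-up FakeNetwork type in place of the real network types (the names here are illustrative, not Seq2SeqSharp's):

using System;
using System.Collections.Generic;

class FakeNetwork
{
    public int DeviceId;
    public FakeNetwork CloneToDeviceAt(int deviceId) => new FakeNetwork { DeviceId = deviceId };
}

class ReplicationSketch
{
    static void Main()
    {
        var defaultNetwork = new FakeNetwork { DeviceId = 0 };
        int[] deviceIds = { 0, 1, 2, 3 };
        var deviceId2Network = new Dictionary<int, FakeNetwork>();

        for (int i = 0; i < deviceIds.Length; i++)
        {
            // Reuse the instance already on the default device, clone onto the others.
            var net = deviceIds[i] == defaultNetwork.DeviceId
                ? defaultNetwork
                : defaultNetwork.CloneToDeviceAt(deviceIds[i]);

            // Single-threaded, so Dictionary.Add needs no lock (unlike under Parallel.For).
            deviceId2Network.Add(deviceIds[i], net);
        }

        Console.WriteLine($"Replicated to {deviceId2Network.Count} device(s).");
    }
}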
Seq2SeqSharp/Tools/ComputeGraphTensor.cs (9 changes: 4 additions & 5 deletions)

@@ -4035,7 +4035,7 @@ void backward()

         private (float, IWeightTensor) CalculateEntropyLoss(IWeightTensor probs, IWeightTensor truthTgtSeqs, float label_smoothing = 0.1f)
         {
-            // float N = (float)probs.Sizes[0];
+            float N = (float)probs.Sizes[0];
             float num_classes = (float)probs.Sizes[1];
             float eps = 1e-9f;

@@ -4050,11 +4050,10 @@ void backward()
             probs = Clip(probs, eps, 1.0f);
             var logProbs = Log(probs);
             var smooth_LogProbs = EltMul(smooth_targets, logProbs);
-            smooth_LogProbs = Sum(smooth_LogProbs, 1); // [seq_size * batch_size, 1]
-            smooth_LogProbs = Mean(smooth_LogProbs, 0); // [1, 1]
-            smooth_LogProbs = Mul(smooth_LogProbs, -1.0f, inPlace: true);
+            smooth_LogProbs = Mul(smooth_LogProbs, -1.0f / N, inPlace: true);

-            var lossValue = smooth_LogProbs.ToWeightArray().Sum() / smooth_LogProbs.ElementCount;
+            var lossTrue = Gather(smooth_LogProbs, scatterIdxTensor, 1, runGradients: false);
+            var lossValue = lossTrue.ToWeightArray().Sum() / smooth_LogProbs.ElementCount;

             return (lossValue, smooth_LogProbs);
         }
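The second change keeps the per-element shape of smooth_LogProbs (scaled by -1.0f / N) instead of reducing it to a scalar, and derives the reported scalar loss from only the gold-class entries gathered via scatterIdxTensor. A rough standalone sketch of the arithmetic this appears to compute, in plain C# rather than the Seq2SeqSharp tensor API; the label-smoothing form and the role of scatterIdxTensor as gold-class indices are assumptions, since those parts of the function are not visible in the diff:

using System;

class EntropyLossSketch
{
    static void Main()
    {
        float smoothing = 0.1f, eps = 1e-9f;
        float[][] probs =                         // [N, num_classes] rows of a softmax output
        {
            new[] { 0.7f, 0.2f, 0.1f },
            new[] { 0.1f, 0.8f, 0.1f },
        };
        int[] gold = { 0, 1 };                    // gold class index per row
        int N = probs.Length, C = probs[0].Length;

        var perElementLoss = new float[N, C];     // analogue of the returned smooth_LogProbs tensor
        float goldSum = 0f;
        for (int i = 0; i < N; i++)
        {
            for (int j = 0; j < C; j++)
            {
                // Assumed label-smoothing form: 1 - smoothing on the gold class, smoothing / C spread over all classes.
                float target = (j == gold[i] ? 1f - smoothing : 0f) + smoothing / C;
                float logP = MathF.Log(Math.Clamp(probs[i][j], eps, 1f));
                perElementLoss[i, j] = -target * logP / N;        // EltMul(smooth_targets, logProbs) then Mul(-1/N)
            }
            goldSum += perElementLoss[i, gold[i]];                // Gather along dim 1 at the gold index
        }
        float lossValue = goldSum / (N * C);                      // divided by smooth_LogProbs.ElementCount, as in the diff

        Console.WriteLine($"reported loss value = {lossValue}");
    }
}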
