diff --git a/Seq2SeqSharp/MultiProcessorNetworkWrapper.cs b/Seq2SeqSharp/MultiProcessorNetworkWrapper.cs index 061d182..75ba841 100644 --- a/Seq2SeqSharp/MultiProcessorNetworkWrapper.cs +++ b/Seq2SeqSharp/MultiProcessorNetworkWrapper.cs @@ -37,7 +37,8 @@ public MultiProcessorNetworkWrapper(T networkOnDefaultDevice, int[] deviceIds, b m_isStaticWeights = isStaticWeights; m_weightsSynced = false; - for (int i = 0; i < deviceIds.Length; i++) + object locker = new object(); + Parallel.For(0, deviceIds.Length, i => { if (deviceIds[i] == m_defaultDeviceId) { @@ -48,8 +49,25 @@ public MultiProcessorNetworkWrapper(T networkOnDefaultDevice, int[] deviceIds, b m_networks[i] = (T)networkOnDefaultDevice.CloneToDeviceAt(deviceIds[i]); } - m_deviceId2Network.Add(deviceIds[i], m_networks[i]); - } + lock (locker) + { + m_deviceId2Network.Add(deviceIds[i], m_networks[i]); + } + }); + + //for (int i = 0; i < deviceIds.Length; i++) + //{ + // if (deviceIds[i] == m_defaultDeviceId) + // { + // m_networks[i] = networkOnDefaultDevice; + // } + // else + // { + // m_networks[i] = (T)networkOnDefaultDevice.CloneToDeviceAt(deviceIds[i]); + // } + + // m_deviceId2Network.Add(deviceIds[i], m_networks[i]); + //} var raDeviceIds = new RoundArray(deviceIds); var weights = networkOnDefaultDevice.GetParams();