Skip to content

Commit

Permalink
Check if updated weight is valid in AdamHalf
Browse files — browse the repository at this point in the history
  • Loading branch information
zhongkaifu committed Sep 18, 2024
1 parent be44132 commit b7a51dc
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion TensorSharp.CUDA/DeviceCode/AdvFuncKernels.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2129,7 +2129,11 @@ __global__ void AdamHalf(__half* __restrict__ w, __half* __restrict__ g, float*
sm[i] = sm[i] * decay_rate_m + (1.0 - decay_rate_m) * g;
sv[i] = sv[i] * decay_rate_v + (1.0 - decay_rate_v) * g * g;
- sw[i] = __float2half(__half2float(sw[i]) - (adapted_learning_rate * sm[i] / (sqrtf(sv[i]) + eps)));
+ __half sw_i = __float2half(__half2float(sw[i]) - (adapted_learning_rate * sm[i] / (sqrtf(sv[i]) + eps)));
+ if (isfinite(sw_i))
+ {
+ sw[i] = sw_i;
+ }
}
}
}
Expand Down

0 comments on commit b7a51dc

Please sign in to comment.