From b7a51dc6d30c1bb938365ac520f59b00335355cc Mon Sep 17 00:00:00 2001 From: Zhongkai Fu Date: Tue, 17 Sep 2024 17:38:01 -0700 Subject: [PATCH] Check if updated weight is valid in AdamHalf --- TensorSharp.CUDA/DeviceCode/AdvFuncKernels.cs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/TensorSharp.CUDA/DeviceCode/AdvFuncKernels.cs b/TensorSharp.CUDA/DeviceCode/AdvFuncKernels.cs index f97a3c2..ab55c1f 100644 --- a/TensorSharp.CUDA/DeviceCode/AdvFuncKernels.cs +++ b/TensorSharp.CUDA/DeviceCode/AdvFuncKernels.cs @@ -2129,7 +2129,11 @@ __global__ void AdamHalf(__half* __restrict__ w, __half* __restrict__ g, float* sm[i] = sm[i] * decay_rate_m + (1.0 - decay_rate_m) * g; sv[i] = sv[i] * decay_rate_v + (1.0 - decay_rate_v) * g * g; - sw[i] = __float2half(__half2float(sw[i]) - (adapted_learning_rate * sm[i] / (sqrtf(sv[i]) + eps))); + __half sw_i = __float2half(__half2float(sw[i]) - (adapted_learning_rate * sm[i] / (sqrtf(sv[i]) + eps))); + if (isfinite(sw_i)) + { + sw[i] = sw_i; + } } } }