From 9c88c875d8cd4873ca490917cd22bb7d26076b12 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Thu, 12 Oct 2023 14:19:56 +0800 Subject: [PATCH] [kernel] support pure fp16 for cpu adam (#4896) --- .../kernel/cuda_native/csrc/cpu_adam.cpp | 149 +++++++++++++++--- colossalai/kernel/cuda_native/csrc/cpu_adam.h | 11 +- colossalai/nn/optimizer/cpu_adam.py | 3 +- colossalai/nn/optimizer/hybrid_adam.py | 3 +- tests/test_optimizer/test_adam_kernel.py | 7 +- tests/test_optimizer/test_adam_optim.py | 2 - 6 files changed, 135 insertions(+), 40 deletions(-) diff --git a/colossalai/kernel/cuda_native/csrc/cpu_adam.cpp b/colossalai/kernel/cuda_native/csrc/cpu_adam.cpp index 0ab250218da3..027d18a9dd58 100644 --- a/colossalai/kernel/cuda_native/csrc/cpu_adam.cpp +++ b/colossalai/kernel/cuda_native/csrc/cpu_adam.cpp @@ -35,7 +35,8 @@ SOFTWARE void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg, float *_exp_avg_sq, size_t _param_size, bool param_half_precision, bool grad_half_precision, - float loss_scale) { + bool momentum_half_precision, + bool variance_half_precision, float loss_scale) { size_t rounded_size = 0; float betta1_minus1 = 1 - _betta1; @@ -45,6 +46,8 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg, __half *params_cast_h = NULL; __half *grads_cast_h = NULL; + __half *momentum_cast_h = NULL; + __half *variance_cast_h = NULL; if (param_half_precision) { params_cast_h = reinterpret_cast<__half *>(_params); @@ -52,6 +55,12 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg, if (grad_half_precision) { grads_cast_h = reinterpret_cast<__half *>(grads); } + if (momentum_half_precision) { + momentum_cast_h = reinterpret_cast<__half *>(_exp_avg); + } + if (variance_half_precision) { + variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq); + } #if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__) AVX_Data betta1_4; @@ -98,10 +107,18 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg, grad_4.data = SIMD_DIV(grad_4.data, loss_scale_vec.data); } AVX_Data momentum_4; - momentum_4.data = SIMD_LOAD(_exp_avg + i); + if (momentum_half_precision) { + momentum_4.data = SIMD_LOAD_HALF(momentum_cast_h + i); + } else { + momentum_4.data = SIMD_LOAD(_exp_avg + i); + } AVX_Data variance_4; - variance_4.data = SIMD_LOAD(_exp_avg_sq + i); + if (variance_half_precision) { + variance_4.data = SIMD_LOAD_HALF(variance_cast_h + i); + } else { + variance_4.data = SIMD_LOAD(_exp_avg_sq + i); + } AVX_Data param_4; if (param_half_precision) { @@ -135,8 +152,16 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg, } else { SIMD_STORE(_params + i, param_4.data); } - SIMD_STORE(_exp_avg + i, momentum_4.data); - SIMD_STORE(_exp_avg_sq + i, variance_4.data); + if (momentum_half_precision) { + SIMD_STORE_HALF((float *)(momentum_cast_h + i), momentum_4.data); + } else { + SIMD_STORE(_exp_avg + i, momentum_4.data); + } + if (variance_half_precision) { + SIMD_STORE_HALF((float *)(variance_cast_h + i), variance_4.data); + } else { + SIMD_STORE(_exp_avg_sq + i, variance_4.data); + } } } #endif @@ -154,8 +179,10 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg, } float param = param_half_precision ? (float)params_cast_h[k] : _params[k]; - float momentum = _exp_avg[k]; - float variance = _exp_avg_sq[k]; + float momentum = + momentum_half_precision ? (float)momentum_cast_h[k] : _exp_avg[k]; + float variance = variance_half_precision ? (float)variance_cast_h[k] + : _exp_avg_sq[k]; if (_weight_decay > 0 && !_adamw_mode) { grad = param * _weight_decay + grad; } @@ -178,8 +205,14 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg, params_cast_h[k] = (__half)param; else _params[k] = param; - _exp_avg[k] = momentum; - _exp_avg_sq[k] = variance; + if (momentum_half_precision) + momentum_cast_h[k] = (__half)(momentum); + else + _exp_avg[k] = momentum; + if (variance_half_precision) + variance_cast_h[k] = (__half)(variance); + else + _exp_avg_sq[k] = variance; } } } @@ -188,17 +221,26 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg, void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg, float *_exp_avg_sq, size_t _param_size, bool param_half_precision, bool grad_half_precision, - float loss_scale) { + bool momentum_half_precision, + bool variance_half_precision, float loss_scale) { size_t rounded_size = 0; __half *params_cast_h = NULL; __half *grads_cast_h = NULL; + __half *momentum_cast_h = NULL; + __half *variance_cast_h = NULL; if (param_half_precision) { params_cast_h = reinterpret_cast<__half *>(_params); } if (grad_half_precision) { grads_cast_h = reinterpret_cast<__half *>(grads); } + if (momentum_half_precision) { + momentum_cast_h = reinterpret_cast<__half *>(_exp_avg); + } + if (variance_half_precision) { + variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq); + } #if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__) AVX_Data betta1_4; @@ -255,8 +297,18 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg, grad_4[j].data = SIMD_DIV(grad_4[j].data, loss_scale_vec.data); } - momentum_4[j].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * j); - variance_4[j].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * j); + if (momentum_half_precision) { + momentum_4[j].data = + SIMD_LOAD_HALF(momentum_cast_h + i + SIMD_WIDTH * j); + } else { + momentum_4[j].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * j); + } + if (variance_half_precision) { + variance_4[j].data = + SIMD_LOAD_HALF(variance_cast_h + i + SIMD_WIDTH * j); + } else { + variance_4[j].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * j); + } if (param_half_precision) { param_4[j].data = SIMD_LOAD_HALF(params_cast_h + i + SIMD_WIDTH * j); @@ -291,8 +343,18 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg, } else { SIMD_STORE(_params + i + SIMD_WIDTH * j, param_4[j].data); } - SIMD_STORE(_exp_avg + i + SIMD_WIDTH * j, momentum_4[j].data); - SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * j, variance_4[j].data); + if (momentum_half_precision) { + SIMD_STORE_HALF((float *)(momentum_cast_h + i + SIMD_WIDTH * j), + momentum_4[j].data); + } else { + SIMD_STORE(_exp_avg + i + SIMD_WIDTH * j, momentum_4[j].data); + } + if (variance_half_precision) { + SIMD_STORE_HALF((float *)(variance_cast_h + i + SIMD_WIDTH * j), + variance_4[j].data); + } else { + SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * j, variance_4[j].data); + } } } } @@ -302,24 +364,37 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg, : _params + rounded_size), (grad_half_precision ? (float *)(grads_cast_h + rounded_size) : grads + rounded_size), - (_exp_avg + rounded_size), (_exp_avg_sq + rounded_size), + (momentum_half_precision ? (float *)(momentum_cast_h + rounded_size) + : _exp_avg + rounded_size), + (variance_half_precision ? (float *)(variance_cast_h + rounded_size) + : _exp_avg_sq + rounded_size), (_param_size - rounded_size), param_half_precision, - grad_half_precision, loss_scale); + grad_half_precision, momentum_half_precision, + variance_half_precision, loss_scale); } void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg, float *_exp_avg_sq, size_t _param_size, bool param_half_precision, bool grad_half_precision, - float loss_scale) { + bool momentum_half_precision, + bool variance_half_precision, float loss_scale) { size_t rounded_size = 0; __half *params_cast_h = NULL; __half *grads_cast_h = NULL; + __half *momentum_cast_h = NULL; + __half *variance_cast_h = NULL; if (param_half_precision) { params_cast_h = reinterpret_cast<__half *>(_params); } if (grad_half_precision) { grads_cast_h = reinterpret_cast<__half *>(grads); } + if (momentum_half_precision) { + momentum_cast_h = reinterpret_cast<__half *>(_exp_avg); + } + if (variance_half_precision) { + variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq); + } #if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__) AVX_Data betta1_4; betta1_4.data = SIMD_SET(_betta1); @@ -375,8 +450,18 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg, grad_4[j].data = SIMD_DIV(grad_4[j].data, loss_scale_vec.data); } - momentum_4[j].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * j); - variance_4[j].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * j); + if (momentum_half_precision) { + momentum_4[j].data = + SIMD_LOAD_HALF(momentum_cast_h + i + SIMD_WIDTH * j); + } else { + momentum_4[j].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * j); + } + if (variance_half_precision) { + variance_4[j].data = + SIMD_LOAD_HALF(variance_cast_h + i + SIMD_WIDTH * j); + } else { + variance_4[j].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * j); + } if (param_half_precision) { param_4[j].data = SIMD_LOAD_HALF(params_cast_h + i + SIMD_WIDTH * j); @@ -412,8 +497,18 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg, SIMD_STORE(_params + i + SIMD_WIDTH * j, param_4[j].data); } - SIMD_STORE(_exp_avg + i + (SIMD_WIDTH * j), momentum_4[j].data); - SIMD_STORE(_exp_avg_sq + i + (SIMD_WIDTH * j), variance_4[j].data); + if (momentum_half_precision) { + SIMD_STORE_HALF((float *)(momentum_cast_h + i + SIMD_WIDTH * j), + momentum_4[j].data); + } else { + SIMD_STORE(_exp_avg + i + SIMD_WIDTH * j, momentum_4[j].data); + } + if (variance_half_precision) { + SIMD_STORE_HALF((float *)(variance_cast_h + i + SIMD_WIDTH * j), + variance_4[j].data); + } else { + SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * j, variance_4[j].data); + } } } } @@ -423,9 +518,13 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg, : _params + rounded_size), (grad_half_precision ? (float *)(grads_cast_h + rounded_size) : grads + rounded_size), - (_exp_avg + rounded_size), (_exp_avg_sq + rounded_size), + (momentum_half_precision ? (float *)(momentum_cast_h + rounded_size) + : _exp_avg + rounded_size), + (variance_half_precision ? (float *)(variance_cast_h + rounded_size) + : _exp_avg_sq + rounded_size), (_param_size - rounded_size), param_half_precision, - grad_half_precision, loss_scale); + grad_half_precision, momentum_half_precision, + variance_half_precision, loss_scale); } void Adam_Optimizer::step(size_t step, float lr, float beta1, float beta2, @@ -447,7 +546,9 @@ void Adam_Optimizer::step(size_t step, float lr, float beta1, float beta2, this->update_state(lr, epsilon, weight_decay, bias_correction); this->Step_8(params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr, params_c.numel(), (params.options().dtype() == at::kHalf), - (grads.options().dtype() == at::kHalf), loss_scale); + (grads.options().dtype() == at::kHalf), + (exp_avg.options().dtype() == at::kHalf), + (exp_avg_sq.options().dtype() == at::kHalf), loss_scale); } namespace py = pybind11; diff --git a/colossalai/kernel/cuda_native/csrc/cpu_adam.h b/colossalai/kernel/cuda_native/csrc/cpu_adam.h index 4247da942775..67f3bffaf46a 100644 --- a/colossalai/kernel/cuda_native/csrc/cpu_adam.h +++ b/colossalai/kernel/cuda_native/csrc/cpu_adam.h @@ -83,11 +83,12 @@ union AVX_Data { #endif -#define STEP(SPAN) \ - void Step_##SPAN(float *_params, float *grads, float *_exp_avg, \ - float *_exp_avg_sq, size_t _param_size, \ - bool param_half_precision = false, \ - bool grad_half_precision = false, float loss_scale = -1); +#define STEP(SPAN) \ + void Step_##SPAN( \ + float *_params, float *grads, float *_exp_avg, float *_exp_avg_sq, \ + size_t _param_size, bool param_half_precision = false, \ + bool grad_half_precision = false, bool momentum_half_precision = false, \ + bool variance_half_precision = false, float loss_scale = -1); class Adam_Optimizer { public: diff --git a/colossalai/nn/optimizer/cpu_adam.py b/colossalai/nn/optimizer/cpu_adam.py index 1bdb81e2d6ec..238ba366da43 100644 --- a/colossalai/nn/optimizer/cpu_adam.py +++ b/colossalai/nn/optimizer/cpu_adam.py @@ -146,8 +146,7 @@ def step(self, closure=None, div_scale: float = -1): assert state["exp_avg"].device.type == "cpu", "exp_avg should stay on cpu" assert state["exp_avg_sq"].device.type == "cpu", "exp_avg should stay on cpu" self._pre_update(p, "exp_avg", "exp_avg_sq") - # FIXME(ver217): CPU adam kernel only supports fp32 states now - if p.grad.dtype is torch.bfloat16 or p.dtype is not torch.float: + if p.grad.dtype is torch.bfloat16: # cpu adam kernel does not support bf16 now bias_correction1 = 1 - beta1 ** state["step"] bias_correction2 = 1 - beta2 ** state["step"] diff --git a/colossalai/nn/optimizer/hybrid_adam.py b/colossalai/nn/optimizer/hybrid_adam.py index 7dc4590dc3f2..c7a309b872ce 100644 --- a/colossalai/nn/optimizer/hybrid_adam.py +++ b/colossalai/nn/optimizer/hybrid_adam.py @@ -122,8 +122,7 @@ def step(self, closure=None, div_scale: float = -1): assert state["exp_avg"].device.type == "cpu", "exp_avg should stay on cpu" assert state["exp_avg_sq"].device.type == "cpu", "exp_avg should stay on cpu" self._pre_update(p, "exp_avg", "exp_avg_sq") - # FIXME(ver217): CPU adam kernel only supports fp32 states now - if p.grad.dtype is torch.bfloat16 or p.dtype is not torch.float: + if p.grad.dtype is torch.bfloat16: # cpu adam kernel does not support bf16 now bias_correction1 = 1 - beta1 ** state["step"] bias_correction2 = 1 - beta2 ** state["step"] diff --git a/tests/test_optimizer/test_adam_kernel.py b/tests/test_optimizer/test_adam_kernel.py index 8131ea3234d8..6bbe3e4e8172 100644 --- a/tests/test_optimizer/test_adam_kernel.py +++ b/tests/test_optimizer/test_adam_kernel.py @@ -13,9 +13,7 @@ _FUSED_ALLOWED_P_G_TYPES = [ (torch.float, torch.half), (torch.float, torch.float), - (torch.half, torch.float), (torch.half, torch.half), - (torch.bfloat16, torch.float), (torch.float, torch.bfloat16), (torch.bfloat16, torch.bfloat16), ] @@ -23,7 +21,6 @@ _CPU_ALLOWED_P_G_TYPES = [ (torch.float, torch.half), (torch.float, torch.float), - (torch.half, torch.float), (torch.half, torch.half), ] @@ -138,8 +135,8 @@ def check_adam_kernel( master_exp_avg_sq = torch.zeros_like(master_p) p = master_p.clone().to(p_dtype) g = master_g.clone().to(g_dtype) - exp_avg = master_exp_avg.clone() - exp_avg_sq = master_exp_avg_sq.clone() + exp_avg = master_exp_avg.clone().to(p_dtype) + exp_avg_sq = master_exp_avg_sq.clone().to(p_dtype) for step in range(1, 1 + n_steps): torch_adam.update(step, master_p, master_g, master_exp_avg, master_exp_avg_sq) diff --git a/tests/test_optimizer/test_adam_optim.py b/tests/test_optimizer/test_adam_optim.py index 59b40a0afa3c..68d71e3c4194 100644 --- a/tests/test_optimizer/test_adam_optim.py +++ b/tests/test_optimizer/test_adam_optim.py @@ -21,8 +21,6 @@ (torch.float, torch.float), # pure fp32 (torch.float, torch.half), # fp16 amp (torch.float, torch.bfloat16), # bfloat16 amp - # (torch.half, torch.half), # FIXME(ver217): cpu adam kernel does not support pure fp16 - # (torch.bfloat16, torch.bfloat16), # FIXME(ver217): cpu adam kernel does not support pure bfloat16 ] N_STEPS = 3