[kernel] support pure fp16 for cpu adam (#4896)
ver217 authored Oct 12, 2023
1 parent 83b52c5 commit 9c88c87
Showing 6 changed files with 135 additions and 40 deletions.
149 changes: 125 additions & 24 deletions colossalai/kernel/cuda_native/csrc/cpu_adam.cpp
@@ -35,7 +35,8 @@ SOFTWARE
void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
float *_exp_avg_sq, size_t _param_size,
bool param_half_precision, bool grad_half_precision,
float loss_scale) {
bool momentum_half_precision,
bool variance_half_precision, float loss_scale) {
size_t rounded_size = 0;

float betta1_minus1 = 1 - _betta1;
@@ -45,13 +46,21 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,

__half *params_cast_h = NULL;
__half *grads_cast_h = NULL;
__half *momentum_cast_h = NULL;
__half *variance_cast_h = NULL;

if (param_half_precision) {
params_cast_h = reinterpret_cast<__half *>(_params);
}
if (grad_half_precision) {
grads_cast_h = reinterpret_cast<__half *>(grads);
}
if (momentum_half_precision) {
momentum_cast_h = reinterpret_cast<__half *>(_exp_avg);
}
if (variance_half_precision) {
variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq);
}

#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
AVX_Data betta1_4;
@@ -98,10 +107,18 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
grad_4.data = SIMD_DIV(grad_4.data, loss_scale_vec.data);
}
AVX_Data momentum_4;
momentum_4.data = SIMD_LOAD(_exp_avg + i);
if (momentum_half_precision) {
momentum_4.data = SIMD_LOAD_HALF(momentum_cast_h + i);
} else {
momentum_4.data = SIMD_LOAD(_exp_avg + i);
}

AVX_Data variance_4;
variance_4.data = SIMD_LOAD(_exp_avg_sq + i);
if (variance_half_precision) {
variance_4.data = SIMD_LOAD_HALF(variance_cast_h + i);
} else {
variance_4.data = SIMD_LOAD(_exp_avg_sq + i);
}

AVX_Data param_4;
if (param_half_precision) {
@@ -135,8 +152,16 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
} else {
SIMD_STORE(_params + i, param_4.data);
}
SIMD_STORE(_exp_avg + i, momentum_4.data);
SIMD_STORE(_exp_avg_sq + i, variance_4.data);
if (momentum_half_precision) {
SIMD_STORE_HALF((float *)(momentum_cast_h + i), momentum_4.data);
} else {
SIMD_STORE(_exp_avg + i, momentum_4.data);
}
if (variance_half_precision) {
SIMD_STORE_HALF((float *)(variance_cast_h + i), variance_4.data);
} else {
SIMD_STORE(_exp_avg_sq + i, variance_4.data);
}
}
}
#endif
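
For context on the half-precision SIMD macros used above, the sketch below shows what SIMD_LOAD_HALF and SIMD_STORE_HALF plausibly expand to on an AVX2/F16C target. This is an assumption for illustration only; the real macros live in cpu_adam.h, select their width from __AVX512__/__AVX256__, and may differ in alignment and rounding. The (float *) casts wrapped around the __half destinations above presumably exist only to satisfy the macro's parameter type; the underlying bytes remain fp16.

#include <immintrin.h>

// Sketch (assumption): 8-lane AVX2 + F16C half-precision load/store.
static inline __m256 load_half8(const void *src) {
  // Read 8 fp16 values (16 bytes) and widen them to 8 fp32 lanes.
  return _mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<const __m128i *>(src)));
}

static inline void store_half8(void *dst, __m256 v) {
  // Narrow 8 fp32 lanes to fp16 with round-to-nearest and write 16 bytes.
  _mm_storeu_si128(reinterpret_cast<__m128i *>(dst),
                   _mm256_cvtps_ph(v, _MM_FROUND_TO_NEAREST_INT));
}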
@@ -154,8 +179,10 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
}
float param =
param_half_precision ? (float)params_cast_h[k] : _params[k];
float momentum = _exp_avg[k];
float variance = _exp_avg_sq[k];
float momentum =
momentum_half_precision ? (float)momentum_cast_h[k] : _exp_avg[k];
float variance = variance_half_precision ? (float)variance_cast_h[k]
: _exp_avg_sq[k];
if (_weight_decay > 0 && !_adamw_mode) {
grad = param * _weight_decay + grad;
}
@@ -178,8 +205,14 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
params_cast_h[k] = (__half)param;
else
_params[k] = param;
_exp_avg[k] = momentum;
_exp_avg_sq[k] = variance;
if (momentum_half_precision)
momentum_cast_h[k] = (__half)(momentum);
else
_exp_avg[k] = momentum;
if (variance_half_precision)
variance_cast_h[k] = (__half)(variance);
else
_exp_avg_sq[k] = variance;
}
}
}
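
The scalar tail above handles the elements left over after the SIMD-rounded region, widening each value to fp32 before the update and narrowing it back to its storage dtype afterwards. As a reference, here is a minimal sketch of the per-element Adam/AdamW math this loop performs; names such as step_size and bias_correction2_sqrt follow the usual formulation and are assumptions, not the file's exact locals.

#include <cmath>

// Sketch (assumption): one element of the bias-corrected Adam/AdamW update,
// computed entirely in fp32 regardless of how the buffers are stored.
inline void adam_update_one(float &param, float grad, float &momentum,
                            float &variance, float beta1, float beta2,
                            float eps, float step_size,
                            float bias_correction2_sqrt, float weight_decay,
                            bool adamw_mode) {
  if (weight_decay > 0.f && !adamw_mode) {
    grad += param * weight_decay;               // classic Adam folds L2 decay into the gradient
  }
  momentum = beta1 * momentum + (1.f - beta1) * grad;
  variance = beta2 * variance + (1.f - beta2) * grad * grad;
  float denom = std::sqrt(variance) / bias_correction2_sqrt + eps;
  if (weight_decay > 0.f && adamw_mode) {
    param -= weight_decay * step_size * param;  // AdamW applies decoupled decay directly to the weight
  }
  param -= step_size * momentum / denom;
  // The caller then stores param, momentum and variance back as fp32 or __half,
  // according to the matching *_half_precision flag.
}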
@@ -188,17 +221,26 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
float *_exp_avg_sq, size_t _param_size,
bool param_half_precision, bool grad_half_precision,
float loss_scale) {
bool momentum_half_precision,
bool variance_half_precision, float loss_scale) {
size_t rounded_size = 0;

__half *params_cast_h = NULL;
__half *grads_cast_h = NULL;
__half *momentum_cast_h = NULL;
__half *variance_cast_h = NULL;
if (param_half_precision) {
params_cast_h = reinterpret_cast<__half *>(_params);
}
if (grad_half_precision) {
grads_cast_h = reinterpret_cast<__half *>(grads);
}
if (momentum_half_precision) {
momentum_cast_h = reinterpret_cast<__half *>(_exp_avg);
}
if (variance_half_precision) {
variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq);
}

#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
AVX_Data betta1_4;
@@ -255,8 +297,18 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
grad_4[j].data = SIMD_DIV(grad_4[j].data, loss_scale_vec.data);
}

momentum_4[j].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * j);
variance_4[j].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * j);
if (momentum_half_precision) {
momentum_4[j].data =
SIMD_LOAD_HALF(momentum_cast_h + i + SIMD_WIDTH * j);
} else {
momentum_4[j].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * j);
}
if (variance_half_precision) {
variance_4[j].data =
SIMD_LOAD_HALF(variance_cast_h + i + SIMD_WIDTH * j);
} else {
variance_4[j].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * j);
}

if (param_half_precision) {
param_4[j].data = SIMD_LOAD_HALF(params_cast_h + i + SIMD_WIDTH * j);
@@ -291,8 +343,18 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
} else {
SIMD_STORE(_params + i + SIMD_WIDTH * j, param_4[j].data);
}
SIMD_STORE(_exp_avg + i + SIMD_WIDTH * j, momentum_4[j].data);
SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * j, variance_4[j].data);
if (momentum_half_precision) {
SIMD_STORE_HALF((float *)(momentum_cast_h + i + SIMD_WIDTH * j),
momentum_4[j].data);
} else {
SIMD_STORE(_exp_avg + i + SIMD_WIDTH * j, momentum_4[j].data);
}
if (variance_half_precision) {
SIMD_STORE_HALF((float *)(variance_cast_h + i + SIMD_WIDTH * j),
variance_4[j].data);
} else {
SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * j, variance_4[j].data);
}
}
}
}
@@ -302,24 +364,37 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
: _params + rounded_size),
(grad_half_precision ? (float *)(grads_cast_h + rounded_size)
: grads + rounded_size),
(_exp_avg + rounded_size), (_exp_avg_sq + rounded_size),
(momentum_half_precision ? (float *)(momentum_cast_h + rounded_size)
: _exp_avg + rounded_size),
(variance_half_precision ? (float *)(variance_cast_h + rounded_size)
: _exp_avg_sq + rounded_size),
(_param_size - rounded_size), param_half_precision,
grad_half_precision, loss_scale);
grad_half_precision, momentum_half_precision,
variance_half_precision, loss_scale);
}

void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
float *_exp_avg_sq, size_t _param_size,
bool param_half_precision, bool grad_half_precision,
float loss_scale) {
bool momentum_half_precision,
bool variance_half_precision, float loss_scale) {
size_t rounded_size = 0;
__half *params_cast_h = NULL;
__half *grads_cast_h = NULL;
__half *momentum_cast_h = NULL;
__half *variance_cast_h = NULL;
if (param_half_precision) {
params_cast_h = reinterpret_cast<__half *>(_params);
}
if (grad_half_precision) {
grads_cast_h = reinterpret_cast<__half *>(grads);
}
if (momentum_half_precision) {
momentum_cast_h = reinterpret_cast<__half *>(_exp_avg);
}
if (variance_half_precision) {
variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq);
}
#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
AVX_Data betta1_4;
betta1_4.data = SIMD_SET(_betta1);
@@ -375,8 +450,18 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
grad_4[j].data = SIMD_DIV(grad_4[j].data, loss_scale_vec.data);
}

momentum_4[j].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * j);
variance_4[j].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * j);
if (momentum_half_precision) {
momentum_4[j].data =
SIMD_LOAD_HALF(momentum_cast_h + i + SIMD_WIDTH * j);
} else {
momentum_4[j].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * j);
}
if (variance_half_precision) {
variance_4[j].data =
SIMD_LOAD_HALF(variance_cast_h + i + SIMD_WIDTH * j);
} else {
variance_4[j].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * j);
}

if (param_half_precision) {
param_4[j].data = SIMD_LOAD_HALF(params_cast_h + i + SIMD_WIDTH * j);
@@ -412,8 +497,18 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
SIMD_STORE(_params + i + SIMD_WIDTH * j, param_4[j].data);
}

SIMD_STORE(_exp_avg + i + (SIMD_WIDTH * j), momentum_4[j].data);
SIMD_STORE(_exp_avg_sq + i + (SIMD_WIDTH * j), variance_4[j].data);
if (momentum_half_precision) {
SIMD_STORE_HALF((float *)(momentum_cast_h + i + SIMD_WIDTH * j),
momentum_4[j].data);
} else {
SIMD_STORE(_exp_avg + i + SIMD_WIDTH * j, momentum_4[j].data);
}
if (variance_half_precision) {
SIMD_STORE_HALF((float *)(variance_cast_h + i + SIMD_WIDTH * j),
variance_4[j].data);
} else {
SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * j, variance_4[j].data);
}
}
}
}
@@ -423,9 +518,13 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
: _params + rounded_size),
(grad_half_precision ? (float *)(grads_cast_h + rounded_size)
: grads + rounded_size),
(_exp_avg + rounded_size), (_exp_avg_sq + rounded_size),
(momentum_half_precision ? (float *)(momentum_cast_h + rounded_size)
: _exp_avg + rounded_size),
(variance_half_precision ? (float *)(variance_cast_h + rounded_size)
: _exp_avg_sq + rounded_size),
(_param_size - rounded_size), param_half_precision,
grad_half_precision, loss_scale);
grad_half_precision, momentum_half_precision,
variance_half_precision, loss_scale);
}

void Adam_Optimizer::step(size_t step, float lr, float beta1, float beta2,
@@ -447,7 +546,9 @@ void Adam_Optimizer::step(size_t step, float lr, float beta1, float beta2,
this->update_state(lr, epsilon, weight_decay, bias_correction);
this->Step_8(params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr,
params_c.numel(), (params.options().dtype() == at::kHalf),
(grads.options().dtype() == at::kHalf), loss_scale);
(grads.options().dtype() == at::kHalf),
(exp_avg.options().dtype() == at::kHalf),
(exp_avg_sq.options().dtype() == at::kHalf), loss_scale);
}

namespace py = pybind11;
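A subtle point in the tail hand-off calls above: Step_4 and Step_8 delegate the remainder that does not fill their wider vectors to a narrower Step_ variant, and every buffer crosses that interface as float *. When a buffer actually holds __half, the element offset therefore has to be applied to the __half pointer first and only then cast back to float *; offsetting the float * directly would skip twice as many bytes. A minimal host-only sketch of the pattern (half_t and offset_as are illustrative names, not part of the file):

#include <cstddef>
#include <cstdint>

using half_t = std::uint16_t;  // stand-in for __half in this host-only sketch

// Advance a buffer by n elements of its *storage* dtype while keeping the
// float * interface the Step_ functions expect.
inline float *offset_as(float *base, std::size_t n, bool is_half) {
  if (is_half) {
    return reinterpret_cast<float *>(reinterpret_cast<half_t *>(base) + n);
  }
  return base + n;
}

// e.g. offset_as(_exp_avg, rounded_size, momentum_half_precision) mirrors the
// (momentum_half_precision ? ... : ...) expressions in the calls above.
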
11 changes: 6 additions & 5 deletions colossalai/kernel/cuda_native/csrc/cpu_adam.h
@@ -83,11 +83,12 @@ union AVX_Data

#endif

#define STEP(SPAN) \
void Step_##SPAN(float *_params, float *grads, float *_exp_avg, \
float *_exp_avg_sq, size_t _param_size, \
bool param_half_precision = false, \
bool grad_half_precision = false, float loss_scale = -1);
#define STEP(SPAN) \
void Step_##SPAN( \
float *_params, float *grads, float *_exp_avg, float *_exp_avg_sq, \
size_t _param_size, bool param_half_precision = false, \
bool grad_half_precision = false, bool momentum_half_precision = false, \
bool variance_half_precision = false, float loss_scale = -1);

class Adam_Optimizer {
public:
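For reference, the updated STEP macro expands (for SPAN = 1) to roughly the declaration below: the two new state flags default to false, so existing call sites that pass only the parameter and gradient flags keep compiling unchanged.

void Step_1(float *_params, float *grads, float *_exp_avg, float *_exp_avg_sq,
            size_t _param_size, bool param_half_precision = false,
            bool grad_half_precision = false,
            bool momentum_half_precision = false,
            bool variance_half_precision = false, float loss_scale = -1);
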
3 changes: 1 addition & 2 deletions colossalai/nn/optimizer/cpu_adam.py
@@ -146,8 +145,7 @@ def step(self, closure=None, div_scale: float = -1):
assert state["exp_avg"].device.type == "cpu", "exp_avg should stay on cpu"
assert state["exp_avg_sq"].device.type == "cpu", "exp_avg should stay on cpu"
self._pre_update(p, "exp_avg", "exp_avg_sq")
# FIXME(ver217): CPU adam kernel only supports fp32 states now
if p.grad.dtype is torch.bfloat16 or p.dtype is not torch.float:
if p.grad.dtype is torch.bfloat16:
# cpu adam kernel does not support bf16 now
bias_correction1 = 1 - beta1 ** state["step"]
bias_correction2 = 1 - beta2 ** state["step"]
3 changes: 1 addition & 2 deletions colossalai/nn/optimizer/hybrid_adam.py
@@ -122,8 +121,7 @@ def step(self, closure=None, div_scale: float = -1):
assert state["exp_avg"].device.type == "cpu", "exp_avg should stay on cpu"
assert state["exp_avg_sq"].device.type == "cpu", "exp_avg should stay on cpu"
self._pre_update(p, "exp_avg", "exp_avg_sq")
# FIXME(ver217): CPU adam kernel only supports fp32 states now
if p.grad.dtype is torch.bfloat16 or p.dtype is not torch.float:
if p.grad.dtype is torch.bfloat16:
# cpu adam kernel does not support bf16 now
bias_correction1 = 1 - beta1 ** state["step"]
bias_correction2 = 1 - beta2 ** state["step"]
7 changes: 2 additions & 5 deletions tests/test_optimizer/test_adam_kernel.py
@@ -13,17 +13,14 @@
_FUSED_ALLOWED_P_G_TYPES = [
(torch.float, torch.half),
(torch.float, torch.float),
(torch.half, torch.float),
(torch.half, torch.half),
(torch.bfloat16, torch.float),
(torch.float, torch.bfloat16),
(torch.bfloat16, torch.bfloat16),
]

_CPU_ALLOWED_P_G_TYPES = [
(torch.float, torch.half),
(torch.float, torch.float),
(torch.half, torch.float),
(torch.half, torch.half),
]

@@ -138,8 +135,8 @@ def check_adam_kernel(
master_exp_avg_sq = torch.zeros_like(master_p)
p = master_p.clone().to(p_dtype)
g = master_g.clone().to(g_dtype)
exp_avg = master_exp_avg.clone()
exp_avg_sq = master_exp_avg_sq.clone()
exp_avg = master_exp_avg.clone().to(p_dtype)
exp_avg_sq = master_exp_avg_sq.clone().to(p_dtype)

for step in range(1, 1 + n_steps):
torch_adam.update(step, master_p, master_g, master_exp_avg, master_exp_avg_sq)
2 changes: 0 additions & 2 deletions tests/test_optimizer/test_adam_optim.py
@@ -21,8 +21,6 @@
(torch.float, torch.float), # pure fp32
(torch.float, torch.half), # fp16 amp
(torch.float, torch.bfloat16), # bfloat16 amp
# (torch.half, torch.half), # FIXME(ver217): cpu adam kernel does not support pure fp16
# (torch.bfloat16, torch.bfloat16), # FIXME(ver217): cpu adam kernel does not support pure bfloat16
]

N_STEPS = 3
