Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[kernel] support pure fp16 for cpu adam #4896

Merged
merged 1 commit into from
Oct 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 125 additions & 24 deletions colossalai/kernel/cuda_native/csrc/cpu_adam.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ SOFTWARE
void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
float *_exp_avg_sq, size_t _param_size,
bool param_half_precision, bool grad_half_precision,
float loss_scale) {
bool momentum_half_precision,
bool variance_half_precision, float loss_scale) {
size_t rounded_size = 0;

float betta1_minus1 = 1 - _betta1;
Expand All @@ -45,13 +46,21 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,

__half *params_cast_h = NULL;
__half *grads_cast_h = NULL;
__half *momentum_cast_h = NULL;
__half *variance_cast_h = NULL;

if (param_half_precision) {
params_cast_h = reinterpret_cast<__half *>(_params);
}
if (grad_half_precision) {
grads_cast_h = reinterpret_cast<__half *>(grads);
}
if (momentum_half_precision) {
momentum_cast_h = reinterpret_cast<__half *>(_exp_avg);
}
if (variance_half_precision) {
variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq);
}

#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
AVX_Data betta1_4;
Expand Down Expand Up @@ -98,10 +107,18 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
grad_4.data = SIMD_DIV(grad_4.data, loss_scale_vec.data);
}
AVX_Data momentum_4;
momentum_4.data = SIMD_LOAD(_exp_avg + i);
if (momentum_half_precision) {
momentum_4.data = SIMD_LOAD_HALF(momentum_cast_h + i);
} else {
momentum_4.data = SIMD_LOAD(_exp_avg + i);
}

AVX_Data variance_4;
variance_4.data = SIMD_LOAD(_exp_avg_sq + i);
if (variance_half_precision) {
variance_4.data = SIMD_LOAD_HALF(variance_cast_h + i);
} else {
variance_4.data = SIMD_LOAD(_exp_avg_sq + i);
}

AVX_Data param_4;
if (param_half_precision) {
Expand Down Expand Up @@ -135,8 +152,16 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
} else {
SIMD_STORE(_params + i, param_4.data);
}
SIMD_STORE(_exp_avg + i, momentum_4.data);
SIMD_STORE(_exp_avg_sq + i, variance_4.data);
if (momentum_half_precision) {
SIMD_STORE_HALF((float *)(momentum_cast_h + i), momentum_4.data);
} else {
SIMD_STORE(_exp_avg + i, momentum_4.data);
}
if (variance_half_precision) {
SIMD_STORE_HALF((float *)(variance_cast_h + i), variance_4.data);
} else {
SIMD_STORE(_exp_avg_sq + i, variance_4.data);
}
}
}
#endif
Expand All @@ -154,8 +179,10 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
}
float param =
param_half_precision ? (float)params_cast_h[k] : _params[k];
float momentum = _exp_avg[k];
float variance = _exp_avg_sq[k];
float momentum =
momentum_half_precision ? (float)momentum_cast_h[k] : _exp_avg[k];
float variance = variance_half_precision ? (float)variance_cast_h[k]
: _exp_avg_sq[k];
if (_weight_decay > 0 && !_adamw_mode) {
grad = param * _weight_decay + grad;
}
Expand All @@ -178,8 +205,14 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
params_cast_h[k] = (__half)param;
else
_params[k] = param;
_exp_avg[k] = momentum;
_exp_avg_sq[k] = variance;
if (momentum_half_precision)
momentum_cast_h[k] = (__half)(momentum);
else
_exp_avg[k] = momentum;
if (variance_half_precision)
variance_cast_h[k] = (__half)(variance);
else
_exp_avg_sq[k] = variance;
}
}
}
Expand All @@ -188,17 +221,26 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
float *_exp_avg_sq, size_t _param_size,
bool param_half_precision, bool grad_half_precision,
float loss_scale) {
bool momentum_half_precision,
bool variance_half_precision, float loss_scale) {
size_t rounded_size = 0;

__half *params_cast_h = NULL;
__half *grads_cast_h = NULL;
__half *momentum_cast_h = NULL;
__half *variance_cast_h = NULL;
if (param_half_precision) {
params_cast_h = reinterpret_cast<__half *>(_params);
}
if (grad_half_precision) {
grads_cast_h = reinterpret_cast<__half *>(grads);
}
if (momentum_half_precision) {
momentum_cast_h = reinterpret_cast<__half *>(_exp_avg);
}
if (variance_half_precision) {
variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq);
}

#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
AVX_Data betta1_4;
Expand Down Expand Up @@ -255,8 +297,18 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
grad_4[j].data = SIMD_DIV(grad_4[j].data, loss_scale_vec.data);
}

momentum_4[j].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * j);
variance_4[j].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * j);
if (momentum_half_precision) {
momentum_4[j].data =
SIMD_LOAD_HALF(momentum_cast_h + i + SIMD_WIDTH * j);
} else {
momentum_4[j].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * j);
}
if (variance_half_precision) {
variance_4[j].data =
SIMD_LOAD_HALF(variance_cast_h + i + SIMD_WIDTH * j);
} else {
variance_4[j].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * j);
}

if (param_half_precision) {
param_4[j].data = SIMD_LOAD_HALF(params_cast_h + i + SIMD_WIDTH * j);
Expand Down Expand Up @@ -291,8 +343,18 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
} else {
SIMD_STORE(_params + i + SIMD_WIDTH * j, param_4[j].data);
}
SIMD_STORE(_exp_avg + i + SIMD_WIDTH * j, momentum_4[j].data);
SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * j, variance_4[j].data);
if (momentum_half_precision) {
SIMD_STORE_HALF((float *)(momentum_cast_h + i + SIMD_WIDTH * j),
momentum_4[j].data);
} else {
SIMD_STORE(_exp_avg + i + SIMD_WIDTH * j, momentum_4[j].data);
}
if (variance_half_precision) {
SIMD_STORE_HALF((float *)(variance_cast_h + i + SIMD_WIDTH * j),
variance_4[j].data);
} else {
SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * j, variance_4[j].data);
}
}
}
}
Expand All @@ -302,24 +364,37 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
: _params + rounded_size),
(grad_half_precision ? (float *)(grads_cast_h + rounded_size)
: grads + rounded_size),
(_exp_avg + rounded_size), (_exp_avg_sq + rounded_size),
(momentum_half_precision ? (float *)(momentum_cast_h + rounded_size)
: _exp_avg + rounded_size),
(variance_half_precision ? (float *)(variance_cast_h + rounded_size)
: _exp_avg_sq + rounded_size),
(_param_size - rounded_size), param_half_precision,
grad_half_precision, loss_scale);
grad_half_precision, momentum_half_precision,
variance_half_precision, loss_scale);
}

void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
float *_exp_avg_sq, size_t _param_size,
bool param_half_precision, bool grad_half_precision,
float loss_scale) {
bool momentum_half_precision,
bool variance_half_precision, float loss_scale) {
size_t rounded_size = 0;
__half *params_cast_h = NULL;
__half *grads_cast_h = NULL;
__half *momentum_cast_h = NULL;
__half *variance_cast_h = NULL;
if (param_half_precision) {
params_cast_h = reinterpret_cast<__half *>(_params);
}
if (grad_half_precision) {
grads_cast_h = reinterpret_cast<__half *>(grads);
}
if (momentum_half_precision) {
momentum_cast_h = reinterpret_cast<__half *>(_exp_avg);
}
if (variance_half_precision) {
variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq);
}
#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
AVX_Data betta1_4;
betta1_4.data = SIMD_SET(_betta1);
Expand Down Expand Up @@ -375,8 +450,18 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
grad_4[j].data = SIMD_DIV(grad_4[j].data, loss_scale_vec.data);
}

momentum_4[j].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * j);
variance_4[j].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * j);
if (momentum_half_precision) {
momentum_4[j].data =
SIMD_LOAD_HALF(momentum_cast_h + i + SIMD_WIDTH * j);
} else {
momentum_4[j].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * j);
}
if (variance_half_precision) {
variance_4[j].data =
SIMD_LOAD_HALF(variance_cast_h + i + SIMD_WIDTH * j);
} else {
variance_4[j].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * j);
}

if (param_half_precision) {
param_4[j].data = SIMD_LOAD_HALF(params_cast_h + i + SIMD_WIDTH * j);
Expand Down Expand Up @@ -412,8 +497,18 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
SIMD_STORE(_params + i + SIMD_WIDTH * j, param_4[j].data);
}

SIMD_STORE(_exp_avg + i + (SIMD_WIDTH * j), momentum_4[j].data);
SIMD_STORE(_exp_avg_sq + i + (SIMD_WIDTH * j), variance_4[j].data);
if (momentum_half_precision) {
SIMD_STORE_HALF((float *)(momentum_cast_h + i + SIMD_WIDTH * j),
momentum_4[j].data);
} else {
SIMD_STORE(_exp_avg + i + SIMD_WIDTH * j, momentum_4[j].data);
}
if (variance_half_precision) {
SIMD_STORE_HALF((float *)(variance_cast_h + i + SIMD_WIDTH * j),
variance_4[j].data);
} else {
SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * j, variance_4[j].data);
}
}
}
}
Expand All @@ -423,9 +518,13 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
: _params + rounded_size),
(grad_half_precision ? (float *)(grads_cast_h + rounded_size)
: grads + rounded_size),
(_exp_avg + rounded_size), (_exp_avg_sq + rounded_size),
(momentum_half_precision ? (float *)(momentum_cast_h + rounded_size)
: _exp_avg + rounded_size),
(variance_half_precision ? (float *)(variance_cast_h + rounded_size)
: _exp_avg_sq + rounded_size),
(_param_size - rounded_size), param_half_precision,
grad_half_precision, loss_scale);
grad_half_precision, momentum_half_precision,
variance_half_precision, loss_scale);
}

void Adam_Optimizer::step(size_t step, float lr, float beta1, float beta2,
Expand All @@ -447,7 +546,9 @@ void Adam_Optimizer::step(size_t step, float lr, float beta1, float beta2,
this->update_state(lr, epsilon, weight_decay, bias_correction);
this->Step_8(params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr,
params_c.numel(), (params.options().dtype() == at::kHalf),
(grads.options().dtype() == at::kHalf), loss_scale);
(grads.options().dtype() == at::kHalf),
(exp_avg.options().dtype() == at::kHalf),
(exp_avg_sq.options().dtype() == at::kHalf), loss_scale);
}

namespace py = pybind11;
Expand Down
11 changes: 6 additions & 5 deletions colossalai/kernel/cuda_native/csrc/cpu_adam.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,12 @@ union AVX_Data {

#endif

#define STEP(SPAN) \
void Step_##SPAN(float *_params, float *grads, float *_exp_avg, \
float *_exp_avg_sq, size_t _param_size, \
bool param_half_precision = false, \
bool grad_half_precision = false, float loss_scale = -1);
#define STEP(SPAN) \
void Step_##SPAN( \
float *_params, float *grads, float *_exp_avg, float *_exp_avg_sq, \
size_t _param_size, bool param_half_precision = false, \
bool grad_half_precision = false, bool momentum_half_precision = false, \
bool variance_half_precision = false, float loss_scale = -1);

class Adam_Optimizer {
public:
Expand Down
3 changes: 1 addition & 2 deletions colossalai/nn/optimizer/cpu_adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,7 @@ def step(self, closure=None, div_scale: float = -1):
assert state["exp_avg"].device.type == "cpu", "exp_avg should stay on cpu"
assert state["exp_avg_sq"].device.type == "cpu", "exp_avg should stay on cpu"
self._pre_update(p, "exp_avg", "exp_avg_sq")
# FIXME(ver217): CPU adam kernel only supports fp32 states now
if p.grad.dtype is torch.bfloat16 or p.dtype is not torch.float:
if p.grad.dtype is torch.bfloat16:
# cpu adam kernel does not support bf16 now
bias_correction1 = 1 - beta1 ** state["step"]
bias_correction2 = 1 - beta2 ** state["step"]
Expand Down
3 changes: 1 addition & 2 deletions colossalai/nn/optimizer/hybrid_adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,7 @@ def step(self, closure=None, div_scale: float = -1):
assert state["exp_avg"].device.type == "cpu", "exp_avg should stay on cpu"
assert state["exp_avg_sq"].device.type == "cpu", "exp_avg should stay on cpu"
self._pre_update(p, "exp_avg", "exp_avg_sq")
# FIXME(ver217): CPU adam kernel only supports fp32 states now
if p.grad.dtype is torch.bfloat16 or p.dtype is not torch.float:
if p.grad.dtype is torch.bfloat16:
# cpu adam kernel does not support bf16 now
bias_correction1 = 1 - beta1 ** state["step"]
bias_correction2 = 1 - beta2 ** state["step"]
Expand Down
7 changes: 2 additions & 5 deletions tests/test_optimizer/test_adam_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,14 @@
_FUSED_ALLOWED_P_G_TYPES = [
(torch.float, torch.half),
(torch.float, torch.float),
(torch.half, torch.float),
(torch.half, torch.half),
(torch.bfloat16, torch.float),
(torch.float, torch.bfloat16),
(torch.bfloat16, torch.bfloat16),
]

_CPU_ALLOWED_P_G_TYPES = [
(torch.float, torch.half),
(torch.float, torch.float),
(torch.half, torch.float),
(torch.half, torch.half),
]

Expand Down Expand Up @@ -138,8 +135,8 @@ def check_adam_kernel(
master_exp_avg_sq = torch.zeros_like(master_p)
p = master_p.clone().to(p_dtype)
g = master_g.clone().to(g_dtype)
exp_avg = master_exp_avg.clone()
exp_avg_sq = master_exp_avg_sq.clone()
exp_avg = master_exp_avg.clone().to(p_dtype)
exp_avg_sq = master_exp_avg_sq.clone().to(p_dtype)

for step in range(1, 1 + n_steps):
torch_adam.update(step, master_p, master_g, master_exp_avg, master_exp_avg_sq)
Expand Down
2 changes: 0 additions & 2 deletions tests/test_optimizer/test_adam_optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@
(torch.float, torch.float), # pure fp32
(torch.float, torch.half), # fp16 amp
(torch.float, torch.bfloat16), # bfloat16 amp
# (torch.half, torch.half), # FIXME(ver217): cpu adam kernel does not support pure fp16
# (torch.bfloat16, torch.bfloat16), # FIXME(ver217): cpu adam kernel does not support pure bfloat16
]

N_STEPS = 3
Expand Down
Loading