feature/fp32 weight master copy #328

Merged: 5 commits, May 1, 2024
Changes from 4 commits
41 changes: 35 additions & 6 deletions train_gpt2.cu
@@ -1159,7 +1159,7 @@ __device__ inline float lerp(float start, float end, float weight) {

// Template type T instead of floatX
template <typename Tp, typename Tg>
-__global__ void adamw_kernel3(Tp* params_memory, Tg* grads_memory, float* m_memory, float* v_memory, size_t num_parameters,
+__global__ void adamw_kernel3(Tp* params_memory, float* master_params, Tg* grads_memory, float* m_memory, float* v_memory, size_t num_parameters,
float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, float eps, float weight_decay,
unsigned int seed) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
@@ -1176,10 +1176,18 @@ __global__ void adamw_kernel3(Tp* params_memory, Tg* grads_memory, float* m_memo
m /= beta1_correction; // m_hat
v /= beta2_correction; // v_hat
// update the parameters (weight/bias)
-float param = (float)params_memory[i] - (learning_rate * (m / (sqrtf(v) + eps) + weight_decay * (float)params_memory[i]));
-unsigned int random = Get2dNoiseUint(threadIdx.x, blockIdx.x, seed);
-// todo - explain stochastic rounding here
-stochastic_rounding(param, &params_memory[i], random);
+float old_param = master_params != NULL ? master_params[i] : (float)params_memory[i];
+float param = old_param - (learning_rate * (m / (sqrtf(v) + eps) + weight_decay * old_param));
+// if we have master parameters, directly update the two weight copies
+if (master_params != NULL) {
+params_memory[i] = (floatX)param; // low-precision copy, for use in the forward pass
+master_params[i] = param; // float copy, for use in the next parameter update
+} else {
+// without a master copy of params in float, do a direct update in low precision
+// and use stochastic rounding to mitigate loss of training stability
+unsigned int random = Get2dNoiseUint(threadIdx.x, blockIdx.x, seed);
+stochastic_rounding(param, &params_memory[i], random);
+}
}

struct SoftmaxParams {
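For context on the else branch in the adamw_kernel3 hunk above: stochastic rounding writes the low-precision parameter so that it equals the full-precision update in expectation, which stops small updates from being rounded away identically on every step; with the float master copy enabled that fallback is skipped, since the accumulated precision now lives in master_params. A minimal sketch of one way to round a float to bf16 stochastically (illustrative only, assuming a bf16 build with cuda_bf16.h available; this is not the stochastic_rounding helper the kernel actually calls):

__device__ __nv_bfloat16 stochastic_round_bf16_sketch(float x, unsigned int random) {
    // bf16 keeps the top 16 bits of the float32 bit pattern; add noise into the 16 bits
    // that truncation would drop, so the carry rounds up with probability proportional
    // to the discarded fraction (NaN/Inf and overflow are not handled in this sketch)
    unsigned int bits = __float_as_uint(x);
    bits += (random & 0xFFFFu);
    return __ushort_as_bfloat16((unsigned short)(bits >> 16));
}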
@@ -1277,6 +1285,11 @@ __global__ void fused_classifier_kernel3(floatX* logits, floatX* losses, floatX*
}
}

+__global__ void copy_kernel(float* dst, const floatX* src, size_t n) {
+const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
+if (i < n) { dst[i] = (float)src[i]; }
+}

// ----------------------------------------------------------------------------
// kernel launchers

@@ -1822,6 +1835,7 @@ typedef struct {
// buffers for the AdamW optimizer
float* m_memory;
float* v_memory;
+float* master_weights; // NULL unless the fp32 master copy of the weights is enabled.
// the activations of the model, and their sizes
ActivationTensors acts;
size_t act_sizes[NUM_ACTIVATION_TENSORS];
@@ -1840,6 +1854,7 @@ typedef struct {
float accumulated_mean_loss; // Mean loss after aggregating it on all GPUs
floatX* cpu_losses; // CPU buffer to copy the losses to, allocated with cudaMallocHost
unsigned long long rng_state; // the RNG state for seeding stochastic rounding etc.
+int use_master_weights;
} GPT2;

void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path) {
@@ -1899,6 +1914,7 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path) {
model->grads_memory = NULL;
model->m_memory = NULL;
model->v_memory = NULL;
+model->master_weights = NULL;
model->grads_acts_memory = NULL;
model->inputs = NULL;
model->targets = NULL;
@@ -1907,6 +1923,7 @@
model->seq_len = 0;
model->mean_loss = -1.0f; // -1.0f will designate no loss
model->rng_state = 13371337;
+model->use_master_weights = 1; // keep a master copy of the weights in float for the optimizer update?
}

void gpt2_forward(GPT2 *model, int* inputs, int* targets, size_t B, size_t T) {
@@ -2229,14 +2246,20 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo
cudaCheck(cudaMemset(model->v_memory, 0, model->num_parameters * sizeof(float)));
printf0("allocated %zu MiB for AdamW optimizer state m\n", (model->num_parameters * sizeof(float)) >> 20);
printf0("allocated %zu MiB for AdamW optimizer state v\n", (model->num_parameters * sizeof(float)) >> 20);
+if (model->use_master_weights == 1) {
+// allocate one more buffer to keep the master copy of weights as float, and copy the weights over
+cudaCheck(cudaMalloc((void**)&model->master_weights, model->num_parameters * sizeof(float)));
+copy_kernel<<<CEIL_DIV(model->num_parameters, 512), 512>>>(model->master_weights, (floatX*)model->params_memory, model->num_parameters);
+}
}

int block_size = 512;
int num_blocks = CEIL_DIV(model->num_parameters, block_size);
float beta1_correction = 1.0f - powf(beta1, t);
float beta2_correction = 1.0f - powf(beta2, t);
unsigned int seed = random_u32(&model->rng_state);
-adamw_kernel3<<<num_blocks, block_size>>>((floatX*)model->params_memory, (floatX*)model->grads_memory, model->m_memory, model->v_memory,
+adamw_kernel3<<<num_blocks, block_size>>>((floatX*)model->params_memory, model->master_weights,
+(floatX*)model->grads_memory, model->m_memory, model->v_memory,
model->num_parameters,
learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay, seed);
cudaCheck(cudaGetLastError());
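Two notes on the gpt2_update hunk above. First, the one-time initialization uses the new copy_kernel rather than cudaMemcpy because the weights live in floatX and the master buffer in float, so an elementwise cast is required; the cost is num_parameters * sizeof(float) of extra device memory on top of the 8 bytes per parameter already used by the m and v buffers (roughly 124e6 * 4 B ≈ 0.5 GB for a ~124M-parameter GPT-2). Second, for reference, the per-parameter update the launched kernel applies is the standard bias-corrected AdamW step (restating the kernel code, with learning rate $\eta$ and weight decay $\lambda$):

$$\hat m = \frac{m}{1-\beta_1^{\,t}},\qquad \hat v = \frac{v}{1-\beta_2^{\,t}},\qquad p \leftarrow p - \eta\left(\frac{\hat m}{\sqrt{\hat v} + \epsilon} + \lambda\, p\right)$$

where p is read from, and written back to, the float master copy when it exists, and otherwise read from and stochastically rounded back into the floatX weights.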
@@ -2247,6 +2270,7 @@ void gpt2_free(GPT2 *model) {
cudaCheck(cudaFree(model->grads_memory));
cudaCheck(cudaFree(model->m_memory));
cudaCheck(cudaFree(model->v_memory));
+cudaCheck(cudaFree(model->master_weights));
cudaCheck(cudaFree(model->acts_memory));
cudaCheck(cudaFree(model->grads_acts_memory));
cudaCheck(cudaFree(model->inputs));
@@ -2408,6 +2432,7 @@ void error_usage() {
fprintf(stderr, " -g <int> genT, how many steps of inference we do (default = 64)\n");
fprintf(stderr, " -a <int> overfit a single batch? 0/1. useful for debugging\n");
fprintf(stderr, " -f <int> enable_tf32 override (default: 1, set to 0 to disable tf32)\n");
+fprintf(stderr, " -w <int> keep f32 copy of weights for the optimizer? (default: 1)\n");
exit(EXIT_FAILURE);
}

@@ -2429,6 +2454,7 @@ int main(int argc, char *argv[]) {
int overfit_single_batch = 0; // useful for debugging, 1 = only load a single data batch once
int max_steps = -1;
int override_enable_tf32 = 1;
+int use_master_weights = 1;
Contributor (inline review comment): +1 for this naming change, the name in my original PR was just confusing :)

for (int i = 1; i < argc; i+=2) {
if (i + 1 >= argc) { error_usage(); } // must have arg after flag
if (argv[i][0] != '-') { error_usage(); } // must start with dash
@@ -2446,6 +2472,7 @@ int main(int argc, char *argv[]) {
else if (argv[i][1] == 'g') { genT = atoi(argv[i+1]); }
else if (argv[i][1] == 'a') { overfit_single_batch = atoi(argv[i+1]); }
else if (argv[i][1] == 'f') { override_enable_tf32 = atoi(argv[i+1]); }
+else if (argv[i][1] == 'w') { use_master_weights = atoi(argv[i+1]); }
else { error_usage(); }
}
printf0("+-----------------------+----------------------------------------------------+\n");
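A usage sketch for the new flag (the binary name train_gpt2cu is taken from the repo's Makefile and is an assumption here): running ./train_gpt2cu -w 0 disables the float master copy, so the optimizer falls back to the stochastic-rounding update and the extra ~4 bytes per parameter are never allocated; the default -w 1 keeps the master copy so updates accumulate in full float precision.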
@@ -2462,6 +2489,7 @@ int main(int argc, char *argv[]) {
printf0("| sample_every | %-50d |\n", sample_every);
printf0("| genT | %-50d |\n", genT);
printf0("| overfit_single_batch | %-50d |\n", overfit_single_batch);
+printf0("| use_master_weights | %-50s |\n", use_master_weights ? "enabled" : "disabled");
printf0("+-----------------------+----------------------------------------------------+\n");

// set up the device
@@ -2498,6 +2526,7 @@ int main(int argc, char *argv[]) {
// build the GPT-2 model from a checkpoint
GPT2 model;
gpt2_build_from_checkpoint(&model, load_filename);
+model.use_master_weights = use_master_weights;
printf0("| load_filename | %-50s |\n", load_filename);
printf0("| max_sequence_length T | %-50d |\n", model.config.max_seq_len);
printf0("| vocab_size V | %-50d |\n", model.config.vocab_size);
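One ordering detail worth noting: main() assigns model.use_master_weights immediately after gpt2_build_from_checkpoint and before the first gpt2_update call, which matters because the master buffer is only allocated, and initialized from the current weights, when the optimizer state is first set up inside gpt2_update.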