
Add option to remove biases
gordicaleksa committed Jul 10, 2024
1 parent db2454f commit 59dd59b
Showing 2 changed files with 25 additions and 12 deletions.
llmc/zero.cuh — 3 changes: 3 additions & 0 deletions
@@ -525,6 +525,9 @@ void multi_gpu_async_reduce_gradient(
cudaCheck(cudaStreamWaitEvent(multi_gpu_config->nccl_stream, multi_gpu_config->compute_nccl_sync));
ncclCheck(ncclGroupStart()); // NCCL group: aggregate all pointers in a single NCCL GPU kernel.
for (int i = 0; i < N; ++i) {
+ if (pointers[i] == NULL) {
+     continue;
+ }
if(multi_gpu_config->zero_stage == 0) {
ncclCheck(ncclAllReduce(
pointers[i], pointers[i],
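Why the guard is needed: with biases disabled, the bias gradient pointers handed to multi_gpu_async_reduce_gradient are NULL, and NCCL must not be given NULL buffers, so those entries are skipped rather than reduced. A minimal sketch of the guarded loop after this change (paraphrased, not the verbatim file; nelem, ncclFloatX, ncclAvg, and the MultiGpuConfig fields are assumptions carried over from the surrounding llmc/zero.cuh):

    for (int i = 0; i < N; ++i) {
        if (pointers[i] == NULL) {
            continue;  // bias tensor disabled via no_biases -> nothing to reduce
        }
        if (multi_gpu_config->zero_stage == 0) {
            // ZeRO stage 0: every rank holds full gradients, average them in place
            ncclCheck(ncclAllReduce(
                pointers[i], pointers[i], nelem[i], ncclFloatX, ncclAvg,
                multi_gpu_config->nccl_comm, multi_gpu_config->nccl_stream));
        }
    }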
train_gpt2.cu — 34 changes: 22 additions & 12 deletions
@@ -365,6 +365,7 @@ typedef struct {
// todo - if other functions need cpu scratch buffers in the future, reuse as generic scratch?
int* workload_indices; // encoder_backward, B*T*num_c_groups (int)
int4* bucket_info; // encoder_backward, B*T*num_c_groups (int4) - size for worst case
+ int no_biases; // disable biases in attn & fc layers
} GPT2;

void gpt2_init_common(GPT2 *model) {
@@ -396,6 +397,7 @@ void gpt2_init_common(GPT2 *model) {
model->use_master_weights = 1; // safe default: do keep master weights in fp32
model->recompute = 1; // good default: recompute gelu but not layernorm
model->gelu_fusion = 0; //deviceProp.major >= 9 ? 2 : 0; // default: off for now (default must match main())
+ model->no_biases = 0; // default: use biases
}

void gpt2_write_to_checkpoint(GPT2 *model, const char* checkpoint_path) {
@@ -639,15 +641,15 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) {

// get the pointers of the weights for this layer
floatX* l_qkvw = params.qkvw + l * 3*C * C;
- floatX* l_qkvb = params.qkvb + l * 3*C;
+ floatX* l_qkvb = model->no_biases ? NULL : params.qkvb + l * 3*C;
floatX* l_attprojw = params.attprojw + l * C * C;
- floatX* l_attprojb = params.attprojb + l * C;
+ floatX* l_attprojb = model->no_biases ? NULL : params.attprojb + l * C;
floatX* l_ln2w = params.ln2w + l * C;
floatX* l_ln2b = params.ln2b + l * C;
floatX* l_fcw = params.fcw + l * 4*C * C;
- floatX* l_fcb = params.fcb + l * 4*C;
+ floatX* l_fcb = model->no_biases ? NULL : params.fcb + l * 4*C;
floatX* l_fcprojw = params.fcprojw + l * C * 4*C;
- floatX* l_fcprojb = params.fcprojb + l * C;
+ floatX* l_fcprojb = model->no_biases ? NULL : params.fcprojb + l * C;

// get the pointers of the activations for this layer
floatX* l_ln1 = (model->recompute < 2) ? acts.ln1 + l * B * T * C : acts.lnf;
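Passing NULL for l_qkvb, l_attprojb, l_fcb, and l_fcprojb only works because the matmul path tolerates a missing bias. A sketch of the bias-optional pattern this relies on (hypothetical helper names, for illustration only; llm.c's real path is a fused cuBLASLt matmul, assumed here to apply the bias only when the pointer is non-NULL):

    // illustration only: matmul whose bias is optional
    void matmul_forward_maybe_bias(floatX* out, const floatX* inp, const floatX* weight,
                                   const floatX* bias, int B, int T, int C, int OC) {
        matmul(out, inp, weight, B, T, C, OC);  // out = inp @ weight^T, no bias term
        if (bias != NULL) {
            add_bias(out, bias, B, T, OC);      // add the bias only when one exists
        }
    }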
@@ -821,15 +823,15 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int
floatX* dl_ln1w = grads.ln1w + l * C;
floatX* dl_ln1b = grads.ln1b + l * C;
floatX* dl_qkvw = grads.qkvw + l * 3*C * C;
- floatX* dl_qkvb = grads.qkvb + l * 3*C;
+ floatX* dl_qkvb = model->no_biases ? NULL : grads.qkvb + l * 3*C;
floatX* dl_attprojw = grads.attprojw + l * C * C;
- floatX* dl_attprojb = grads.attprojb + l * C;
+ floatX* dl_attprojb = model->no_biases ? NULL : grads.attprojb + l * C;
floatX* dl_ln2w = grads.ln2w + l * C;
floatX* dl_ln2b = grads.ln2b + l * C;
floatX* dl_fcw = grads.fcw + l * 4*C * C;
- floatX* dl_fcb = grads.fcb + l * 4*C;
+ floatX* dl_fcb = model->no_biases ? NULL : grads.fcb + l * 4*C;
floatX* dl_fcprojw = grads.fcprojw + l * C * 4*C;
- floatX* dl_fcprojb = grads.fcprojb + l * C;
+ floatX* dl_fcprojb = model->no_biases ? NULL : grads.fcprojb + l * C;
// get the pointers of the activations for this layer
floatX* l_ln1 = (model->recompute < 2) ? acts.ln1 + l * B * T * C : acts.lnf;
float* l_ln1_mean = acts.ln1_mean + l * B * T;
@@ -886,11 +888,11 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int
if(last_step) {
floatX* const pointers[] = {
dl_ln1w, dl_ln1b,
- dl_qkvw, dl_qkvb,
- dl_attprojw, dl_attprojb,
+ dl_qkvw, model->no_biases ? NULL : dl_qkvb,
+ dl_attprojw, model->no_biases ? NULL : dl_attprojb,
dl_ln2w, dl_ln2b,
- dl_fcw, dl_fcb,
- dl_fcprojw, dl_fcprojb
+ dl_fcw, model->no_biases ? NULL : dl_fcb,
+ dl_fcprojw, model->no_biases ? NULL : dl_fcprojb
};
const size_t nelem[] = {
C, C,
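The pointers[] array must stay index-aligned with nelem[], so disabled biases keep their slots as NULL instead of being dropped. A hedged reconstruction of the full pairing, with sizes inferred from the per-layer strides in the hunks above (the view here truncates the array, and the final call line is likewise an assumption, shown only to make the pairing concrete):

    const size_t nelem[] = {
        C, C,            // dl_ln1w,     dl_ln1b
        3*C * C, 3*C,    // dl_qkvw,     dl_qkvb     (pointer NULL when no_biases)
        C * C, C,        // dl_attprojw, dl_attprojb (pointer NULL when no_biases)
        C, C,            // dl_ln2w,     dl_ln2b
        4*C * C, 4*C,    // dl_fcw,      dl_fcb      (pointer NULL when no_biases)
        C * 4*C, C       // dl_fcprojw,  dl_fcprojb  (pointer NULL when no_biases)
    };
    // the NULL guard added in llmc/zero.cuh above then skips the disabled entries
    multi_gpu_async_reduce_gradient(pointers, nelem, &multi_gpu_config, main_stream);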
@@ -1035,6 +1037,11 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo
// AdamW update
// handle adamw for all the transformer blocks
for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) {

+ if (model->no_biases && (i == 5 || i == 7 || i == 11 || i == 13)) {
+     continue;
+ }

// generate a unique seed for each tensor
unsigned int seed = random_u32(&model->rng_state);

@@ -1419,6 +1426,7 @@ int main(int argc, char *argv[]) {
int recompute = 1; // recompute during backward setting, 0 = none, 1 = recompute gelu
int zero_stage = 0; // Zero Optimization Stage for Multi-GPU training
int hellaswag_eval = 0;
+ int no_biases = 0; // default: include biases
// multi-node settings
int num_processes = 1; // this should be set by the slurm environment
int process_rank = 0; // this should be set by the slurm environment
@@ -1467,6 +1475,7 @@
else if (argv[i][1] == 's' && argv[i][2] == 'g') { skip_update_gradz = atof(argv[i+1]); }
else if (argv[i][1] == 'n' && argv[i][2] == 'k') { checkpoints_keep = atoi(argv[i+1]); }
else if (argv[i][1] == 'n' && argv[i][2] == 'm') { major_checkpoint_every = atoi(argv[i+1]); }
+ else if (argv[i][1] == 'n' && argv[i][2] == 'b') { no_biases = atoi(argv[i+1]); }
else { error_usage(); }
}

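With parsing wired up, the option follows the same convention as the other -n* flags: the argument is read with atoi, so any nonzero value enables it and -nb 0 (the default) keeps biases on. Assuming the usual build (binary name as produced by the repo's Makefile), enabling it looks like:

    ./train_gpt2cu -nb 1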
@@ -1541,6 +1550,7 @@
// build the GPT-2 model
GPT2 model;
gpt2_init_common(&model);
+ model.no_biases = no_biases;
// if load_filename is of the form "dX" where X is an integer (e.g. d12), then we build
// a random model with the depth of the model specified by X (e.g. 12). otherwise interpret
// this variable as a checkpoint filename, and load that checkpoint
