Switch rmsnorm weights to float
This doesn't have a tangible impact on precision, but it will make
future conversions of actual weights a little easier to do mechanically,
as we don't want these to use fp8 et al.

This costs us ~20us/tok (out of ~16ms), which is something that may need
to be reconsidered later.
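(For scale: ~20us out of ~16ms works out to roughly 20/16000 ≈ 0.13% of the per-token time.)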
zeux committed Dec 29, 2023
1 parent 112fd49 commit e3a250a
Showing 5 changed files with 21 additions and 16 deletions.
src/infer.c (1 addition & 1 deletion)
@@ -34,7 +34,7 @@ void prepare(struct Transformer* transformer) {
#endif
}

-static void rmsnorm(float* o, float* x, dtype_t* weight, int size) {
+static void rmsnorm(float* o, float* x, float* weight, int size) {
	// calculate sum of squares
	float ss = 0.0f;
	for (int j = 0; j < size; j++) {
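The hunk above cuts off inside the function body. For reference, a minimal sketch of the complete float-weight rmsnorm, following the standard formulation; only the signature and the lines shown in the diff come from the actual file, and the epsilon constant below is an assumption:

#include <math.h>

static void rmsnorm(float* o, float* x, float* weight, int size) {
	// calculate sum of squares
	float ss = 0.0f;
	for (int j = 0; j < size; j++) {
		ss += x[j] * x[j];
	}

	// 1 / rms(x); the 1e-5f epsilon is an assumed value
	ss = 1.0f / sqrtf(ss / size + 1e-5f);

	// normalize and scale; weight is plain fp32 now, so no dtype_t conversion is needed
	for (int j = 0; j < size; j++) {
		o[j] = weight[j] * (ss * x[j]);
	}
}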
src/infer.cu (5 additions & 5 deletions)
@@ -59,8 +59,8 @@ extern "C" void prepare_cuda(struct Transformer* transformer) {
	int kv_dim = (config->dim * config->n_kv_heads) / config->n_heads;

	for (int l = 0; l < config->n_layers; ++l) {
-		weights->rms_att_weight[l] = (dtype_t*)cuda_devicecopy(weights->rms_att_weight[l], dim * sizeof(dtype_t));
-		weights->rms_ffn_weight[l] = (dtype_t*)cuda_devicecopy(weights->rms_ffn_weight[l], dim * sizeof(dtype_t));
+		weights->rms_att_weight[l] = (float*)cuda_devicecopy(weights->rms_att_weight[l], dim * sizeof(float));
+		weights->rms_ffn_weight[l] = (float*)cuda_devicecopy(weights->rms_ffn_weight[l], dim * sizeof(float));

		weights->wq[l] = (dtype_t*)cuda_devicecopy(weights->wq[l], dim * dim * sizeof(dtype_t));
		weights->wk[l] = (dtype_t*)cuda_devicecopy(weights->wk[l], dim * kv_dim * sizeof(dtype_t));
@@ -72,7 +72,7 @@ extern "C" void prepare_cuda(struct Transformer* transformer) {
		weights->w3[l] = (dtype_t*)cuda_devicecopy(weights->w3[l], dim * hidden_dim * sizeof(dtype_t));
	}

-	weights->rms_final_weight = (dtype_t*)cuda_devicecopy(weights->rms_final_weight, dim * sizeof(dtype_t));
+	weights->rms_final_weight = (float*)cuda_devicecopy(weights->rms_final_weight, dim * sizeof(float));
	weights->token_embedding_table = (dtype_t*)cuda_devicecopy(weights->token_embedding_table, config->vocab_size * dim * sizeof(dtype_t));
	weights->wcls = (dtype_t*)cuda_devicecopy(weights->wcls, dim * config->vocab_size * sizeof(dtype_t));

@@ -98,7 +98,7 @@ __global__ static void kernel_embed(float* o, dtype_t* weight, int size) {
	o[i] = float(weight[i]);
}

-__global__ static void kernel_rmsnorm(float* o, float* x, dtype_t* weight, int size) {
+__global__ static void kernel_rmsnorm(float* o, float* x, float* weight, int size) {
	int i = threadIdx.x;
	int blockSize = blockDim.x;

@@ -118,7 +118,7 @@ __global__ static void kernel_rmsnorm(float* o, float* x, dtype_t* weight, int size) {

	// normalize and scale
	for (int j = i; j < size; j += blockSize) {
-		o[j] = float(weight[j]) * (ss * x[j]);
+		o[j] = weight[j] * (ss * x[j]);
	}
}

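The diff also elides the block reduction in the middle of kernel_rmsnorm. Below is a sketch of how the full kernel can look with float weights, using a plain shared-memory reduction; the reduction strategy, the scratch layout, and the epsilon are assumptions (the real kernel may use a different reduction helper), and blockSize is assumed to be a power of two:

__global__ static void kernel_rmsnorm(float* o, float* x, float* weight, int size) {
	extern __shared__ float scratch[]; // blockDim.x floats, sized at launch
	int i = threadIdx.x;
	int blockSize = blockDim.x;

	// per-thread partial sum of squares
	float ss = 0.0f;
	for (int j = i; j < size; j += blockSize) {
		ss += x[j] * x[j];
	}

	// reduce the partial sums across the block
	scratch[i] = ss;
	__syncthreads();
	for (int s = blockSize / 2; s > 0; s /= 2) {
		if (i < s) {
			scratch[i] += scratch[i + s];
		}
		__syncthreads();
	}
	ss = scratch[0];

	// 1 / rms(x); the epsilon is an assumed value
	ss = rsqrtf(ss / size + 1e-5f);

	// normalize and scale; weight[] is read directly as fp32 after this change
	for (int j = i; j < size; j += blockSize) {
		o[j] = weight[j] * (ss * x[j]);
	}
}

A hypothetical launch would pass the scratch size explicitly, e.g. kernel_rmsnorm<<<1, 1024, 1024 * sizeof(float)>>>(o, x, w, size); the launch configuration actually used by infer.cu is not shown in this diff.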
src/model.h (4 additions & 4 deletions)
@@ -34,8 +34,8 @@ struct Weights {
	// token embedding table
	dtype_t* token_embedding_table; // (vocab_size, dim)
	// weights for rmsnorms
-	dtype_t* rms_att_weight[MAX_LAYERS]; // (dim) rmsnorm weights
-	dtype_t* rms_ffn_weight[MAX_LAYERS]; // (dim)
+	float* rms_att_weight[MAX_LAYERS]; // (dim) rmsnorm weights
+	float* rms_ffn_weight[MAX_LAYERS]; // (dim)
	// weights for matmuls. note dim == n_heads * head_size
	dtype_t* wq[MAX_LAYERS]; // (dim, n_heads * head_size)
	dtype_t* wk[MAX_LAYERS]; // (dim, n_kv_heads * head_size)
@@ -46,8 +46,8 @@
	dtype_t* w2[MAX_LAYERS]; // (dim, hidden_dim)
	dtype_t* w3[MAX_LAYERS]; // (hidden_dim, dim)
	// final rmsnorm
-	dtype_t* rms_final_weight; // (dim,)
-	// (optional) classifier weights for the logits, on the last layer
+	float* rms_final_weight; // (dim,)
+	// classifier weights for the logits, on the last layer
	dtype_t* wcls;
};

src/run.c (6 additions & 3 deletions)
@@ -44,18 +44,21 @@ void build_transformer(struct Config* config, struct Weights* weights, struct Te
	enum DType dtype = dt_f16;

	weights->token_embedding_table = (dtype_t*)tensors_get(tensors, "model.embed.weight", 0, dtype, (int[]){config->vocab_size, config->dim, 0, 0});
+
	for (int l = 0; l < config->n_layers; ++l) {
-		weights->rms_att_weight[l] = (dtype_t*)tensors_get(tensors, "model.layers.%d.attn.norm.weight", l, dtype, (int[]){config->dim, 0, 0, 0});
+		weights->rms_att_weight[l] = (float*)tensors_get(tensors, "model.layers.%d.attn.norm.weight", l, dt_f32, (int[]){config->dim, 0, 0, 0});
		weights->wq[l] = (dtype_t*)tensors_get(tensors, "model.layers.%d.attn.wq.weight", l, dtype, (int[]){config->dim, config->n_heads * head_size, 0, 0});
		weights->wk[l] = (dtype_t*)tensors_get(tensors, "model.layers.%d.attn.wk.weight", l, dtype, (int[]){config->n_kv_heads * head_size, config->dim, 0, 0});
		weights->wv[l] = (dtype_t*)tensors_get(tensors, "model.layers.%d.attn.wv.weight", l, dtype, (int[]){config->n_kv_heads * head_size, config->dim, 0, 0});
		weights->wo[l] = (dtype_t*)tensors_get(tensors, "model.layers.%d.attn.wo.weight", l, dtype, (int[]){config->n_heads * head_size, config->dim, 0, 0});
-		weights->rms_ffn_weight[l] = (dtype_t*)tensors_get(tensors, "model.layers.%d.mlp.norm.weight", l, dtype, (int[]){config->dim, 0, 0, 0});
+
+		weights->rms_ffn_weight[l] = (float*)tensors_get(tensors, "model.layers.%d.mlp.norm.weight", l, dt_f32, (int[]){config->dim, 0, 0, 0});
		weights->w1[l] = (dtype_t*)tensors_get(tensors, "model.layers.%d.mlp.w1.weight", l, dtype, (int[]){config->hidden_dim, config->dim, 0, 0});
		weights->w2[l] = (dtype_t*)tensors_get(tensors, "model.layers.%d.mlp.w2.weight", l, dtype, (int[]){config->dim, config->hidden_dim, 0, 0});
		weights->w3[l] = (dtype_t*)tensors_get(tensors, "model.layers.%d.mlp.w3.weight", l, dtype, (int[]){config->hidden_dim, config->dim, 0, 0});
	}
-	weights->rms_final_weight = (dtype_t*)tensors_get(tensors, "model.norm.weight", 0, dtype, (int[]){config->dim, 0, 0, 0});
+
+	weights->rms_final_weight = (float*)tensors_get(tensors, "model.norm.weight", 0, dt_f32, (int[]){config->dim, 0, 0, 0});
	weights->wcls = (dtype_t*)tensors_get(tensors, "model.output.weight", 0, dtype, (int[]){config->vocab_size, config->dim, 0, 0});
}

tools/convert.py (5 additions & 3 deletions)
@@ -139,17 +139,19 @@ def permute_reverse(w, heads):
tensors["model.embed.weight"] = weights["model.embed_tokens.weight"].to(dtype)

for l in range(config["num_hidden_layers"]):
tensors[f"model.layers.{l}.attn.norm.weight"] = weights[f"model.layers.{l}.input_layernorm.weight"].to(dtype)
tensors[f"model.layers.{l}.attn.norm.weight"] = weights[f"model.layers.{l}.input_layernorm.weight"].float()
tensors[f"model.layers.{l}.attn.wq.weight"] = permute_reverse(weights[f"model.layers.{l}.self_attn.q_proj.weight"], config["num_attention_heads"]).to(dtype)
tensors[f"model.layers.{l}.attn.wk.weight"] = permute_reverse(weights[f"model.layers.{l}.self_attn.k_proj.weight"], config["num_key_value_heads"]).to(dtype)
tensors[f"model.layers.{l}.attn.wv.weight"] = weights[f"model.layers.{l}.self_attn.v_proj.weight"].to(dtype)
tensors[f"model.layers.{l}.attn.wo.weight"] = weights[f"model.layers.{l}.self_attn.o_proj.weight"].to(dtype)
tensors[f"model.layers.{l}.mlp.norm.weight"] = weights[f"model.layers.{l}.post_attention_layernorm.weight"].to(dtype)

tensors[f"model.layers.{l}.mlp.norm.weight"] = weights[f"model.layers.{l}.post_attention_layernorm.weight"].float()

tensors[f"model.layers.{l}.mlp.w1.weight"] = weights[f"model.layers.{l}.mlp.gate_proj.weight"].to(dtype)
tensors[f"model.layers.{l}.mlp.w2.weight"] = weights[f"model.layers.{l}.mlp.down_proj.weight"].to(dtype)
tensors[f"model.layers.{l}.mlp.w3.weight"] = weights[f"model.layers.{l}.mlp.up_proj.weight"].to(dtype)

tensors["model.norm.weight"] = weights["model.norm.weight"].to(dtype)
tensors["model.norm.weight"] = weights["model.norm.weight"].float()
tensors["model.output.weight"] = weights["lm_head.weight"].to(dtype)

# metadata values must be strings in safetensors
