From 43e7006fcac34e96a159720b1c74c9f88acbe0c8 Mon Sep 17 00:00:00 2001 From: FFengIll Date: Tue, 21 Nov 2023 12:15:11 +0800 Subject: [PATCH] refactor: extract bert graph build logic. --- bert.cpp | 358 +++++++++++++++++++++++++++++-------------------------- 1 file changed, 189 insertions(+), 169 deletions(-) diff --git a/bert.cpp b/bert.cpp index a3fd819..ff19bcd 100644 --- a/bert.cpp +++ b/bert.cpp @@ -841,6 +841,176 @@ void bert_resize_ctx(bert_ctx *ctx, int32_t new_size) } } +// build the bert model graph with given tokens +static ggml_cgraph *bert_build(bert_ctx *ctx, struct ggml_context *ctx0, bert_vocab_id *const tokens, int N) +{ + const bert_model &model = ctx->model; + + const float eps = model.hparams.eps; + + const auto &hparams = model.hparams; + + const int n_embd = hparams.n_embd; + const int n_layer = hparams.n_layer; + const int n_max_tokens = hparams.n_max_tokens; + const int n_head = hparams.n_head; + + const int d_head = n_embd / n_head; + + auto &mem_per_token = ctx->mem_per_token; + auto &buf_compute = ctx->buf_compute; + + ggml_cgraph *gf = ggml_new_graph(ctx0); + + // Embeddings. word_embeddings + token_type_embeddings + position_embeddings + // in bert, it is + // token_embedding + segment_embedding + position_embedding + struct ggml_tensor *token_layer = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + memcpy(token_layer->data, tokens, N * ggml_element_size(token_layer)); + + struct ggml_tensor *token_types = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + ggml_set_zero(token_types); + + struct ggml_tensor *positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + for (int i = 0; i < N; i++) + { + ggml_set_i32_1d(positions, i, i); + } + + struct ggml_tensor *inpL = ggml_get_rows(ctx0, model.word_embeddings, token_layer); + + inpL = ggml_add(ctx0, + ggml_get_rows(ctx0, model.token_type_embeddings, token_types), + inpL); + inpL = ggml_add(ctx0, + ggml_get_rows(ctx0, model.position_embeddings, positions), + inpL); + + // embd norm + { + inpL = ggml_norm(ctx0, inpL, eps); + + inpL = ggml_add(ctx0, + ggml_mul(ctx0, + ggml_repeat(ctx0, model.ln_e_w, inpL), + inpL), + ggml_repeat(ctx0, model.ln_e_b, inpL)); + } + // layers + for (int il = 0; il < n_layer; il++) + { + struct ggml_tensor *cur = inpL; + + // self-attention (multiple head) + { + // linear + struct ggml_tensor *Qcur = cur; + Qcur = ggml_reshape_3d(ctx0, + ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].q_b, Qcur), + ggml_mul_mat(ctx0, model.layers[il].q_w, Qcur)), + d_head, n_head, N); + struct ggml_tensor *Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); + + struct ggml_tensor *Kcur = cur; + Kcur = ggml_reshape_3d(ctx0, + ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].k_b, Kcur), + ggml_mul_mat(ctx0, model.layers[il].k_w, Kcur)), + d_head, n_head, N); + struct ggml_tensor *K = ggml_permute(ctx0, Kcur, 0, 2, 1, 3); + + struct ggml_tensor *Vcur = cur; + Vcur = ggml_reshape_3d(ctx0, + ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].v_b, Vcur), + ggml_mul_mat(ctx0, model.layers[il].v_w, Vcur)), + d_head, n_head, N); + struct ggml_tensor *V = ggml_permute(ctx0, Vcur, 0, 2, 1, 3); + + // Scaled Dot-Product Attention + // KQ = soft_max(KQ / sqrt(head width)) + struct ggml_tensor *KQ = ggml_mul_mat(ctx0, K, Q); + KQ = ggml_soft_max(ctx0, + ggml_scale(ctx0, + KQ, + ggml_new_f32(ctx0, 1.0f / sqrt((float)d_head)))); + + V = ggml_cont(ctx0, ggml_transpose(ctx0, V)); + struct ggml_tensor *KQV = ggml_mul_mat(ctx0, V, KQ); + KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + + cur = ggml_cpy(ctx0, + KQV, + ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); + } + // attention output + cur = ggml_add(ctx0, + ggml_repeat(ctx0, model.layers[il].o_b, cur), + ggml_mul_mat(ctx0, model.layers[il].o_w, cur)); + + // Add & Norm + // re-add the layer input + cur = ggml_add(ctx0, cur, inpL); + + // attention norm + { + cur = ggml_norm(ctx0, cur, eps); + + cur = ggml_add(ctx0, + ggml_mul(ctx0, + ggml_repeat(ctx0, model.layers[il].ln_att_w, cur), + cur), + ggml_repeat(ctx0, model.layers[il].ln_att_b, cur)); + } + struct ggml_tensor *att_output = cur; + + // Forward Feed + // intermediate_output = self.intermediate(attention_output) + cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur); + cur = ggml_add(ctx0, + ggml_repeat(ctx0, model.layers[il].ff_i_b, cur), + cur); + cur = ggml_gelu(ctx0, cur); + + // layer_output = self.output(intermediate_output, attention_output) + cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur); + cur = ggml_add(ctx0, + ggml_repeat(ctx0, model.layers[il].ff_o_b, cur), + cur); + + // Add & Norm + // attentions bypass the intermediate layer + cur = ggml_add(ctx0, att_output, cur); + + // output norm + { + cur = ggml_norm(ctx0, cur, eps); + + cur = ggml_add(ctx0, + ggml_mul(ctx0, + ggml_repeat(ctx0, model.layers[il].ln_out_w, cur), + cur), + ggml_repeat(ctx0, model.layers[il].ln_out_b, cur)); + } + inpL = cur; + } + inpL = ggml_cont(ctx0, ggml_transpose(ctx0, inpL)); + + // pooling + // FIXME: pooling method is hard code here + struct ggml_tensor *sum = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, N, 1); + ggml_set_f32(sum, 1.0f / N); + inpL = ggml_mul_mat(ctx0, inpL, sum); + + // normalizer + ggml_tensor *length = ggml_sqrt(ctx0, + ggml_sum(ctx0, ggml_sqr(ctx0, inpL))); + inpL = ggml_scale(ctx0, inpL, ggml_div(ctx0, ggml_new_f32(ctx0, 1.0f), length)); + + // run the computation + ggml_build_forward_expand(gf, inpL); + + return gf; +} + void bert_free(bert_ctx *ctx) { ggml_free(ctx->model.ctx); @@ -878,7 +1048,18 @@ void bert_eval_batch( } } - const float eps = model.hparams.eps; + const auto &hparams = model.hparams; + const int n_embd = hparams.n_embd; + const int n_max_tokens = hparams.n_max_tokens; + + auto &mem_per_token = ctx->mem_per_token; + auto &buf_compute = ctx->buf_compute; + + struct ggml_init_params params = { + .mem_size = buf_compute.size, + .mem_buffer = buf_compute.data, + .no_alloc = false, + }; // TODO: implement real batching for (int ba = 0; ba < n_batch_size; ba++) @@ -886,15 +1067,6 @@ void bert_eval_batch( const int N = n_tokens[ba]; const auto &tokens = batch_tokens[ba]; - const auto &hparams = model.hparams; - - const int n_embd = hparams.n_embd; - const int n_layer = hparams.n_layer; - const int n_max_tokens = hparams.n_max_tokens; - const int n_head = hparams.n_head; - - const int d_head = n_embd / n_head; - std::vector result; if (N > n_max_tokens) { @@ -902,165 +1074,15 @@ void bert_eval_batch( return; } - auto &mem_per_token = ctx->mem_per_token; - auto &buf_compute = ctx->buf_compute; - - struct ggml_init_params params = { - .mem_size = buf_compute.size, - .mem_buffer = buf_compute.data, - .no_alloc = false, - }; - struct ggml_context *ctx0 = ggml_init(params); - ggml_cgraph *gf = ggml_new_graph(ctx0); - - // Embeddings. word_embeddings + token_type_embeddings + position_embeddings - // in bert, it is - // token_embedding + segment_embedding + position_embedding - struct ggml_tensor *token_layer = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); - memcpy(token_layer->data, tokens, N * ggml_element_size(token_layer)); - - struct ggml_tensor *token_types = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); - ggml_set_zero(token_types); - - struct ggml_tensor *positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); - for (int i = 0; i < N; i++) - { - ggml_set_i32_1d(positions, i, i); - } - - struct ggml_tensor *inpL = ggml_get_rows(ctx0, model.word_embeddings, token_layer); - - inpL = ggml_add(ctx0, - ggml_get_rows(ctx0, model.token_type_embeddings, token_types), - inpL); - inpL = ggml_add(ctx0, - ggml_get_rows(ctx0, model.position_embeddings, positions), - inpL); - - // embd norm - { - inpL = ggml_norm(ctx0, inpL, eps); - - inpL = ggml_add(ctx0, - ggml_mul(ctx0, - ggml_repeat(ctx0, model.ln_e_w, inpL), - inpL), - ggml_repeat(ctx0, model.ln_e_b, inpL)); - } - // layers - for (int il = 0; il < n_layer; il++) - { - struct ggml_tensor *cur = inpL; - - // self-attention (multiple head) - { - // linear - struct ggml_tensor *Qcur = cur; - Qcur = ggml_reshape_3d(ctx0, - ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].q_b, Qcur), - ggml_mul_mat(ctx0, model.layers[il].q_w, Qcur)), - d_head, n_head, N); - struct ggml_tensor *Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); - - struct ggml_tensor *Kcur = cur; - Kcur = ggml_reshape_3d(ctx0, - ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].k_b, Kcur), - ggml_mul_mat(ctx0, model.layers[il].k_w, Kcur)), - d_head, n_head, N); - struct ggml_tensor *K = ggml_permute(ctx0, Kcur, 0, 2, 1, 3); - - struct ggml_tensor *Vcur = cur; - Vcur = ggml_reshape_3d(ctx0, - ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].v_b, Vcur), - ggml_mul_mat(ctx0, model.layers[il].v_w, Vcur)), - d_head, n_head, N); - struct ggml_tensor *V = ggml_permute(ctx0, Vcur, 0, 2, 1, 3); - - // Scaled Dot-Product Attention - // KQ = soft_max(KQ / sqrt(head width)) - struct ggml_tensor *KQ = ggml_mul_mat(ctx0, K, Q); - KQ = ggml_soft_max(ctx0, - ggml_scale(ctx0, - KQ, - ggml_new_f32(ctx0, 1.0f / sqrt((float)d_head)))); - - V = ggml_cont(ctx0, ggml_transpose(ctx0, V)); - struct ggml_tensor *KQV = ggml_mul_mat(ctx0, V, KQ); - KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - - cur = ggml_cpy(ctx0, - KQV, - ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); - } - // attention output - cur = ggml_add(ctx0, - ggml_repeat(ctx0, model.layers[il].o_b, cur), - ggml_mul_mat(ctx0, model.layers[il].o_w, cur)); - - // Add & Norm - // re-add the layer input - cur = ggml_add(ctx0, cur, inpL); - - // attention norm - { - cur = ggml_norm(ctx0, cur, eps); - - cur = ggml_add(ctx0, - ggml_mul(ctx0, - ggml_repeat(ctx0, model.layers[il].ln_att_w, cur), - cur), - ggml_repeat(ctx0, model.layers[il].ln_att_b, cur)); - } - struct ggml_tensor *att_output = cur; - - // Forward Feed - // intermediate_output = self.intermediate(attention_output) - cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur); - cur = ggml_add(ctx0, - ggml_repeat(ctx0, model.layers[il].ff_i_b, cur), - cur); - cur = ggml_gelu(ctx0, cur); + ggml_cgraph *gf = bert_build(ctx, ctx0, tokens, N); - // layer_output = self.output(intermediate_output, attention_output) - cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur); - cur = ggml_add(ctx0, - ggml_repeat(ctx0, model.layers[il].ff_o_b, cur), - cur); - - // Add & Norm - // attentions bypass the intermediate layer - cur = ggml_add(ctx0, att_output, cur); - - // output norm - { - cur = ggml_norm(ctx0, cur, eps); - - cur = ggml_add(ctx0, - ggml_mul(ctx0, - ggml_repeat(ctx0, model.layers[il].ln_out_w, cur), - cur), - ggml_repeat(ctx0, model.layers[il].ln_out_b, cur)); - } - inpL = cur; - } - inpL = ggml_cont(ctx0, ggml_transpose(ctx0, inpL)); - - // pooling - // FIXME: pooling method is hard code here - struct ggml_tensor *sum = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, N, 1); - ggml_set_f32(sum, 1.0f / N); - inpL = ggml_mul_mat(ctx0, inpL, sum); - - // normalizer - ggml_tensor *length = ggml_sqrt(ctx0, - ggml_sum(ctx0, ggml_sqr(ctx0, inpL))); - inpL = ggml_scale(ctx0, inpL, ggml_div(ctx0, ggml_new_f32(ctx0, 1.0f), length)); - - ggml_tensor *output = inpL; - // run the computation - ggml_build_forward_expand(gf, output); ggml_graph_compute_with_ctx(ctx0, gf, n_threads); + ggml_free(ctx0); + + // we can get result from the graph directly + struct ggml_tensor *res = gf->nodes[gf->n_nodes - 1]; + struct ggml_tensor *embeddings = gf->nodes[gf->n_nodes - 2]; // float *dat = ggml_get_data_f32(output); // pretty_print_tensor(dat, output->ne, output->nb, output->n_dims - 1, ""); @@ -1073,7 +1095,7 @@ void bert_eval_batch( if (!mem_req_mode) { - memcpy(batch_embeddings[ba], (float *)ggml_get_data(output), sizeof(float) * n_embd); + memcpy(batch_embeddings[ba], (float *)ggml_get_data(embeddings), sizeof(float) * n_embd); } else { @@ -1082,8 +1104,6 @@ void bert_eval_batch( printf("used_mem = %zu KB \n", ggml_used_mem(ctx0) / 1024); printf("mem_per_token = %zu KB \n", mem_per_token / 1024); } - - ggml_free(ctx0); } }