Skip to content

Commit

Permalink
refactor: follow ggml changes.
Browse files Browse the repository at this point in the history
  • Loading branch information
FFengIll committed Nov 21, 2023
1 parent e25766d commit b0a2b2e
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions bert.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ struct bert_loader
printf("%s: ggml ctx size = %6.2f MB\n", __func__, mmapped_size_p / (1024.0 * 1024.0));
}

struct ggml_tensor *create_tensor_for(struct ggml_context *ctx, struct ggml_tensor *meta, ggml_backend backend)
struct ggml_tensor *create_tensor_for(struct ggml_context *ctx, struct ggml_tensor *meta, ggml_backend_type backend)
{
if (backend != GGML_BACKEND_CPU)
{
Expand All @@ -363,7 +363,7 @@ struct bert_loader
return tensor;
}

struct ggml_tensor *create_tensor(struct ggml_context *ctx, const std::string &name, const std::vector<int64_t> &ne, ggml_backend backend)
struct ggml_tensor *create_tensor(struct ggml_context *ctx, const std::string &name, const std::vector<int64_t> &ne, ggml_backend_type backend)
{
struct ggml_tensor *cur = ggml_get_tensor(ctx_meta, name.c_str());

Expand Down Expand Up @@ -613,7 +613,7 @@ struct bert_loader
const int n_vocab = hparams.n_vocab;
const int n_vocab_size = hparams.n_vocab_size;

const ggml_backend backend = GGML_BACKEND_CPU;
const ggml_backend_type backend = GGML_BACKEND_CPU;

size_t ctx_size;
size_t mmapped_size;
Expand Down Expand Up @@ -912,7 +912,7 @@ void bert_eval_batch(
};

struct ggml_context *ctx0 = ggml_init(params);
struct ggml_cgraph gf = {};
ggml_cgraph *gf = ggml_new_graph(ctx0);

// Embeddings. word_embeddings + token_type_embeddings + position_embeddings
// in bert, it is
Expand Down Expand Up @@ -1059,8 +1059,8 @@ void bert_eval_batch(

ggml_tensor *output = inpL;
// run the computation
ggml_build_forward_expand(&gf, output);
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
ggml_build_forward_expand(gf, output);
ggml_graph_compute_with_ctx(ctx0, gf, n_threads);

// float *dat = ggml_get_data_f32(output);
// pretty_print_tensor(dat, output->ne, output->nb, output->n_dims - 1, "");
Expand Down

0 comments on commit b0a2b2e

Please sign in to comment.