Commit 5428681

fix webassembly, update README
JohannesGaessler committed Aug 17, 2024
1 parent 5620754 commit 5428681
Showing 6 changed files with 119 additions and 80 deletions.
2 changes: 2 additions & 0 deletions examples/mnist/.gitignore
@@ -1 +1,3 @@
data/
*.gguf
*.ggml
41 changes: 31 additions & 10 deletions examples/mnist/README.md
@@ -1,6 +1,7 @@
# MNIST Examples for GGML

This directory contains simple examples of how to use GGML for training and inference using the [MNIST dataset](https://yann.lecun.com/exdb/mnist/).
All commands listed in this README assume the working directory to be `examples/mnist`.
Please note that training in GGML is a work-in-progress and not production ready.

## Obtaining the data
@@ -13,7 +14,7 @@ For our first example we will train a fully connected network.
To train a fully connected model in PyTorch and save it as a GGUF file, run:

```bash
$ python3 ../examples/mnist/mnist-train-fc.py mnist-fc-f32.gguf
$ python3 mnist-train-fc.py mnist-fc-f32.gguf

...

Expand All @@ -30,7 +31,7 @@ The training script includes an evaluation of the model on the test set.
To evaluate the model using GGML, run:

```bash
$ bin/mnist-eval mnist-fc-f32.gguf data/MNIST/raw/t10k-images-idx3-ubyte data/MNIST/raw/t10k-labels-idx1-ubyte
$ ../../build/bin/mnist-eval mnist-fc-f32.gguf data/MNIST/raw/t10k-images-idx3-ubyte data/MNIST/raw/t10k-labels-idx1-ubyte

________________________________________________________
________________________________________________________
@@ -77,7 +78,7 @@ In addition to the evaluation on the test set the GGML evaluation also prints a
To train a fully connected model using GGML, run:

``` bash
$ bin/mnist-train mnist-fc mnist-fc-f32.gguf data/MNIST/raw/train-images-idx3-ubyte data/MNIST/raw/train-labels-idx1-ubyte
$ ../../build/bin/mnist-train mnist-fc mnist-fc-f32.gguf data/MNIST/raw/train-images-idx3-ubyte data/MNIST/raw/train-labels-idx1-ubyte
```

It can then be evaluated with the same binary as above.
@@ -91,7 +92,7 @@ It can be evaluated using the `mnist-eval` binary by substituting the argument f
To train a convolutional network using TensorFlow, run:

```bash
$ python3 ../examples/mnist/mnist-train-cnn.py mnist-cnn-f32.gguf
$ python3 mnist-train-cnn.py mnist-cnn-f32.gguf

...

@@ -103,7 +104,7 @@ GGUF model saved to 'mnist-cnn-f32.gguf'
The saved model can be evaluated using the `mnist-eval` binary:

```bash
$ bin/mnist-eval mnist-fc-f32.gguf data/MNIST/raw/t10k-images-idx3-ubyte data/MNIST/raw/t10k-labels-idx1-ubyte
$ ../../build/bin/mnist-eval mnist-cnn-f32.gguf data/MNIST/raw/t10k-images-idx3-ubyte data/MNIST/raw/t10k-labels-idx1-ubyte

________________________________________________________
________________________________________________________
@@ -149,18 +150,38 @@ main: test_acc=98.40+-0.13%
As with the fully connected network, the convolutional network can also be trained using GGML:

``` bash
bin/mnist-train mnist-cnn mnist-cnn-f32.gguf data/MNIST/raw/train-images-idx3-ubyte data/MNIST/raw/train-labels-idx1-ubyte
$ ../../build/bin/mnist-train mnist-cnn mnist-cnn-f32.gguf data/MNIST/raw/train-images-idx3-ubyte data/MNIST/raw/train-labels-idx1-ubyte
```

As always, the evaluation is done using `mnist-eval`, and as with the fully connected network the GGML graph is exported to `mnist-cnn-f32.ggml`.
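For example, with the test-set files obtained as above:

```bash
$ ../../build/bin/mnist-eval mnist-cnn-f32.gguf data/MNIST/raw/t10k-images-idx3-ubyte data/MNIST/raw/t10k-labels-idx1-ubyte
```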

## Web demo

The example can be compiled with Emscripten like this:
The evaluation code can be compiled to WebAssembly using [Emscripten](https://emscripten.org/) (after installing it you may need to log out and back in again so that `$PATH` is updated).
First, copy the GGUF file of either of the trained models to `examples/mnist` and name it `mnist-f32.gguf`.
Copy the test set to `examples/mnist` and name it `t10k-images-idx3-ubyte`.
Symlinking these files will *not* work!
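For example, assuming the fully connected model trained above (a minimal sketch; adjust the source paths if your files are elsewhere):

```bash
$ cp mnist-fc-f32.gguf mnist-f32.gguf
$ cp data/MNIST/raw/t10k-images-idx3-ubyte .
```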
Compile the code like so:

```bash
cd examples/mnist
emcc -I../../include -I../../include/ggml -I../../examples ../../src/ggml.c ../../src/ggml-quants.c main.cpp -o web/mnist.js -s EXPORTED_FUNCTIONS='["_wasm_eval","_wasm_random_digit","_malloc","_free"]' -s EXPORTED_RUNTIME_METHODS='["ccall"]' -s ALLOW_MEMORY_GROWTH=1 --preload-file models/mnist
$ emcc -I../../include -I../../include/ggml -I../../examples ../../src/ggml.c ../../src/ggml-quants.c ../../src/ggml-aarch64.c mnist-common.cpp -o web/mnist.js -s EXPORTED_FUNCTIONS='["_wasm_eval","_wasm_random_digit","_malloc","_free"]' -s EXPORTED_RUNTIME_METHODS='["ccall"]' -s ALLOW_MEMORY_GROWTH=1 --preload-file mnist-f32.gguf --preload-file t10k-images-idx3-ubyte
```

Online demo: https://mnist.ggerganov.com
The compilation output is in `examples/mnist/web`.
To run it, you need an HTTP server.
For example:

``` bash
$ cd web
$ python3 -m http.server

Serving HTTP on 0.0.0.0 port 8000 (http://0.0.0.0:8000/) ...
```

The web demo can then be accessed via the link printed on the console.
Simply draw a digit on the canvas and the model will try to predict what it's supposed to be.
Alternatively, click the "Random" button to retrieve a random digit from the test set.
Be aware that, like all neural networks, the one we trained is susceptible to distributional shift:
if the digits you draw look different from those in the training set
(e.g. because they're not centered), the model will perform comparatively worse.
An online demo can be accessed [here](https://mnist.ggerganov.com).
107 changes: 79 additions & 28 deletions examples/mnist/mnist-common.cpp
@@ -308,28 +308,30 @@ mnist_model mnist_model_init_random(const std::string & arch) {
return model;
}

void mnist_model_build(mnist_model & model) {
void mnist_model_build(mnist_model & model, const int nbatch) {
model.nbatch = nbatch;

if (model.arch == "mnist-fc") {
ggml_set_param(model.ctx_compute, model.fc1_weight);
ggml_set_param(model.ctx_compute, model.fc1_bias);
ggml_set_param(model.ctx_compute, model.fc2_weight);
ggml_set_param(model.ctx_compute, model.fc2_bias);

model.images = ggml_new_tensor_2d(model.ctx_compute, GGML_TYPE_F32, MNIST_NINPUT, MNIST_NBATCH);
model.images = ggml_new_tensor_2d(model.ctx_compute, GGML_TYPE_F32, MNIST_NINPUT, model.nbatch);
ggml_set_input(model.images);
ggml_set_name(model.images, "images");

ggml_tensor * fc1_bias = model.fc1_bias;
if (MNIST_NBATCH > 1) {
if (model.nbatch > 1) {
fc1_bias = ggml_repeat(model.ctx_compute,
model.fc1_bias,
ggml_new_tensor_2d(model.ctx_compute, GGML_TYPE_F32, MNIST_NHIDDEN, MNIST_NBATCH));
ggml_new_tensor_2d(model.ctx_compute, GGML_TYPE_F32, MNIST_NHIDDEN, model.nbatch));
}
ggml_tensor * fc2_bias = model.fc2_bias;
if (MNIST_NBATCH > 1) {
if (model.nbatch > 1) {
fc2_bias = ggml_repeat(model.ctx_compute,
model.fc2_bias,
ggml_new_tensor_2d(model.ctx_compute, GGML_TYPE_F32, MNIST_NCLASSES, MNIST_NBATCH));
ggml_new_tensor_2d(model.ctx_compute, GGML_TYPE_F32, MNIST_NCLASSES, model.nbatch));
}

ggml_tensor * fc1 = ggml_relu(model.ctx_compute, ggml_add(model.ctx_compute,
@@ -346,15 +348,15 @@ void mnist_model_build(mnist_model & model) {
ggml_set_param(model.ctx_compute, model.dense_weight);
ggml_set_param(model.ctx_compute, model.dense_bias);

model.images = ggml_new_tensor_4d(model.ctx_compute, GGML_TYPE_F32, 28, 28, 1, MNIST_NBATCH);
model.images = ggml_new_tensor_4d(model.ctx_compute, GGML_TYPE_F32, 28, 28, 1, model.nbatch);
ggml_set_input(model.images);
ggml_set_name(model.images, "images");

struct ggml_tensor * conv2d_1_bias = model.conv1_bias;
if (MNIST_NBATCH > 1) {
if (model.nbatch > 1) {
int64_t ne[4];
memcpy(ne, conv2d_1_bias->ne, sizeof(ne));
ne[3] = MNIST_NBATCH;
ne[3] = model.nbatch;
conv2d_1_bias = ggml_repeat(model.ctx_compute, conv2d_1_bias, ggml_new_tensor(model.ctx_compute, GGML_TYPE_F32, 4, ne));
}

@@ -364,19 +366,19 @@ void mnist_model_build(mnist_model & model) {
GGML_ASSERT(conv1_out->ne[0] == MNIST_HW);
GGML_ASSERT(conv1_out->ne[1] == MNIST_HW);
GGML_ASSERT(conv1_out->ne[2] == MNIST_CNN_NCB);
GGML_ASSERT(conv1_out->ne[3] == MNIST_NBATCH);
GGML_ASSERT(conv1_out->ne[3] == model.nbatch);

struct ggml_tensor * conv2_in = ggml_pool_2d(model.ctx_compute, conv1_out, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
GGML_ASSERT(conv2_in->ne[0] == MNIST_HW/2);
GGML_ASSERT(conv2_in->ne[1] == MNIST_HW/2);
GGML_ASSERT(conv2_in->ne[2] == MNIST_CNN_NCB);
GGML_ASSERT(conv2_in->ne[3] == MNIST_NBATCH);
GGML_ASSERT(conv2_in->ne[3] == model.nbatch);

struct ggml_tensor * conv2d_2_bias = model.conv2_bias;
if (MNIST_NBATCH > 1) {
if (model.nbatch > 1) {
int64_t ne[4];
memcpy(ne, conv2d_2_bias->ne, sizeof(ne));
ne[3] = MNIST_NBATCH;
ne[3] = model.nbatch;
conv2d_2_bias = ggml_repeat(model.ctx_compute, conv2d_2_bias, ggml_new_tensor(model.ctx_compute, GGML_TYPE_F32, 4, ne));
}

@@ -386,27 +388,27 @@ void mnist_model_build(mnist_model & model) {
GGML_ASSERT(conv2_out->ne[0] == MNIST_HW/2);
GGML_ASSERT(conv2_out->ne[1] == MNIST_HW/2);
GGML_ASSERT(conv2_out->ne[2] == MNIST_CNN_NCB*2);
GGML_ASSERT(conv2_out->ne[3] == MNIST_NBATCH);
GGML_ASSERT(conv2_out->ne[3] == model.nbatch);

struct ggml_tensor * dense_in = ggml_pool_2d(model.ctx_compute, conv2_out, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
GGML_ASSERT(dense_in->ne[0] == MNIST_HW/4);
GGML_ASSERT(dense_in->ne[1] == MNIST_HW/4);
GGML_ASSERT(dense_in->ne[2] == MNIST_CNN_NCB*2);
GGML_ASSERT(dense_in->ne[3] == MNIST_NBATCH);
GGML_ASSERT(dense_in->ne[3] == model.nbatch);

dense_in = ggml_reshape_2d(model.ctx_compute,
ggml_cont(model.ctx_compute, ggml_permute(model.ctx_compute, dense_in, 1, 2, 0, 3)),
(MNIST_HW/4)*(MNIST_HW/4)*(MNIST_CNN_NCB*2), MNIST_NBATCH);
(MNIST_HW/4)*(MNIST_HW/4)*(MNIST_CNN_NCB*2), model.nbatch);
GGML_ASSERT(dense_in->ne[0] == (MNIST_HW/4)*(MNIST_HW/4)*(MNIST_CNN_NCB*2));
GGML_ASSERT(dense_in->ne[1] == MNIST_NBATCH);
GGML_ASSERT(dense_in->ne[1] == model.nbatch);
GGML_ASSERT(dense_in->ne[2] == 1);
GGML_ASSERT(dense_in->ne[3] == 1);

struct ggml_tensor * dense_bias = model.dense_bias;
if (MNIST_NBATCH > 1) {
if (model.nbatch > 1) {
int64_t ne[4];
memcpy(ne, dense_bias->ne, sizeof(ne));
ne[1] = MNIST_NBATCH;
ne[1] = model.nbatch;
dense_bias = ggml_repeat(model.ctx_compute, dense_bias, ggml_new_tensor(model.ctx_compute, GGML_TYPE_F32, 4, ne));
}
model.logits = ggml_add(model.ctx_compute, ggml_mul_mat(model.ctx_compute, model.dense_weight, dense_in), dense_bias);
@@ -416,21 +418,36 @@ void mnist_model_build(mnist_model & model) {

ggml_set_output(model.logits);
ggml_set_name(model.logits, "logits");
GGML_ASSERT(model.logits->type == GGML_TYPE_F32);
GGML_ASSERT(model.logits->ne[0] == MNIST_NCLASSES);
GGML_ASSERT(model.logits->ne[1] == model.nbatch);
GGML_ASSERT(model.logits->ne[2] == 1);
GGML_ASSERT(model.logits->ne[3] == 1);

model.probs = ggml_soft_max(model.ctx_compute, model.logits);
ggml_set_output(model.probs);
ggml_set_name(model.probs, "probs");
GGML_ASSERT(model.probs->type == GGML_TYPE_F32);
GGML_ASSERT(model.probs->ne[0] == MNIST_NCLASSES);
GGML_ASSERT(model.probs->ne[1] == model.nbatch);
GGML_ASSERT(model.probs->ne[2] == 1);
GGML_ASSERT(model.probs->ne[3] == 1);

model.labels = ggml_new_tensor_2d(model.ctx_compute, GGML_TYPE_F32, MNIST_NCLASSES, MNIST_NBATCH);
model.labels = ggml_new_tensor_2d(model.ctx_compute, GGML_TYPE_F32, MNIST_NCLASSES, model.nbatch);
ggml_set_input(model.labels);
ggml_set_name(model.labels, "labels");

model.loss = ggml_cross_entropy_loss(model.ctx_compute, model.logits, model.labels);
ggml_set_output(model.loss);
ggml_set_name(model.loss, "loss");
GGML_ASSERT(model.loss->type == GGML_TYPE_F32);
GGML_ASSERT(model.loss->ne[0] == 1);
GGML_ASSERT(model.loss->ne[1] == 1);
GGML_ASSERT(model.loss->ne[2] == 1);
GGML_ASSERT(model.loss->ne[3] == 1);
}

mnist_eval_result mnist_model_eval(const mnist_model & model, const float * images, const float * labels, const int nex) {
mnist_eval_result mnist_model_eval(const mnist_model & model, const float * images, const float * labels, const int nex, const int nthreads) {
mnist_eval_result result;

struct ggml_cgraph * gf = ggml_new_graph(model.ctx_compute);
@@ -439,15 +456,15 @@ mnist_eval_result mnist_model_eval(const mnist_model & model, const float * imag
{
const int64_t t_start_us = ggml_time_us();

GGML_ASSERT(nex % MNIST_NBATCH == 0);
for (int iex0; iex0 < nex; iex0 += MNIST_NBATCH) {
GGML_ASSERT(nex % model.nbatch == 0);
for (int iex0 = 0; iex0 < nex; iex0 += model.nbatch) {
memcpy(model.images->data, images + iex0*MNIST_NINPUT, ggml_nbytes(model.images));
memcpy(model.labels->data, labels + iex0*MNIST_NCLASSES, ggml_nbytes(model.labels));
ggml_graph_compute_with_ctx(model.ctx_compute, gf, 16);
ggml_graph_compute_with_ctx(model.ctx_compute, gf, nthreads);

result.loss.push_back(*ggml_get_data_f32(model.loss));

for (int iexb = 0; iexb < MNIST_NBATCH; ++iexb) {
for (int iexb = 0; iexb < model.nbatch; ++iexb) {
const float * logits_data = ggml_get_data_f32(model.logits) + iexb*MNIST_NCLASSES;
result.pred.push_back(std::max_element(logits_data, logits_data + MNIST_NCLASSES) - logits_data);
}
@@ -481,10 +498,10 @@ void mnist_model_train(const float * images, const float * labels, const int nex
ggml_opt_init(model.ctx_compute, &opt_ctx, opt_pars, 0);

for (int epoch = 0; epoch < 20; ++epoch) {
fprintf(stderr, "%s: epoch %d start...", __func__, epoch, nex);
fprintf(stderr, "%s: epoch %d start...", __func__, epoch);
const int64_t t_start_us = ggml_time_us();
mnist_eval_result result;
for (int iex0 = 0; iex0 < nex; iex0 += MNIST_NBATCH) {
for (int iex0 = 0; iex0 < nex; iex0 += model.nbatch) {
memcpy(model.images->data, images + iex0*MNIST_NINPUT, ggml_nbytes(model.images));
memcpy(model.labels->data, labels + iex0*MNIST_NCLASSES, ggml_nbytes(model.labels));

@@ -493,7 +510,7 @@ void mnist_model_train(const float * images, const float * labels, const int nex

result.loss.push_back(*ggml_get_data_f32(model.loss));

for (int iexb = 0; iexb < MNIST_NBATCH; ++iexb) {
for (int iexb = 0; iexb < model.nbatch; ++iexb) {
const float * ptr_p = (const float *) model.logits->data + iexb*MNIST_NCLASSES;
result.pred.push_back(std::max_element(ptr_p, ptr_p + MNIST_NCLASSES) - ptr_p);
}
@@ -575,3 +592,37 @@ std::pair<double, double> mnist_accuracy(const mnist_eval_result & result, const

return std::make_pair(fraction_correct, uncertainty);
}
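The `uncertainty` returned here appears to be the binomial standard error of the measured accuracy (an assumption based on the reported numbers, since the function body is collapsed above): with accuracy $\hat p$ over $n$ test examples,

$$\sigma_{\hat p} = \sqrt{\frac{\hat p\,(1-\hat p)}{n}}, \qquad \sqrt{\frac{0.9840 \cdot (1 - 0.9840)}{10000}} \approx 0.0013,$$

which matches the `test_acc=98.40+-0.13%` quoted in the README.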

#ifdef __cplusplus
extern "C" {
#endif

int wasm_eval(uint8_t * digitPtr) {
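// Called from JavaScript via ccall (see the emcc flags in the README);
// digitPtr points to the MNIST_NINPUT pixels of a single 28x28 digit.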
std::vector<float> digit(digitPtr, digitPtr + MNIST_NINPUT);
std::vector<float> labels(MNIST_NCLASSES);

mnist_model model = mnist_model_init_from_file("mnist-f32.gguf");
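// Build the compute graph with a batch size of 1 and evaluate on a single
// thread; the labels are dummy values since only the prediction is used.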
mnist_model_build(model, 1);
mnist_eval_result result = mnist_model_eval(model, digit.data(), labels.data(), 1, 1);

return result.pred[0];
}

int wasm_random_digit(char * digitPtr) {
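// Pick a random image from the raw MNIST test set file and copy its
// MNIST_NINPUT bytes into digitPtr.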
auto fin = std::ifstream("t10k-images-idx3-ubyte", std::ios::binary);
if (!fin) {
fprintf(stderr, "failed to open digits file\n");
return 0;
}
srand(time(NULL));

// Seek to a random digit: 16-byte header + 28*28 bytes per image * random index in [0, 10000)
fin.seekg(16 + MNIST_NINPUT * (rand() % MNIST_NTEST));
fin.read(digitPtr, MNIST_NINPUT);

return 1;
}

#ifdef __cplusplus
}
#endif
3 changes: 2 additions & 1 deletion examples/mnist/mnist-common.h
@@ -21,6 +21,7 @@ static_assert(MNIST_NTEST % MNIST_NBATCH == 0, "MNIST_NTRAIN % MNIST_BATCH != 0

struct mnist_model {
std::string arch;
int nbatch;

struct ggml_tensor * images = nullptr;
struct ggml_tensor * labels = nullptr;
@@ -94,7 +95,7 @@ mnist_eval_result mnist_graph_eval(const std::string & fname, const float * imag

mnist_model mnist_model_init_from_file(const std::string & fname);
mnist_model mnist_model_init_random(const std::string & arch);
void mnist_model_build(mnist_model & model);
void mnist_model_build(mnist_model & model, const int nbatch);
mnist_eval_result mnist_model_eval(const mnist_model & model, const float * images, const float * labels, const int nex);
mnist_eval_result mnist_model_eval(const mnist_model & model, const float * images, const float * labels, const int nex, const int nthreads);
void mnist_model_train(const float * images, const float * labels, const int nex, mnist_model & model);
void mnist_model_save(const std::string & fname, mnist_model & model);