Skip to content

Commit

Permalink
Add yolov3-tiny example
Browse files Browse the repository at this point in the history
  • Loading branch information
rgerganov committed Oct 12, 2023
1 parent 39d80ea commit e220e24
Show file tree
Hide file tree
Showing 5 changed files with 389 additions and 0 deletions.
1 change: 1 addition & 0 deletions examples/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,4 @@ add_subdirectory(replit)
add_subdirectory(mpt)
add_subdirectory(starcoder)
add_subdirectory(sam)
add_subdirectory(yolo)
6 changes: 6 additions & 0 deletions examples/yolo/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#
# yolov3-tiny
#
# Builds the yolov3-tiny example executable from yolov3-tiny.cpp.

# Name of the executable target; reused by the commands below
set(TEST_TARGET yolov3-tiny)
add_executable(${TEST_TARGET} yolov3-tiny.cpp)
# Link privately against the ggml and common targets (expected to be defined elsewhere in the build)
target_link_libraries(${TEST_TARGET} PRIVATE ggml common)
25 changes: 25 additions & 0 deletions examples/yolo/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
This example shows how to implement YOLO object detection with ggml.

# YOLOv3-tiny

Download the model weights:

```bash
$ wget https://pjreddie.com/media/files/yolov3-tiny.weights
$ sha1sum yolov3-tiny.weights
40f3c11883bef62fd850213bc14266632ed4414f yolov3-tiny.weights
```

Convert the weights to GGUF format (the conversion script needs the `gguf` Python module):

```bash
$ ./convert-yolov3-tiny.py yolov3-tiny.weights
yolov3-tiny.weights converted to yolov3-tiny.gguf
```

Object detection:

```bash
$ wget https://raw.githubusercontent.com/pjreddie/darknet/master/data/dog.jpg
$ ./yolov3-tiny yolov3-tiny.gguf dog.jpg
```
53 changes: 53 additions & 0 deletions examples/yolo/convert-yolov3-tiny.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#!/usr/bin/env python3
import sys
import gguf
import numpy as np

def save_conv2d_layer(f, gguf_writer, prefix, inp_c, filters, size, batch_normalize=True):
    """Read one convolutional layer from `f` and register its tensors with `gguf_writer`.

    Values are consumed from `f` as float32 in this order: biases, then (only
    when `batch_normalize` is true) scales, rolling mean and rolling variance,
    and finally the convolution weights. Each array is passed to
    `gguf_writer.add_tensor` under the name `<prefix>_<kind>`.

    Args:
        f: binary file object positioned at the start of the layer's data.
        gguf_writer: object exposing `add_tensor(name, array, raw_shape=...)`.
        prefix: tensor name prefix, e.g. "l0".
        inp_c: number of input channels of the layer.
        filters: number of filters (output channels) of the layer.
        size: kernel width and height.
        batch_normalize: whether batch-normalization parameters are stored for this layer.

    Raises:
        ValueError: if `f` ends before all expected values have been read.
    """
    def read_f32(count, what):
        values = np.fromfile(f, dtype=np.float32, count=count)
        # np.fromfile returns a shorter array at end of file instead of failing,
        # which would otherwise produce a silently corrupt model.
        if values.size != count:
            raise ValueError("%s_%s: expected %d float32 values, got %d (truncated weights file?)"
                             % (prefix, what, count, values.size))
        return values

    per_filter_shape = (1, filters, 1, 1)
    gguf_writer.add_tensor(prefix + "_biases", read_f32(filters, "biases"), raw_shape=per_filter_shape)

    if batch_normalize:
        for kind in ("scales", "rolling_mean", "rolling_variance"):
            gguf_writer.add_tensor(prefix + "_" + kind, read_f32(filters, kind), raw_shape=per_filter_shape)

    weights_count = filters * inp_c * size * size
    l0_weights = read_f32(weights_count, "weights")
    ## ggml doesn't support f32 convolution yet, use f16 instead
    l0_weights = l0_weights.astype(np.float16)
    gguf_writer.add_tensor(prefix + "_weights", l0_weights, raw_shape=(filters, inp_c, size, size))


if __name__ == '__main__':
    # Usage: convert-yolov3-tiny.py <yolov3-tiny.weights>
    # Writes the converted model to 'yolov3-tiny.gguf' in the current directory.
    if len(sys.argv) != 2:
        print("Usage: %s <yolov3-tiny.weights>" % sys.argv[0])
        sys.exit(1)
    outfile = 'yolov3-tiny.gguf'
    gguf_writer = gguf.GGUFWriter(outfile, 'yolov3-tiny')

    # (prefix, input channels, filters, kernel size, batch_normalize), listed in
    # the order the layers are read from the weights file.
    conv_layers = [
        ("l0",     3,   16, 3, True),
        ("l1",    16,   32, 3, True),
        ("l2",    32,   64, 3, True),
        ("l3",    64,  128, 3, True),
        ("l4",   128,  256, 3, True),
        ("l5",   256,  512, 3, True),
        ("l6",   512, 1024, 3, True),
        ("l7",  1024,  256, 1, True),
        ("l8",   256,  512, 3, True),
        ("l9",   512,  255, 1, False),
        ("l10",  256,  128, 1, True),
        ("l11",  384,  256, 3, True),
        ("l12",  256,  255, 1, False),
    ]

    # `with` guarantees the input file is closed even if reading a layer fails
    with open(sys.argv[1], 'rb') as f:
        f.read(20) # skip header
        for prefix, inp_c, filters, size, batch_normalize in conv_layers:
            save_conv2d_layer(f, gguf_writer, prefix, inp_c, filters, size, batch_normalize=batch_normalize)

    gguf_writer.write_header_to_file()
    gguf_writer.write_kv_data_to_file()
    gguf_writer.write_tensors_to_file()
    gguf_writer.close()
    print("{} converted to {}".format(sys.argv[1], outfile))
304 changes: 304 additions & 0 deletions examples/yolo/yolov3-tiny.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,304 @@
#include "ggml/ggml.h"
#define STB_IMAGE_IMPLEMENTATION
#include "stb_image.h"
#define STB_IMAGE_WRITE_IMPLEMENTATION
#include "stb_image_write.h"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <string>
#include <vector>

#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif

// Width and height of the network input; detect() creates its input tensor
// with these dimensions and main() letterboxes the loaded image to them.
const int NET_W = 416;
const int NET_H = 416;

// Float image stored as `c` consecutive planes of `h` rows by `w` columns:
// the sample (x, y, channel) lives at data[channel*w*h + y*w + x].
struct yolo_image {
    int w, h, c;
    std::vector<float> data;

    // Allocates w*h*c samples (value-initialized by std::vector).
    yolo_image(int w, int h, int c) : w(w), h(h), c(c), data(w*h*c) {}

    // Flat index of sample (x, y, ch); asserts that the coordinates are in range.
    int offset(int x, int y, int ch) const {
        assert(x >= 0 && x < w && y >= 0 && y < h && ch >= 0 && ch < c);
        return ch*w*h + y*w + x;
    }

    // Returns the sample at (x, y) in channel `c`.
    float get_pixel(int x, int y, int c) const {
        return data[offset(x, y, c)];
    }

    // Overwrites the sample at (x, y) in channel `c` with `val`.
    void set_pixel(int x, int y, int c, float val) {
        data[offset(x, y, c)] = val;
    }

    // Adds `val` to the sample at (x, y) in channel `c`.
    void add_pixel(int x, int y, int c, float val) {
        data[offset(x, y, c)] += val;
    }

    // Sets every sample of the image to `val`.
    void fill(float val) {
        data.assign(data.size(), val);
    }
};

// Loads the image file `fname` into `img` as a 3-channel planar float image
// with samples scaled from [0, 255] to [0, 1].
//
// Returns false (after printing a message to stderr) when the file cannot be
// decoded; `img` is left untouched in that case.
bool load_image(const char *fname, yolo_image & img)
{
    // stbi_load is asked to return exactly `channels` interleaved components
    // per pixel. The value it writes through the 4th argument is the channel
    // count stored in the *file*, which may differ (grayscale, RGBA, ...), so
    // it must not be used to index the returned buffer.
    const int channels = 3;
    int w, h, file_channels;
    unsigned char * data = stbi_load(fname, &w, &h, &file_channels, channels);
    if (!data) {
        fprintf(stderr, "Failed to load image: %s\n", fname);
        return false;
    }
    img.w = w;
    img.h = h;
    img.c = channels;
    img.data.resize(w*h*channels);
    // Convert interleaved (x, y, channel) bytes to planar (channel, y, x) floats
    for (int k = 0; k < channels; ++k){
        for (int j = 0; j < h; ++j){
            for (int i = 0; i < w; ++i){
                int dst_index = i + w*j + w*h*k;
                int src_index = k + channels*i + channels*w*j;
                img.data[dst_index] = (float)data[src_index]/255.;
            }
        }
    }
    stbi_image_free(data);
    return true;
}

// Returns a copy of `im` resized to `w` x `h` using linear interpolation.
//
// The resize is done in two passes: first every row is resampled horizontally
// into the intermediate image `part` (w x im.h), then `part` is resampled
// vertically into the result. The channel count is preserved.
yolo_image resize_image(const yolo_image & im, int w, int h)
{
    yolo_image resized(w, h, im.c);
    yolo_image part(w, im.h, im.c);
    // Guard the denominators: a 1-pixel target dimension would otherwise
    // divide by zero and, for the vertical pass, turn `sy` into NaN/inf whose
    // conversion to int is undefined. A scale of 0 samples the first row/column.
    const float w_scale = w > 1 ? (float)(im.w - 1) / (w - 1) : 0.0f;
    const float h_scale = h > 1 ? (float)(im.h - 1) / (h - 1) : 0.0f;
    // Horizontal pass
    for (int k = 0; k < im.c; ++k){
        for (int r = 0; r < im.h; ++r) {
            for (int c = 0; c < w; ++c) {
                float val = 0;
                if (c == w-1 || im.w == 1){
                    // last column (or 1-pixel-wide source): no right neighbour to blend with
                    val = im.get_pixel(im.w-1, r, k);
                } else {
                    float sx = c*w_scale;
                    int ix = (int) sx;
                    float dx = sx - ix;
                    val = (1 - dx) * im.get_pixel(ix, r, k) + dx * im.get_pixel(ix+1, r, k);
                }
                part.set_pixel(c, r, k, val);
            }
        }
    }
    // Vertical pass
    for (int k = 0; k < im.c; ++k){
        for (int r = 0; r < h; ++r){
            float sy = r*h_scale;
            int iy = (int) sy;
            float dy = sy - iy;
            for (int c = 0; c < w; ++c){
                float val = (1-dy) * part.get_pixel(c, iy, k);
                resized.set_pixel(c, r, k, val);
            }
            // last row (or 1-pixel-high source): there is no row iy+1 to add
            if (r == h-1 || im.h == 1) continue;
            for (int c = 0; c < w; ++c){
                float val = dy * part.get_pixel(c, iy+1, k);
                resized.add_pixel(c, r, k, val);
            }
        }
    }
    return resized;
}

// Copies every sample of `source` into `dest`, shifted so that the source
// origin lands at (dx, dy) in `dest`. No clipping is performed here; the
// destination coordinates are only checked by set_pixel's assert.
void embed_image(const yolo_image & source, yolo_image & dest, int dx, int dy)
{
    for (int ch = 0; ch < source.c; ++ch) {
        for (int sy = 0; sy < source.h; ++sy) {
            const int ty = dy + sy;
            for (int sx = 0; sx < source.w; ++sx) {
                dest.set_pixel(dx + sx, ty, ch, source.get_pixel(sx, sy, ch));
            }
        }
    }
}

// Fits `im` into a `w` x `h` canvas while keeping its aspect ratio: the image
// is resized so that the more constrained dimension fills the canvas, then it
// is centered on a canvas pre-filled with 0.5.
yolo_image letterbox_image(yolo_image im, int w, int h)
{
    // true when the width is the limiting dimension
    const bool width_bound = ((float)w/im.w) < ((float)h/im.h);
    const int scaled_w = width_bound ? w : (im.w * h)/im.h;
    const int scaled_h = width_bound ? (im.h * w)/im.w : h;

    yolo_image resized = resize_image(im, scaled_w, scaled_h);

    yolo_image boxed(w, h, im.c);
    boxed.fill(0.5);
    const int off_x = (w - scaled_w)/2;
    const int off_y = (h - scaled_h)/2;
    embed_image(resized, boxed, off_x, off_y);
    return boxed;
}

// Parameters of one convolutional layer. The tensor pointers are looked up by
// name in the model context by yolo_model_load(); they default to nullptr so
// that a layer that has not been loaded (or a batch-norm tensor that a layer
// does not have) is never an indeterminate pointer.
struct conv2d_layer {
    struct ggml_tensor * weights = nullptr;
    struct ggml_tensor * biases = nullptr;
    // The three tensors below are only loaded when batch_normalize is true
    struct ggml_tensor * scales = nullptr;
    struct ggml_tensor * rolling_mean = nullptr;
    struct ggml_tensor * rolling_variance = nullptr;
    int padding = 1;
    bool batch_normalize = true;
    bool activate = true; // true for leaky relu, false for linear
};

// A loaded model: the per-layer tensor handles plus the ggml context they
// were looked up in.
struct yolo_model {
    std::vector<conv2d_layer> conv2d_layers;
    // Filled in by yolo_model_load(); nullptr until a model has been loaded
    struct ggml_context * ctx = nullptr;
};

bool yolo_model_load(const std::string & fname, yolo_model & model) {
struct gguf_init_params params = {
/*.no_alloc =*/ false,
/*.ctx =*/ &model.ctx,
};
gguf_context * ctx = gguf_init_from_file(fname.c_str(), params);
if (!ctx) {
fprintf(stderr, "%s: gguf_init_from_file() failed\n", __func__);
return false;
}
model.conv2d_layers.resize(10);
model.conv2d_layers[7].padding = 0;
model.conv2d_layers[9].padding = 0;
model.conv2d_layers[9].batch_normalize = false;
model.conv2d_layers[9].activate = false;
for (int i = 0; i < (int)model.conv2d_layers.size(); i++) {
char name[256];
snprintf(name, sizeof(name), "l%d_weights", i);
model.conv2d_layers[i].weights = ggml_get_tensor(model.ctx, name);
snprintf(name, sizeof(name), "l%d_biases", i);
model.conv2d_layers[i].biases = ggml_get_tensor(model.ctx, name);
if (model.conv2d_layers[i].batch_normalize) {
snprintf(name, sizeof(name), "l%d_scales", i);
model.conv2d_layers[i].scales = ggml_get_tensor(model.ctx, name);
snprintf(name, sizeof(name), "l%d_rolling_mean", i);
model.conv2d_layers[i].rolling_mean = ggml_get_tensor(model.ctx, name);
snprintf(name, sizeof(name), "l%d_rolling_variance", i);
model.conv2d_layers[i].rolling_variance = ggml_get_tensor(model.ctx, name);
}
}
return true;
}

// Appends one convolutional layer to the graph being built in `ctx` and
// returns its output tensor: convolution of `input` with layer.weights,
// optional batch normalization ((x - rolling_mean) / sqrt(rolling_variance) * scales),
// bias addition, and an optional ggml_leaky activation.
ggml_tensor* apply_conv2d(ggml_context * ctx, ggml_tensor * input, const conv2d_layer & layer)
{
    struct ggml_tensor * out = ggml_conv_2d(ctx, layer.weights, input, 1, 1, layer.padding, layer.padding, 1, 1);

    if (layer.batch_normalize) {
        // Each per-filter tensor is repeated to the shape of the current output first
        struct ggml_tensor * mean = ggml_repeat(ctx, layer.rolling_mean, out);
        out = ggml_sub(ctx, out, mean);

        struct ggml_tensor * variance = ggml_repeat(ctx, layer.rolling_variance, out);
        struct ggml_tensor * stddev = ggml_sqrt(ctx, variance);
        out = ggml_div(ctx, out, stddev);

        struct ggml_tensor * scale = ggml_repeat(ctx, layer.scales, out);
        out = ggml_mul(ctx, out, scale);
    }

    struct ggml_tensor * bias = ggml_repeat(ctx, layer.biases, out);
    out = ggml_add(ctx, out, bias);

    return layer.activate ? ggml_leaky(ctx, out) : out;
}

// Prints up to `n` leading float values of `t` to stdout, 13 per line.
// The count is capped at ne[0]*ne[1] of the tensor.
void dump_tensor(ggml_tensor *t, int n)
{
    const float * values = ggml_get_data_f32(t);
    const int plane = (int)t->ne[0] * (int)t->ne[1];
    const int count = std::min(n, plane);
    int idx = 0;
    while (idx < count) {
        // start a new output line before every 13th value (including the first)
        if (idx % 13 == 0) {
            printf("\n");
        }
        printf("%8.4f ", values[idx]);
        ++idx;
    }
    printf("\n");
}

// Prints the four dimensions (ne[0]..ne[3]) of `t` to stdout, labelled with
// the layer index `layer`.
void print_shape(int layer, const ggml_tensor * t)
{
    const int d0 = (int)t->ne[0];
    const int d1 = (int)t->ne[1];
    const int d2 = (int)t->ne[2];
    const int d3 = (int)t->ne[3];
    printf("Layer %2d output shape: %3d x %3d x %4d x %3d\n", layer, d0, d1, d2, d3);
}

void detect(const yolo_image & img, const yolo_model & model)
{
static size_t buf_size = 20000000 * sizeof(float) * 4;
static void * buf = malloc(buf_size);

struct ggml_init_params params = {
/*.mem_size =*/ buf_size,
/*.mem_buffer =*/ buf,
/*.no_alloc =*/ false,
};

struct ggml_context * ctx0 = ggml_init(params);
struct ggml_cgraph gf = {};

struct ggml_tensor * input = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, NET_W, NET_H, 3, 1);
memcpy(input->data, img.data.data(), ggml_nbytes(input));
ggml_set_name(input, "input");

struct ggml_tensor * result = apply_conv2d(ctx0, input, model.conv2d_layers[0]);
print_shape(0, result);
result = ggml_pool_2d(ctx0, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
print_shape(1, result);
result = apply_conv2d(ctx0, result, model.conv2d_layers[1]);
print_shape(2, result);
result = ggml_pool_2d(ctx0, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
print_shape(3, result);
result = apply_conv2d(ctx0, result, model.conv2d_layers[2]);
print_shape(4, result);
result = ggml_pool_2d(ctx0, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
print_shape(5, result);
result = apply_conv2d(ctx0, result, model.conv2d_layers[3]);
print_shape(6, result);
result = ggml_pool_2d(ctx0, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
print_shape(7, result);
result = apply_conv2d(ctx0, result, model.conv2d_layers[4]);
print_shape(8, result);
result = ggml_pool_2d(ctx0, result, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
print_shape(9, result);
result = apply_conv2d(ctx0, result, model.conv2d_layers[5]);
print_shape(10, result);
result = ggml_pool_2d(ctx0, result, GGML_OP_POOL_MAX, 2, 2, 1, 1, 0.5, 0.5);
print_shape(11, result);
result = apply_conv2d(ctx0, result, model.conv2d_layers[6]);
print_shape(12, result);
result = apply_conv2d(ctx0, result, model.conv2d_layers[7]);
print_shape(13, result);
result = apply_conv2d(ctx0, result, model.conv2d_layers[8]);
print_shape(14, result);
result = apply_conv2d(ctx0, result, model.conv2d_layers[9]);
print_shape(15, result);

ggml_build_forward_expand(&gf, result);
ggml_graph_compute_with_ctx(ctx0, &gf, 1);

dump_tensor(result, 13*13);
ggml_free(ctx0);
}

// Entry point: yolov3-tiny <yolov3-tiny.gguf> <image>
//
// Loads the model and the image, letterboxes the image to NET_W x NET_H and
// runs detect() on it. Returns 0 on success and 1 on a usage error or when
// the model or the image cannot be loaded.
int main(int argc, char *argv[])
{
    ggml_time_init();
    yolo_model model;

    if (argc != 3) {
        fprintf(stderr, "Usage: %s <yolov3-tiny.gguf> <image>\n", argv[0]);
        // wrong invocation is an error, so report it through the exit status
        return 1;
    }
    if (!yolo_model_load(argv[1], model)) {
        fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, argv[1]);
        return 1;
    }

    yolo_image img(0,0,0);
    if (!load_image(argv[2], img)) {
        fprintf(stderr, "%s: failed to load image from '%s'\n", __func__, argv[2]);
        // the model was loaded successfully above; release it before bailing out
        ggml_free(model.ctx);
        return 1;
    }
    yolo_image sized = letterbox_image(img, NET_W, NET_H);
    detect(sized, model);
    ggml_free(model.ctx);
    return 0;
}

0 comments on commit e220e24

Please sign in to comment.