Skip to content

Commit

Permalink
vsncnn/vs_ncnn.cpp: fix fp16 inference
Browse files Browse the repository at this point in the history
  • Loading branch information
WolframRhodium committed Sep 15, 2022
1 parent b0adf9f commit babf997
Showing 1 changed file with 31 additions and 13 deletions.
44 changes: 31 additions & 13 deletions vsncnn/vs_ncnn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,12 @@ struct Resource {
std::unique_ptr<ncnn::VkCompute> cmd;
ncnn::VkAllocator * blob_vkallocator;
ncnn::VkAllocator * staging_vkallocator;
ncnn::Mat h_src_fp32;
ncnn::Mat h_src;
ncnn::VkMat d_src;
ncnn::VkMat d_dst;
ncnn::Mat h_dst;
ncnn::Mat h_dst_fp32;
};

static std::atomic<int> num_plugin_instances {};
Expand All @@ -114,6 +116,8 @@ struct vsNcnnData {
int in_tile_c, in_tile_w, in_tile_h;
int out_tile_c, out_tile_w, out_tile_h;

bool fp16;

std::vector<Resource> resources;
std::vector<int> tickets;
std::mutex ticket_lock;
Expand Down Expand Up @@ -266,7 +270,7 @@ static const VSFrameRef *VS_CC vsNcnnGetFrame(
int x_crop_end = (x == src_width - src_tile_w) ? 0 : d->overlap_w;

{
auto input_buffer = reinterpret_cast<uint8_t *>(resource.h_src.data);
auto input_buffer = reinterpret_cast<uint8_t *>(d->fp16 ? resource.h_src_fp32.data : resource.h_src.data);

// assumes the pitches of ncnn::Mat to be
// (cstep * elemsize, w * h * elemsize, h * elemsize)
Expand All @@ -286,6 +290,10 @@ static const VSFrameRef *VS_CC vsNcnnGetFrame(
}
}

if (d->fp16) {
ncnn::cast_float32_to_float16(resource.h_src_fp32, resource.h_src);
}

resource.cmd->record_clone(resource.h_src, resource.d_src, opt);

{
Expand All @@ -306,8 +314,12 @@ static const VSFrameRef *VS_CC vsNcnnGetFrame(
return set_error("cmd reset failed");
}

if (d->fp16) {
ncnn::cast_float16_to_float32(resource.h_dst, resource.h_dst_fp32);
}

{
auto output_buffer = reinterpret_cast<uint8_t *>(resource.h_dst.data);
auto output_buffer = reinterpret_cast<uint8_t *>(d->fp16 ? resource.h_dst_fp32.data : resource.h_dst.data);

for (int plane = 0; plane < dst_planes; ++plane) {
auto dst_ptr = (dst_ptrs[plane] +
Expand Down Expand Up @@ -475,9 +487,9 @@ static void VS_CC vsNcnnCreate(
d->tickets.push_back(i);
}

bool fp16 = !!vsapi->propGetInt(in, "fp16", 0, &error);
d->fp16 = !!vsapi->propGetInt(in, "fp16", 0, &error);
if (error) {
fp16 = false;
d->fp16 = false;
}

bool path_is_serialization = !!vsapi->propGetInt(in, "path_is_serialization", 0, &error);
Expand Down Expand Up @@ -552,12 +564,9 @@ static void VS_CC vsNcnnCreate(

d->net.opt.num_threads = 1;
d->net.opt.use_vulkan_compute = true;
d->net.opt.use_fp16_packed = fp16;
d->net.opt.use_fp16_storage = fp16;
d->net.opt.use_fp16_arithmetic = fp16;
d->net.opt.use_int8_packed = false;
d->net.opt.use_fp16_packed = d->fp16;
d->net.opt.use_fp16_storage = d->fp16;
d->net.opt.use_int8_storage = false;
d->net.opt.use_int8_arithmetic = false;
d->net.set_vulkan_device(d->device);
if (d->net.load_param_mem(ncnn_param) != 0) {
vs_aligned_free(ncnn_param);
Expand All @@ -572,15 +581,24 @@ static void VS_CC vsNcnnCreate(
d->input_index = d->net.input_indexes().front();
d->output_index = d->net.output_indexes().front();

size_t bps = 4;
if (d->fp16) {
bps = 2;
}

d->resources.resize(num_streams);
for (auto & resource : d->resources) {
resource.cmd = std::make_unique<ncnn::VkCompute>(d->device);
resource.blob_vkallocator = d->device->acquire_blob_allocator();
resource.staging_vkallocator = d->device->acquire_staging_allocator();
resource.h_src.create(d->in_tile_w, d->in_tile_h, d->in_tile_c);
resource.d_src.create(d->in_tile_w, d->in_tile_h, d->in_tile_c, sizeof(float), resource.blob_vkallocator);
resource.d_dst.create(d->out_tile_w, d->out_tile_h, d->out_tile_c, sizeof(float), resource.blob_vkallocator);
resource.h_dst.create(d->out_tile_w, d->out_tile_h, d->out_tile_c);
resource.h_src.create(d->in_tile_w, d->in_tile_h, d->in_tile_c, bps);
resource.d_src.create(d->in_tile_w, d->in_tile_h, d->in_tile_c, bps, resource.blob_vkallocator);
resource.d_dst.create(d->out_tile_w, d->out_tile_h, d->out_tile_c, bps, resource.blob_vkallocator);
resource.h_dst.create(d->out_tile_w, d->out_tile_h, d->out_tile_c, bps);
if (d->fp16) {
resource.h_src_fp32.create(d->in_tile_w, d->in_tile_h, d->in_tile_c, sizeof(float));
resource.h_dst_fp32.create(d->out_tile_w, d->out_tile_h, d->out_tile_c, sizeof(float));
}
}

vsapi->createFilter(
Expand Down

0 comments on commit babf997

Please sign in to comment.