From babf997d3389f2c5418d97e6bd4c72ea099a70cc Mon Sep 17 00:00:00 2001 From: WolframRhodium Date: Thu, 15 Sep 2022 17:54:59 +0800 Subject: [PATCH] vsncnn/vs_ncnn.cpp: fix fp16 inference --- vsncnn/vs_ncnn.cpp | 44 +++++++++++++++++++++++++++++++------------- 1 file changed, 31 insertions(+), 13 deletions(-) diff --git a/vsncnn/vs_ncnn.cpp b/vsncnn/vs_ncnn.cpp index 16b541e..8b06821 100644 --- a/vsncnn/vs_ncnn.cpp +++ b/vsncnn/vs_ncnn.cpp @@ -97,10 +97,12 @@ struct Resource { std::unique_ptr cmd; ncnn::VkAllocator * blob_vkallocator; ncnn::VkAllocator * staging_vkallocator; + ncnn::Mat h_src_fp32; ncnn::Mat h_src; ncnn::VkMat d_src; ncnn::VkMat d_dst; ncnn::Mat h_dst; + ncnn::Mat h_dst_fp32; }; static std::atomic num_plugin_instances {}; @@ -114,6 +116,8 @@ struct vsNcnnData { int in_tile_c, in_tile_w, in_tile_h; int out_tile_c, out_tile_w, out_tile_h; + bool fp16; + std::vector resources; std::vector tickets; std::mutex ticket_lock; @@ -266,7 +270,7 @@ static const VSFrameRef *VS_CC vsNcnnGetFrame( int x_crop_end = (x == src_width - src_tile_w) ? 0 : d->overlap_w; { - auto input_buffer = reinterpret_cast(resource.h_src.data); + auto input_buffer = reinterpret_cast(d->fp16 ? resource.h_src_fp32.data : resource.h_src.data); // assumes the pitches of ncnn::Mat to be // (cstep * elemsize, w * h * elemsize, h * elemsize) @@ -286,6 +290,10 @@ static const VSFrameRef *VS_CC vsNcnnGetFrame( } } + if (d->fp16) { + ncnn::cast_float32_to_float16(resource.h_src_fp32, resource.h_src); + } + resource.cmd->record_clone(resource.h_src, resource.d_src, opt); { @@ -306,8 +314,12 @@ static const VSFrameRef *VS_CC vsNcnnGetFrame( return set_error("cmd reset failed"); } + if (d->fp16) { + ncnn::cast_float16_to_float32(resource.h_dst, resource.h_dst_fp32); + } + { - auto output_buffer = reinterpret_cast(resource.h_dst.data); + auto output_buffer = reinterpret_cast(d->fp16 ? resource.h_dst_fp32.data : resource.h_dst.data); for (int plane = 0; plane < dst_planes; ++plane) { auto dst_ptr = (dst_ptrs[plane] + @@ -475,9 +487,9 @@ static void VS_CC vsNcnnCreate( d->tickets.push_back(i); } - bool fp16 = !!vsapi->propGetInt(in, "fp16", 0, &error); + d->fp16 = !!vsapi->propGetInt(in, "fp16", 0, &error); if (error) { - fp16 = false; + d->fp16 = false; } bool path_is_serialization = !!vsapi->propGetInt(in, "path_is_serialization", 0, &error); @@ -552,12 +564,9 @@ static void VS_CC vsNcnnCreate( d->net.opt.num_threads = 1; d->net.opt.use_vulkan_compute = true; - d->net.opt.use_fp16_packed = fp16; - d->net.opt.use_fp16_storage = fp16; - d->net.opt.use_fp16_arithmetic = fp16; - d->net.opt.use_int8_packed = false; + d->net.opt.use_fp16_packed = d->fp16; + d->net.opt.use_fp16_storage = d->fp16; d->net.opt.use_int8_storage = false; - d->net.opt.use_int8_arithmetic = false; d->net.set_vulkan_device(d->device); if (d->net.load_param_mem(ncnn_param) != 0) { vs_aligned_free(ncnn_param); @@ -572,15 +581,24 @@ static void VS_CC vsNcnnCreate( d->input_index = d->net.input_indexes().front(); d->output_index = d->net.output_indexes().front(); + size_t bps = 4; + if (d->fp16) { + bps = 2; + } + d->resources.resize(num_streams); for (auto & resource : d->resources) { resource.cmd = std::make_unique(d->device); resource.blob_vkallocator = d->device->acquire_blob_allocator(); resource.staging_vkallocator = d->device->acquire_staging_allocator(); - resource.h_src.create(d->in_tile_w, d->in_tile_h, d->in_tile_c); - resource.d_src.create(d->in_tile_w, d->in_tile_h, d->in_tile_c, sizeof(float), resource.blob_vkallocator); - resource.d_dst.create(d->out_tile_w, d->out_tile_h, d->out_tile_c, sizeof(float), resource.blob_vkallocator); - resource.h_dst.create(d->out_tile_w, d->out_tile_h, d->out_tile_c); + resource.h_src.create(d->in_tile_w, d->in_tile_h, d->in_tile_c, bps); + resource.d_src.create(d->in_tile_w, d->in_tile_h, d->in_tile_c, bps, resource.blob_vkallocator); + resource.d_dst.create(d->out_tile_w, d->out_tile_h, d->out_tile_c, bps, resource.blob_vkallocator); + resource.h_dst.create(d->out_tile_w, d->out_tile_h, d->out_tile_c, bps); + if (d->fp16) { + resource.h_src_fp32.create(d->in_tile_w, d->in_tile_h, d->in_tile_c, sizeof(float)); + resource.h_dst_fp32.create(d->out_tile_w, d->out_tile_h, d->out_tile_c, sizeof(float)); + } } vsapi->createFilter(