From d12dd8b1deab03526cc44abd3f229a6a40db43fe Mon Sep 17 00:00:00 2001
From: daquexian
Date: Thu, 22 Aug 2019 18:06:51 +0800
Subject: [PATCH] Revert "use bgemm_naive instead of bconv_naive as fallback
 due to the 128-align weight"

This reverts commit be606fa6f87db0a4e9166b79f5d9cb86804525ba.
---
 dabnn/layers/BinConv.cpp | 57 +++++++++++++++++++---------------------
 1 file changed, 27 insertions(+), 30 deletions(-)

diff --git a/dabnn/layers/BinConv.cpp b/dabnn/layers/BinConv.cpp
index 7481337..27ad5ce 100644
--- a/dabnn/layers/BinConv.cpp
+++ b/dabnn/layers/BinConv.cpp
@@ -25,16 +25,14 @@ BinConv::BinConv(NetCP net, const std::string &name, css input, css weight,
       stride_h(stride_h),
       stride_w(stride_w) {
     auto &mat_map = net.lock()->mat_map_;
-    if (direct_conv_compatible()) {
-        const auto binaized_name = "binaized_for_" + output + "_cal";
-        if (mat_map.find(binaized_name) == mat_map.end()) {
-            auto &input_mat = *mat_map[input];
-            mat_map[binaized_name] =
-                std::make_shared<Mat>(input_mat.h, input_mat.w, input_mat.elem_c,
-                                      DataType::Bit, binaized_name);
-        }
-        binarized_mat = mat(binaized_name);
+    const auto binaized_name = "binaized_for_" + output + "_cal";
+    if (mat_map.find(binaized_name) == mat_map.end()) {
+        auto &input_mat = *mat_map[input];
+        mat_map[binaized_name] =
+            std::make_shared<Mat>(input_mat.h, input_mat.w, input_mat.elem_c,
+                                  DataType::Bit, binaized_name);
     }
+    binarized_mat = mat(binaized_name);
 
     const auto pad_name = "pad_for_" + output + "_cal";
     if (mat_map.find(pad_name) == mat_map.end()) {
@@ -45,18 +43,18 @@ BinConv::BinConv(NetCP net, const std::string &name, css input, css weight,
     }
     padded_mat = mat(pad_name);
 
-    if (net.lock()->optimize && !direct_conv_compatible()) {
-
-        const auto col_mat_name = "col_for_" + output + "_cal";
-        if (mat_map.find(col_mat_name) == mat_map.end()) {
-            const auto len =
-                output_mat->h * output_mat->w *
-                align_to(weight_mat->h * weight_mat->w * input_mat->elem_c, 128);
-            mat_map[col_mat_name] =
-                std::make_shared<Mat>(1, 1, len, bnn::DataType::Bit);
-        }
-        col_mat = mat(col_mat_name);
+    const auto col_mat_name = "col_for_" + output + "_cal";
+    if (mat_map.find(col_mat_name) == mat_map.end()) {
+        const auto len =
+            output_mat->h * output_mat->w *
+            align_to(weight_mat->h * weight_mat->w * input_mat->elem_c, 128);
+        mat_map[col_mat_name] =
+            std::make_shared<Mat>(1, 1, len, bnn::DataType::Bit);
+    }
+    col_mat = mat(col_mat_name);
 
+    if (net.lock()->optimize && !direct_conv_compatible() &&
+        gemm_compatible()) {
         const auto trans_weight_mat_name = "trans_" + weight;
         // transpose the weight for bgemm
         const int m = weight_mat->n;
@@ -128,7 +126,7 @@ void BinConv::forward_impl() const {
             pack_mat(*input_mat, *binarized_mat);
             pad(*binarized_mat, pad_h, pad_w, *padded_mat);
             bconv_3x3(*padded_mat, *weight_mat, *output_mat, stride_h);
-        } else {
+        } else if (gemm_compatible()) {
            output_mat->fill(0.f);
 
             bnn::fused_binarize_im2col(*input_mat, weight_mat->h, weight_mat->w,
@@ -138,15 +136,14 @@ void BinConv::forward_impl() const {
             const int m = weight_mat->n;
             const int n = output_mat->h * output_mat->w;
             const int k = weight_mat->total() / weight_mat->n;
-            if (gemm_compatible()) {
-                bgemm(m, n, k, static_cast<uint64_t *>(transposed_weight_mat->data),
-                      m, static_cast<uint64_t *>(col_mat->data), k,
-                      static_cast<float *>(output_mat->data), m);
-            } else {
-                bgemm_naive(m, n, k, static_cast<uint64_t *>(transposed_weight_mat->data),
-                            m, static_cast<uint64_t *>(col_mat->data), k,
-                            static_cast<float *>(output_mat->data), m);
-            }
+            bgemm(m, n, k, static_cast<uint64_t *>(transposed_weight_mat->data),
+                  m, static_cast<uint64_t *>(col_mat->data), k,
+                  static_cast<float *>(output_mat->data), m);
+        } else {
+            pack_mat(*input_mat, *binarized_mat);
+            baseline_bconv(*binarized_mat, *weight_mat, weight_mat->h,
+                           weight_mat->w, pad_h, pad_w, stride_h, stride_w, 1,
+                           1, output_mat->c, *output_mat);
         }
     } else {
         pack_mat(*input_mat, *binarized_mat);
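
Note on the align_to(..., 128) arithmetic restored above: the bit-packed
im2col buffer rounds each output window's bit count up to a multiple of 128,
so every im2col row starts on a 128-bit boundary (two uint64_t words). Below
is a minimal standalone sketch of that size calculation, assuming align_to
has round-up-to-multiple semantics; the kernel and channel sizes are
illustrative, not taken from the patch.

    #include <cassert>
    #include <cstddef>

    // Stand-in for dabnn's align_to (assumed round-up semantics): round
    // `value` up to the nearest multiple of `alignment`.
    constexpr std::size_t align_to(std::size_t value, std::size_t alignment) {
        return (value + alignment - 1) / alignment * alignment;
    }

    int main() {
        // A 3x3 binary kernel over 64 input channels packs to 576 bits
        // per window; rounding up to the 128-bit boundary gives 640 bits.
        const std::size_t bits_per_window = align_to(3 * 3 * 64, 128);
        assert(bits_per_window == 640);

        // Col-buffer bit length for a hypothetical 7x7 output map,
        // mirroring output_mat->h * output_mat->w * align_to(...) above.
        const std::size_t len = 7 * 7 * bits_per_window;
        assert(len == 31360);
        return 0;
    }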