Skip to content

Commit

Permalink
Fix wrong col_mat size
Browse files Browse the repository at this point in the history
  • Loading branch information
daquexian committed Aug 20, 2019
1 parent 0cbd770 commit 75e6992
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions dabnn/layers/BinConv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

namespace bnn {

int align_to(int a, int b) { return (a + (b - 1) / b) * b; }

BinConv::BinConv(NetCP net, const std::string &name, css input, css weight,
css output, int pad_h, int pad_w, int stride_h, int stride_w)
: Layer(net, name, "Bin Conv"),
Expand Down Expand Up @@ -43,9 +45,11 @@ BinConv::BinConv(NetCP net, const std::string &name, css input, css weight,

const auto col_mat_name = "col_for_" + output + "_cal";
if (mat_map.find(col_mat_name) == mat_map.end()) {
const auto len = output_mat->h * output_mat->w * weight_mat->h *
weight_mat->w * input_mat->elem_c;
mat_map[col_mat_name] = std::make_shared<Mat>(1, 1, len, bnn::DataType::Bit);
const auto len =
output_mat->h * output_mat->w *
align_to(weight_mat->h * weight_mat->w * input_mat->elem_c, 128);
mat_map[col_mat_name] =
std::make_shared<Mat>(1, 1, len, bnn::DataType::Bit);
}
col_mat = mat(col_mat_name);

Expand Down Expand Up @@ -119,22 +123,22 @@ void BinConv::forward_impl() const {
output_mat->fill<float>(0.f);
// pack_mat_64(*input_mat, *binarized_mat);
// bnn::im2col(*binarized_mat, weight_mat->h, weight_mat->w,
// pad_h, pad_w, stride_h, stride_w, 1, 1,
// *col_mat);
// pad_h, pad_w, stride_h, stride_w, 1,
// 1, *col_mat);

// const auto len = output_mat->h * output_mat->w * weight_mat->h *
// weight_mat->w * input_mat->elem_c;
// Mat temp(1, 1, len, bnn::DataType::Float);
// im2col(*input_mat, weight_mat->h, weight_mat->w, pad_h, pad_w, stride_h, stride_w, 1, 1, temp);
// pack_mat(temp, *col_mat);
// im2col(*input_mat, weight_mat->h, weight_mat->w, pad_h, pad_w,
// stride_h, stride_w, 1, 1, temp); pack_mat(temp, *col_mat);

bnn::fused_binarize_im2col(*input_mat, weight_mat->h, weight_mat->w,
pad_h, pad_w, stride_h, stride_w, 1, 1,
*col_mat);

const int m = weight_mat->n;
const int n = output_mat->h * output_mat->w;
const int k = weight_mat->h * weight_mat->w * weight_mat->c;
const int k = weight_mat->total() / weight_mat->n;
bgemm(m, n, k, static_cast<uint64_t *>(transposed_weight_mat->data),
m, static_cast<uint64_t *>(col_mat->data), k,
static_cast<float *>(output_mat->data), m);
Expand Down

0 comments on commit 75e6992

Please sign in to comment.