
Commit

fused_binarize_im2col works
daquexian committed Aug 20, 2019
1 parent 691c67c commit ddf6a6f
Showing 4 changed files with 92 additions and 16 deletions.
24 changes: 18 additions & 6 deletions dabnn/bitpack.h
@@ -26,12 +26,20 @@
#include "mat.h"

namespace bnn {
inline void pack_64(const float *float_ptr, void *binary_ptr, size_t size) {
    uint64_t *u64_bptr = static_cast<uint64_t *>(binary_ptr);
    FORZS(_, size, 64) {
        pack_64_bitfield(float_ptr, u64_bptr);
        float_ptr += 64;
        u64_bptr++;
    }
}

#ifdef __aarch64__
inline void pack_128_opt(const float *float_ptr, void *binary_ptr,
size_t size) {
/**
* size: the number of __elements__ needed to be packed.
* size: the number of __float elements__ needed to be packed.
*
* This is the optimized bit-packing.
*
@@ -249,7 +257,8 @@ inline void pack_mat_64(const bnn::Mat &float_mat, bnn::Mat &binary_mat) {
BNN_ASSERT(
float_mat.w * float_mat.c > 0 && float_mat.w * float_mat.c % 64 == 0,
float_mat.w * float_mat.c);
BNN_ASSERT(float_mat.c / 64 == binary_mat.c && float_mat.c % 64 == 0, "float_mat.c ", float_mat.c, ", binary_mat.c ", binary_mat.c);
BNN_ASSERT(float_mat.c / 64 == binary_mat.c && float_mat.c % 64 == 0,
"float_mat.c ", float_mat.c, ", binary_mat.c ", binary_mat.c);

FORZ(n, float_mat.n) {
FORZ(h, float_mat.h) {
@@ -265,12 +274,15 @@ inline void pack_mat_64(const bnn::Mat &float_mat, bnn::Mat &binary_mat) {
}

inline void pack_mat(const bnn::Mat &float_mat, bnn::Mat &binary_mat) {
BNN_ASSERT(float_mat.data_type == DataType::Float , "float_mat has wrong data type");
BNN_ASSERT(binary_mat.data_type == DataType::Bit, "binary_mat has wrong data type");
BNN_ASSERT(float_mat.data_type == DataType::Float,
"float_mat has wrong data type");
BNN_ASSERT(binary_mat.data_type == DataType::Bit,
"binary_mat has wrong data type");
BNN_ASSERT(float_mat.c % 64 == 0, float_mat.c);
#ifdef __aarch64__
if (float_mat.c % 128 == 0) {
pack_mat_128_opt(float_mat, binary_mat);
// pack_mat_128_opt(float_mat, binary_mat);
pack_mat_64(float_mat, binary_mat);
} else {
pack_mat_64(float_mat, binary_mat);
}
@@ -279,5 +291,5 @@ inline void pack_mat(const bnn::Mat &float_mat, bnn::Mat &binary_mat) {
#endif // __aarch64__
}

}
} // namespace bnn
#endif /* BITPACK_H */
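Note: `pack_64_bitfield`, which the new `pack_64` above calls, is not shown in this diff. As a rough mental model it can be thought of as collecting the sign bit of each of 64 consecutive floats into a single `uint64_t`. A minimal sketch under that assumption, with a hypothetical name and no claim about the real bit ordering or NEON usage in dabnn:

#include <cstdint>
#include <cstring>

// Hypothetical sketch of the idea behind pack_64_bitfield: take the sign bit of
// each of 64 floats and store it as one bit of a uint64_t. The real dabnn
// routine may use a different bit order and NEON intrinsics.
inline void pack_64_bitfield_sketch(const float *float_ptr, uint64_t *out) {
    uint64_t bits = 0;
    for (int i = 0; i < 64; ++i) {
        uint32_t raw;
        std::memcpy(&raw, &float_ptr[i], sizeof(raw));  // reinterpret the float's bits
        bits |= static_cast<uint64_t>(raw >> 31) << i;  // sign bit -> bit i
    }
    *out = bits;
}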
61 changes: 54 additions & 7 deletions dabnn/fused_binarize_im2col.h
@@ -8,22 +8,63 @@ inline void fused_binarize_im2col(const Mat &im, const int kernel_h,
const int pad_w, const int stride_h,
const int stride_w, const int dilation_h,
const int dilation_w, Mat &col) {
BNN_ASSERT(im.data_type == DataType::Float, "Input of fused_binarize_im2col should be float");
BNN_ASSERT(col.data_type == DataType::Bit, "Output of fused_binarize_im2col should be bit");

BNN_ASSERT(kernel_h * kernel_w * im.c < 60000,
"kernel_h * kernel_w * im.c must be smaller than 60000");

const int output_h =
(im.h + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
const int output_w =
(im.w + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;

BNN_ASSERT(kernel_h * kernel_w * im.c < 60000,
"kernel_h * kernel_w * im.c must be smaller than 60000");
// Mat temp(1, 1, kernel_h * kernel_w * output_h * output_w * im.c, DataType::Float);
// char *data_col = static_cast<char *>(temp);
// int input_y = 0;
// FORZ(output_y, output_h) {
// int input_x = 0;
// FORZ(output_x, output_w) {
// FORZ(kh, kernel_h) {
// int y = input_y - pad_h + kh * dilation_h;
// const char *data_im = static_cast<char *>(im.data) +
// y * im.w * im.c * im.elemsize;
// FORZ(kw, kernel_w) {
// int x = input_x - pad_w + kw * dilation_w;
// if (y < 0 || y >= im.h || x < 0 || x >= im.w) {
// memset(data_col, 0, im.c * im.elemsize);
// } else {
// memcpy(data_col, data_im + x * im.c * im.elemsize,
// im.c * im.elemsize);
// }
// data_col += im.c * im.elemsize;
// }
// }
// input_x += stride_w;
// }
// input_y += stride_h;
// }
// pack_64(static_cast<float *>(temp.data), col.data, temp.total());
// if (true) {
// Mat temp(1, 1, kernel_h * kernel_w * output_h * output_w * im.c, DataType::Float);
// im2col(im, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, temp);
// pack_mat(temp, col);
// } else {
// Mat temp(1, 9999999, DataType::Bit);
// pack_mat_128_opt(im, temp);
// im2col(temp, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, col);
// }

// TODO: More elegant way
static float buf[60000];
static char buf[2400000];


char *data_col = static_cast<char *>(col);
int input_y = 0;
FORZ(output_y, output_h) {
int input_x = 0;
FORZ(output_x, output_w) {
float *buf_ptr = buf;
char *buf_ptr = buf;
FORZ(kh, kernel_h) {
int y = input_y - pad_h + kh * dilation_h;
const char *data_im = static_cast<char *>(im.data) +
@@ -40,16 +81,22 @@ inline void fused_binarize_im2col(const Mat &im, const int kernel_h,
}
}

BNN_ASSERT(im.elemsize == 4, "");
// len: the number of elements in one column
const size_t len = (buf_ptr - buf) / im.elemsize;
const size_t len_aligned_128 = (len + 127) / 128 * 128;
BNN_ASSERT(len == len_aligned_128, "");
// pad the buffer so that its length aligns to 128
memset(buf_ptr, 0, (len_aligned_128 - len) * im.elemsize);

pack_128_opt(buf_ptr, data_col, len_aligned_128);
auto *fbuf = reinterpret_cast<float *>(buf);
pack_64(fbuf, data_col, len_aligned_128);

// `len_aligned_128` is the number of appended __bits__ in
// mat `col`, so divide sizeof(decltype(data_col)) here
data_col += len_aligned_128 / sizeof(decltype(data_col));
// mat `col`, so divide here
const auto tmp = len_aligned_128 / 8;

data_col += tmp;

input_x += stride_w;
}
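As a quick sanity check of the output-size formula used in fused_binarize_im2col above, the same computation can be run standalone; the 3x3 kernel, pad 1, stride 1, 32x32 input numbers below are illustrative only:

#include <cstdio>

// Standalone check of the im2col output-size formula from fused_binarize_im2col.
// Example numbers only: 3x3 kernel, pad 1, stride 1, no dilation, 32x32 input.
int main() {
    const int im_h = 32, im_w = 32;
    const int kernel_h = 3, kernel_w = 3;
    const int pad_h = 1, pad_w = 1;
    const int stride_h = 1, stride_w = 1;
    const int dilation_h = 1, dilation_w = 1;

    const int output_h =
        (im_h + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
    const int output_w =
        (im_w + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;

    std::printf("output: %d x %d\n", output_h, output_w);  // prints "output: 32 x 32"
    return 0;
}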
17 changes: 15 additions & 2 deletions dabnn/layers/BinConv.cpp
@@ -41,11 +41,11 @@ BinConv::BinConv(NetCP net, const std::string &name, css input, css weight,
}
padded_mat = mat(pad_name);

const auto col_mat_name = "col_mat";
const auto col_mat_name = "col_for_" + output + "_cal";
if (mat_map.find(col_mat_name) == mat_map.end()) {
const auto len = output_mat->h * output_mat->w * weight_mat->h *
weight_mat->w * input_mat->elem_c;
mat_map[col_mat_name] = std::make_shared<Mat>(len, bnn::DataType::Bit);
mat_map[col_mat_name] = std::make_shared<Mat>(1, 1, len, bnn::DataType::Bit);
}
col_mat = mat(col_mat_name);

@@ -74,6 +74,7 @@

bool BinConv::direct_conv_compatible() const {
#ifdef __aarch64__
return false;
if (weight_mat->h == 3 && weight_mat->w == 3 && input_mat->elem_c == 64 &&
stride_h == stride_w) {
return true;
@@ -116,9 +117,21 @@ void BinConv::forward_impl() const {
bconv_3x3(*padded_mat, *weight_mat, *output_mat, stride_h);
} else if (gemm_compatible()) {
output_mat->fill<float>(0.f);
// pack_mat_64(*input_mat, *binarized_mat);
// bnn::im2col(*binarized_mat, weight_mat->h, weight_mat->w,
// pad_h, pad_w, stride_h, stride_w, 1, 1,
// *col_mat);

// const auto len = output_mat->h * output_mat->w * weight_mat->h *
// weight_mat->w * input_mat->elem_c;
// Mat temp(1, 1, len, bnn::DataType::Float);
// im2col(*input_mat, weight_mat->h, weight_mat->w, pad_h, pad_w, stride_h, stride_w, 1, 1, temp);
// pack_mat(temp, *col_mat);

bnn::fused_binarize_im2col(*input_mat, weight_mat->h, weight_mat->w,
pad_h, pad_w, stride_h, stride_w, 1, 1,
*col_mat);

const int m = weight_mat->n;
const int n = output_mat->h * output_mat->w;
const int k = weight_mat->h * weight_mat->w * weight_mat->c;
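For context on the gemm path above: m is the number of filters, n the number of output positions, and k the per-filter length, with both operands bit-packed. dabnn's actual binary GEMM is NEON-optimized and not part of this diff; the naive reference below only illustrates the usual xnor-popcount formulation, with hypothetical names and layout assumptions (row-major operands, all k bits significant):

#include <cstddef>
#include <cstdint>

// Illustrative-only reference for a binary GEMM in the m/n/k layout used above.
// k_words is the per-filter length in 64-bit words, since operands are bit-packed.
// Uses the GCC/Clang __builtin_popcountll intrinsic.
inline void bgemm_naive(int m, int n, std::size_t k_words,
                        const std::uint64_t *a,  // m x k_words, packed weights
                        const std::uint64_t *b,  // n x k_words, packed im2col columns
                        float *c) {              // m x n accumulator
    const int k_bits = static_cast<int>(k_words) * 64;
    for (int i = 0; i < m; i++) {
        for (int j = 0; j < n; j++) {
            int matched = 0;
            for (std::size_t w = 0; w < k_words; w++) {
                // xnor marks positions where the two bit vectors agree
                matched += __builtin_popcountll(~(a[i * k_words + w] ^ b[j * k_words + w]));
            }
            // +1/-1 dot product: agreements count +1, disagreements count -1
            c[i * n + j] += static_cast<float>(2 * matched - k_bits);
        }
    }
}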
6 changes: 5 additions & 1 deletion dabnn/net.cpp
@@ -98,7 +98,11 @@ void Net::prepare() {
add_mat(name, std::make_shared<Mat>(
shape[0], shape[1], shape[2], shape[3],
bnn::DataType::Bit, len, false));
pack_mat_128(*tmp, *mat_map_[name]);
pack_mat(*tmp, *mat_map_[name]);
// add_mat(name, std::make_shared<Mat>(
// shape[0], shape[1], shape[2], shape[3],
// const_cast<uint64_t *>(data),
// bnn::DataType::Bit, false));
} else {
#endif // __aarch64__
add_mat(name, std::make_shared<Mat>(
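For context on the pack_mat call above: per the assertion in dabnn/bitpack.h, a float mat with C channels packs into C / 64 bit-words per spatial position. A tiny hypothetical sizing example:

// Hypothetical sizing example for a packed weight mat: a 3x3 filter over
// 256 float channels becomes 256 / 64 = 4 uint64_t words per spatial position.
constexpr int kernel_h = 3, kernel_w = 3, float_c = 256;
constexpr int bit_words_per_position = float_c / 64;                       // 4
constexpr int words_per_filter = kernel_h * kernel_w * bit_words_per_position;  // 36
static_assert(words_per_filter == 36, "3 * 3 * (256 / 64) bit-words per filter");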
