Skip to content

Commit

Permalink
Merge pull request #49 from JDAI-CV/support_arbitrarily_channels
Browse files Browse the repository at this point in the history
Support arbitrary channels
  • Loading branch information
daquexian authored Aug 21, 2019
2 parents c08695a + ce6e843 commit 49bc3e9
Show file tree
Hide file tree
Showing 37 changed files with 467 additions and 8,715 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,6 @@
[submodule "third_party/protobuf"]
path = third_party/protobuf
url = https://github.com/protocolbuffers/protobuf
[submodule "third_party/flatbuffers"]
path = third_party/flatbuffers
url = https://github.com/google/flatbuffers
3 changes: 3 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ include(cmake/system.cmake)
include(cmake/glog.cmake)
configure_glog()

include(cmake/flatbuffers.cmake)
configure_flatbuffers()

add_compile_options("-DEIGEN_MPL2_ONLY")
if (${BNN_NET_BENCHMARK})
add_compile_options("-DBNN_BENCHMARK")
Expand Down
8 changes: 4 additions & 4 deletions benchmark/benchmark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,17 @@
#include <dabnn/net.h>

static void BM_pack_mat_64_small(benchmark::State &state) {
const bnn::Mat a(1, 32, 32, 128, bnn::DataType::Float, 0);
bnn::Mat b(1, 32, 32, 128, bnn::DataType::Bit, 0);
const bnn::Mat a(1, 32, 32, 128, bnn::DataType::Float, false);
bnn::Mat b(1, 32, 32, 128, bnn::DataType::Bit, false);
for (auto _ : state) {
pack_mat_64(a, b);
}
}

#ifdef __aarch64__
static void BM_pack_mat_128_small(benchmark::State &state) {
const bnn::Mat a(1, 32, 32, 128, bnn::DataType::Float, 0);
bnn::Mat b(1, 32, 32, 128, bnn::DataType::Bit, 0);
const bnn::Mat a(1, 32, 32, 128, bnn::DataType::Float, false);
bnn::Mat b(1, 32, 32, 128, bnn::DataType::Bit, false);
for (auto _ : state) {
pack_mat_128(a, b);
}
Expand Down
10 changes: 10 additions & 0 deletions cmake/flatbuffers.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
function(configure_flatbuffers)
option(FLATBUFFERS_BUILD_TESTS "Enable the build of tests and samples." OFF)
option(FLATBUFFERS_BUILD_FLATHASH "Enable the build of flathash" OFF)
option(FLATBUFFERS_BUILD_FLATC "Enable the build of the flatbuffers compiler"
OFF)
option(FLATBUFFERS_BUILD_FLATLIB "Enable the build of the flatbuffers library"
ON)
add_subdirectory(third_party/flatbuffers)
endfunction()

4 changes: 3 additions & 1 deletion common/baseline.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ inline void baseline_bconv(const Mat &input, const Mat &weight,
const int stride_w, const int dilation_h,
const int dilation_w, const int output_channels,
Mat &output) {
BNN_ASSERT(weight.total() % weight.n == 0, "");
const auto HWC = weight.total() / weight.n;
int input_y = 0;
FORZ(th, output.h) {
int input_x = 0;
Expand All @@ -91,7 +93,7 @@ inline void baseline_bconv(const Mat &input, const Mat &weight,
FORZ(ww, kernel_w) {
int x = input_x - pad_w + ww * dilation_w;
FORZ(wc, input.c) {
int idx = tc * kernel_h * kernel_w * input.c +
int idx = tc * HWC +
wh * kernel_w * input.c + ww * input.c +
wc;
const auto w_value =
Expand Down
17 changes: 15 additions & 2 deletions common/common_bitpack.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,24 @@

#include <common/helper.h>

inline void pack_64_bitset(const float *fptr, uint64_t *buf) {
inline void pack_64_bitset(const float *fptr, uint64_t *buf,
const size_t eff_bits = 64) {
/**
* The eff_bits is to support non-128-multiple channels.
* In this case, we need pad the tensor to make the
* channel aligned with 128.
*/
// BNN_ASSERT(eff_bits == 64, eff_bits);
const size_t UNIT_LEN = 64;
BNN_ASSERT(eff_bits <= UNIT_LEN, "The eff_bits ", eff_bits,
" must be smaller than UNIT_LEN ", UNIT_LEN);
std::bitset<UNIT_LEN> bits;
for (size_t i = 0; i < UNIT_LEN; i++) {
bits[i] = (*(fptr + i) > 0);
if (i < eff_bits) {
bits[i] = (*(fptr + i) > 0);
} else {
bits[i] = 0;
}
}
static_assert(sizeof(decltype(bits.to_ullong())) * CHAR_BIT == 64,
"bits.to_ullong() must return a 64-bit element");
Expand Down
1 change: 1 addition & 0 deletions common/dab.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ table Tensor {
float32_data: [float32];
shape: [uint32];
name: string;
align_hwc_to_128: bool;
}

table Input {
Expand Down
Loading

0 comments on commit 49bc3e9

Please sign in to comment.