Merge pull request #33 from JDAI-CV/docs
English Docs
daquexian authored Jun 5, 2019
2 parents 62160c4 + 26b4fcf commit 66dc6a9
Showing 59 changed files with 1,342 additions and 4 deletions.
4 changes: 3 additions & 1 deletion README.md
@@ -62,7 +62,9 @@ We publish two pretrained binary neural network models based on [Bi-Real Net](ht

## Implementation Details

We plan to participate in the [ACM Multimedia 2019 Open Source Software Competition](https://www.acmmm.org/2019/call-for-open-source-software-competition/). Our implementation details will be presented in a 4-page short paper soon.
* The Implementation of Binary Convolutions: [docs/bconv.md](docs/bconv.md)

* Model Conversion: [docs/onnx2bnn.md](docs/onnx2bnn.md)

## Example project

4 changes: 3 additions & 1 deletion README_CN.md
@@ -64,7 +64,9 @@ dabnn_bireal18_imagenet_stem 43279353 ns 41533009 ns 14 <---

## Implementation Details

We plan to participate in the [ACM Multimedia 2019 Open Source Software Competition](https://www.acmmm.org/2019/call-for-open-source-software-competition/). The implementation details of dabnn will soon be described in a 4-page short paper.
* The Implementation of Binary Convolutions: [docs/bconv.md](docs/bconv_CN.md)

* Model Conversion: [docs/onnx2bnn.md](docs/onnx2bnn_CN.md)

## Example project

47 changes: 47 additions & 0 deletions docs/bconv.md
@@ -0,0 +1,47 @@
## Bit-packing

Bit-packing is performed in `Binarize` layers. It packs N 32-bit floats/integers into an N-bit operand according to their signs. For example, performing bit-packing on 128 float numbers produces a 128-bit operand. xnor/xor can only be performed on these packed operands.

The details of bit-packing are in

* https://github.com/JDAI-CV/dabnn/blob/master/dabnn/bitpack.h#L20 (optimized, for tensors with 128 or more channels)
* https://github.com/JDAI-CV/dabnn/blob/master/dabnn/bitpack.h#L204 (normal, for tensors with fewer than 128 channels)

The optimized version is 4X faster than the normal version. The bit-packing algorithm directly leverages the sign bit of int32 and IEEE 754 float numbers, eliminating the comparison with zero. SIMD instructions are also used to speed up this process. Note that after the SIMD instructions are performed, the N bits in the result are re-arranged, so they are not in the same order as the N 32-bit inputs. Fortunately, the output of xnor/xor is not affected as long as the input and the weight are re-arranged in the same way. Given this observation, we re-arrange the weights of binary convs whose inputs are bit-packed in the optimized way. The details are in https://github.com/JDAI-CV/dabnn/blob/master/dabnn/net.cpp#L82.
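
As a rough illustration of the sign-bit trick, the following is a hedged scalar sketch only; the function name and the 64-float granularity are ours for illustration, and the real SIMD code lives in `dabnn/bitpack.h`:

```cpp
// Hedged scalar sketch of sign-bit bit-packing; not the NEON code in dabnn/bitpack.h.
// Bit i of the result is simply the sign bit of x[i], so no comparison with zero is needed.
#include <cstdint>
#include <cstring>

uint64_t pack_64_floats(const float *x) {  // hypothetical helper: 64 floats -> 64 bits
    uint64_t packed = 0;
    for (int i = 0; i < 64; ++i) {
        uint32_t bits;
        std::memcpy(&bits, &x[i], sizeof(bits));           // reinterpret the IEEE 754 bits
        packed |= static_cast<uint64_t>(bits >> 31) << i;  // keep only the sign bit
    }
    return packed;
}
```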

dabnn presents the following two optimized implementations of binary convolutions.

## BGEMM

SGEMM (Single-precision GEneral Matrix Multiplication) is a widely adopted approach to implementing float convolutions in various high-performance scientific programs. In the context of BNNs, the counterpart of SGEMM is BGEMM, which performs binary matrix multiplication for binary convolution after [im2col](https://github.com/JDAI-CV/dabnn/blob/master/dabnn/im2col.h). dabnn presents an optimized BGEMM. The advantage of GEMM is that it covers all cases of convolutions (various kernel sizes, strides, paddings, ...) and is easy to implement.

The detailed implementation of BGEMM is in https://github.com/JDAI-CV/dabnn/blob/master/dabnn/bgemm.h.
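
For intuition, here is a minimal scalar sketch of a BGEMM-style loop on bit-packed operands; the function name, the `uint64_t` element type and the row-major layout are assumptions for illustration, not the register-blocked NEON code in `dabnn/bgemm.h`:

```cpp
// Hedged scalar sketch of BGEMM on bit-packed operands (each element is one 64-bit pack).
// A is M x K, B is K x N, C is M x N. Uses the GCC/Clang popcount builtin.
#include <cstdint>

void bgemm_naive(int M, int N, int K,
                 const uint64_t *A,   // row-major, M*K packed elements
                 const uint64_t *B,   // row-major, K*N packed elements
                 int32_t *C) {        // row-major, M*N integer accumulators
    for (int i = 0; i < M; ++i) {
        for (int j = 0; j < N; ++j) {
            int32_t acc = 0;
            for (int k = 0; k < K; ++k) {
                // one binary multiply-add: xnor followed by bitcount
                acc += __builtin_popcountll(~(A[i * K + k] ^ B[k * N + j]));
            }
            C[i * N + j] = acc;
        }
    }
}
```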

## Binary Direct Convolution

However, we argue that BGEMM is sub-optimal for binary convolutions, especially on ARM devices.

In addition to the common multiplication and add operations, BGEMM includes extra operations that count how many 1s are in a vector. Specifically, we denote <img src="svgs/88cf5350b4c645c31edaa0cbba3ee5f9.svg" align=middle width=48.70555799999999pt height=27.6567522pt/> as the space of matrices with dimension <img src="svgs/252b59b1233ed40f0396e2cd369f514d.svg" align=middle width=52.83089789999999pt height=22.465723500000017pt/> whose elements are bit-packed vectors. Given two matrices (i.e., <img src="svgs/c82ef99e46a995ca2c9e5865a66d022f.svg" align=middle width=83.29436114999999pt height=27.6567522pt/> and <img src="svgs/8edf0a665654dc211972f609f97cb684.svg" align=middle width=81.91594949999998pt height=27.6567522pt/>), <img src="svgs/68d27da8ea3f60dda13e915b722c2c25.svg" align=middle width=82.22933399999998pt height=27.6567522pt/> (<img src="svgs/4fd661cfefdf4318d1aa35fb483796b2.svg" align=middle width=11.87217899999999pt height=22.648391699999998pt/> represents the set of non-negative integers), <img src="svgs/f29d99803e443e4e6e87180539b3197f.svg" align=middle width=160.69418474999998pt height=24.65753399999998pt/> is computed as:
<p align="center"><img src="svgs/f8b4daba6c4183a3c1000ebb2d64de5f.svg" align=middle width=274.75177454999994pt height=27.1234854pt/></p>

where <img src="svgs/7adcdcafe095c28283fc5a319a9b6cdb.svg" align=middle width=28.14985964999999pt height=31.799054100000024pt/> and <img src="svgs/a6b6654f6dbe55b7fa2c8f5104fb8370.svg" align=middle width=29.743322399999986pt height=31.799054100000024pt/> denote the elements of <img src="svgs/ff7cbf533a4e41019c689366004849fb.svg" align=middle width=14.29216634999999pt height=22.55708729999998pt/> and <img src="svgs/d0b09e58d8b197fff6fc95ea3bca20fe.svg" align=middle width=15.037050599999992pt height=22.55708729999998pt/>. In SGEMM, to amortize the cost of memory loads, <img src="svgs/f6128a2d469857252e8e52385e7a00c5.svg" align=middle width=14.57641844999999pt height=22.55708729999998pt/> is often calculated as
<p align="center"><img src="svgs/cc1dbbcd450fb3182ca125d94560c60d.svg" align=middle width=96.21701429999999pt height=17.9744895pt/></p>
<p align="center"><img src="svgs/f5feb9f32839cb69ccdf8b0838d8c7cb.svg" align=middle width=77.2440273pt height=17.9744895pt/></p>

where <img src="svgs/1cb45f0e1e422f5a042ce0dc8710ed27.svg" align=middle width=24.97105709999999pt height=27.91243950000002pt/> is the <img src="svgs/9034606aa4dd18758a6889347abf0302.svg" align=middle width=21.21969464999999pt height=22.831056599999986pt/> column of <img src="svgs/09e963a9a257d451169d317f04f4cf59.svg" align=middle width=14.29216634999999pt height=22.55708729999998pt/> and <img src="svgs/f0fa7d7a09a30703b30ba8aae9c1c1b5.svg" align=middle width=19.71994364999999pt height=27.91243950000002pt/> is the <img src="svgs/9034606aa4dd18758a6889347abf0302.svg" align=middle width=21.21969464999999pt height=22.831056599999986pt/> row of <img src="svgs/d0b09e58d8b197fff6fc95ea3bca20fe.svg" align=middle width=15.037050599999992pt height=22.55708729999998pt/>.

In particular, on ARMv8 (the 64-bit ARM architecture) devices, the bitcount operation consists of two instructions: "cnt" and "addv". "cnt" takes an <img src="svgs/f9c4988898e7f532b9f826a75014ed3c.svg" align=middle width=14.99998994999999pt height=22.465723500000017pt/>-byte vector <img src="svgs/c745b9b57c145ec5577b82542b2df546.svg" align=middle width=10.57650494999999pt height=14.15524440000002pt/> as input and outputs an <img src="svgs/f9c4988898e7f532b9f826a75014ed3c.svg" align=middle width=14.99998994999999pt height=22.465723500000017pt/>-byte vector <img src="svgs/8217ed3c32a785f0b5aad4055f432ad8.svg" align=middle width=10.16555099999999pt height=22.831056599999986pt/>, where <img src="svgs/dc61d515b6f36dadf6ab7371698a9ef1.svg" align=middle width=196.27936515pt height=24.65753399999998pt/>, and <img src="svgs/2e32e0141d372413f25c35045d246695.svg" align=middle width=15.16654589999999pt height=14.15524440000002pt/> and <img src="svgs/f3d9f6f447d13bcef7127ff6c98710a3.svg" align=middle width=13.948864049999989pt height=22.831056599999986pt/> are the <img src="svgs/22aefc0b275701a94e3684ede71e1cbf.svg" align=middle width=18.32504519999999pt height=21.68300969999999pt/> bytes of <img src="svgs/c745b9b57c145ec5577b82542b2df546.svg" align=middle width=10.57650494999999pt height=14.15524440000002pt/> and <img src="svgs/8217ed3c32a785f0b5aad4055f432ad8.svg" align=middle width=10.16555099999999pt height=22.831056599999986pt/> respectively. "addv" sums up all bytes in a vector and outputs the aggregated scalar. The equation is then expanded as:
<p align="center"><img src="svgs/2c08f38f094ac03aea56779378242468.svg" align=middle width=245.01614279999998pt height=25.139101349999997pt/></p>
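
As a concrete (hedged) illustration, the ARMv8 NEON intrinsics below perform one such binary multiply-add in exactly this order; the helper name is ours and this is not the actual code in `dabnn/bgemm.h`:

```cpp
// One binary multiply-add in the BGEMM order on ARMv8: xnor, cnt, addv, addition.
// Hedged sketch only; the register allocation in dabnn/bgemm.h differs.
#include <arm_neon.h>
#include <cstdint>

static inline uint32_t bmac(uint32_t c, uint8x16_t a, uint8x16_t b) {
    uint8x16_t x   = vmvnq_u8(veorq_u8(a, b));  // xnor = not(xor)
    uint8x16_t cnt = vcntq_u8(x);               // "cnt": per-byte popcount
    return c + vaddvq_u8(cnt);                  // "addv", then an ordinary addition
}
```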

Thus, the above equation shows that the operation of binary multiply-addition on ARMv8 devices consists of four instructions: xnor, cnt, addv, and addition. Moreover, on ARMv7 (the 32-bit ARM architecture) devices, there is no "addv" instruction at all, and <img src="svgs/2e67a96431b169a7b134a2ab4c5f3457.svg" align=middle width=60.95890844999999pt height=24.65753399999998pt/> instructions are needed to sum up all bytes in an <img src="svgs/f9c4988898e7f532b9f826a75014ed3c.svg" align=middle width=14.99998994999999pt height=22.465723500000017pt/>-byte vector, so the operation of binary multiply-addition consists of <img src="svgs/4723cf14b1da3a0da99410e67984882d.svg" align=middle width=89.26930814999999pt height=24.65753399999998pt/> instructions on these devices. To improve the efficiency of this operation, we re-arrange the calculation order and calculate <img src="svgs/904d8a3dfde39f4fb05df9337f05b65f.svg" align=middle width=160.69418474999998pt height=24.65753399999998pt/> as the multiplication of a row vector <img src="svgs/15e03f3c82848a46865db186cb4c1092.svg" align=middle width=71.45799539999999pt height=27.6567522pt/> and <img src="svgs/5615b81594cc5f5f54f6c86a17443fea.svg" align=middle width=73.22358779999999pt height=27.6567522pt/>:
<p align="center"><img src="svgs/25bbbd23c3609fee3f26aa5f809dbe2e.svg" align=middle width=86.85560895pt height=19.4813124pt/></p>

where <img src="svgs/81299da238f63ff881f8365a2a3b638a.svg" align=middle width=15.260267549999991pt height=27.91243950000002pt/> is the <img src="svgs/22aefc0b275701a94e3684ede71e1cbf.svg" align=middle width=18.32504519999999pt height=21.68300969999999pt/> row of <img src="svgs/09e963a9a257d451169d317f04f4cf59.svg" align=middle width=14.29216634999999pt height=22.55708729999998pt/> and <img src="svgs/ab03e97f653c3b2963d6a503b2a9719b.svg" align=middle width=16.23744374999999pt height=27.91243950000002pt/> is the <img src="svgs/3b5fe08410dc2e357ad56d5e09c013c5.svg" align=middle width=19.43124974999999pt height=21.68300969999999pt/> column of <img src="svgs/d0b09e58d8b197fff6fc95ea3bca20fe.svg" align=middle width=15.037050599999992pt height=22.55708729999998pt/>.

In this way, the cost of the "addv" instructions can be mostly eliminated by summing up the results of "cnt" in advance:
<p align="center"><img src="svgs/9998129ab540f7bc0985032e06e974ed.svg" align=middle width=238.7193303pt height=27.1234854pt/></p>
<p align="center"><img src="svgs/eb7ee640b8ff98c0068ed4d9ec3baf60.svg" align=middle width=128.08869314999998pt height=20.602701899999996pt/></p>

Please note that the same transformation cannot be employed in BGEMM, because <img src="svgs/f6128a2d469857252e8e52385e7a00c5.svg" align=middle width=14.57641844999999pt height=22.55708729999998pt/> is stored as 32-bit integers to save valuable registers. Therefore, in the BGEMM equation we have to utilize "addv" to reduce the vector into an integer before every "addition" instruction. Taking a close look at the above two equations, we can observe some interesting connections between them and the operation of convolution. Specifically, if we treat <img src="svgs/d0740c8f4fc4e3563ada4e53f43a81a1.svg" align=middle width=83.29436114999999pt height=27.6567522pt/> and <img src="svgs/8edf0a665654dc211972f609f97cb684.svg" align=middle width=81.91594949999998pt height=27.6567522pt/> as the weight and the im2col-ed input (<img src="svgs/fb97d38bcc19230b0acd442e17db879c.svg" align=middle width=17.73973739999999pt height=22.465723500000017pt/>: the number of output channels, <img src="svgs/f9c4988898e7f532b9f826a75014ed3c.svg" align=middle width=14.99998994999999pt height=22.465723500000017pt/>: output height <img src="svgs/bdbf342b57819773421273d508dba586.svg" align=middle width=12.785434199999989pt height=19.1781018pt/> output width, and <img src="svgs/d6328eaebbcd5c358f426dbea4bdbf70.svg" align=middle width=15.13700594999999pt height=22.465723500000017pt/>: the number of bit-packed vectors in a weight filter), the above two equations can be directly interpreted as the definition of convolution. As such, the refined operation of binary convolution is dubbed "Binary Direct Convolution".
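
To make the re-arranged order concrete, here is a hedged ARMv8 NEON sketch of one output element's inner loop, accumulating the per-byte "cnt" results and reducing with a single "addv"-style instruction at the end; the function name and buffer layout are assumptions, not the code in `dabnn/bconv.h`:

```cpp
// Hedged ARMv8 (AArch64) sketch of the re-arranged inner loop of Binary Direct
// Convolution: accumulate the per-byte "cnt" results, use "addv" only once.
#include <arm_neon.h>
#include <cstdint>

// w and x each hold K consecutive 128-bit bit-packed operands (16 bytes apiece).
uint32_t binary_dot_product(const uint8_t *w, const uint8_t *x, int K) {
    uint16x8_t acc = vdupq_n_u16(0);                 // 16-bit lanes avoid byte overflow
    for (int k = 0; k < K; ++k) {
        uint8x16_t a    = vld1q_u8(w + 16 * k);      // packed weight vector
        uint8x16_t b    = vld1q_u8(x + 16 * k);      // packed (im2col-ed) input vector
        uint8x16_t xnor = vmvnq_u8(veorq_u8(a, b));  // xnor = not(xor)
        acc = vpadalq_u8(acc, vcntq_u8(xnor));       // "cnt", then accumulate without "addv"
    }
    return vaddvq_u16(acc);                          // single "addv"-style reduction
}
```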

The implementation of Binary Direct Convolution is in https://github.com/JDAI-CV/dabnn/blob/master/dabnn/bconv.h.
61 changes: 61 additions & 0 deletions docs/bconv.md.in
@@ -0,0 +1,61 @@
## Bit-packing

Bit-packing is performed in `Binarize` layers. It packs N 32-bit floats/integers into an N-bit operand according to their signs. For example, performing bit-packing on 128 float numbers produces a 128-bit operand. xnor/xor can only be performed on these packed operands.

The details of bit-packing are in

* https://github.com/JDAI-CV/dabnn/blob/master/dabnn/bitpack.h#L20 (optimized, for tensors with 128 or more channels)
* https://github.com/JDAI-CV/dabnn/blob/master/dabnn/bitpack.h#L204 (normal, for tensors with fewer than 128 channels)

The optimized version is 4X faster than the normal version. The bit-packing algorithm directly leverages the sign bit of int32 and IEEE 754 float numbers, eliminating the comparison with zero. SIMD instructions are also used to speed up this process. Note that after the SIMD instructions are performed, the N bits in the result are re-arranged, so they are not in the same order as the N 32-bit inputs. Fortunately, the output of xnor/xor is not affected as long as the input and the weight are re-arranged in the same way. Given this observation, we re-arrange the weights of binary convs whose inputs are bit-packed in the optimized way. The details are in https://github.com/JDAI-CV/dabnn/blob/master/dabnn/net.cpp#L82.

dabnn presents the following two optimized implementations of binary convolutions.

## BGEMM

SGEMM (Single-precision GEneral Matrix Multiplication) is a widely adopted approach to implementing float convolutions in various high-performance scientific programs. In the context of BNNs, the counterpart of SGEMM is BGEMM, which performs binary matrix multiplication for binary convolution after [im2col](https://github.com/JDAI-CV/dabnn/blob/master/dabnn/im2col.h). dabnn presents an optimized BGEMM. The advantage of GEMM is that it covers all cases of convolutions (various kernel sizes, strides, paddings, ...) and is easy to implement.

The detailed implementation of BGEMM is in https://github.com/JDAI-CV/dabnn/blob/master/dabnn/bgemm.h.

## Binary Direct Convolution

However, we argue that BGEMM is sub-optimal for binary convolutions, especially on ARM devices.

In addition to the common multiplication and add operations, BGEMM includes extra operations that count how many 1s are in a vector. Specifically, we denote $U^{M \times N}$ as the space of matrices with dimension $M \times N$ whose elements are bit-packed vectors. Given two matrices (i.e., $\boldsymbol{A} \in U^{M \times K}$ and $\boldsymbol{B} \in U^{K \times N}$), $\boldsymbol{C} \in \mathbb{N}^{M \times N}$ ($\mathbb{N}$ represents the set of non-negative integers), $\boldsymbol{C} = BGEMM(\boldsymbol{A}, \boldsymbol{B})$ is computed as:
$$
C_{i,j} = \sum\nolimits_{k} bitcount(xnor(\Vec{A_{i,k}}, \Vec{B_{k,j}})),
$$

where $\Vec{A_{i,k}}$ and $\Vec{B_{k,j}}$ denote the elements of $\boldsymbol{A}$ and $\boldsymbol{B}$. In SGEMM, to amortize the cost of memory loads, $\boldsymbol{C}$ is often calculated as
$$
\boldsymbol{C^{k}} = \boldsymbol{m^{k}}\boldsymbol{n^{k}},
$$
$$
\boldsymbol{C} \mathrel{+}= \boldsymbol{C^{k}},
$$

where $\boldsymbol{m^{k}}$ is the $k_{th}$ column of $\boldsymbol{A}$ and $\boldsymbol{n^{k}}$ is the $k_{th}$ row of $\boldsymbol{B}$.

In particular, on ARMv8 (the 64-bit ARM architecture) devices, the bitcount operation consists of two instructions: "cnt" and "addv". "cnt" takes an $N$-byte vector $\alpha$ as input and outputs an $N$-byte vector $\beta$, where $\beta_{i} = the\_number\_of\_1s(\alpha_{i})$, and $\alpha_{i}$ and $\beta_{i}$ are the $i_{th}$ bytes of $\alpha$ and $\beta$ respectively. "addv" sums up all bytes in a vector and outputs the aggregated scalar. The equation is then expanded as:
$$
C_{i,j} \mathrel{+}= addv(cnt(xnor(\Vec{m^{k}_{i}}, \Vec{n^{k}_{j}}))).
$$

Thus, the above equation shows that the operation of binary multiply-addition on ARMv8 devices consists of four instructions: xnor, cnt, addv, and addition. Moreover, on ARMv7 (the 32-bit ARM architecture) devices, there is no "addv" instruction at all, and $\lceil \log_{2}N \rceil$ instructions are needed to sum up all bytes in an $N$-byte vector, so the operation of binary multiply-addition consists of $\lceil \log_{2}N \rceil+3$ instructions on these devices. To improve the efficiency of this operation, we re-arrange the calculation order and calculate $\boldsymbol{C}=BGEMM(\boldsymbol{A},\boldsymbol{B})$ as the multiplication of a row vector $\boldsymbol{p} \in U^{1 \times N}$ and $\boldsymbol{q} \in U^{M \times 1}$:
$$
C_{i,j} = \boldsymbol{p^{i}}\boldsymbol{q^{j}},
$$

where $\boldsymbol{p^{i}}$ is the $i_{th}$ row of $\boldsymbol{A}$ and $\boldsymbol{q^{j}}$ is the $j_{th}$ column of $\boldsymbol{B}$.

In this way, the cost of the "addv" instructions can be mostly eliminated by summing up the results of "cnt" in advance:
$$
\Vec{C_{i,j}} = \sum\nolimits_{k} cnt(xnor(\Vec{A_{i,k}}, \Vec{B_{k,j}})),
$$
$$
C_{i,j} = addv(\Vec{C_{i,j}}).
$$

Please note that the same transformation cannot be employed in BGEMM, because $\boldsymbol{C}$ is stored as 32-bit integers to save valuable registers. Therefore, in the BGEMM equation we have to utilize "addv" to reduce the vector into an integer before every "addition" instruction. Taking a close look at the above two equations, we can observe some interesting connections between them and the operation of convolution. Specifically, if we treat $\boldsymbol{A} \in U^{M \times K}$ and $\boldsymbol{B} \in U^{K \times N}$ as the weight and the im2col-ed input ($M$: the number of output channels, $N$: output height $\times$ output width, and $K$: the number of bit-packed vectors in a weight filter), the above two equations can be directly interpreted as the definition of convolution. As such, the refined operation of binary convolution is dubbed "Binary Direct Convolution".

The implementation of Binary Direct Convolution is in https://github.com/JDAI-CV/dabnn/blob/master/dabnn/bconv.h.
4 changes: 2 additions & 2 deletions docs/bconv_CN.md
@@ -1,5 +1,5 @@
## Bit-packing
Before performing binary convolutions, a `Binarize` layer has to be manually inserted into the network. It binarizes N 32-bit floats/integers into N bits (0 or 1) according to their comparison with 0 and packs them into one N-bit operand; for example, bit-packing 128 floating-point numbers produces a 128-bit operand. This step is called bit-packing, and only after it can the bitwise operations xnor/xor be performed.
Bit-packing is performed in `Binarize` layers. It binarizes N 32-bit floats/integers into N bits (0 or 1) according to their comparison with 0 and packs them into one N-bit operand; for example, bit-packing 128 floating-point numbers produces a 128-bit operand. Only after this step can the bitwise operations xnor/xor be performed.

The detailed implementation of bit-packing is in

@@ -20,6 +20,6 @@ The detailed implementation of BGEMM is in https://github.com/JDAI-CV/dabnn/blob/master/dabnn/bgem

However, BGEMM is not efficient on ARM devices, because in the binary multiply-add operation the addition takes two steps: a bitcount and an ordinary addition. Bitcount counts how many bits of an N-bit operand are 1. On ARMv8 devices bitcount takes two instructions, and on ARMv7 devices it takes even more, which greatly limits the speed of BGEMM. Therefore dabnn proposes a direct convolution method, called Binary Direct Convolution (BDC), which computes the convolution directly according to its definition. In BDC, a simple transformation eliminates most of the bitcount instructions. Its advantage is higher performance than BGEMM, but unlike BGEMM it cannot cover all cases with a single piece of code.

For how BDC eliminates most of the bitcount instructions, please look out for our upcoming paper
How BDC eliminates most of the bitcount instructions is explained in detail in [bconv.md](bconv.md)

The detailed implementation of BDC is in https://github.com/JDAI-CV/dabnn/blob/master/dabnn/bconv.h.