Merge pull request #33 from JDAI-CV/docs
English Docs
Showing 59 changed files with 1,342 additions and 4 deletions.

## Bit-packing

Bit-packing is performed in `Binarize` layers. It packs N 32-bit floats/integers into an N-bit operand according to their signs. For example, performing bit-packing on 128 float numbers produces a 128-bit operand. The xnor/xor operations are only performed on these packed operands.

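For illustration, a plain (non-SIMD) sketch of this packing is shown below. The helper name `pack_64_floats` and the 64-bit width are ours, and the sign convention (negative maps to 1) is an assumption for the sketch rather than a statement about dabnn's exact layout.

```cpp
// A minimal illustration (not dabnn's Binarize implementation) of bit-packing:
// pack the signs of 64 floats into one 64-bit operand.
#include <cstdint>

uint64_t pack_64_floats(const float *x) {
    uint64_t packed = 0;
    for (int i = 0; i < 64; ++i) {
        if (x[i] < 0) packed |= (uint64_t{1} << i);  // bit i = 1 when x[i] is negative
    }
    return packed;
}
```
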
The details of bit-packing are in:

* https://github.com/JDAI-CV/dabnn/blob/master/dabnn/bitpack.h#L20 (optimized, for tensors with 128 or more channels)
* https://github.com/JDAI-CV/dabnn/blob/master/dabnn/bitpack.h#L204 (normal, for tensors with fewer than 128 channels)

The optimized version is 4X faster than the normal version. The bit-packing algorithm directly leverages the sign bits of int32 and IEEE 754 float numbers, eliminating the comparison with zero. SIMD instructions are also used to speed up this process. Note that after the SIMD instructions are performed, the N bits in the result are re-arranged, so they are no longer in the same order as the N 32-bit inputs. Fortunately, the output of xnor/xor is not affected as long as the input and the weight are re-arranged in the same way. Given this observation, we re-arrange the weights of the binary convolutions whose inputs are bit-packed in the optimized way. The details are in https://github.com/JDAI-CV/dabnn/blob/master/dabnn/net.cpp#L82.

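The comparison-free idea can be sketched as follows. This is again an illustrative scalar helper, not the code in `dabnn/bitpack.h` (which additionally uses SIMD and re-arranges the bit order): the most significant bit of an int32 or an IEEE 754 float already is its sign, so it can be read directly instead of comparing the value with zero.

```cpp
// Sketch of sign-bit extraction without a comparison (illustrative only).
#include <cstdint>
#include <cstring>

uint64_t pack_64_floats_signbit(const float *x) {
    uint64_t packed = 0;
    for (int i = 0; i < 64; ++i) {
        uint32_t bits;
        std::memcpy(&bits, &x[i], sizeof(bits));            // reinterpret the float's raw bits
        packed |= static_cast<uint64_t>(bits >> 31) << i;   // take the IEEE 754 sign bit
    }
    return packed;
}
```
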
dabnn provides the following two optimized implementations of binary convolutions.

## BGEMM

SGEMM (Single float GEneral Matrix Multiplication) is a widely adopted approach to implementing float convolutions in various high-performance scientific programs. In the context of BNNs, the binary counterpart of SGEMM is BGEMM, which performs binary matrix multiplication for binary convolution after [im2col](https://github.com/JDAI-CV/dabnn/blob/master/dabnn/im2col.h). dabnn provides an optimized BGEMM. The advantage of GEMM is that it covers all cases of convolutions (various kernel sizes, strides, paddings, ...) and it is easy to implement.

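For reference, a plain scalar BGEMM can be written as below. This is only a sketch of the definition, not dabnn's optimized kernel; it assumes 64-bit packed elements and the GCC/Clang builtin `__builtin_popcountll`.

```cpp
// Scalar reference BGEMM: every element of A and B is a 64-bit bit-packed
// vector, and C[i][j] = sum_k bitcount(xnor(A[i][k], B[k][j])).
#include <cstdint>
#include <cstddef>

void bgemm_ref(const uint64_t *A, const uint64_t *B, uint32_t *C,
               size_t M, size_t N, size_t K) {
    for (size_t i = 0; i < M; ++i) {
        for (size_t j = 0; j < N; ++j) {
            uint32_t acc = 0;
            for (size_t k = 0; k < K; ++k) {
                // xnor = NOT(xor); popcount counts the matching bits
                acc += __builtin_popcountll(~(A[i * K + k] ^ B[k * N + j]));
            }
            C[i * N + j] = acc;
        }
    }
}
```
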
The detailed implementation of BGEMM is in https://github.com/JDAI-CV/dabnn/blob/master/dabnn/bgemm.h.

## Binary Direct Convolution

However, we argue that BGEMM is sub-optimal for binary convolutions, especially on ARM devices.

In addition to the common multiplication and addition operations, BGEMM includes extra operations that count how many 1s are in a vector. Specifically, we denote $U^{M \times N}$ as the space of matrices with dimension $M \times N$ in which each element is a bit-packed vector. Given two matrices (i.e., $\boldsymbol{A} \in U^{M \times K}$ and $\boldsymbol{B} \in U^{K \times N}$), $\boldsymbol{C} \in \mathbb{N}^{M \times N}$ ($\mathbb{N}$ represents the set of non-negative integers), $\boldsymbol{C} = BGEMM(\boldsymbol{A}, \boldsymbol{B})$ is computed as:

$$
C_{i,j} = \sum\nolimits_{k} bitcount(xnor(\Vec{A_{i,k}}, \Vec{B_{k,j}})),
$$

where $\Vec{A_{i,k}}$ and $\Vec{B_{k,j}}$ denote the elements of $\boldsymbol{A}$ and $\boldsymbol{B}$. In SGEMM, to amortize the cost of loading memory, $\boldsymbol{C}$ is often calculated as

$$
\boldsymbol{C^{k}} = \boldsymbol{m^{k}}\boldsymbol{n^{k}},
$$
$$
\boldsymbol{C} \mathrel{+}= \boldsymbol{C^{k}},
$$

where $\boldsymbol{m^{k}}$ is the $k_{th}$ column of $\boldsymbol{A}$ and $\boldsymbol{n^{k}}$ is the $k_{th}$ row of $\boldsymbol{B}$.

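In code, this rank-1-update form looks roughly like the following schematic (not dabnn's or any BLAS library's implementation): for each $k$, one column of $\boldsymbol{A}$ and one row of $\boldsymbol{B}$ are loaded and used to update the whole output, which amortizes the memory loads.

```cpp
// Schematic SGEMM as a sequence of rank-1 updates: C += m^k * n^k for each k.
#include <cstddef>

void sgemm_rank1(const float *A, const float *B, float *C,
                 size_t M, size_t N, size_t K) {
    for (size_t k = 0; k < K; ++k) {
        for (size_t i = 0; i < M; ++i) {
            const float m_ik = A[i * K + k];          // element of the k-th column of A
            for (size_t j = 0; j < N; ++j) {
                C[i * N + j] += m_ik * B[k * N + j];  // update the whole output row
            }
        }
    }
}
```
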
In particular, on ARMv8 (the 64-bit ARM architecture) devices, the bitcount operation consists of two instructions: "cnt" and "addv". "cnt" takes an $N$-byte vector $\alpha$ as input and outputs an $N$-byte vector $\beta$, where $\beta_{i} = the\_number\_of\_1s(\alpha_{i})$ and $\alpha_{i}$ and $\beta_{i}$ are the $i_{th}$ bytes of $\alpha$ and $\beta$ respectively. "addv" sums up all bytes in a vector and outputs the aggregated scalar. The equation is then expanded as:

$$
C_{i,j} \mathrel{+}= addv(cnt(xnor(\Vec{m^{k}_{i}}, \Vec{n^{k}_{j}}))).
$$

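With ARM NEON intrinsics, one such binary multiply-addition step can be sketched as below. This is an illustrative helper, not dabnn's kernel; it assumes an ARMv8 target and 128-bit packed operands.

```cpp
// One binary multiply-addition step on ARMv8: xnor -> cnt -> addv.
#include <arm_neon.h>
#include <cstdint>

inline uint32_t binary_mac_128(const uint8_t *a, const uint8_t *b) {
    uint8x16_t va = vld1q_u8(a);                 // 128 packed input bits
    uint8x16_t vb = vld1q_u8(b);                 // 128 packed weight bits
    uint8x16_t vx = vmvnq_u8(veorq_u8(va, vb));  // xnor = NOT(xor)
    uint8x16_t vc = vcntq_u8(vx);                // "cnt": per-byte popcount
    return vaddlvq_u8(vc);                       // "addv": reduce the 16 byte counts
}
```

Adding the returned scalar into $C_{i,j}$ is the fourth instruction counted in the next paragraph.
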
Thus, the above equation shows that a binary multiply-addition on ARMv8 devices consists of four instructions: xnor, cnt, addv, and addition. Moreover, on ARMv7 (the 32-bit ARM architecture) devices, there is no "addv" instruction at all, and $\lceil \log_{2}N \rceil$ instructions are needed to sum up all bytes in an $N$-byte vector, so a binary multiply-addition consists of $\lceil \log_{2}N \rceil+3$ instructions on these devices. To improve the efficiency of this operation, we re-arrange the calculation order and compute $\boldsymbol{C}=BGEMM(\boldsymbol{A},\boldsymbol{B})$ as the multiplication of a row vector $\boldsymbol{p} \in U^{1 \times K}$ and a column vector $\boldsymbol{q} \in U^{K \times 1}$:

$$
C_{i,j} = \boldsymbol{p^{i}}\boldsymbol{q^{j}},
$$

where $\boldsymbol{p^{i}}$ is the $i_{th}$ row of $\boldsymbol{A}$ and $\boldsymbol{q^{j}}$ is the $j_{th}$ column of $\boldsymbol{B}$.

In this way, the cost of the "addv" instructions can be mostly eliminated by summing up the results of "cnt" in advance:

$$
\Vec{C_{i,j}} = \sum\nolimits_{k} cnt(xnor(\Vec{A_{i,k}}, \Vec{B_{k,j}})),
$$
$$
C_{i,j} = addv(\Vec{C_{i,j}}).
$$

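A NEON sketch of this re-arranged accumulation might look like the following (same assumptions as the earlier sketch, not dabnn's implementation): the per-byte "cnt" results are kept in a vector accumulator inside the loop, and only one "addv"-style reduction is issued at the end.

```cpp
// Accumulate cnt(xnor(...)) in a vector register; reduce to a scalar once.
#include <arm_neon.h>
#include <cstdint>
#include <cstddef>

// a and b each point to k bit-packed 128-bit vectors (16 bytes apiece).
inline uint32_t binary_dot_direct(const uint8_t *a, const uint8_t *b, size_t k) {
    uint16x8_t acc = vdupq_n_u16(0);
    for (size_t i = 0; i < k; ++i) {
        uint8x16_t va = vld1q_u8(a + 16 * i);
        uint8x16_t vb = vld1q_u8(b + 16 * i);
        uint8x16_t cnt = vcntq_u8(vmvnq_u8(veorq_u8(va, vb)));  // cnt(xnor(...))
        acc = vpadalq_u8(acc, cnt);   // widen and accumulate the byte counts
    }
    return vaddlvq_u16(acc);          // one final reduction instead of one per step
}
```

(With 16-bit accumulator lanes this sketch is only safe for a bounded number of iterations; a real kernel would widen the accumulator or split the loop.)
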
Please note that the same transformation cannot be employed in BGEMM, because $\boldsymbol{C}$ is stored as 32-bit integers to save the valuable registers. Therefore, in the equation of BGEMM we have to utilize "addv" to reduce the vector into an integer before every "addition" instruction. Taking a closer look at the above two equations, we can observe some interesting connections between them and the operation of convolution. Specifically, if we treat $\boldsymbol{A} \in U^{M \times K}$ and $\boldsymbol{B} \in U^{K \times N}$ as the weight and the im2col-ed input ($M$: the number of output channels, $N$: output height $\times$ output width, and $K$: the number of bit-packed vectors in a weight filter), the above two equations can be directly interpreted as the definition of convolution. As such, the refined operation of binary convolution is dubbed "Binary Direct Convolution".

The implementation of Binary Direct Convolution is in https://github.com/JDAI-CV/dabnn/blob/master/dabnn/bconv.h.