Skip to content
This repository has been archived by the owner on Oct 23, 2023. It is now read-only.

Commit

Permalink
Add support for Bipolar and Binary FINN datatype for Quant op. (#41)
Browse files Browse the repository at this point in the history
* Add support for Bipolar and Binary FINN datatype for Quant op.

* [Quant] custom qnt execution for bipolar until Brevitas bug resolved

Co-authored-by: Yaman Umuroglu <[email protected]>
  • Loading branch information
heborras and maltanar authored Aug 26, 2021
1 parent 3aaaff4 commit 1429935
Showing 1 changed file with 23 additions and 12 deletions.
35 changes: 23 additions & 12 deletions src/finn/custom_op/general/quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import numpy as np
import onnx.helper as helper

import onnx.helper as helper
from finn.core.datatype import DataType
from finn.custom_op.base import CustomOp

Expand Down Expand Up @@ -97,14 +97,19 @@ def quant(inp_tensor, scale, zeropt, bitwidth, signed, narrow):
# Scaling
y_int = inp_tensor / scale
y_int = y_int + zeropt
# Clamping
min_int_val = min_int(signed, narrow, bitwidth)
max_int_val = max_int(signed, narrow, bitwidth)
y_int = np.where(y_int > max_int_val, max_int_val.astype(y_int.dtype), y_int)
y_int = np.where(y_int < min_int_val, min_int_val.astype(y_int.dtype), y_int)
# Rounding
y_int = np.round(y_int)

if bitwidth == 1 and signed:
# BUG: 1-bit Quant ops currently not exported correctly
# manually convert to bipolar values
y_ones = np.ones(y_int.shape, dtype=y_int.dtype)
y_int = np.where(y_int >= 0.0, y_ones, -y_ones)
else:
# Clamping
min_int_val = min_int(signed, narrow, bitwidth)
max_int_val = max_int(signed, narrow, bitwidth)
y_int = np.where(y_int > max_int_val, max_int_val.astype(y_int.dtype), y_int)
y_int = np.where(y_int < min_int_val, min_int_val.astype(y_int.dtype), y_int)
# Rounding
y_int = np.round(y_int)
# Re-scaling
out_tensor = y_int - zeropt
out_tensor = out_tensor * scale
Expand Down Expand Up @@ -165,10 +170,16 @@ def get_quant_config(self, model):
zero_zeropt = np.all(zeropt == 0.0)
assert zero_zeropt, "Only zero_point=0 Quant nodes supported for now"
if unit_scale and zero_zeropt:
if signed:
finn_dt = DataType["INT" + str(bitwidth)]
if bitwidth == 1:
if signed:
finn_dt = DataType["BIPOLAR"]
else:
finn_dt = DataType["BINARY"]
else:
finn_dt = DataType["UINT" + str(bitwidth)]
if signed:
finn_dt = DataType["INT" + str(bitwidth)]
else:
finn_dt = DataType["UINT" + str(bitwidth)]
else:
if signed:
finn_dt = DataType["SCALEDINT" + str(bitwidth)]
Expand Down

0 comments on commit 1429935

Please sign in to comment.