[lower_convs_to_matmul]: added support for non-square input images and kernels and non-equal padding.

[test_conv_lowering]: added/modified test cases for non-equal padding, depthwise convolution and 'standard' convolution.
mmrahorovic committed Dec 12, 2020
1 parent c524020 commit 15e34ed
Showing 2 changed files with 265 additions and 42 deletions.
src/finn/transformation/lower_convs_to_matmul.py: 88 changes (61 additions, 27 deletions)
@@ -34,16 +34,21 @@
 from finn.util.basic import get_by_name


-def _auto_pad_to_explicit_padding(autopad_str, idim, k, stride, n_dims):
-    pad_total = (stride - 1) * idim - stride + k
-    pad_half_small = int((pad_total / 2))
-    pad_half_large = pad_total - pad_half_small
+def _auto_pad_to_explicit_padding(
+    autopad_str, idim_H, idim_W, k_H, k_W, stride, n_dims
+):
+    pad_total_H = (stride - 1) * idim_H - stride + k_H
+    pad_total_W = (stride - 1) * idim_W - stride + k_W
+    pad_half_small_H = int((pad_total_H / 2))
+    pad_half_small_W = int((pad_total_W / 2))
+    pad_half_large_H = pad_total_H - pad_half_small_H
+    pad_half_large_W = pad_total_W - pad_half_small_W
     if autopad_str == "VALID":
         return [0 for i in range(2 * n_dims)]
     elif autopad_str == "SAME_UPPER":
-        return [pad_half_small, pad_half_large] * n_dims
+        return [pad_half_small_H, pad_half_small_W, pad_half_large_H, pad_half_large_W]
     elif autopad_str == "SAME_LOWER":
-        return [pad_half_large, pad_half_small] * n_dims
+        return [pad_half_large_H, pad_half_large_W, pad_half_small_H, pad_half_small_W]
     else:
         raise Exception("Unsupported auto_pad: " + autopad_str)

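For stride=1 the helper's pad_total reduces to k - 1 per axis, so a 3x5 kernel now yields unequal H/W padding instead of a repeated square split. A quick standalone check (hypothetical sizes; assumes the updated helper above is importable):

from finn.transformation.lower_convs_to_matmul import _auto_pad_to_explicit_padding

pad = _auto_pad_to_explicit_padding(
    "SAME_UPPER", idim_H=28, idim_W=56, k_H=3, k_W=5, stride=1, n_dims=2
)
# ONNX pads ordering: [H_begin, W_begin, H_end, W_end]
print(pad)  # [1, 2, 1, 2]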
@@ -65,15 +70,23 @@ def apply(self, model):
                 idt = model.get_tensor_datatype(cnv_input)
                 odt = model.get_tensor_datatype(cnv_output)
                 # extract conv parameters
-                k = get_by_name(n.attribute, "kernel_shape").ints[-1]
+                k = get_by_name(n.attribute, "kernel_shape").ints
+                if len(k) == 1:  # assume square kernel
+                    k_H = k[0]
+                    k_W = k[0]
+                else:
+                    k_H = k[0]
+                    k_W = k[1]
                 stride = get_by_name(n.attribute, "strides").ints[-1]
                 group = get_by_name(n.attribute, "group").i
                 weight_name = n.input[1]
                 W_conv = model.get_initializer(weight_name)
                 ifm_ch = model.get_tensor_shape(n.input[0])[1]  # assume NCHW
                 ofm_ch = model.get_tensor_shape(n.output[0])[1]  # assume NCHW
-                ifm_dim = model.get_tensor_shape(n.input[0])[-1]  # assume NCHW
-                ofm_dim = model.get_tensor_shape(n.output[0])[-1]  # assume NCHW
+                ifm_dim_H = model.get_tensor_shape(n.input[0])[2]  # assume NCHW
+                ifm_dim_W = model.get_tensor_shape(n.input[0])[3]
+                ofm_dim_H = model.get_tensor_shape(n.output[0])[2]  # assume NCHW
+                ofm_dim_W = model.get_tensor_shape(n.output[0])[3]
                 # handle both auto_pad and explicit padding
                 auto_pad = get_by_name(n.attribute, "auto_pad")
                 if auto_pad is not None:
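The extraction above now reads per-axis values instead of taking the last entry of each shape. For illustration (hypothetical shapes):

# NCHW input (1, 3, 28, 56) -> ifm_ch=3, ifm_dim_H=28, ifm_dim_W=56;
# kernel_shape [3] -> k_H = k_W = 3; kernel_shape [3, 5] -> k_H, k_W = 3, 5
ishape = (1, 3, 28, 56)  # N, C, H, W
ifm_ch, ifm_dim_H, ifm_dim_W = ishape[1], ishape[2], ishape[3]
assert (ifm_ch, ifm_dim_H, ifm_dim_W) == (3, 28, 56)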
@@ -83,36 +96,53 @@ def apply(self, model):
                         # use specified padding
                         pad = get_by_name(n.attribute, "pads").ints
                     else:
+                        assert auto_pad != "NOTSET", print("AUTOPAD NOT SUPPORTED YET")
                         pad = _auto_pad_to_explicit_padding(
                             auto_pad,
-                            ifm_dim,
-                            k,
+                            ifm_dim_H,
+                            ifm_dim_W,
+                            k_H,
+                            k_W,
                             stride,
                             len(model.get_tensor_shape(n.input[0])) - 2,
                         )
                 else:
                     # use specified padding
                     pad = get_by_name(n.attribute, "pads").ints
-                # ensure all pads are equal for now
-                assert (
-                    len(set(pad)) <= 1
-                ), "Only all-equal padding supported for now: " + str(pad)
-                pad = pad[-1]

+                # If len(pad) == 2, assume no padding for other dimension
+                if len(pad) == 2:  # only one dimension should be padded
+                    assert (
+                        ifm_dim_H == 1 or ifm_dim_W == 1
+                    ), "Padding is assumed to be 1D, image is 2D"
+                    if ifm_dim_H == 1:  # Assumption: dim H is not padded
+                        pad_2D = [0, 0, 0, 0]
+                        pad_2D[1] = pad[0]
+                        pad_2D[3] = pad[1]
+                    elif ifm_dim_W == 1:  # Assumption: dim W is not padded
+                        pad_2D = [0, 0, 0, 0]
+                        pad_2D[0] = pad[0]
+                        pad_2D[2] = pad[1]
+                    pad = pad_2D
+
                 # if depthwise conv create sparse matrix and variable "dw"
                 # to store as attribute in Im2Col that indicates that the created
                 # Im2Col node belongs to a depthwise convolution
                 dw = False
                 if group == ifm_ch and ofm_ch == ifm_ch:
-                    W_sparse = np.zeros((ofm_ch, ifm_ch, k, k))
+                    W_sparse = np.zeros(
+                        (ofm_ch, ifm_ch, k_H, k_W)
+                    )  # (OFM, IFM, k_H, k_W)
                     for ch in range(ifm_ch):
-                        W_sparse[ch][ch] = W_conv[ch][0]
+                        W_sparse[ch][ch] = W_conv[ch][
+                            0
+                        ]  # W_conv = [OFM, IFM, k_H, k_W]
                     W_conv = W_sparse.astype(np.float32)
                     # we need to store information of the
                     # sparsity of the weight matrix. For this
                     # we use the sparsity annotation of the
                     # weight tensor
-                    sparsity = {"dw": {"kernel_shape": k}}
+                    sparsity = {"dw": {"kernel_shape": k_H}}
                     model.set_tensor_sparsity(weight_name, sparsity)
                     # additionally create variable "dw" to store
                     # as attribute in Im2Col that indicates that the created
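The depthwise branch expands the per-channel kernels of shape (IFM, 1, k_H, k_W) into a block-diagonal (IFM, IFM, k_H, k_W) weight so the lowered MatMul sees a dense weight matrix; a minimal numpy sketch with made-up sizes:

import numpy as np

ifm_ch, k_H, k_W = 3, 3, 5  # hypothetical depthwise conv (group == ifm_ch)
W_conv = np.random.randn(ifm_ch, 1, k_H, k_W).astype(np.float32)

W_sparse = np.zeros((ifm_ch, ifm_ch, k_H, k_W), dtype=np.float32)
for ch in range(ifm_ch):
    W_sparse[ch][ch] = W_conv[ch][0]  # kernel for channel ch on the "diagonal"

# only 1/ifm_ch of the entries can be nonzero, hence the sparsity annotation
assert np.count_nonzero(W_sparse) <= ifm_ch * k_H * k_W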
@@ -123,9 +153,9 @@ def apply(self, model):
                 # conv weights are [OFM][IFM][k][k]
                 # first convert to [OFM][k][k][IFM] (to remain compatible with
                 # finn-hlslib and how it does im2col/sliding window)
-                W_matmul = W_conv.transpose(0, 2, 3, 1)
+                W_matmul = W_conv.transpose(0, 2, 3, 1)  # W_conv = [OFM, IFM, k_H, k_W]
                 # reshape into [OFM][k*k*IFM] matrix
-                W_matmul = W_matmul.reshape(ofm_ch, ifm_ch * k * k)
+                W_matmul = W_matmul.reshape(ofm_ch, ifm_ch * k_H * k_W)
                 # transpose to get ONNX-compatible [k*k*IFM][OFM] matrix
                 W_matmul = W_matmul.T
                 model.set_initializer(weight_name, W_matmul)
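The transpose/reshape pair flattens each output channel's kernel into one (k_H * k_W * IFM)-long column so the MatMul weight lines up with the Im2Col output layout; a shape check with arbitrary sizes:

import numpy as np

ofm_ch, ifm_ch, k_H, k_W = 4, 3, 3, 5  # hypothetical sizes
W_conv = np.random.randn(ofm_ch, ifm_ch, k_H, k_W).astype(np.float32)

W_matmul = W_conv.transpose(0, 2, 3, 1)  # (OFM, k_H, k_W, IFM)
W_matmul = W_matmul.reshape(ofm_ch, ifm_ch * k_H * k_W)
W_matmul = W_matmul.T  # ONNX-compatible (k_H*k_W*IFM, OFM)
assert W_matmul.shape == (k_H * k_W * ifm_ch, ofm_ch)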
@@ -134,21 +164,25 @@ def apply(self, model):
                 inp_trans_out = helper.make_tensor_value_info(
                     model.make_new_valueinfo_name(),
                     TensorProto.FLOAT,
-                    (1, ifm_dim, ifm_dim, ifm_ch),  # NHWC
+                    (1, ifm_dim_H, ifm_dim_W, ifm_ch),  # NHWC
                 )
                 graph.value_info.append(inp_trans_out)
                 inp_trans_out = inp_trans_out.name
                 model.set_tensor_datatype(inp_trans_out, idt)

                 need_im2col = True
-                if k == 1 and pad == 0 and stride == 1:
+                if all(p == 0 for p in pad):
+                    padding = 0
+
+                # k_H=k_W==1: pointwise convolution, thus no im2col needed
+                if k_H == 1 and k_W == 1 and padding == 0 and stride == 1:
                     need_im2col = False

                 if need_im2col:
                     im2col_out = helper.make_tensor_value_info(
                         model.make_new_valueinfo_name(),
                         TensorProto.FLOAT,
-                        (1, ofm_dim, ofm_dim, ifm_ch * k * k),
+                        (1, ofm_dim_H, ofm_dim_W, ifm_ch * k_H * k_W),
                     )
                     graph.value_info.append(im2col_out)
                     im2col_out = im2col_out.name
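The im2col step is skipped for a pointwise (1x1, stride-1, unpadded) convolution because such a conv is exactly a per-pixel matrix multiply over channels; a small numpy check with illustrative shapes:

import numpy as np

ifm_ch, ofm_ch, H, W = 3, 4, 8, 6  # hypothetical sizes
x = np.random.randn(1, H, W, ifm_ch).astype(np.float32)  # NHWC activations
Wm = np.random.randn(ifm_ch, ofm_ch).astype(np.float32)  # (k_H*k_W*IFM, OFM), k=1

y = x @ Wm  # 1x1 conv == matmul on each pixel's channel vector
assert y.shape == (1, H, W, ofm_ch)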
@@ -157,7 +191,7 @@ def apply(self, model):
                 matmul_out = helper.make_tensor_value_info(
                     model.make_new_valueinfo_name(),
                     TensorProto.FLOAT,
-                    (1, ofm_dim, ofm_dim, ofm_ch),
+                    (1, ofm_dim_H, ofm_dim_W, ofm_ch),
                 )
                 graph.value_info.append(matmul_out)
                 matmul_out = matmul_out.name
@@ -178,9 +212,9 @@ def apply(self, model):
                         [im2col_out],
                         domain="finn.custom_op.general",
                         stride=stride,
-                        kernel_size=[k],
+                        kernel_size=[k_H, k_W],
                         pad_amount=pad,
-                        input_shape="(1,{},{},{})".format(ifm_dim, ifm_dim, ifm_ch),
+                        input_shape="(1,{},{},{})".format(ifm_dim_H, ifm_dim_W, ifm_ch),
                         depthwise=dw,
                     )

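For context, the transformation is applied like any other FINN graph transformation; a minimal usage sketch (the model path is hypothetical, and the import locations are those FINN used around the time of this commit):

from finn.core.modelwrapper import ModelWrapper
from finn.transformation.infer_shapes import InferShapes
from finn.transformation.lower_convs_to_matmul import LowerConvsToMatMul

model = ModelWrapper("model_with_convs.onnx")  # hypothetical input model
model = model.transform(InferShapes())
model = model.transform(LowerConvsToMatMul())
# every Conv is now a Transpose -> (Im2Col) -> MatMul -> Transpose chain
model.save("model_lowered.onnx")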
(Remainder of this diff, and the second changed file, test_conv_lowering, not shown.)
