diff --git a/README.md b/README.md
index 7e69430..fc932dc 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,4 @@
-## Core Components for Quantized Neural Network Inference
+## Core Components for Quantized Neural Network Inference
[![Gitter](https://badges.gitter.im/xilinx-finn/community.svg)](https://gitter.im/xilinx-finn/community?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge)
[![ReadTheDocs](https://readthedocs.org/projects/finn-base/badge/?version=latest&style=plastic)](http://finn-base.readthedocs.io/)
diff --git a/docker/Dockerfile b/docker/Dockerfile
index 6fb062c..73d9e50 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -73,7 +73,7 @@ RUN python -mpip install --upgrade pip && \
rm requirements.txt
# Install custom fork of pyverilator
-RUN pip install git+https://github.com/maltanar/pyverilator.git#egg=pyverilator
+RUN pip install git+https://github.com/maltanar/pyverilator.git@0c3eb9343500fc1352a02c020a736c8c2db47e8e
# Install pytest-xdist (not in requirements, only for faster testing in Docker)
RUN pip install pytest-xdist==2.0.0
diff --git a/src/finn/core/data_layout.py b/src/finn/core/data_layout.py
index 630a25f..4a5d87a 100644
--- a/src/finn/core/data_layout.py
+++ b/src/finn/core/data_layout.py
@@ -31,5 +31,19 @@
NHWC = ["N", "H", "W", "C"]
NCHW = ["N", "C", "H", "W"]
+NCW = ["N", "C", "W"]
+NWC = ["N", "W", "C"]
NC = ["N", "C"]
UNKNOWN = []
+
+
+def is_channels_last(layout):
+ # empty layouts (e.g. UNKNOWN) are not channels-last
+ return len(layout) > 0 and layout[-1] == "C"
+
+
+def get_channels_last_layout_for_ndims(ndims):
+ return {4: NHWC, 3: NWC, 2: NC}[ndims]
+
+
+def get_channels_first_layout_for_ndims(ndims):
+ return {4: NCHW, 3: NCW, 2: NC}[ndims]
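+
+
+# minimal usage sketch of the helpers above (illustrative only):
+#   is_channels_last(NHWC)                 -> True
+#   get_channels_last_layout_for_ndims(3)  -> NWC
+#   get_channels_first_layout_for_ndims(4) -> NCHW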
diff --git a/src/finn/core/modelwrapper.py b/src/finn/core/modelwrapper.py
index 12d0b38..eec52a1 100644
--- a/src/finn/core/modelwrapper.py
+++ b/src/finn/core/modelwrapper.py
@@ -58,7 +58,9 @@ def __init__(self, onnx_model_proto, make_deepcopy=False):
is made internally.
"""
if isinstance(onnx_model_proto, str):
- assert os.path.isfile(onnx_model_proto)
+ assert os.path.isfile(
+ onnx_model_proto
+ ), f"File not found: {onnx_model_proto}"
self._model_proto = onnx.load(onnx_model_proto)
elif isinstance(onnx_model_proto, bytes):
self._model_proto = onnx.load_from_string(onnx_model_proto)
@@ -217,7 +219,11 @@ def get_tensor_valueinfo(self, tensor_name):
vi_names += [(x.name, x) for x in graph.output]
vi_names += [(x.name, x) for x in graph.value_info]
try:
- vi_ind = [x[0] for x in vi_names].index(tensor_name)
+ vi_t_names = [x[0] for x in vi_names]
+ assert vi_t_names.count(tensor_name) <= 1, (
+ "Multiple ValueInfoProto found for " + tensor_name
+ )
+ vi_ind = vi_t_names.index(tensor_name)
vi = vi_names[vi_ind][1]
return vi
except ValueError:
@@ -230,7 +236,11 @@ def get_tensor_shape(self, tensor_name):
vi_names += [(x.name, x) for x in graph.output]
vi_names += [(x.name, x) for x in graph.value_info]
try:
- vi_ind = [x[0] for x in vi_names].index(tensor_name)
+ vi_t_names = [x[0] for x in vi_names]
+ assert vi_t_names.count(tensor_name) <= 1, (
+ "Multiple ValueInfoProto found for " + tensor_name
+ )
+ vi_ind = vi_t_names.index(tensor_name)
vi = vi_names[vi_ind][1]
dims = [x.dim_value for x in vi.type.tensor_type.shape.dim]
return dims
@@ -240,6 +250,8 @@ def get_tensor_shape(self, tensor_name):
def set_tensor_shape(self, tensor_name, tensor_shape, dtype=TensorProto.FLOAT):
"""Assigns shape in ValueInfoProto for tensor with given name."""
new_vi = oh.make_tensor_value_info(tensor_name, dtype, tensor_shape)
+ # call get_tensor_shape to catch multiple ValueInfoProto cases
+ self.get_tensor_shape(tensor_name)
# find what container this tensor's ValueInfo lives in
# if not found anywhere, we assume it's a new value_info
target_container = self.graph.value_info
@@ -534,13 +546,7 @@ def get_tensor_layout(self, tensor_name):
def set_tensor_layout(self, tensor_name, data_layout):
"""Sets the data layout annotation of tensor with given name. See
get_tensor_layout for examples."""
- tensor_shape = self.get_tensor_shape(tensor_name)
assert type(data_layout) == list, "data_layout must be a list"
- if tensor_shape is not None:
- assert len(tensor_shape) == len(
- data_layout
- ), """Mismatch between number
- of dimensions of tensor shape and data layout annotation."""
graph = self._model_proto.graph
qnt_annotations = graph.quantization_annotation
ret = util.get_by_name(qnt_annotations, tensor_name, "tensor_name")
diff --git a/src/finn/core/onnx_exec.py b/src/finn/core/onnx_exec.py
index 5de1afd..9ba8a47 100644
--- a/src/finn/core/onnx_exec.py
+++ b/src/finn/core/onnx_exec.py
@@ -51,7 +51,7 @@ def execute_node(node, context, graph, return_full_exec_context=False):
Input/output provided via context."""
- if node.op_type == "GenericPartition":
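+ # both partition op types store their subgraph as a "model" nodeattr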
+ if node.op_type in ["GenericPartition", "StreamingDataflowPartition"]:
partition_node = getCustomOp(node)
model = ModelWrapper(partition_node.get_nodeattr("model"))
inp_ctx = dict(filter(lambda x: x[0] in node.input, context.items()))
@@ -71,32 +71,6 @@ def execute_node(node, context, graph, return_full_exec_context=False):
for tname in ret.keys():
if tname not in [x.name for x in model.graph.output]:
context[node.name + "_" + tname] = ret[tname]
- elif node.op_type == "StreamingDataflowPartition":
- sdp_node = getCustomOp(node)
- model = ModelWrapper(sdp_node.get_nodeattr("model"))
- inp_ctx = dict(filter(lambda x: x[0] in node.input, context.items()))
- # input may have been renamed in partition
- assert len(inp_ctx) == 1
- old_iname = node.input[0]
- new_iname = model.graph.input[0].name
- if old_iname != new_iname:
- inp_ctx[new_iname] = inp_ctx[old_iname]
- del inp_ctx[old_iname]
- ret = execute_onnx(model, inp_ctx, return_full_exec_context)
- # if the model was in ip-stitched rtlsim mode, may get annotation
- # for numbet of elapsed cycles, save again
- if model.get_metadata_prop("exec_mode") == "rtlsim":
- model.save(sdp_node.get_nodeattr("model"))
- # output may have been renamed in partition
- assert len(model.graph.output) == 1
- node_oname = node.output[0]
- model_oname = model.graph.output[0].name
- context[node_oname] = ret[model_oname]
- # prefix and insert exec context entries
- if return_full_exec_context:
- for tname in ret.keys():
- if tname != model_oname:
- context[node.name + "_" + tname] = ret[tname]
else:
if is_finn_op(node.domain):
ex_cu_node.execute_custom_node(node, context, graph)
@@ -108,7 +82,9 @@ def execute_node(node, context, graph, return_full_exec_context=False):
# graph.value_info as well as graph.output or graph.input
# nodes with multiple outputs that are a mix of value_info and
# input/outputs may get them reordered below
+ # note: a node's input may (also) be a top-level input or output
node_inputs = list(filter(lambda x: x.name in node.input, graph.input))
+ node_inputs += list(filter(lambda x: x.name in node.input, graph.output))
node_inputs += list(
filter(lambda x: x.name in node.input, graph.value_info)
)
diff --git a/src/finn/custom_op/general/quantavgpool2d.py b/src/finn/custom_op/general/quantavgpool2d.py
index 148e266..99a9d43 100644
--- a/src/finn/custom_op/general/quantavgpool2d.py
+++ b/src/finn/custom_op/general/quantavgpool2d.py
@@ -94,7 +94,7 @@ def make_shape_compatible_op(self, model):
def infer_node_datatype(self, model):
node = self.onnx_node
bw = self.get_nodeattr("obits")
- if bw in [2, 4, 8, 16, 32]:
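+ # obits may be any integer bitwidth from 2 to 32, not only powers of two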
+ if bw in range(2, 33):
if self.get_nodeattr("signed") == 0:
dtype = DataType["UINT%d" % bw]
else:
diff --git a/src/finn/transformation/change_3d_tensors_to_4d.py b/src/finn/transformation/change_3d_tensors_to_4d.py
index 251f609..23912a6 100644
--- a/src/finn/transformation/change_3d_tensors_to_4d.py
+++ b/src/finn/transformation/change_3d_tensors_to_4d.py
@@ -56,6 +56,12 @@ def _find_invalid_nodes(model):
"Transpose",
"LogSoftmax",
"ArgMax",
+ "Div",
+ "TopK",
+ "MatMul",
+ "Flatten",
+ "Reshape",
+ "MaxPool",
]
invalid_nodes = []
for n in model.graph.node:
@@ -96,7 +102,7 @@ def apply(self, model):
model = model.transform(RemoveUnusedTensors())
# This list contains all nodes with initializers that need to be converted
- nodes_with_initializers = ["Mul", "Conv", "Add"]
+ nodes_with_initializers = ["Mul", "Conv", "Add", "Div", "Reshape"]
# Obtain a list of initializer names (used to filter out only value infos)
initializers_names = [x.name for x in model.graph.initializer]
@@ -118,8 +124,7 @@ def apply(self, model):
if x.name not in initializers_names
},
}
- # Extract only initializers from Conv, Mul and Add nodes (which are the
- # only ones relevant for conversion)
+ # Extract only initializers from nodes that are relevant for conversion
all_tensors = {
**all_tensors,
**{
@@ -143,10 +148,11 @@ def apply(self, model):
tensors_reduced_dimension = []
for n in model.graph.node:
node_op_type = n.op_type
+ input_shape = model.get_tensor_shape(n.input[0])
# Find tensors that are the output of nodes that reduce the dimension
if node_op_type == "ArgMax":
keep_dims = get_by_name(n.attribute, "keepdims", "name").i
- if keep_dims == 0:
+ if len(input_shape) == 3 and keep_dims == 0:
node_out = n.output
for n_o in node_out:
tensors_reduced_dimension.append(n_o)
@@ -158,10 +164,10 @@ def apply(self, model):
len(perm) == 3
): # Meaning that the transpose operation was on a 3D tensor
perm.append(3) # append 4th dimension
- elif node_op_type == "ArgMax" or node_op_type == "LogSoftMax":
+ elif node_op_type in ["ArgMax", "LogSoftmax", "TopK", "Flatten"]:
axis = get_by_name(n.attribute, "axis", "name")
- if axis.i == -1:
- axis.i = 2 # argmax is now on the second-to-last axis
+ if len(input_shape) == 3 and axis.i < 0:
+ axis.i = 3 + axis.i # count dimensions from the front
elif node_op_type == "Conv":
dilations = get_by_name(n.attribute, "dilations", "name").ints
kernel_shape = get_by_name(n.attribute, "kernel_shape", "name").ints
@@ -180,6 +186,19 @@ def apply(self, model):
pads.append(0)
if len(strides) == 1:  # strides = [stride_h] --> [stride_h, stride_w]
strides.append(1)
+ elif node_op_type == "MaxPool":
+ kernel_shape = get_by_name(n.attribute, "kernel_shape", "name").ints
+ pads = get_by_name(n.attribute, "pads", "name").ints
+ strides = get_by_name(n.attribute, "strides", "name").ints
+ if len(kernel_shape) == 1: # we must add another dimension to it
+ kernel_shape.append(1)
+ if (
+ len(pads) == 2
+ ): # pads = [x1_begin, x1_end] --> [x1_begin, x2_begin, x1_end, x2_end]
+ pads.insert(1, 0)
+ pads.append(0)
+ if len(strides) == 1: # strides = [stride_h, stride_w]
+ strides.append(1)
# Change format of each input/value_info/output tensor
for k, v in all_tensors.items():
diff --git a/src/finn/transformation/create_generic_partitions.py b/src/finn/transformation/create_generic_partitions.py
index 67da854..00430ed 100755
--- a/src/finn/transformation/create_generic_partitions.py
+++ b/src/finn/transformation/create_generic_partitions.py
@@ -131,33 +131,23 @@ def apply(self, model):
to_check = next_to_check
# set p graph in/out to be p_in/p_out
- for x in p_model.graph.input:
- p_model.graph.input.remove(x)
+ while len(p_model.graph.input) > 0:
+ p_model.graph.input.pop()
for i in p_in_vi:
p_model.graph.input.append(i)
- for x in p_model.graph.output:
- p_model.graph.output.remove(x)
+ while len(p_model.graph.output) > 0:
+ p_model.graph.output.pop()
for o in p_out_vi:
p_model.graph.output.append(o)
# remove redundant input and output value_info entries
for i in p_in_vi:
- # the tensor can be both an input and value_info, so we also have to
- # ensure that the tensor is not a relevant value_info before removing
- if (
- i in p_model.graph.value_info
- and p_model.find_producer(i.name) is None
- ):
+ if i in p_model.graph.value_info:
p_model.graph.value_info.remove(i)
for o in p_out_vi:
- # the tensor can both an output and value_info, so we also have to
- # ensure that the tensor is not a relevant value_info before removing
- if (
- o in p_model.graph.value_info
- and p_model.find_consumers(o.name) is None
- ):
+ if o in p_model.graph.value_info:
p_model.graph.value_info.remove(o)
# save partition model
diff --git a/src/finn/transformation/general.py b/src/finn/transformation/general.py
index 475a8d4..e2fb54e 100644
--- a/src/finn/transformation/general.py
+++ b/src/finn/transformation/general.py
@@ -134,9 +134,13 @@ def apply(self, model):
if model.get_initializer(i) is not None:
model.rename_tensor(i, "%s_param%d" % (n.name, init_in_num))
init_in_num += 1
- # give special names to the main model input and output
- model.rename_tensor(model.graph.input[0].name, "global_in")
- model.rename_tensor(model.graph.output[0].name, "global_out")
+ # give special names to the model inputs and outputs
+ for i, inp in enumerate(model.graph.input):
+ iname = "global_in" if i == 0 else "global_in_%d" % i
+ model.rename_tensor(inp.name, iname)
+ for i, outp in enumerate(model.graph.output):
+ oname = "global_out" if i == 0 else "global_out_%d" % i
+ model.rename_tensor(outp.name, oname)
# return model_was_changed = False as single iteration is always enough
return (model, False)
diff --git a/src/finn/transformation/infer_data_layouts.py b/src/finn/transformation/infer_data_layouts.py
index 7066a66..4bae4d4 100644
--- a/src/finn/transformation/infer_data_layouts.py
+++ b/src/finn/transformation/infer_data_layouts.py
@@ -63,35 +63,40 @@ def _infer_node_data_layout(model, node):
"""Infer output data layout annotation(s) for a particular node.
Returns True if any changes were made."""
old_layouts = list(map(lambda x: model.get_tensor_layout(x), node.output))
- if is_finn_op(node.domain):
- # try to guess based on number of output dims
- for o in node.output:
- ndims = len(model.get_tensor_shape(o))
- new_layout = _dims_to_layout(model, node, ndims)
- model.set_tensor_layout(o, new_layout)
- else:
- if node.op_type == "Transpose":
- # grab input annotation and switch it around using perm
- perm = get_by_name(node.attribute, "perm").ints
- inp_layout = model.get_tensor_layout(node.input[0])
- out_layout = [inp_layout[i] for i in perm]
- model.set_tensor_layout(node.output[0], out_layout)
- elif node.op_type == "Unsqueeze":
- inp_layout = model.get_tensor_layout(node.input[0])
- # add dummy dimension at the output
- out_layout = inp_layout + ["x"]
- model.set_tensor_layout(node.output[0], out_layout)
- elif node.op_type == "Squeeze":
- inp_layout = model.get_tensor_layout(node.input[0])
- assert inp_layout[-1] == "x"
- # remove dummy dimension
- out_layout = inp_layout[:-1]
- model.set_tensor_layout(node.output[0], out_layout)
- else:
+ try:
+ if is_finn_op(node.domain):
# try to guess based on number of output dims
for o in node.output:
ndims = len(model.get_tensor_shape(o))
- model.set_tensor_layout(o, _dims_to_layout(model, node, ndims))
+ new_layout = _dims_to_layout(model, node, ndims)
+ model.set_tensor_layout(o, new_layout)
+ else:
+ if node.op_type == "Transpose":
+ # grab input annotation and switch it around using perm
+ perm = get_by_name(node.attribute, "perm").ints
+ inp_layout = model.get_tensor_layout(node.input[0])
+ out_layout = [inp_layout[i] for i in perm]
+ model.set_tensor_layout(node.output[0], out_layout)
+ elif node.op_type == "Unsqueeze":
+ inp_layout = model.get_tensor_layout(node.input[0])
+ # add dummy dimension at the output
+ out_layout = inp_layout + ["x"]
+ model.set_tensor_layout(node.output[0], out_layout)
+ elif node.op_type == "Squeeze":
+ inp_layout = model.get_tensor_layout(node.input[0])
+ assert inp_layout[-1] == "x"
+ # remove dummy dimension
+ out_layout = inp_layout[:-1]
+ model.set_tensor_layout(node.output[0], out_layout)
+ else:
+ # try to guess based on number of output dims
+ for o in node.output:
+ ndims = len(model.get_tensor_shape(o))
+ model.set_tensor_layout(o, _dims_to_layout(model, node, ndims))
+ except Exception:
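+ # layout inference above may fail (e.g. missing shape annotation);
+ # fall back to UNKNOWN for all outputs of this node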
+ for o in node.output:
+ model.set_tensor_layout(o, DataLayout.UNKNOWN)
+
# compare old and new output dtypes to see if anything changed
new_layouts = list(map(lambda x: model.get_tensor_layout(x), node.output))
graph_modified = new_layouts != old_layouts
diff --git a/src/finn/transformation/infer_datatypes.py b/src/finn/transformation/infer_datatypes.py
index 66d91ca..e76ea69 100644
--- a/src/finn/transformation/infer_datatypes.py
+++ b/src/finn/transformation/infer_datatypes.py
@@ -29,7 +29,7 @@
import finn.custom_op.registry as registry
from finn.core.datatype import DataType
from finn.transformation.base import Transformation
-from finn.util.basic import is_finn_op
+from finn.util.basic import get_by_name, is_finn_op
def _infer_node_datatype(model, node):
@@ -41,7 +41,19 @@ def _infer_node_datatype(model, node):
"Flatten",
"Slice",
"Gather",
+ "GatherElements",
+ "GatherND",
"Identity",
+ "Expand",
+ "Flatten",
+ "MaxPool",
+ "GlobalMaxPool",
+ "Scatter",
+ "ScatterElements",
+ "ScatterND",
+ "Squeeze",
+ "Unsqueeze",
+ "Tile",
]
idtypes = list(map(lambda x: model.get_tensor_datatype(x), node.input))
odtypes = list(map(lambda x: model.get_tensor_datatype(x), node.output))
@@ -72,6 +84,16 @@ def _infer_node_datatype(model, node):
else:
odtype = DataType.UINT32
model.set_tensor_datatype(node.output[0], odtype)
+ elif node.op_type in ["Resize", "Upsample"]:
+ mode = get_by_name(node.attribute, "mode")
+ if mode is None:
+ # attribute absent: ONNX defaults to "nearest"
+ mode = "nearest"
+ else:
+ mode = mode.s.decode("UTF-8")
+ if mode == "nearest":
+ # set output dtype = input dtype
+ idtype = model.get_tensor_datatype(node.input[0])
+ model.set_tensor_datatype(node.output[0], idtype)
elif node.op_type in dt_identity_optypes:
# set output dtype = input dtype
idtype = model.get_tensor_datatype(node.input[0])
diff --git a/src/finn/transformation/make_input_chanlast.py b/src/finn/transformation/make_input_chanlast.py
new file mode 100644
index 0000000..9819abd
--- /dev/null
+++ b/src/finn/transformation/make_input_chanlast.py
@@ -0,0 +1,88 @@
+# Copyright (c) 2021 Xilinx, Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# * Neither the name of Xilinx nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+from onnx import helper as oh
+
+import finn.core.data_layout as data_layout
+from finn.transformation.base import Transformation
+
+
+class MakeInputChannelsLast(Transformation):
+ """For networks with an input using the NCx data layout, add a transpose node
+ at the beginning and mark the input as using NxC (channels-last)."""
+
+ def __init__(self):
+ super().__init__()
+
+ def apply(self, model):
+ graph_in_name = model.graph.input[0].name
+ graph_new_in_name = graph_in_name + "_transposed"
+ orig_ishape = model.get_tensor_shape(graph_in_name)
+ ndim = len(orig_ishape)
+ if ndim == 2:
+ # assume NC layout, no action needed
+ return (model, False)
+ elif ndim > 2:
+ orig_layout = model.get_tensor_layout(graph_in_name)
+ if orig_layout == data_layout.get_channels_last_layout_for_ndims(ndim):
+ # already marked as channels-last, no action needed
+ return (model, False)
+ else:
+ # determine channels-last shape and required permutation to
+ # go from channels-last to previous format
+ new_perm = list(range(ndim))
+ new_perm.remove(ndim - 1)
+ new_perm.insert(1, ndim - 1)
+ new_ishape = list(orig_ishape)
+ new_ishape.remove(orig_ishape[1])
+ new_ishape.append(orig_ishape[1])
+ # create and insert transpose node
+ t_trans_node = oh.make_node(
+ "Transpose", [graph_in_name], [graph_new_in_name], perm=new_perm
+ )
+ model.graph.node.insert(0, t_trans_node)
+ # rewire all consumers of original input to transpose's output
+ consumers = model.find_consumers(graph_in_name)
+ for cons in consumers:
+ if cons == t_trans_node:
+ continue
+ for i, ci in enumerate(cons.input):
+ if ci == graph_in_name:
+ cons.input[i] = graph_new_in_name
+ # set tensor shapes and layouts
+ model.set_tensor_shape(graph_in_name, new_ishape)
+ model.set_tensor_shape(graph_new_in_name, orig_ishape)
+ model.set_tensor_layout(
+ graph_in_name, data_layout.get_channels_last_layout_for_ndims(ndim)
+ )
+ model.set_tensor_layout(
+ graph_new_in_name,
+ data_layout.get_channels_first_layout_for_ndims(ndim),
+ )
+ # single iteration is enough so return model_was_changed=False
+ return (model, False)
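+
+
+# usage sketch (illustrative): for a model with an NCHW input,
+#   model = model.transform(MakeInputChannelsLast())
+# marks the input tensor as NHWC and inserts a Transpose back to NCHW,
+# so existing consumers still see channels-first data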
diff --git a/src/finn/util/platforms.py b/src/finn/util/platforms.py
new file mode 100644
index 0000000..6a94812
--- /dev/null
+++ b/src/finn/util/platforms.py
@@ -0,0 +1,480 @@
+# Copyright (c) 2021, Xilinx
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# * Neither the name of FINN nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import numpy as np
+from abc import abstractmethod
+
+# contains the amount of available FPGA resources for several
+# Xilinx platforms, as well as certain resource limit guidelines
+# for creating designs that can achieve timing closure
+
+# explicit value for res types/costs we don't care about
+DONT_CARE = -1
+# recommended resource limits from Xilinx for timing closure
+# for the LUT, FF, BRAM_18K, URAM and DSP resource types, respectively
+DEFAULT_RES_LIMITS = np.array([0.7, 0.5, 0.80, 0.80, 0.80])
+DEFAULT_AVG_CONSTRAINTS = [((2, 3, 4), 0.7)]
+
+# resources required to instantiate certain infrastructure components
+# such as memory controllers and network interfaces
+DDR_RESOURCE_REQUIREMENTS = {
+ "LUT": 33256,
+ "FF": 44889,
+ "BRAM_18K": 199,
+ "URAM": 0,
+ "DSP": 3,
+}
+HBM_RESOURCE_REQUIREMENTS = {
+ "LUT": 10718,
+ "FF": 21793,
+ "BRAM_18K": 8,
+ "URAM": 0,
+ "DSP": 0,
+}
+
+# we assume use of VNx Alveo UDP stack
+# see: https://gitenterprise.xilinx.com/mruiznog/vitis_network_layer
+ETH_RESOURCE_REQUIREMENTS = {
+ "LUT": 35219,
+ "FF": 86269,
+ "BRAM_18K": 183,
+ "URAM": 0,
+ "DSP": 0,
+}
+
+
+class Platform:
+ def __init__(
+ self,
+ nslr=1,
+ ndevices=1,
+ sll_count=[],
+ hbm_slr=-1,
+ ddr_slr=[0],
+ eth_slr=0,
+ eth_gbps=0,
+ limits=DEFAULT_RES_LIMITS,
+ avg_constraints=DEFAULT_AVG_CONSTRAINTS,
+ ):
+ self.nslr = nslr
+ self.sll_count = sll_count
+ self.eth_slr = eth_slr
+ self.eth_gbps = eth_gbps
+ self.ndevices = ndevices
+ self.hbm_slr = hbm_slr
+ self.ddr_slr = ddr_slr
+ # limits must be a np.array either of
+ # the same shape as compute_resources
+ # or broadcastable to it
+ self.res_limits = limits
+ # list of tuples of the form ( tuple of resource positions to avg, limit )
+ self.avg_constraints = avg_constraints
+
+ @property
+ @abstractmethod
+ def compute_resources(self):
+ pass
+
+ @property
+ def guide_resources(self):
+ guide = []
+ # TODO: assert limits is of correct size
+ guide_res = (
+ np.tile(np.array(self.compute_resources), (self.ndevices, 1))
+ ).astype(int)
+ for i in range(self.nslr * self.ndevices):
+ # when in multi-FPGA mode, subtract cost of UDP connection from eth_slr
+ local_slr = i % self.nslr
+ if self.ndevices > 1 and local_slr == self.eth_slr:
+ guide_res[i][0] -= ETH_RESOURCE_REQUIREMENTS["LUT"]
+ guide_res[i][1] -= ETH_RESOURCE_REQUIREMENTS["FF"]
+ guide_res[i][2] -= ETH_RESOURCE_REQUIREMENTS["BRAM_18K"]
+ guide_res[i][3] -= ETH_RESOURCE_REQUIREMENTS["URAM"]
+ guide_res[i][4] -= ETH_RESOURCE_REQUIREMENTS["DSP"]
+ # subtract the cost of memory controllers
+ # if we have a choice between DDR and HBM, use HBM
+ if local_slr == self.hbm_slr:
+ guide_res[i][0] -= HBM_RESOURCE_REQUIREMENTS["LUT"]
+ guide_res[i][1] -= HBM_RESOURCE_REQUIREMENTS["FF"]
+ guide_res[i][2] -= HBM_RESOURCE_REQUIREMENTS["BRAM_18K"]
+ guide_res[i][3] -= HBM_RESOURCE_REQUIREMENTS["URAM"]
+ guide_res[i][4] -= HBM_RESOURCE_REQUIREMENTS["DSP"]
+ elif local_slr in self.ddr_slr:
+ guide_res[i][0] -= DDR_RESOURCE_REQUIREMENTS["LUT"]
+ guide_res[i][1] -= DDR_RESOURCE_REQUIREMENTS["FF"]
+ guide_res[i][2] -= DDR_RESOURCE_REQUIREMENTS["BRAM_18K"]
+ guide_res[i][3] -= DDR_RESOURCE_REQUIREMENTS["URAM"]
+ guide_res[i][4] -= DDR_RESOURCE_REQUIREMENTS["DSP"]
+ guide.append(list(guide_res[i]))
+ return guide
+
+ @property
+ def resource_count_dict(self):
+ res = dict()
+ for i in range(self.nslr * self.ndevices):
+ slr_res = dict()
+ slr_res["LUT"] = self.compute_resources[i % self.nslr][0]
+ slr_res["FF"] = self.compute_resources[i % self.nslr][1]
+ slr_res["BRAM_18K"] = self.compute_resources[i % self.nslr][2]
+ slr_res["URAM"] = self.compute_resources[i % self.nslr][3]
+ slr_res["DSP"] = self.compute_resources[i % self.nslr][4]
+ res["slr" + str(i)] = slr_res
+ return res
+
+ @property
+ def compute_connection_cost(self):
+ x = np.full((self.nslr * self.ndevices, self.nslr * self.ndevices), DONT_CARE)
+ # build connection cost matrix for one device's SLRs
+ xlocal = np.full((self.nslr, self.nslr), DONT_CARE)
+ for i in range(self.nslr):
+ for j in range(self.nslr):
+ if i == j:
+ xlocal[i][j] = 0
+ elif abs(i - j) == 1:
+ xlocal[i][j] = 1
+ # tile connection cost matrices for entire system
+ for i in range(self.ndevices):
+ x[
+ i * self.nslr : (i + 1) * self.nslr, i * self.nslr : (i + 1) * self.nslr
+ ] = xlocal
+ # set cost for ethernet connections, assuming daisy-chaining
+ for i in range(self.ndevices - 1):
+ x[i * self.nslr + self.eth_slr][(i + 1) * self.nslr + self.eth_slr] = 10
+ x[(i + 1) * self.nslr + self.eth_slr][i * self.nslr + self.eth_slr] = 10
+ return x
+
+ @property
+ def compute_connection_resource(self):
+ sll = np.full((self.nslr * self.ndevices, self.nslr * self.ndevices), 0)
+ # build connection resource matrix for one device's SLRs
+ slllocal = np.full((self.nslr, self.nslr), -1)
+ for i in range(self.nslr):
+ for j in range(self.nslr):
+ if i == j:
+ # no SLL constraint when going from one SLR to itself
+ slllocal[i][j] = -1
+ else:
+ slllocal[i][j] = self.sll_count[i][j]
+ # tile connection cost matrices for entire system
+ for i in range(self.ndevices):
+ sll[
+ i * self.nslr : (i + 1) * self.nslr, i * self.nslr : (i + 1) * self.nslr
+ ] = slllocal
+ # set cost for ethernet connections, assuming daisy-chaining
+ eth = np.full((self.nslr * self.ndevices, self.nslr * self.ndevices), 0)
+ # no Eth throughput constraints from one SLR to itself
+ for i in range(self.ndevices * self.nslr):
+ eth[i][i] = -1
+ # apply symmetric ETH throughput constraints between the SLRs that have GTXes
+ for i in range(self.ndevices - 1):
+ eth[i * self.nslr + self.eth_slr][
+ (i + 1) * self.nslr + self.eth_slr
+ ] = self.eth_gbps * (10 ** 9)
+ eth[(i + 1) * self.nslr + self.eth_slr][
+ i * self.nslr + self.eth_slr
+ ] = self.eth_gbps * (10 ** 9)
+ # pack sll and eth info in one list-of-list-of-tuple structure
+ constraints = []
+ for i in range(self.ndevices * self.nslr):
+ constraints_line = []
+ for j in range(self.ndevices * self.nslr):
+ # make sure not to constrain both resources at the same time
+ # constrain for Eth throughput between SLRs on different devices
+ # constrain for SLLs between SLRs on same device
+ is_offchip = i // self.nslr != j // self.nslr
+ constraints_line.append(
+ (-1 if is_offchip else sll[i][j], eth[i][j] if is_offchip else -1)
+ )
+ constraints.append(constraints_line)
+ return constraints
+
+ def map_device_to_slr(self, idx):
+ """Given a global SLR index, return device id and local slr index"""
+ assert idx <= self.nslr * self.ndevices
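+ # e.g. nslr=3, ndevices=2: idx=4 -> (1, 1), i.e. local SLR 1 on device 1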
+ return (idx % self.nslr, idx // self.nslr)
+
+
+class Zynq7020_Platform(Platform):
+ def __init__(
+ self,
+ ndevices=1,
+ limits=DEFAULT_RES_LIMITS,
+ avg_constraints=DEFAULT_AVG_CONSTRAINTS,
+ ):
+ super(Zynq7020_Platform, self).__init__(
+ nslr=1,
+ ndevices=ndevices,
+ sll_count=[[0]],
+ ddr_slr=[],
+ eth_slr=0,
+ eth_gbps=1,
+ limits=limits,
+ avg_constraints=avg_constraints,
+ )
+
+ @property
+ def compute_resources(self):
+ return [[53200, 2 * 53200, 280, 0, 220] for i in range(1)]
+
+
+class ZU3EG_Platform(Platform):
+ def __init__(
+ self,
+ ndevices=1,
+ limits=DEFAULT_RES_LIMITS,
+ avg_constraints=DEFAULT_AVG_CONSTRAINTS,
+ ):
+ super(ZU3EG_Platform, self).__init__(
+ nslr=1,
+ ndevices=ndevices,
+ sll_count=[[0]],
+ ddr_slr=[],
+ eth_slr=0,
+ eth_gbps=1,
+ limits=limits,
+ avg_constraints=avg_constraints,
+ )
+
+ @property
+ def compute_resources(self):
+ return [[71000, 2 * 71000, 412, 0, 360] for i in range(1)]
+
+
+class ZU7EV_Platform(Platform):
+ def __init__(
+ self,
+ ndevices=1,
+ limits=DEFAULT_RES_LIMITS,
+ avg_constraints=DEFAULT_AVG_CONSTRAINTS,
+ ):
+ super(ZU7EV_Platform, self).__init__(
+ nslr=1,
+ ndevices=ndevices,
+ sll_count=[[0]],
+ ddr_slr=[],
+ eth_slr=0,
+ eth_gbps=1,
+ limits=limits,
+ avg_constraints=avg_constraints,
+ )
+
+ @property
+ def compute_resources(self):
+ return [[230000, 2 * 230000, 610, 92, 1728] for i in range(1)]
+
+
+class ZU9EG_Platform(Platform):
+ def __init__(
+ self,
+ ndevices=1,
+ limits=DEFAULT_RES_LIMITS,
+ avg_constraints=DEFAULT_AVG_CONSTRAINTS,
+ ):
+ super(ZU9EG_Platform, self).__init__(
+ nslr=1,
+ ndevices=ndevices,
+ sll_count=[[0]],
+ ddr_slr=[],
+ eth_slr=0,
+ eth_gbps=1,
+ limits=limits,
+ avg_constraints=avg_constraints,
+ )
+
+ @property
+ def compute_resources(self):
+ return [[274000, 2 * 274000, 1824, 0, 2520] for i in range(1)]
+
+
+class ZU28DR_Platform(Platform):
+ def __init__(
+ self,
+ ndevices=1,
+ limits=DEFAULT_RES_LIMITS,
+ avg_constraints=DEFAULT_AVG_CONSTRAINTS,
+ ):
+ super(ZU28DR_Platform, self).__init__(
+ nslr=1,
+ ndevices=ndevices,
+ sll_count=[[0]],
+ ddr_slr=[],
+ eth_slr=0,
+ eth_gbps=1,
+ limits=limits,
+ avg_constraints=avg_constraints,
+ )
+
+ @property
+ def compute_resources(self):
+ return [[425000, 2 * 425000, 2160, 80, 4272] for i in range(1)]
+
+
+class Alveo_NxU50_Platform(Platform):
+ def __init__(
+ self,
+ ndevices=1,
+ limits=DEFAULT_RES_LIMITS,
+ avg_constraints=DEFAULT_AVG_CONSTRAINTS,
+ ):
+ # Vivado reports 23040 SLLs SLR0 <-> SLR1; a lower limit of 5000 is used below
+ sll_counts = [[0, 5000], [5000, 0]]
+ super(Alveo_NxU50_Platform, self).__init__(
+ nslr=2,
+ ndevices=ndevices,
+ sll_count=sll_counts,
+ ddr_slr=[],
+ hbm_slr=0,
+ eth_slr=1,
+ eth_gbps=100,
+ limits=limits,
+ avg_constraints=avg_constraints,
+ )
+
+ @property
+ def compute_resources(self):
+ # According to UG1120:
+ # U50 has identical resource counts on both SLRs
+ # return [[365000,2*365000,2*564, 304, 2580] for i in range(2)]
+ # we observe from Vivado that the resource counts are actually:
+ return [
+ [374400, 2 * 374400, 2 * 564, 304, 2592],
+ [368160, 2 * 368160, 2 * 564, 304, 2760],
+ ]
+
+
+class Alveo_NxU200_Platform(Platform):
+ def __init__(
+ self,
+ ndevices=1,
+ limits=DEFAULT_RES_LIMITS,
+ avg_constraints=DEFAULT_AVG_CONSTRAINTS,
+ ):
+ sll_counts = [[0, 5000, 0], [5000, 0, 5000], [0, 5000, 0]]
+ super(Alveo_NxU200_Platform, self).__init__(
+ nslr=3,
+ ndevices=ndevices,
+ sll_count=sll_counts,
+ ddr_slr=[0, 2],
+ eth_slr=2,
+ eth_gbps=100,
+ limits=limits,
+ avg_constraints=avg_constraints,
+ )
+
+ @property
+ def compute_resources(self):
+ # According to UG1120:
+ # return [[355000, 723000, 2*638, 320, 2265],
+ # [160000, 331000, 2*326, 160, 1317],
+ # [355000, 723000, 2*638, 320, 2265]]
+ # we observe from Vivado that the resource counts are actually:
+ return [
+ [385920, 2 * 385920, 2 * 714, 320, 2268],
+ [199680, 2 * 199680, 2 * 420, 160, 1320],
+ [385920, 2 * 385920, 2 * 714, 320, 2268],
+ ]
+
+
+class Alveo_NxU250_Platform(Platform):
+ def __init__(
+ self,
+ ndevices=1,
+ limits=DEFAULT_RES_LIMITS,
+ avg_constraints=DEFAULT_AVG_CONSTRAINTS,
+ ):
+ sll_counts = [
+ [0, 5000, 0, 0],
+ [5000, 0, 5000, 0],
+ [0, 5000, 0, 5000],
+ [0, 0, 5000, 0],
+ ]
+ super(Alveo_NxU250_Platform, self).__init__(
+ nslr=4,
+ ndevices=ndevices,
+ sll_count=sll_counts,
+ ddr_slr=[0, 1, 2, 3],
+ eth_slr=3,
+ eth_gbps=100,
+ limits=limits,
+ avg_constraints=avg_constraints,
+ )
+
+ @property
+ def compute_resources(self):
+ # According to UG1120:
+ # U250 has identical resource counts on all 4 SLRs:
+ # return [[345000,2*345000,2*500, 320, 2877] for i in range(4)]
+ # we observe from Vivado that the resource counts are actually:
+ return [[375000, 2 * 375000, 2 * 576, 320, 2880] for i in range(4)]
+
+
+class Alveo_NxU280_Platform(Platform):
+ def __init__(
+ self,
+ ndevices=1,
+ limits=DEFAULT_RES_LIMITS,
+ avg_constraints=DEFAULT_AVG_CONSTRAINTS,
+ ):
+ sll_counts = [[0, 5000, 0], [5000, 0, 5000], [0, 5000, 0]]
+ super(Alveo_NxU280_Platform, self).__init__(
+ nslr=3,
+ ndevices=ndevices,
+ sll_count=sll_counts,
+ ddr_slr=[0, 1],
+ hbm_slr=0,
+ eth_slr=2,
+ eth_gbps=100,
+ limits=limits,
+ avg_constraints=avg_constraints,
+ )
+
+ @property
+ def compute_resources(self):
+ # according to UG1120
+ # return [[369000, 746000, 2*507, 320, 2733],
+ # [333000, 675000, 2*468, 320, 2877],
+ # [367000, 729000, 2*512, 320, 2880]]
+ # observed from Vivado:
+ return [
+ [400800, 2 * 400800, 2 * 600, 320, 2736],
+ [382080, 2 * 382080, 2 * 576, 320, 2880],
+ [380640, 2 * 380640, 2 * 576, 320, 2880],
+ ]
+
+
+platforms = dict()
+platforms["U50"] = Alveo_NxU50_Platform
+platforms["U200"] = Alveo_NxU200_Platform
+platforms["U250"] = Alveo_NxU250_Platform
+platforms["U280"] = Alveo_NxU280_Platform
+platforms["Pynq-Z1"] = Zynq7020_Platform
+platforms["Pynq-Z2"] = Zynq7020_Platform
+platforms["Ultra96"] = ZU3EG_Platform
+platforms["ZCU104"] = ZU7EV_Platform
+platforms["ZCU102"] = ZU9EG_Platform
+platforms["ZCU111"] = ZU28DR_Platform
diff --git a/src/finn/util/pyverilator.py b/src/finn/util/pyverilator.py
index b598a4a..78e6706 100644
--- a/src/finn/util/pyverilator.py
+++ b/src/finn/util/pyverilator.py
@@ -72,7 +72,7 @@ def rtlsim_multi_io(sim, io_dict, num_out_values, trace_file="", sname="_V_V_"):
sim.start_vcd_trace(trace_file)
for outp in io_dict["outputs"]:
- sim.io[outp + sname + "TREADY"] = 1
+ _write_signal(sim, outp + sname + "TREADY", 1)
# observe if output is completely calculated
# total_cycle_count will contain the number of cycles the calculation ran
@@ -89,11 +89,13 @@ def rtlsim_multi_io(sim, io_dict, num_out_values, trace_file="", sname="_V_V_"):
while not (output_done):
for inp in io_dict["inputs"]:
inputs = io_dict["inputs"][inp]
- sim.io[inp + sname + "TVALID"] = 1 if len(inputs) > 0 else 0
- sim.io[inp + sname + "TDATA"] = inputs[0] if len(inputs) > 0 else 0
+ _write_signal(sim, inp + sname + "TVALID", 1 if len(inputs) > 0 else 0)
+ _write_signal(
+ sim, inp + sname + "TDATA", inputs[0] if len(inputs) > 0 else 0
+ )
if (
- sim.io[inp + sname + "TREADY"] == 1
- and sim.io[inp + sname + "TVALID"] == 1
+ _read_signal(sim, inp + sname + "TREADY") == 1
+ and _read_signal(sim, inp + sname + "TVALID") == 1
):
inputs = inputs[1:]
io_dict["inputs"][inp] = inputs
@@ -101,15 +103,13 @@ def rtlsim_multi_io(sim, io_dict, num_out_values, trace_file="", sname="_V_V_"):
for outp in io_dict["outputs"]:
outputs = io_dict["outputs"][outp]
if (
- sim.io[outp + sname + "TVALID"] == 1
- and sim.io[outp + sname + "TREADY"] == 1
+ _read_signal(sim, outp + sname + "TREADY") == 1
+ and _read_signal(sim, outp + sname + "TVALID") == 1
):
- outputs = outputs + [sim.io[outp + sname + "TDATA"]]
- output_count += 1
+ outputs = outputs + [_read_signal(sim, outp + sname + "TDATA")]
io_dict["outputs"][outp] = outputs
- sim.io.ap_clk = 1
- sim.io.ap_clk = 0
+ toggle_clk(sim)
total_cycle_count = total_cycle_count + 1
@@ -141,12 +141,20 @@ def rtlsim_multi_io(sim, io_dict, num_out_values, trace_file="", sname="_V_V_"):
return total_cycle_count
-def pyverilate_stitched_ip(model, read_internal_signals=True):
+def pyverilate_stitched_ip(
+ model, read_internal_signals=True, disable_common_warnings=True
+):
"""Given a model with stitched IP, return a PyVerilator sim object.
- If read_internal_signals is True, it will be possible to examine the
- internal (not only port) signals of the Verilog module, but this may
- slow down compilation and emulation.
Trace depth is also controllable, see get_rtlsim_trace_depth()
+
+ :param read_internal_signals: If set, it will be possible to examine the
+ internal (not only port) signals of the Verilog module, but this may
+ slow down compilation and emulation.
+
+ :param disable_common_warnings: If set, disable the set of warnings that
+ Vivado-HLS-generated Verilog typically triggers in Verilator
+ (which can otherwise be very verbose)
+
"""
if PyVerilator is None:
raise ImportError("Installation of PyVerilator is required.")
@@ -192,6 +200,19 @@ def file_to_basename(x):
wf.write("//Added from " + vfile + "\n\n")
wf.write(rf.read())
+ verilator_args = []
+ # disable common verilator warnings that should be harmless but commonly occur
+ # in large quantities for Vivado HLS-generated verilog code
+ if disable_common_warnings:
+ verilator_args += ["-Wno-STMTDLY"]
+ verilator_args += ["-Wno-PINMISSING"]
+ verilator_args += ["-Wno-IMPLICIT"]
+ verilator_args += ["-Wno-WIDTH"]
+ verilator_args += ["-Wno-COMBDLY"]
+ # force inlining of all submodules to ensure we can read internal signals properly
+ if read_internal_signals:
+ verilator_args += ["--inline-mult", "0"]
+
sim = PyVerilator.build(
top_module_file_name,
verilog_path=[vivado_stitch_proj_dir],
@@ -200,6 +221,7 @@ def file_to_basename(x):
top_module_name=top_module_name,
auto_eval=False,
read_internal_signals=read_internal_signals,
+ extra_args=verilator_args,
)
return sim
diff --git a/tests/transformation/test_4d_conversion.py b/tests/transformation/test_4d_conversion.py
index 18fe9cc..d6eb11c 100644
--- a/tests/transformation/test_4d_conversion.py
+++ b/tests/transformation/test_4d_conversion.py
@@ -1,3 +1,5 @@
+import pytest
+
import numpy as np
import onnx
@@ -26,7 +28,7 @@ def generate_random_input(model):
def set_all_initializers(model):
"""Sets all initializers of the graph to a random value."""
for n in model.graph.node:
- if len(n.input) > 1:
+ if len(n.input) > 1 and n.name != "TopK1":
init_name = n.input[1]
init_shape = model.get_tensor_shape(init_name)
init_val = gen_finn_dt_tensor(DataType.FLOAT32, init_shape)
@@ -189,11 +191,153 @@ def create_arbitrary_model(invalid=False):
return model
-def test_4d_conversion():
+def create_arbitrary_model_vgg():
+ """
+ Creates arbitrary model for testing the 3D to 4D transform.
+ This model is based on a subpart of VGG10.
+ """
+ Conv1_node = onnx.helper.make_node(
+ "Conv",
+ inputs=["in1_conv1", "in2_conv1"],
+ outputs=["out1_conv1"],
+ name="Conv1",
+ dilations=[1],
+ group=1,
+ kernel_shape=[3],
+ pads=[1, 1],
+ strides=[1],
+ )
+
+ Div1_node = onnx.helper.make_node(
+ "Div", inputs=["out1_conv1", "in2_div1"], outputs=["out1_div1"], name="Div1"
+ )
+
+ MaxPool1_node = onnx.helper.make_node(
+ "MaxPool",
+ inputs=["out1_div1"],
+ outputs=["out1_maxpool1"],
+ name="MaxPool1",
+ kernel_shape=[2],
+ pads=[0, 0],
+ strides=[2],
+ )
+
+ Flatten1_node = onnx.helper.make_node(
+ "Flatten",
+ inputs=["out1_maxpool1"],
+ outputs=["out1_flatten1"],
+ name="Flatten1",
+ axis=1,
+ )
+
+ MatMul1_node = onnx.helper.make_node(
+ "MatMul",
+ inputs=["out1_flatten1", "in2_matmul1"],
+ outputs=["out1_matmul1"],
+ name="MatMul1",
+ )
+
+ TopK1_node = onnx.helper.make_node(
+ "TopK",
+ inputs=["out1_matmul1", "in2topk1"],
+ outputs=["out1_topk1", "out2_topk1"],
+ name="TopK1",
+ axis=-1,
+ largest=1,
+ sorted=1,
+ )
+
+ # Inputs and outputs
+ in1_conv1 = onnx.helper.make_tensor_value_info(
+ "in1_conv1", onnx.TensorProto.FLOAT, [1, 64, 16]
+ )
+ out2_topk1 = onnx.helper.make_tensor_value_info(
+ "out2_topk1", onnx.TensorProto.INT64, [1, 3]
+ )
+
+ # Value infos
+ out1_conv1 = onnx.helper.make_tensor_value_info(
+ "out1_conv1", onnx.TensorProto.FLOAT, [1, 64, 16]
+ )
+ out1_div1 = onnx.helper.make_tensor_value_info(
+ "out1_div1", onnx.TensorProto.FLOAT, [1, 64, 16]
+ )
+ out1_maxpool1 = onnx.helper.make_tensor_value_info(
+ "out1_maxpool1", onnx.TensorProto.FLOAT, [1, 64, 8]
+ )
+ out1_flatten1 = onnx.helper.make_tensor_value_info(
+ "out1_flatten1", onnx.TensorProto.FLOAT, [1, 512]
+ )
+ out1_matmul1 = onnx.helper.make_tensor_value_info(
+ "out1_matmul1", onnx.TensorProto.FLOAT, [1, 24]
+ )
+ out1_topk1 = onnx.helper.make_tensor_value_info(
+ "out1_topk1", onnx.TensorProto.FLOAT, [1, 3]
+ )
+
+ # Initializers
+ in2_conv1 = onnx.helper.make_tensor_value_info(
+ "in2_conv1", onnx.TensorProto.FLOAT, [64, 64, 3]
+ )
+ in2_div1 = onnx.helper.make_tensor_value_info(
+ "in2_div1", onnx.TensorProto.FLOAT, [1]
+ )
+ in2_matmul1 = onnx.helper.make_tensor_value_info(
+ "in2_matmul1", onnx.TensorProto.FLOAT, [512, 24]
+ )
+ in2topk1 = onnx.helper.make_tensor_value_info(
+ "in2topk1", onnx.TensorProto.FLOAT, [1]
+ )
+
+ list_of_nodes = [
+ Conv1_node,
+ Div1_node,
+ MaxPool1_node,
+ Flatten1_node,
+ MatMul1_node,
+ TopK1_node,
+ ]
+ list_of_value_infos = [
+ out1_conv1,
+ out1_div1,
+ out1_maxpool1,
+ out1_flatten1,
+ out1_matmul1,
+ out1_topk1,
+ in2_conv1,
+ in2_div1,
+ in2_matmul1,
+ in2topk1,
+ ]
+
+ graph = onnx.helper.make_graph(
+ nodes=list_of_nodes,
+ name="4d_conversion_test_graph",
+ inputs=[in1_conv1],
+ outputs=[out2_topk1],
+ value_info=list_of_value_infos,
+ )
+ onnx_model = onnx.helper.make_model(graph, producer_name="4d_conversion_test-model")
+ model = ModelWrapper(onnx_model)
+
+ # Fixed TopK initializer (K=3)
+ model.set_initializer("in2topk1", np.array([3]))
+
+ return model
+
+
+@pytest.mark.parametrize("test_model", ["Quartz", "VGG"])
+def test_4d_conversion(test_model):
"""
Test for the 3D to 4D transformation with a valid graph.
"""
- model = create_arbitrary_model(invalid=False)
+
+ if test_model == "Quartz":
+ model = create_arbitrary_model(invalid=False)
+ elif test_model == "VGG":
+ model = create_arbitrary_model_vgg()
+ else:
+ raise Exception("Unknown test_model in test_4d_conversion")
# Inputs
input_dict = generate_random_input(model)
diff --git a/tests/transformation/test_batchnorm_to_affine.py b/tests/transformation/test_batchnorm_to_affine.py
index 4adc874..821338f 100644
--- a/tests/transformation/test_batchnorm_to_affine.py
+++ b/tests/transformation/test_batchnorm_to_affine.py
@@ -64,7 +64,7 @@ def test_batchnorm_to_affine_shufflenet():
op_types = list(map(lambda x: x.op_type, new_model.graph.node))
assert "BatchNormalization" not in op_types
produced = oxe.execute_onnx(new_model, input_dict)[oname]
- assert np.isclose(expected, produced).all()
+ assert np.isclose(expected, produced, atol=1e-05).all()
os.remove(export_onnx_path)
diff --git a/tests/transformation/test_infer_datatypes.py b/tests/transformation/test_infer_datatypes.py
index 2b18a88..96b1b28 100644
--- a/tests/transformation/test_infer_datatypes.py
+++ b/tests/transformation/test_infer_datatypes.py
@@ -46,13 +46,10 @@ def test_infer_datatypes():
# this model has no DataType info, so add some DataType annotation
# to make things a bit more exciting
model.set_tensor_datatype("global_in", DataType.UINT8)
- # manual non-float annotations on regular ONNX nodes won't disappear
- # (InferDataTypes assumes they've been put there with good reason)
- model.set_tensor_datatype("MaxPool_1_out0", DataType.INT4)
- # MatMul with int weights + inputs will have int output datatype
- model.set_tensor_datatype("MatMul_0_param0", DataType.UINT8)
+ # Conv with int weights + inputs will have int output datatype
+ model.set_tensor_datatype("Conv_0_param0", DataType.INT4)
model = model.transform(InferDataTypes())
assert model.get_tensor_datatype("global_in") == DataType.UINT8
- assert model.get_tensor_datatype("Reshape_0_out0") == DataType.INT4
- assert model.get_tensor_datatype("MatMul_0_out0") == DataType.INT32
+ assert model.get_tensor_datatype("Conv_0_out0") == DataType.INT32
+ assert model.get_tensor_datatype("Relu_0_out0") == DataType.FLOAT32
assert model.get_tensor_datatype("global_out") == DataType.FLOAT32
diff --git a/tests/transformation/test_make_input_chanlast.py b/tests/transformation/test_make_input_chanlast.py
new file mode 100644
index 0000000..4e4f894
--- /dev/null
+++ b/tests/transformation/test_make_input_chanlast.py
@@ -0,0 +1,46 @@
+# Copyright (c) 2021 Xilinx, Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# * Neither the name of Xilinx nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+
+from pkgutil import get_data
+
+import finn.core.data_layout as data_layout
+from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.make_input_chanlast import MakeInputChannelsLast
+
+
+def test_make_input_chanlast():
+ # load the onnx model
+ raw_m = get_data("finn.data", "onnx/mnist-conv/model.onnx")
+ model = ModelWrapper(raw_m)
+ iname = model.graph.input[0].name
+ assert tuple(model.get_tensor_shape(iname)) == (1, 1, 28, 28)
+ model = model.transform(MakeInputChannelsLast())
+ assert model.graph.node[0].op_type == "Transpose"
+ assert tuple(model.get_tensor_shape(iname)) == (1, 28, 28, 1)
+ assert model.get_tensor_layout(iname) == data_layout.NHWC
diff --git a/tests/transformation/test_renaming.py b/tests/transformation/test_renaming.py
index 491ccb0..b082a36 100644
--- a/tests/transformation/test_renaming.py
+++ b/tests/transformation/test_renaming.py
@@ -26,9 +26,13 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+import pytest
+
import numpy as np
import onnx
import onnx.numpy_helper as np_helper
+import os
+import urllib.request as ureq
from pkgutil import get_data
import finn.core.onnx_exec as oxe
@@ -72,3 +76,26 @@ def test_renaming():
assert np.isclose(
np_helper.to_array(output_tensor), output_dict["global_out"], atol=1e-3
).all()
+
+
+def test_rename_multi_io_tinyyolov3():
+ download_url = (
+ "https://github.com/onnx/models/raw/master/vision/object_detection_segmentation"
+ )
+ download_url += "/tiny-yolov3/model/tiny-yolov3-11.onnx"
+ export_onnx_path = download_url.split("/")[-1]
+ # urlretrieve raises on network failure; catch it so the test can skip below
+ try:
+ ureq.urlretrieve(download_url, export_onnx_path)
+ except Exception:
+ pass
+ if not os.path.isfile(export_onnx_path):
+ pytest.skip("Couldn't download ONNX model, skipping")
+ model = ModelWrapper(export_onnx_path)
+ model = model.transform(GiveUniqueNodeNames())
+ model = model.transform(GiveReadableTensorNames())
+ assert len(model.graph.input) == 2
+ assert model.graph.input[0].name == "global_in"
+ assert model.graph.input[1].name == "global_in_1"
+ assert len(model.graph.output) == 3
+ assert model.graph.output[0].name == "global_out"
+ assert model.graph.output[1].name == "global_out_1"
+ assert model.graph.output[2].name == "global_out_2"
+ model.save("dbg.onnx")
+ os.remove(export_onnx_path)