Skip to content

Commit

Permalink
[ONNX] Onnx node tests (#7720)
Browse files Browse the repository at this point in the history
* WIP

* some fixes

* more fixes

* fix some conv_transpose tests

* fix out of bounds slice

* fix flatten import

* fix logsoftmax and softmax tests

* fix Error in Upsample

* fix onehot

* normalize errors

* fix gather with negative indices

* parameterize test

* skip unsupported tests

* clean up

* fix rebase

* fix lint

* add an error message when we find an un-identified tensor
  • Loading branch information
Matthew Brookhart committed Mar 24, 2021
1 parent 6f0a656 commit 8131364
Show file tree
Hide file tree
Showing 3 changed files with 269 additions and 34 deletions.
133 changes: 101 additions & 32 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,11 @@ def get_numpy(tensor_proto):
def get_type(elem_type):
"""Converts onnx integer datatype to numpy datatype"""
try:
from onnx import TensorProto
from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
except ImportError as e:
raise ImportError("Unable to import onnx which is required {}".format(e))
return TensorProto.DataType.Name(elem_type).lower()

return str(TENSOR_TYPE_TO_NP_TYPE[elem_type])


def get_info(info_proto):
Expand Down Expand Up @@ -157,14 +158,16 @@ def revert_caffe2_pad(pads):
return pads


def get_pad_pair(input1d, kernel1d, stride1d):
def get_pad_pair(input1d, kernel1d, stride1d, mode):
"""infer pad size"""
if input1d % stride1d == 0:
pad = max(kernel1d - stride1d, 0)
else:
pad = max(kernel1d - (input1d % stride1d), 0)
pad_before = pad // 2
pad_after = pad - pad_before
if "LOWER" in mode:
return [pad_after, pad_before]
return [pad_before, pad_after]


Expand Down Expand Up @@ -280,9 +283,9 @@ def _impl_v1(cls, inputs, attr, params):
pad_tuple = []
for axis in range(len(input_shape) - 2):
axis_shape = input_shape[2 + axis]
stride = attr["strides"][axis]
stride = attr.get("strides", [1] * ndim)[axis]
kernel = attr["kernel_shape"][axis]
pad = get_pad_pair(axis_shape, kernel, stride)
pad = get_pad_pair(axis_shape, kernel, stride, attr["auto_pad"])
pad_tuple.append(pad)
pad_tuple = tuple([val for pair in zip(*pad_tuple) for val in pair])
attr["pads"] = pad_tuple
Expand Down Expand Up @@ -444,9 +447,15 @@ class ConvTranspose(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
# get number of channels
channels = infer_channels(inputs[1], True)
out_type = infer_type(inputs[1])
out_shapes = [get_const_tuple(out_type.checked_type.shape)]
channels = out_shapes[0][1]
attr["channels"] = channels
groups = attr.get("group", 1)

if "kernel_shape" not in attr:
attr["kernel_shape"] = out_shapes[0][2:]

attr["groups"] = groups
# infer pads for auto_pad
data = inputs[0]
Expand Down Expand Up @@ -528,13 +537,11 @@ def _impl_v1(cls, inputs, attr, params):
if not transB:
inputs[1] = _op.transpose(inputs[1], axes=(1, 0))
inputs[0] = _op.nn.batch_flatten(inputs[0])

if alpha != 1.0:
inputs[0] *= _expr.const(alpha)
out = _op.nn.dense(inputs[0], inputs[1], units=channels)

if len(inputs) == 3:
return _op.nn.bias_add(out, _expr.const(beta) * inputs[2])
out = out + _expr.const(beta) * inputs[2]
return out


Expand Down Expand Up @@ -618,7 +625,7 @@ def _impl_v1(cls, inputs, attr, params):
# Note: attr['fmod'] determines whether the operator should behave like np.fmod or np.mod.
# attr['fmod'] == 0 will behave as np.mod and attr['fmod'] == 1 will force fmod treatment.
# The relay equivalent of np.fmod is relay.mod and np.mod is relay.floor_mod
if attr["fmod"] == 0:
if attr.get("fmod", 0) == 0:
op_name = "floor_mod"
else:
op_name = "mod"
Expand Down Expand Up @@ -849,12 +856,18 @@ class Flatten(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
axis = attr.get("axis", 1)
ishape = _op.shape_of(inputs[0])
ndim = infer_shape(ishape)[0]
if axis < 0:
axis = axis + ndim

if axis == 1:
out = _op.nn.batch_flatten(inputs[0])
else:
newshape = [0] * (axis + 1)
newshape[axis] = -1
out = _op.reshape(inputs[0], list(newshape))
pre_shape = _op.prod(_op.strided_slice(ishape, [0], [axis], [1]), keepdims=True)
post_shape = _op.prod(_op.strided_slice(ishape, [axis], [ndim], [1]), keepdims=True)
newshape = _op.concatenate([pre_shape, post_shape], axis=0)
out = _op.reshape(inputs[0], newshape)
return out


Expand Down Expand Up @@ -1036,7 +1049,7 @@ def _impl_v9(cls, inputs, attr, params):

# in 3d case, we use the purely static op
if dims == 5:
if isinstance(scales, _expr.Call):
if isinstance(scales, _expr.Expr):
scale_h = _op.take(scales, _op.const(3))
scale_w = _op.take(scales, _op.const(4))
scale_d = _op.take(scales, _op.const(1))
Expand All @@ -1052,7 +1065,7 @@ def _impl_v9(cls, inputs, attr, params):
)
# in 2d case, use dynamic op
else:
if isinstance(scales, _expr.Call):
if isinstance(scales, _expr.Expr):
scale_h = _op.take(scales, _op.const(3))
scale_w = _op.take(scales, _op.const(4))
else:
Expand Down Expand Up @@ -1247,7 +1260,13 @@ class Gather(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
axis = attr.get("axis", 0)
return AttrCvt("take", extras={"axis": axis})(inputs, {})
data = inputs[0]
indices = inputs[1]
ind_dtype = infer_type(indices).checked_type.dtype
# Normalize the indices to a positive range
s = _op.take(_op.shape_of(data, dtype=ind_dtype), _op.const(axis))
indices = _op.where(indices < _op.const(0, ind_dtype), indices + s, indices)
return _op.take(data, indices, axis)


class GatherElements(OnnxOpConverter):
Expand All @@ -1258,6 +1277,10 @@ def _impl_v1(cls, inputs, attr, params):
data = inputs[0]
indices = inputs[1]
axis = attr.get("axis", 0)
ind_dtype = infer_type(indices).checked_type.dtype
# Normalize the indices to a positive range
s = _op.take(_op.shape_of(data, dtype=ind_dtype), _op.const(axis))
indices = _op.where(indices < _op.const(0, ind_dtype), indices + s, indices)
return _op.gather(data, axis, indices)


Expand Down Expand Up @@ -1318,8 +1341,8 @@ class Maximum(OnnxOpConverter):

@classmethod
def _impl_v1(cls, inputs, attr, params):
if not isinstance(inputs, (list, onnx_input)) or len(inputs) < 2:
raise ValueError("Expect minimum 2 inputs")
if len(inputs) == 1:
return inputs[0]
_max = inputs[0]
for i in range(1, len(inputs)):
_max = AttrCvt("maximum")([_max, inputs[i]], {})
Expand All @@ -1331,8 +1354,8 @@ class Minimum(OnnxOpConverter):

@classmethod
def _impl_v1(cls, inputs, attr, params):
if not isinstance(inputs, (list, onnx_input)) or len(inputs) < 2:
raise ValueError("Expect minimum 2 inputs")
if len(inputs) == 1:
return inputs[0]
_min = inputs[0]
for i in range(1, len(inputs)):
_min = AttrCvt("minimum")([_min, inputs[i]], {})
Expand All @@ -1344,8 +1367,8 @@ class Mean(OnnxOpConverter):

@classmethod
def _impl_v1(cls, inputs, attr, params):
if not isinstance(inputs, (list, onnx_input)) or len(inputs) < 2:
raise ValueError("Expect minimum 2 inputs")
if len(inputs) == 1:
return inputs[0]
# avoid overflow
concat = _op.concatenate([_op.expand_dims(x, axis=0) for x in inputs], axis=0)
return _op.mean(concat, axis=0, keepdims=False)
Expand Down Expand Up @@ -1485,6 +1508,8 @@ class ArgMax(OnnxOpConverter):

@classmethod
def _impl_v1(cls, inputs, attr, params):
if "select_last_index" in attr:
raise NotImplementedError("select_last_index not supported in ArgMax")
axis = attr.get("axis", 0)
keepdims = attr.get("keepdims", True)
attr = {"axis": axis, "keepdims": keepdims}
Expand All @@ -1496,6 +1521,8 @@ class ArgMin(OnnxOpConverter):

@classmethod
def _impl_v1(cls, inputs, attr, params):
if "select_last_index" in attr:
raise NotImplementedError("select_last_index not supported in ArgMin")
axis = attr.get("axis", 0)
keepdims = attr.get("keepdims", True)
attr = {"axis": axis, "keepdims": keepdims}
Expand All @@ -1510,7 +1537,35 @@ def _impl_v1(cls, inputs, attr, params):
# set default value when axis is not set in the model
if "axis" not in attr:
attr["axis"] = 1
return AttrCvt("softmax", transforms={"axis": ("axis", 1)})(inputs, attr, params)
axis = attr["axis"]
ndim = len(infer_shape(inputs[0]))
if axis < 0:
axis += ndim
axes = list(range(axis, ndim))
x = inputs[0]
m = _op.max(x, axes, keepdims=True)
e = _op.exp(x - m)
return e / _op.sum(e, axes, keepdims=True)


class LogSoftmax(OnnxOpConverter):
"""Operator converter for Softmax."""

@classmethod
def _impl_v1(cls, inputs, attr, params):
# set default value when axis is not set in the model
if "axis" not in attr:
attr["axis"] = 1
axis = attr["axis"]
ndim = len(infer_shape(inputs[0]))
if axis < 0:
axis += ndim
axes = list(range(axis, ndim))
x = inputs[0]
m = _op.max(x, axes, keepdims=True)
e = _op.exp(x - m)
s = _op.sum(e, axes, keepdims=True)
return x - m - _op.log(s)


class OneHot(OnnxOpConverter):
Expand All @@ -1520,14 +1575,24 @@ class OneHot(OnnxOpConverter):
def _impl_v9(cls, inputs, attr, params):
# Extract relay one_hot inputs.
indices, depth, values = inputs
ndim = len(infer_shape(indices))
# Split onnx on off values into two separate expressions.
off_value, on_value = _op.take(values, _op.const(0)), _op.take(values, _op.const(1))
# Extract the datatype of the output from on_value.
dtype = infer_type(on_value).checked_type.dtype
ind_dtype = infer_type(indices).checked_type.dtype
# Normalize the indices to a positive range
indices = _op.where(
indices < _op.const(0, ind_dtype), indices + _op.cast(depth, ind_dtype), indices
)
# set default value when axis is not set in the model
if "axis" not in attr:
attr["axis"] = -1
return _op.one_hot(indices, on_value, off_value, depth, int(attr["axis"]), dtype=dtype)
axis = attr["axis"]
if axis < 0:
axis += ndim + 1

return _op.one_hot(indices, on_value, off_value, depth, axis, dtype=dtype)


class ConstantOfShape(OnnxOpConverter):
Expand All @@ -1552,7 +1617,7 @@ class Constant(OnnxOpConverter):
@classmethod
def _impl_v9(cls, inputs, attr, params):
if "value" not in attr:
raise "No Value in Constant"
raise tvm.errors.OpAttributeRequired("no value in Constant")
np_value = get_numpy(attr.pop("value"))
dtype = np_value.dtype.name
value = _expr.const(np_value, dtype)
Expand Down Expand Up @@ -2042,7 +2107,7 @@ def _impl_v1(cls, inputs, attr, params):
largest = attr.get("largest", 1)

if largest == 0:
raise ValueError("TVM only supports finding TopK largest elements")
raise NotImplementedError("TVM only supports finding TopK largest elements")

return _op.topk(inputs[0], inputs[1], axis=axis, dtype="int64")

Expand Down Expand Up @@ -2087,7 +2152,7 @@ def _impl_v1(cls, inputs, attr, params):
batch_indices = inputs[2]
mode = attr.get("mode", b"avg")
if mode not in (b"avg", b"max"):
raise ValueError("RoiAlign in Relay only uses avg and max modes")
raise NotImplementedError("RoiAlign in Relay only uses avg and max modes")
output_height = attr.get("output_height", 1)
output_width = attr.get("output_width", 1)

Expand Down Expand Up @@ -2128,7 +2193,8 @@ def _impl_v11(cls, inputs, attr, params):
result = inputs[0]
for i, op in enumerate([_op.tensor.maximum, _op.tensor.minimum]):
if i < len(inputs) - 1:
result = op(result, inputs[i + 1])
if inputs[i + 1] is not None:
result = op(result, inputs[i + 1])
return result


Expand Down Expand Up @@ -2393,9 +2459,10 @@ def _impl_v10(cls, inputs, attr, params):
dtype = infer_type(boxes).checked_type.dtype

if "center_point_box" in attr:
assert (
attr["center_point_box"] == 0
), "Only support center_point_box = 0 in onnx importer right now"
if attr["center_point_box"] != 0:
raise NotImplementedError(
"Only support center_point_box = 0 in ONNX NonMaxSuprresion"
)

if iou_threshold is None:
iou_threshold = _expr.const(0.0, dtype="float32")
Expand Down Expand Up @@ -2718,7 +2785,7 @@ def _get_convert_map(opset):
"Softplus": Softplus.get_converter(opset),
# softmax default axis is different in onnx
"Softmax": Softmax.get_converter(opset),
"LogSoftmax": AttrCvt("log_softmax", {"axis": ("axis", 1)}),
"LogSoftmax": LogSoftmax.get_converter(opset),
"OneHot": OneHot.get_converter(opset),
# 'Hardmax'
"Softsign": Softsign.get_converter(opset),
Expand Down Expand Up @@ -2958,6 +3025,8 @@ def from_onnx(self, graph, opset, get_output_expr=False):
for i in node.input:
if i != "":
inputs[i] = self._nodes[self._renames.get(i, i)]
else:
inputs[i] = None
i_name = self._parse_value_proto(node)
node_output = self._fix_outputs(op_name, node.output)
attr["tvm_custom"] = {}
Expand Down
7 changes: 5 additions & 2 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -905,10 +905,13 @@ def strided_slice(data, begin, end, strides=None, slice_mode="end"):
end = const(list(end))
if isinstance(strides, (tuple, list)):
strides = const(list(strides))
normalized_begin = _make.where(
begin = _make.where(
begin < cast_like(const(0), begin), begin + cast_like(shape_of(data), begin), begin
)
return _dyn_make.strided_slice(data, normalized_begin, end, strides, slice_mode)
begin = _make.where(
begin >= cast_like(shape_of(data), begin), cast_like(shape_of(data), begin), begin
)
return _dyn_make.strided_slice(data, begin, end, strides, slice_mode)
return _make.strided_slice(data, begin, end, strides, slice_mode)


Expand Down
Loading

0 comments on commit 8131364

Please sign in to comment.