[ONNX] Onnx node tests #7720

Merged · 17 commits · Mar 24, 2021
python/tvm/relay/frontend/onnx.py (133 changes: 101 additions & 32 deletions)
@@ -103,10 +103,11 @@ def get_numpy(tensor_proto):
def get_type(elem_type):
"""Converts onnx integer datatype to numpy datatype"""
try:
from onnx import TensorProto
from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
except ImportError as e:
raise ImportError("Unable to import onnx which is required {}".format(e))
return TensorProto.DataType.Name(elem_type).lower()

return str(TENSOR_TYPE_TO_NP_TYPE[elem_type])
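A quick illustration of the mapping-based lookup (assumes the `onnx` package is installed; `TENSOR_TYPE_TO_NP_TYPE` maps the `TensorProto` dtype enum to numpy dtypes):

```python
from onnx import TensorProto
from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE

# The enum member indexes straight into the mapping; str() of a numpy
# dtype yields the name relay expects.
print(str(TENSOR_TYPE_TO_NP_TYPE[TensorProto.FLOAT]))  # float32
print(str(TENSOR_TYPE_TO_NP_TYPE[TensorProto.INT64]))  # int64
```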


def get_info(info_proto):
@@ -157,14 +158,16 @@ def revert_caffe2_pad(pads):
return pads


def get_pad_pair(input1d, kernel1d, stride1d):
def get_pad_pair(input1d, kernel1d, stride1d, mode):
"""infer pad size"""
if input1d % stride1d == 0:
pad = max(kernel1d - stride1d, 0)
else:
pad = max(kernel1d - (input1d % stride1d), 0)
pad_before = pad // 2
pad_after = pad - pad_before
if "LOWER" in mode:
return [pad_after, pad_before]
return [pad_before, pad_after]
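An illustrative check of the new `mode` argument, assuming `auto_pad` has already been decoded to a plain `str` by the caller:

```python
# Total SAME padding for input 7, kernel 3, stride 2 is max(3 - 7 % 2, 0) = 2,
# split evenly, so both modes agree when the total is even.
assert get_pad_pair(7, 3, 2, "SAME_UPPER") == [1, 1]
# An odd total puts the extra element after (SAME_UPPER) or before (SAME_LOWER).
assert get_pad_pair(7, 2, 2, "SAME_UPPER") == [0, 1]
assert get_pad_pair(7, 2, 2, "SAME_LOWER") == [1, 0]
```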


@@ -280,9 +283,9 @@ def _impl_v1(cls, inputs, attr, params):
pad_tuple = []
for axis in range(len(input_shape) - 2):
axis_shape = input_shape[2 + axis]
stride = attr["strides"][axis]
stride = attr.get("strides", [1] * ndim)[axis]
kernel = attr["kernel_shape"][axis]
pad = get_pad_pair(axis_shape, kernel, stride)
pad = get_pad_pair(axis_shape, kernel, stride, attr["auto_pad"])
pad_tuple.append(pad)
pad_tuple = tuple([val for pair in zip(*pad_tuple) for val in pair])
attr["pads"] = pad_tuple
@@ -444,9 +447,15 @@ class ConvTranspose(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
# get number of channels
channels = infer_channels(inputs[1], True)
out_type = infer_type(inputs[1])
out_shapes = [get_const_tuple(out_type.checked_type.shape)]
channels = out_shapes[0][1]
Contributor: Does this need to work for layouts other than NCHW? It looks like the ONNX op doesn't specify layout in the ConvTranspose operator.

Contributor (Author): ONNX always assumes NCHW.

Contributor: Cool, just wanted to make sure we didn't have to worry about it!
attr["channels"] = channels
groups = attr.get("group", 1)

if "kernel_shape" not in attr:
attr["kernel_shape"] = out_shapes[0][2:]

attr["groups"] = groups
# infer pads for auto_pad
data = inputs[0]
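A sketch of the shape bookkeeping above, with numpy standing in for the inferred weight type (values hypothetical): ONNX lays ConvTranspose weights out as `(in_channels, out_channels / groups, *kernel)`, NCHW assumed per the review thread, so both `channels` and a missing `kernel_shape` can be read off the weight shape.

```python
import numpy as np

# Hypothetical ConvTranspose weight: 16 input channels, 32 output channels, 3x3.
w = np.zeros((16, 32, 3, 3), dtype="float32")
channels = w.shape[1]             # 32 when group == 1
kernel_shape = list(w.shape[2:])  # [3, 3], recovered when the attribute is absent
```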
@@ -528,13 +537,11 @@ def _impl_v1(cls, inputs, attr, params):
if not transB:
inputs[1] = _op.transpose(inputs[1], axes=(1, 0))
inputs[0] = _op.nn.batch_flatten(inputs[0])

if alpha != 1.0:
inputs[0] *= _expr.const(alpha)
out = _op.nn.dense(inputs[0], inputs[1], units=channels)

if len(inputs) == 3:
return _op.nn.bias_add(out, _expr.const(beta) * inputs[2])
out = out + _expr.const(beta) * inputs[2]
return out
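In numpy terms, Gemm computes `Y = alpha * A @ B + beta * C`; switching from `bias_add` to a plain add lets `C` broadcast instead of being restricted to a 1-D bias. An illustrative reference:

```python
import numpy as np

a = np.ones((2, 3), dtype="float32")
b = np.ones((3, 4), dtype="float32")
c = np.ones((1, 4), dtype="float32")  # broadcasts against the (2, 4) product
alpha, beta = 2.0, 0.5
y = alpha * (a @ b) + beta * c        # shape (2, 4)
```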


@@ -618,7 +625,7 @@ def _impl_v1(cls, inputs, attr, params):
# Note: attr['fmod'] determines whether the operator should behave like np.fmod or np.mod.
# attr['fmod'] == 0 will behave as np.mod and attr['fmod'] == 1 will force fmod treatment.
# The relay equivalent of np.fmod is relay.mod and np.mod is relay.floor_mod
if attr["fmod"] == 0:
if attr.get("fmod", 0) == 0:
op_name = "floor_mod"
else:
op_name = "mod"
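The difference the `fmod` attribute controls, shown with the numpy semantics the comment above references:

```python
import numpy as np

# fmod == 0: result sign follows the divisor, like relay floor_mod.
print(np.mod(-4, 3))   # 2
# fmod == 1: result sign follows the dividend, like relay mod.
print(np.fmod(-4, 3))  # -1
```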
@@ -849,12 +856,18 @@ class Flatten(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
axis = attr.get("axis", 1)
ishape = _op.shape_of(inputs[0])
ndim = infer_shape(ishape)[0]
if axis < 0:
axis = axis + ndim

if axis == 1:
out = _op.nn.batch_flatten(inputs[0])
else:
newshape = [0] * (axis + 1)
newshape[axis] = -1
out = _op.reshape(inputs[0], list(newshape))
pre_shape = _op.prod(_op.strided_slice(ishape, [0], [axis], [1]), keepdims=True)
post_shape = _op.prod(_op.strided_slice(ishape, [axis], [ndim], [1]), keepdims=True)
newshape = _op.concatenate([pre_shape, post_shape], axis=0)
out = _op.reshape(inputs[0], newshape)
return out
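The dynamic reshape above reduces the shape to two products, everything before `axis` and everything from `axis` on; in concrete numbers:

```python
import numpy as np

ishape, axis = np.array([2, 3, 4, 5]), 2
pre = int(np.prod(ishape[:axis]))   # 6
post = int(np.prod(ishape[axis:]))  # 20
newshape = (pre, post)              # a (2, 3, 4, 5) input flattens to (6, 20)
```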


@@ -1036,7 +1049,7 @@ def _impl_v9(cls, inputs, attr, params):

# in 3d case, we use the purely static op
if dims == 5:
if isinstance(scales, _expr.Call):
if isinstance(scales, _expr.Expr):
scale_h = _op.take(scales, _op.const(3))
scale_w = _op.take(scales, _op.const(4))
scale_d = _op.take(scales, _op.const(1))
@@ -1052,7 +1065,7 @@
)
# in 2d case, use dynamic op
else:
if isinstance(scales, _expr.Call):
if isinstance(scales, _expr.Expr):
scale_h = _op.take(scales, _op.const(3))
scale_w = _op.take(scales, _op.const(4))
else:
@@ -1247,7 +1260,13 @@ class Gather(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
axis = attr.get("axis", 0)
return AttrCvt("take", extras={"axis": axis})(inputs, {})
data = inputs[0]
indices = inputs[1]
ind_dtype = infer_type(indices).checked_type.dtype
# Normalize the indices to a positive range
s = _op.take(_op.shape_of(data, dtype=ind_dtype), _op.const(axis))
indices = _op.where(indices < _op.const(0, ind_dtype), indices + s, indices)
return _op.take(data, indices, axis)
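The `where` rewrite maps negative indices onto their positive equivalents before `take`; a numpy equivalent:

```python
import numpy as np

data = np.arange(5)          # [0, 1, 2, 3, 4]
indices = np.array([-1, 2])
dim = data.shape[0]
indices = np.where(indices < 0, indices + dim, indices)  # [4, 2]
print(np.take(data, indices, axis=0))                    # [4 2]
```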


class GatherElements(OnnxOpConverter):
@@ -1258,6 +1277,10 @@ def _impl_v1(cls, inputs, attr, params):
data = inputs[0]
indices = inputs[1]
axis = attr.get("axis", 0)
ind_dtype = infer_type(indices).checked_type.dtype
# Normalize the indices to a positive range
s = _op.take(_op.shape_of(data, dtype=ind_dtype), _op.const(axis))
indices = _op.where(indices < _op.const(0, ind_dtype), indices + s, indices)
return _op.gather(data, axis, indices)


@@ -1318,8 +1341,8 @@ class Maximum(OnnxOpConverter):

@classmethod
def _impl_v1(cls, inputs, attr, params):
if not isinstance(inputs, (list, onnx_input)) or len(inputs) < 2:
raise ValueError("Expect minimum 2 inputs")
if len(inputs) == 1:
return inputs[0]
_max = inputs[0]
for i in range(1, len(inputs)):
_max = AttrCvt("maximum")([_max, inputs[i]], {})
@@ -1331,8 +1354,8 @@ class Minimum(OnnxOpConverter):

@classmethod
def _impl_v1(cls, inputs, attr, params):
if not isinstance(inputs, (list, onnx_input)) or len(inputs) < 2:
raise ValueError("Expect minimum 2 inputs")
if len(inputs) == 1:
return inputs[0]
_min = inputs[0]
for i in range(1, len(inputs)):
_min = AttrCvt("minimum")([_min, inputs[i]], {})
@@ -1344,8 +1367,8 @@ class Mean(OnnxOpConverter):

@classmethod
def _impl_v1(cls, inputs, attr, params):
if not isinstance(inputs, (list, onnx_input)) or len(inputs) < 2:
raise ValueError("Expect minimum 2 inputs")
if len(inputs) == 1:
return inputs[0]
# avoid overflow
concat = _op.concatenate([_op.expand_dims(x, axis=0) for x in inputs], axis=0)
return _op.mean(concat, axis=0, keepdims=False)
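A numpy sketch of the stacked form the converter builds (`np.stack` stands in for the expand_dims-plus-concatenate above):

```python
import numpy as np

xs = [np.ones((2, 2)), 3 * np.ones((2, 2))]
stacked = np.stack(xs, axis=0)  # shape (2, 2, 2)
print(stacked.mean(axis=0))     # elementwise mean: [[2. 2.] [2. 2.]]
```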
@@ -1485,6 +1508,8 @@ class ArgMax(OnnxOpConverter):

@classmethod
def _impl_v1(cls, inputs, attr, params):
if "select_last_index" in attr:
raise NotImplementedError("select_last_index not supported in ArgMax")
axis = attr.get("axis", 0)
keepdims = attr.get("keepdims", True)
attr = {"axis": axis, "keepdims": keepdims}
@@ -1496,6 +1521,8 @@ class ArgMin(OnnxOpConverter):

@classmethod
def _impl_v1(cls, inputs, attr, params):
if "select_last_index" in attr:
raise NotImplementedError("select_last_index not supported in ArgMin")
axis = attr.get("axis", 0)
keepdims = attr.get("keepdims", True)
attr = {"axis": axis, "keepdims": keepdims}
@@ -1510,7 +1537,35 @@ def _impl_v1(cls, inputs, attr, params):
# set default value when axis is not set in the model
if "axis" not in attr:
attr["axis"] = 1
return AttrCvt("softmax", transforms={"axis": ("axis", 1)})(inputs, attr, params)
axis = attr["axis"]
ndim = len(infer_shape(inputs[0]))
if axis < 0:
axis += ndim
axes = list(range(axis, ndim))
x = inputs[0]
m = _op.max(x, axes, keepdims=True)
e = _op.exp(x - m)
return e / _op.sum(e, axes, keepdims=True)
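The lowering above is the standard max-subtracted softmax taken over every axis from `axis` to the last, matching ONNX's coerce-to-2D semantics; a numpy reference sketch:

```python
import numpy as np

def ref_softmax(x, axis):
    axes = tuple(range(axis, x.ndim))
    m = x.max(axis=axes, keepdims=True)  # subtract the max for stability
    e = np.exp(x - m)
    return e / e.sum(axis=axes, keepdims=True)

out = ref_softmax(np.random.rand(2, 3, 4).astype("float32"), 1)
```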


class LogSoftmax(OnnxOpConverter):
"""Operator converter for Softmax."""

@classmethod
def _impl_v1(cls, inputs, attr, params):
# set default value when axis is not set in the model
if "axis" not in attr:
attr["axis"] = 1
axis = attr["axis"]
ndim = len(infer_shape(inputs[0]))
if axis < 0:
axis += ndim
axes = list(range(axis, ndim))
x = inputs[0]
m = _op.max(x, axes, keepdims=True)
e = _op.exp(x - m)
s = _op.sum(e, axes, keepdims=True)
return x - m - _op.log(s)
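The same decomposition with the log folded in, using the identity `log(softmax(x)) = x - m - log(sum(exp(x - m)))`:

```python
import numpy as np

def ref_log_softmax(x, axis):
    axes = tuple(range(axis, x.ndim))
    m = x.max(axis=axes, keepdims=True)
    s = np.exp(x - m).sum(axis=axes, keepdims=True)
    return x - m - np.log(s)
```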


class OneHot(OnnxOpConverter):
@@ -1520,14 +1575,24 @@
def _impl_v9(cls, inputs, attr, params):
# Extract relay one_hot inputs.
indices, depth, values = inputs
ndim = len(infer_shape(indices))
# Split onnx on off values into two separate expressions.
off_value, on_value = _op.take(values, _op.const(0)), _op.take(values, _op.const(1))
# Extract the datatype of the output from on_value.
dtype = infer_type(on_value).checked_type.dtype
ind_dtype = infer_type(indices).checked_type.dtype
# Normalize the indices to a positive range
indices = _op.where(
indices < _op.const(0, ind_dtype), indices + _op.cast(depth, ind_dtype), indices
)
# set default value when axis is not set in the model
if "axis" not in attr:
attr["axis"] = -1
return _op.one_hot(indices, on_value, off_value, depth, int(attr["axis"]), dtype=dtype)
axis = attr["axis"]
if axis < 0:
axis += ndim + 1

return _op.one_hot(indices, on_value, off_value, depth, axis, dtype=dtype)
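What the negative-index handling buys, in a small numpy sketch (values hypothetical):

```python
import numpy as np

indices, depth = np.array([0, -1, 2]), 3
indices = np.where(indices < 0, indices + depth, indices)  # [0, 2, 2]
out = np.zeros((indices.size, depth), dtype="float32")
out[np.arange(indices.size), indices] = 1.0  # one-hot along the last axis
```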


class ConstantOfShape(OnnxOpConverter):
@@ -1552,7 +1617,7 @@ class Constant(OnnxOpConverter):
@classmethod
def _impl_v9(cls, inputs, attr, params):
if "value" not in attr:
raise "No Value in Constant"
raise tvm.error.OpAttributeRequired("no value in Constant")
np_value = get_numpy(attr.pop("value"))
dtype = np_value.dtype.name
value = _expr.const(np_value, dtype)
@@ -2042,7 +2107,7 @@ def _impl_v1(cls, inputs, attr, params):
largest = attr.get("largest", 1)

if largest == 0:
raise ValueError("TVM only supports finding TopK largest elements")
raise NotImplementedError("TVM only supports finding TopK largest elements")

return _op.topk(inputs[0], inputs[1], axis=axis, dtype="int64")

@@ -2087,7 +2152,7 @@ def _impl_v1(cls, inputs, attr, params):
batch_indices = inputs[2]
mode = attr.get("mode", b"avg")
if mode not in (b"avg", b"max"):
raise ValueError("RoiAlign in Relay only uses avg and max modes")
raise NotImplementedError("RoiAlign in Relay only uses avg and max modes")
output_height = attr.get("output_height", 1)
output_width = attr.get("output_width", 1)

@@ -2128,7 +2193,8 @@ def _impl_v11(cls, inputs, attr, params):
result = inputs[0]
for i, op in enumerate([_op.tensor.maximum, _op.tensor.minimum]):
if i < len(inputs) - 1:
result = op(result, inputs[i + 1])
if inputs[i + 1] is not None:
result = op(result, inputs[i + 1])
return result
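Since opset 11, Clip takes `min` and `max` as optional inputs rather than attributes, so either slot may arrive as `None` and must be skipped; a minimal reference of the behavior:

```python
import numpy as np

def ref_clip(x, a_min=None, a_max=None):
    if a_min is not None:
        x = np.maximum(x, a_min)  # clamp from below only when min is given
    if a_max is not None:
        x = np.minimum(x, a_max)  # clamp from above only when max is given
    return x
```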


@@ -2393,9 +2459,10 @@ def _impl_v10(cls, inputs, attr, params):
dtype = infer_type(boxes).checked_type.dtype

if "center_point_box" in attr:
assert (
attr["center_point_box"] == 0
), "Only support center_point_box = 0 in onnx importer right now"
if attr["center_point_box"] != 0:
raise NotImplementedError(
"Only support center_point_box = 0 in ONNX NonMaxSuprresion"
)

if iou_threshold is None:
iou_threshold = _expr.const(0.0, dtype="float32")
@@ -2718,7 +2785,7 @@ def _get_convert_map(opset):
"Softplus": Softplus.get_converter(opset),
# softmax default axis is different in onnx
"Softmax": Softmax.get_converter(opset),
"LogSoftmax": AttrCvt("log_softmax", {"axis": ("axis", 1)}),
"LogSoftmax": LogSoftmax.get_converter(opset),
"OneHot": OneHot.get_converter(opset),
# 'Hardmax'
"Softsign": Softsign.get_converter(opset),
@@ -2958,6 +3025,8 @@ def from_onnx(self, graph, opset, get_output_expr=False):
for i in node.input:
if i != "":
inputs[i] = self._nodes[self._renames.get(i, i)]
else:
inputs[i] = None
i_name = self._parse_value_proto(node)
node_output = self._fix_outputs(op_name, node.output)
attr["tvm_custom"] = {}
python/tvm/relay/op/transform.py (7 changes: 5 additions & 2 deletions)
@@ -905,10 +905,13 @@ def strided_slice(data, begin, end, strides=None, slice_mode="end"):
end = const(list(end))
if isinstance(strides, (tuple, list)):
strides = const(list(strides))
normalized_begin = _make.where(
begin = _make.where(
begin < cast_like(const(0), begin), begin + cast_like(shape_of(data), begin), begin
)
return _dyn_make.strided_slice(data, normalized_begin, end, strides, slice_mode)
begin = _make.where(
begin >= cast_like(shape_of(data), begin), cast_like(shape_of(data), begin), begin
)
return _dyn_make.strided_slice(data, begin, end, strides, slice_mode)
return _make.strided_slice(data, begin, end, strides, slice_mode)
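The effect of the two `where` passes on a dynamic `begin`, in numpy terms: first wrap negatives, then clamp anything past the extent.

```python
import numpy as np

shape = np.array([10, 10])
begin = np.array([-3, 12])
begin = np.where(begin < 0, begin + shape, begin)  # [7, 12]: wrap negatives
begin = np.where(begin >= shape, shape, begin)     # [7, 10]: clamp to extent
```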

