From 813136401a11a49d6c15e6013c34dd822a5c4ff6 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Tue, 23 Mar 2021 20:40:32 -0600 Subject: [PATCH] [ONNX] Onnx node tests (#7720) * WIP * some fixes * more fixes * fix some conv_transpose tests * fix out of bounds slice * fix flatten import * fix logsoftmax and softmax tests * fix Error in Upsample * fix onehot * normalize errors * fix gather with negative indices * parameterize test * skip unsupported tests * clean up * fix rebase * fix lint * add an error message when we find an un-identified tensor --- python/tvm/relay/frontend/onnx.py | 133 +++++++++++++---- python/tvm/relay/op/transform.py | 7 +- tests/python/frontend/onnx/test_forward.py | 163 +++++++++++++++++++++ 3 files changed, 269 insertions(+), 34 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index fab4ae889dd7..d9fc2ff99a76 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -103,10 +103,11 @@ def get_numpy(tensor_proto): def get_type(elem_type): """Converts onnx integer datatype to numpy datatype""" try: - from onnx import TensorProto + from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE except ImportError as e: raise ImportError("Unable to import onnx which is required {}".format(e)) - return TensorProto.DataType.Name(elem_type).lower() + + return str(TENSOR_TYPE_TO_NP_TYPE[elem_type]) def get_info(info_proto): @@ -157,7 +158,7 @@ def revert_caffe2_pad(pads): return pads -def get_pad_pair(input1d, kernel1d, stride1d): +def get_pad_pair(input1d, kernel1d, stride1d, mode): """infer pad size""" if input1d % stride1d == 0: pad = max(kernel1d - stride1d, 0) @@ -165,6 +166,8 @@ def get_pad_pair(input1d, kernel1d, stride1d): pad = max(kernel1d - (input1d % stride1d), 0) pad_before = pad // 2 pad_after = pad - pad_before + if "LOWER" in mode: + return [pad_after, pad_before] return [pad_before, pad_after] @@ -280,9 +283,9 @@ def _impl_v1(cls, inputs, attr, params): pad_tuple = [] for axis in range(len(input_shape) - 2): axis_shape = input_shape[2 + axis] - stride = attr["strides"][axis] + stride = attr.get("strides", [1] * ndim)[axis] kernel = attr["kernel_shape"][axis] - pad = get_pad_pair(axis_shape, kernel, stride) + pad = get_pad_pair(axis_shape, kernel, stride, attr["auto_pad"]) pad_tuple.append(pad) pad_tuple = tuple([val for pair in zip(*pad_tuple) for val in pair]) attr["pads"] = pad_tuple @@ -444,9 +447,15 @@ class ConvTranspose(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): # get number of channels - channels = infer_channels(inputs[1], True) + out_type = infer_type(inputs[1]) + out_shapes = [get_const_tuple(out_type.checked_type.shape)] + channels = out_shapes[0][1] attr["channels"] = channels groups = attr.get("group", 1) + + if "kernel_shape" not in attr: + attr["kernel_shape"] = out_shapes[0][2:] + attr["groups"] = groups # infer pads for auto_pad data = inputs[0] @@ -528,13 +537,11 @@ def _impl_v1(cls, inputs, attr, params): if not transB: inputs[1] = _op.transpose(inputs[1], axes=(1, 0)) inputs[0] = _op.nn.batch_flatten(inputs[0]) - if alpha != 1.0: inputs[0] *= _expr.const(alpha) out = _op.nn.dense(inputs[0], inputs[1], units=channels) - if len(inputs) == 3: - return _op.nn.bias_add(out, _expr.const(beta) * inputs[2]) + out = out + _expr.const(beta) * inputs[2] return out @@ -618,7 +625,7 @@ def _impl_v1(cls, inputs, attr, params): # Note: attr['fmod'] determines whether the operator should behave like np.fmod or np.mod. # attr['fmod'] == 0 will behave as np.mod and attr['fmod'] == 1 will force fmod treatment. # The relay equivalent of np.fmod is relay.mod and np.mod is relay.floor_mod - if attr["fmod"] == 0: + if attr.get("fmod", 0) == 0: op_name = "floor_mod" else: op_name = "mod" @@ -849,12 +856,18 @@ class Flatten(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): axis = attr.get("axis", 1) + ishape = _op.shape_of(inputs[0]) + ndim = infer_shape(ishape)[0] + if axis < 0: + axis = axis + ndim + if axis == 1: out = _op.nn.batch_flatten(inputs[0]) else: - newshape = [0] * (axis + 1) - newshape[axis] = -1 - out = _op.reshape(inputs[0], list(newshape)) + pre_shape = _op.prod(_op.strided_slice(ishape, [0], [axis], [1]), keepdims=True) + post_shape = _op.prod(_op.strided_slice(ishape, [axis], [ndim], [1]), keepdims=True) + newshape = _op.concatenate([pre_shape, post_shape], axis=0) + out = _op.reshape(inputs[0], newshape) return out @@ -1036,7 +1049,7 @@ def _impl_v9(cls, inputs, attr, params): # in 3d case, we use the purely static op if dims == 5: - if isinstance(scales, _expr.Call): + if isinstance(scales, _expr.Expr): scale_h = _op.take(scales, _op.const(3)) scale_w = _op.take(scales, _op.const(4)) scale_d = _op.take(scales, _op.const(1)) @@ -1052,7 +1065,7 @@ def _impl_v9(cls, inputs, attr, params): ) # in 2d case, use dynamic op else: - if isinstance(scales, _expr.Call): + if isinstance(scales, _expr.Expr): scale_h = _op.take(scales, _op.const(3)) scale_w = _op.take(scales, _op.const(4)) else: @@ -1247,7 +1260,13 @@ class Gather(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): axis = attr.get("axis", 0) - return AttrCvt("take", extras={"axis": axis})(inputs, {}) + data = inputs[0] + indices = inputs[1] + ind_dtype = infer_type(indices).checked_type.dtype + # Normalize the indices to a positive range + s = _op.take(_op.shape_of(data, dtype=ind_dtype), _op.const(axis)) + indices = _op.where(indices < _op.const(0, ind_dtype), indices + s, indices) + return _op.take(data, indices, axis) class GatherElements(OnnxOpConverter): @@ -1258,6 +1277,10 @@ def _impl_v1(cls, inputs, attr, params): data = inputs[0] indices = inputs[1] axis = attr.get("axis", 0) + ind_dtype = infer_type(indices).checked_type.dtype + # Normalize the indices to a positive range + s = _op.take(_op.shape_of(data, dtype=ind_dtype), _op.const(axis)) + indices = _op.where(indices < _op.const(0, ind_dtype), indices + s, indices) return _op.gather(data, axis, indices) @@ -1318,8 +1341,8 @@ class Maximum(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): - if not isinstance(inputs, (list, onnx_input)) or len(inputs) < 2: - raise ValueError("Expect minimum 2 inputs") + if len(inputs) == 1: + return inputs[0] _max = inputs[0] for i in range(1, len(inputs)): _max = AttrCvt("maximum")([_max, inputs[i]], {}) @@ -1331,8 +1354,8 @@ class Minimum(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): - if not isinstance(inputs, (list, onnx_input)) or len(inputs) < 2: - raise ValueError("Expect minimum 2 inputs") + if len(inputs) == 1: + return inputs[0] _min = inputs[0] for i in range(1, len(inputs)): _min = AttrCvt("minimum")([_min, inputs[i]], {}) @@ -1344,8 +1367,8 @@ class Mean(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): - if not isinstance(inputs, (list, onnx_input)) or len(inputs) < 2: - raise ValueError("Expect minimum 2 inputs") + if len(inputs) == 1: + return inputs[0] # avoid overflow concat = _op.concatenate([_op.expand_dims(x, axis=0) for x in inputs], axis=0) return _op.mean(concat, axis=0, keepdims=False) @@ -1485,6 +1508,8 @@ class ArgMax(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): + if "select_last_index" in attr: + raise NotImplementedError("select_last_index not supported in ArgMax") axis = attr.get("axis", 0) keepdims = attr.get("keepdims", True) attr = {"axis": axis, "keepdims": keepdims} @@ -1496,6 +1521,8 @@ class ArgMin(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): + if "select_last_index" in attr: + raise NotImplementedError("select_last_index not supported in ArgMin") axis = attr.get("axis", 0) keepdims = attr.get("keepdims", True) attr = {"axis": axis, "keepdims": keepdims} @@ -1510,7 +1537,35 @@ def _impl_v1(cls, inputs, attr, params): # set default value when axis is not set in the model if "axis" not in attr: attr["axis"] = 1 - return AttrCvt("softmax", transforms={"axis": ("axis", 1)})(inputs, attr, params) + axis = attr["axis"] + ndim = len(infer_shape(inputs[0])) + if axis < 0: + axis += ndim + axes = list(range(axis, ndim)) + x = inputs[0] + m = _op.max(x, axes, keepdims=True) + e = _op.exp(x - m) + return e / _op.sum(e, axes, keepdims=True) + + +class LogSoftmax(OnnxOpConverter): + """Operator converter for Softmax.""" + + @classmethod + def _impl_v1(cls, inputs, attr, params): + # set default value when axis is not set in the model + if "axis" not in attr: + attr["axis"] = 1 + axis = attr["axis"] + ndim = len(infer_shape(inputs[0])) + if axis < 0: + axis += ndim + axes = list(range(axis, ndim)) + x = inputs[0] + m = _op.max(x, axes, keepdims=True) + e = _op.exp(x - m) + s = _op.sum(e, axes, keepdims=True) + return x - m - _op.log(s) class OneHot(OnnxOpConverter): @@ -1520,14 +1575,24 @@ class OneHot(OnnxOpConverter): def _impl_v9(cls, inputs, attr, params): # Extract relay one_hot inputs. indices, depth, values = inputs + ndim = len(infer_shape(indices)) # Split onnx on off values into two separate expressions. off_value, on_value = _op.take(values, _op.const(0)), _op.take(values, _op.const(1)) # Extract the datatype of the output from on_value. dtype = infer_type(on_value).checked_type.dtype + ind_dtype = infer_type(indices).checked_type.dtype + # Normalize the indices to a positive range + indices = _op.where( + indices < _op.const(0, ind_dtype), indices + _op.cast(depth, ind_dtype), indices + ) # set default value when axis is not set in the model if "axis" not in attr: attr["axis"] = -1 - return _op.one_hot(indices, on_value, off_value, depth, int(attr["axis"]), dtype=dtype) + axis = attr["axis"] + if axis < 0: + axis += ndim + 1 + + return _op.one_hot(indices, on_value, off_value, depth, axis, dtype=dtype) class ConstantOfShape(OnnxOpConverter): @@ -1552,7 +1617,7 @@ class Constant(OnnxOpConverter): @classmethod def _impl_v9(cls, inputs, attr, params): if "value" not in attr: - raise "No Value in Constant" + raise tvm.errors.OpAttributeRequired("no value in Constant") np_value = get_numpy(attr.pop("value")) dtype = np_value.dtype.name value = _expr.const(np_value, dtype) @@ -2042,7 +2107,7 @@ def _impl_v1(cls, inputs, attr, params): largest = attr.get("largest", 1) if largest == 0: - raise ValueError("TVM only supports finding TopK largest elements") + raise NotImplementedError("TVM only supports finding TopK largest elements") return _op.topk(inputs[0], inputs[1], axis=axis, dtype="int64") @@ -2087,7 +2152,7 @@ def _impl_v1(cls, inputs, attr, params): batch_indices = inputs[2] mode = attr.get("mode", b"avg") if mode not in (b"avg", b"max"): - raise ValueError("RoiAlign in Relay only uses avg and max modes") + raise NotImplementedError("RoiAlign in Relay only uses avg and max modes") output_height = attr.get("output_height", 1) output_width = attr.get("output_width", 1) @@ -2128,7 +2193,8 @@ def _impl_v11(cls, inputs, attr, params): result = inputs[0] for i, op in enumerate([_op.tensor.maximum, _op.tensor.minimum]): if i < len(inputs) - 1: - result = op(result, inputs[i + 1]) + if inputs[i + 1] is not None: + result = op(result, inputs[i + 1]) return result @@ -2393,9 +2459,10 @@ def _impl_v10(cls, inputs, attr, params): dtype = infer_type(boxes).checked_type.dtype if "center_point_box" in attr: - assert ( - attr["center_point_box"] == 0 - ), "Only support center_point_box = 0 in onnx importer right now" + if attr["center_point_box"] != 0: + raise NotImplementedError( + "Only support center_point_box = 0 in ONNX NonMaxSuprresion" + ) if iou_threshold is None: iou_threshold = _expr.const(0.0, dtype="float32") @@ -2718,7 +2785,7 @@ def _get_convert_map(opset): "Softplus": Softplus.get_converter(opset), # softmax default axis is different in onnx "Softmax": Softmax.get_converter(opset), - "LogSoftmax": AttrCvt("log_softmax", {"axis": ("axis", 1)}), + "LogSoftmax": LogSoftmax.get_converter(opset), "OneHot": OneHot.get_converter(opset), # 'Hardmax' "Softsign": Softsign.get_converter(opset), @@ -2958,6 +3025,8 @@ def from_onnx(self, graph, opset, get_output_expr=False): for i in node.input: if i != "": inputs[i] = self._nodes[self._renames.get(i, i)] + else: + inputs[i] = None i_name = self._parse_value_proto(node) node_output = self._fix_outputs(op_name, node.output) attr["tvm_custom"] = {} diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 4129b610cb7c..df0ae767460a 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -905,10 +905,13 @@ def strided_slice(data, begin, end, strides=None, slice_mode="end"): end = const(list(end)) if isinstance(strides, (tuple, list)): strides = const(list(strides)) - normalized_begin = _make.where( + begin = _make.where( begin < cast_like(const(0), begin), begin + cast_like(shape_of(data), begin), begin ) - return _dyn_make.strided_slice(data, normalized_begin, end, strides, slice_mode) + begin = _make.where( + begin >= cast_like(shape_of(data), begin), cast_like(shape_of(data), begin), begin + ) + return _dyn_make.strided_slice(data, begin, end, strides, slice_mode) return _make.strided_slice(data, begin, end, strides, slice_mode) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 5a6216ac705d..ec89a3d844d1 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -4090,6 +4090,169 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"): verify_cumsum(data, 1, 1, 1, type="int32") +from onnx import numpy_helper + +f = onnx.__file__ +import glob + +onnx_test_folders = sorted(glob.glob("/".join(f.split("/")[0:-1]) + "/backend/test/data/node/*/")) + +unsupported_onnx_tests = [ + "test_basic_convinteger/", + "test_bitshift_left_uint16/", + "test_bitshift_left_uint32/", + "test_bitshift_left_uint64/", + "test_bitshift_left_uint8/", + "test_bitshift_right_uint16/", + "test_bitshift_right_uint32/", + "test_bitshift_right_uint64/", + "test_bitshift_right_uint8/", + "test_cast_DOUBLE_to_FLOAT16/", + "test_cast_FLOAT16_to_DOUBLE/", + "test_cast_FLOAT16_to_FLOAT/", + "test_cast_FLOAT_to_FLOAT16/", + "test_cast_FLOAT_to_STRING/", + "test_cast_STRING_to_FLOAT/", + "test_compress_0/", + "test_compress_1/", + "test_compress_default_axis/", + "test_compress_negative_axis/", + "test_convinteger_with_padding/", + "test_convtranspose_dilations/", + "test_convtranspose_output_shape/", + "test_cumsum_1d/", + "test_cumsum_1d_exclusive/", + "test_cumsum_1d_reverse/", + "test_cumsum_1d_reverse_exclusive/", + "test_cumsum_2d_axis_0/", + "test_cumsum_2d_axis_1/", + "test_cumsum_2d_negative_axis/", + "test_dequantizelinear/", + "test_det_2d/", + "test_det_nd/", + "test_dynamicquantizelinear/", + "test_dynamicquantizelinear_expanded/", + "test_dynamicquantizelinear_max_adjusted/", + "test_dynamicquantizelinear_max_adjusted_expanded/", + "test_dynamicquantizelinear_min_adjusted/", + "test_dynamicquantizelinear_min_adjusted_expanded/", + "test_eyelike_populate_off_main_diagonal/", + "test_eyelike_with_dtype/", + "test_eyelike_without_dtype/", + "test_hardmax_axis_0/", + "test_hardmax_axis_1/", + "test_hardmax_axis_2/", + "test_hardmax_default_axis/", + "test_hardmax_example/", + "test_hardmax_negative_axis/", + "test_hardmax_one_hot/", + "test_isinf_negative/", + "test_isinf_positive/", + "test_lstm_defaults/", + "test_lstm_with_initial_bias/", + "test_lstm_with_peepholes/", + "test_matmulinteger/", + "test_maxpool_2d_dilations/", + "test_maxpool_2d_same_lower/", + "test_maxpool_2d_same_upper/", + "test_maxpool_with_argmax_2d_precomputed_pads/", + "test_maxpool_with_argmax_2d_precomputed_strides/", + "test_maxunpool_export_with_output_shape/", + "test_mvn/", + "test_nonmaxsuppression_center_point_box_format/", + "test_qlinearconv/", + "test_qlinearmatmul_2D/", + "test_qlinearmatmul_3D/", + "test_quantizelinear/", + "test_range_float_type_positive_delta_expanded/", + "test_range_int32_type_negative_delta_expanded/", + "test_resize_downsample_scales_cubic/", + "test_resize_downsample_scales_cubic_A_n0p5_exclude_outside/", + "test_resize_downsample_scales_cubic_align_corners/", + "test_resize_downsample_scales_linear/", + "test_resize_downsample_scales_nearest/", + "test_resize_downsample_sizes_cubic/", + "test_resize_downsample_sizes_linear_pytorch_half_pixel/", + "test_resize_downsample_sizes_nearest/", + "test_resize_downsample_sizes_nearest_tf_half_pixel_for_nn/", + "test_resize_tf_crop_and_resize/", + "test_resize_upsample_scales_cubic/", + "test_resize_upsample_scales_cubic_A_n0p5_exclude_outside/", + "test_resize_upsample_scales_cubic_align_corners/", + "test_resize_upsample_scales_cubic_asymmetric/", + "test_resize_upsample_scales_linear/", + "test_resize_upsample_sizes_cubic/", + "test_resize_upsample_sizes_nearest_ceil_half_pixel/", + "test_resize_upsample_sizes_nearest_floor_align_corners/", + "test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric/", + "test_reversesequence_batch/", + "test_reversesequence_time/", + "test_rnn_seq_length/", + "test_roialign/", + "test_round/", + "test_scan9_sum/", + "test_scan_sum/", + "test_scatternd/", + "test_selu_default/", + "test_shrink_hard/", + "test_shrink_soft/", + "test_simple_rnn_defaults/", + "test_simple_rnn_with_initial_bias/", + "test_slice_neg_steps/", + "test_slice_start_out_of_bounds/", + "test_strnormalizer_export_monday_casesensintive_lower/", + "test_strnormalizer_export_monday_casesensintive_nochangecase/", + "test_strnormalizer_export_monday_casesensintive_upper/", + "test_strnormalizer_export_monday_empty_output/", + "test_strnormalizer_export_monday_insensintive_upper_twodim/", + "test_strnormalizer_nostopwords_nochangecase/", + "test_tfidfvectorizer_tf_batch_onlybigrams_skip0/", + "test_tfidfvectorizer_tf_batch_onlybigrams_skip5/", + "test_tfidfvectorizer_tf_batch_uniandbigrams_skip5/", + "test_tfidfvectorizer_tf_only_bigrams_skip0/", + "test_tfidfvectorizer_tf_onlybigrams_levelempty/", + "test_tfidfvectorizer_tf_onlybigrams_skip5/", + "test_tfidfvectorizer_tf_uniandbigrams_skip5/", + "test_top_k_smallest/", + "test_unique_not_sorted_without_axis/", + "test_unique_sorted_with_axis/", + "test_unique_sorted_with_axis_3d/", + "test_unique_sorted_with_negative_axis/", + "test_unique_sorted_without_axis/", + "test_unsqueeze_unsorted_axes/", + "test_upsample_nearest/", +] + + +@pytest.mark.parametrize("test", onnx_test_folders) +def test_onnx_nodes(test): + for failure in unsupported_onnx_tests: + if failure in test: + pytest.skip() + break + onnx_model = onnx.load(test + "/model.onnx") + inputs = [] + outputs = [] + for dataset in glob.glob(test + "/*/"): + tensors = sorted(glob.glob(dataset + "/*.pb")) + for tensor in tensors: + new_tensor = onnx.TensorProto() + with open(tensor, "rb") as f: + new_tensor.ParseFromString(f.read()) + if "input" in tensor.split("/")[-1]: + inputs.append(numpy_helper.to_array(new_tensor)) + elif "output" in tensor.split("/")[-1]: + outputs.append(numpy_helper.to_array(new_tensor)) + else: + raise ImportError(str(tensor) + " not labeled as an import or an output") + tvm_val = get_tvm_output_with_vm(onnx_model, inputs, "llvm", tvm.cpu(0)) + if len(outputs) == 1: + tvm.testing.assert_allclose(outputs[0], tvm_val, rtol=1e-5, atol=1e-5) + else: + for output, val in zip(outputs, tvm_val): + tvm.testing.assert_allclose(output, val, rtol=1e-5, atol=1e-5) + + def test_wrong_input(): node = helper.make_node( "Softplus",