diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py index f2fc21e89..594c625e2 100644 --- a/tests/test_optimizers.py +++ b/tests/test_optimizers.py @@ -145,7 +145,7 @@ def test_transpose_with_split(self, input_shape, perm, inner_perm): ((1, -1), (1, 1710), (1710,), [1, 0]), ((3, 1, 1, 5, -1), (3, 1, 1, 5, 6), (3, 5, 6), [0, 2, 3, 4, 1]), ]) - @check_opset_max_version(12, "split attribute changed to input in opset 13") + @check_opset_max_version(12, "split attribute changed to input since opset 13") def test_transpose_with_split_dynamic_shape(self, input_shape, specific_input, output_shape, perm): node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm, name="trans") node2 = helper.make_node("Split", ["Y"], ["Z"], axis=1, split=[1], name="split") @@ -166,7 +166,7 @@ def test_transpose_with_split_dynamic_shape(self, input_shape, specific_input, o ((3, 1, 1), (1, 1, 3), (1), [0, 2, 3, 1]), ((256, 1, 1), (1, 1, 256), (1), [0, 2, 3, 1]) ]) - @check_opset_min_version(13, "split attribute changed to input in opset 13") + @check_opset_min_version(13, "split attribute changed to input since opset 13") def test_transpose_with_split_opset13(self, input_shape, output_shape, split_val, perm): unsqueeze_axes = self._make_onnx_const(np.array([0], dtype=np.int64), "axes1") unsqueeze = helper.make_node("Unsqueeze", ["X", "axes1"], ["Y"], name="unsqueeze") @@ -742,7 +742,7 @@ def test_transpose_sqrt(self, shape, perm_input, perm_output): ((1, 3, 4, 5), (4, 5, 3), [0, 2, 3, 1], [1, 2, 0]), ((1, 3, 4, 5, 6), (4, 5, 6, 3), [0, 2, 3, 4, 1], [1, 2, 3, 0]), ]) - @check_opset_max_version(12, "Squeeze/Unsqueeze changed in opset 13") + @check_opset_max_version(12, "Squeeze/Unsqueeze changed since opset 13") def test_transpose_with_squeeze1(self, input_shape, output_shape, perm, expected_perm): # squeeze the first dim node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm, name="trans") @@ -793,7 +793,7 @@ def test_transpose_with_unsqueeze(self, input_shape, output_shape, perm, axes_va ((1, 3, 4, 5), (4, 5, 3), [0, 2, 3, 1], [1, 2, 0]), ((1, 3, 4, 5, 6), (4, 5, 6, 3), [0, 2, 3, 4, 1], [1, 2, 3, 0]), ]) - @check_opset_min_version(13, "Squeeze/Unsqueeze changed in opset 13") + @check_opset_min_version(13, "Squeeze/Unsqueeze changed since opset 13") def test_transpose_with_squeeze1_13(self, input_shape, output_shape, perm, expected_perm): # squeeze the first dim node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm, name="trans") @@ -816,7 +816,7 @@ def test_transpose_with_squeeze1_13(self, input_shape, output_shape, perm, expec ((3, 4, 1, 5), (3, 5, 4), [0, 2, 3, 1], [0, 2, 1]), ((3, 4, 1, 5, 6), (3, 5, 6, 4), [0, 2, 3, 4, 1], [0, 2, 3, 1]), ]) - @check_opset_max_version(12, "Squeeze/Unsqueeze changed in opset 13") + @check_opset_max_version(12, "Squeeze/Unsqueeze changed since opset 13") def test_transpose_with_squeeze2(self, input_shape, output_shape, perm, expected_perm): # squeeze the second dim node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm, name="trans") @@ -838,7 +838,7 @@ def test_transpose_with_squeeze2(self, input_shape, output_shape, perm, expected ((3, 4, 1, 5), (3, 5, 4), [0, 2, 3, 1], [0, 2, 1]), ((3, 4, 1, 5, 6), (3, 5, 6, 4), [0, 2, 3, 4, 1], [0, 2, 3, 1]), ]) - @check_opset_min_version(13, "Squeeze/Unsqueeze changed in opset 13") + @check_opset_min_version(13, "Squeeze/Unsqueeze changed since opset 13") def test_transpose_with_squeeze2_13(self, input_shape, output_shape, perm, expected_perm): # squeeze the second dim node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm, name="trans") @@ -861,7 +861,7 @@ def test_transpose_with_squeeze2_13(self, input_shape, output_shape, perm, expec ((3, 1, 4, 5), (3, 4, 5), [0, 2, 3, 1]), ((3, 1, 4, 5, 6), (3, 4, 5, 6), [0, 2, 3, 4, 1]), ]) - @check_opset_max_version(12, "Squeeze/Unsqueeze changed in opset 13") + @check_opset_max_version(12, "Squeeze/Unsqueeze changed since opset 13") def test_transpose_with_squeeze3(self, input_shape, output_shape, perm): # squeeze the last dim node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm, name="trans") @@ -882,7 +882,7 @@ def test_transpose_with_squeeze3(self, input_shape, output_shape, perm): ((3, 1, 4, 5), (3, 4, 5), [0, 2, 3, 1]), ((3, 1, 4, 5, 6), (3, 4, 5, 6), [0, 2, 3, 4, 1]), ]) - @check_opset_min_version(13, "Squeeze/Unsqueeze changed in opset 13") + @check_opset_min_version(13, "Squeeze/Unsqueeze changed since opset 13") def test_transpose_with_squeeze3_13(self, input_shape, output_shape, perm): # squeeze the last dim node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm, name="trans") @@ -904,7 +904,7 @@ def test_transpose_with_squeeze3_13(self, input_shape, output_shape, perm): ((3, 1, 1, 5), (3, 5), [0, 2, 3, 1]), ((3, 1, 1, 5, 4), (3, 5, 4), [0, 2, 3, 4, 1]), ]) - @check_opset_max_version(12, "Squeeze/Unsqueeze changed in opset 13") + @check_opset_max_version(12, "Squeeze/Unsqueeze changed since opset 13") def test_transpose_with_squeeze4(self, input_shape, output_shape, perm): # squeeze the two dims node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm, name="trans") @@ -925,7 +925,7 @@ def test_transpose_with_squeeze4(self, input_shape, output_shape, perm): ((3, 1, 1, 5), (3, 5), [0, 2, 3, 1]), ((3, 1, 1, 5, 4), (3, 5, 4), [0, 2, 3, 4, 1]), ]) - @check_opset_min_version(13, "Squeeze/Unsqueeze changed in opset 13") + @check_opset_min_version(13, "Squeeze/Unsqueeze changed since opset 13") def test_transpose_with_squeeze4_13(self, input_shape, output_shape, perm): # squeeze the two dims node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm, name="trans") @@ -2181,7 +2181,7 @@ def test_const_fold_concat(self): self.run_and_compare(["res"], {"inp": np.random.randn(6, 12).astype(np.float32)}, model_proto, "Concat", 0) - @check_opset_max_version(12, "Squeeze/Unsqueeze changed in opset 13") + @check_opset_max_version(12, "Squeeze/Unsqueeze changed since opset 13") def test_const_fold_unsqueeze_with_const(self): shape = (6, 6) const_tensor = helper.make_tensor(name='const_tensor', data_type=TensorProto.FLOAT, dims=shape, @@ -2201,7 +2201,7 @@ def test_const_fold_unsqueeze_with_const(self): self.run_and_compare(["res"], {"X": np.random.randn(1).astype(np.float32)}, model_proto, "Unsqueeze", 0) - @check_opset_min_version(13, "Squeeze/Unsqueeze changed in opset 13") + @check_opset_min_version(13, "Squeeze/Unsqueeze changed since opset 13") def test_const_fold_unsqueeze_with_const_13(self): shape = (6, 6) const_tensor = helper.make_tensor(name='const_tensor', data_type=TensorProto.FLOAT, dims=shape, @@ -2279,7 +2279,7 @@ def test_const_fold_split_one(self): self.run_and_compare(["out4"], {"inp": np.random.randn(2, 6, 1).astype(np.float32)}, model_proto, "Split", 0) - @check_opset_min_version(13, "Split changed in opset 13") + @check_opset_min_version(13, "Split changed since opset 13") def test_const_fold_split_const_splits_13(self): shape = (2, 6, 1) const_tensor = helper.make_tensor(name='const_tensor', data_type=TensorProto.FLOAT, dims=shape, @@ -2302,7 +2302,7 @@ def test_const_fold_split_const_splits_13(self): self.run_and_compare(["out4"], {"inp": np.random.randn(2, 3, 1).astype(np.float32)}, model_proto, "Split", 0) - @check_opset_max_version(12, "Split changed in opset 13") + @check_opset_max_version(12, "Split changed since opset 13") def test_const_fold_split_const_splits(self): shape = (2, 6, 1) const_tensor = helper.make_tensor(name='const_tensor', data_type=TensorProto.FLOAT, dims=shape, diff --git a/tf2onnx/optimizer/transpose_optimizer.py b/tf2onnx/optimizer/transpose_optimizer.py index cffe6ef24..1a82a11fa 100644 --- a/tf2onnx/optimizer/transpose_optimizer.py +++ b/tf2onnx/optimizer/transpose_optimizer.py @@ -672,15 +672,15 @@ def _concat_handler(self, trans, node): def _split_handler(self, trans, node): # Todo: need handle cases where Split node has more than 1 outputs. split = None - if len(node.input) > 1 and node.inputs[1].is_const(): + if self._g.opset >= 13 and len(node.input) > 1 and node.inputs[1].is_const(): + # split is an input not attr since opset 13 split = node.inputs[1].get_tensor_value(as_list=True) if self._handle_node_having_branches(trans, node): perm = trans.get_attr_value("perm") axis = node.get_attr_value("axis", 0) new_axis = perm[axis] node.set_attr("axis", new_axis) - if self._g.opset >= 13 and split: - # in opset 13, split attr is an input not attr + if split: new_axes_np = np.array(split, dtype=np.int64) new_axes_const = self._g.make_const(utils.make_name(node.inputs[1].name), new_axes_np) self._g.replace_inputs(node, [node.input[0], new_axes_const.output[0]]) @@ -765,7 +765,7 @@ def _calculate_new_attr(ori_perm, ori_squeeze_axes): return False axes = None - # in opset 13, axes is an input not attr + # axes is an input not attr since opset 13 if node.get_attr("axes"): axes = node.get_attr("axes").ints if len(node.input) > 1 and node.inputs[1].is_const():