diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py index 180913640..594c625e2 100644 --- a/tests/test_optimizers.py +++ b/tests/test_optimizers.py @@ -145,7 +145,7 @@ def test_transpose_with_split(self, input_shape, perm, inner_perm): ((1, -1), (1, 1710), (1710,), [1, 0]), ((3, 1, 1, 5, -1), (3, 1, 1, 5, 6), (3, 5, 6), [0, 2, 3, 4, 1]), ]) - @check_opset_max_version(12, "split attribute changed to input in opset 13") + @check_opset_max_version(12, "split attribute changed to input since opset 13") def test_transpose_with_split_dynamic_shape(self, input_shape, specific_input, output_shape, perm): node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm, name="trans") node2 = helper.make_node("Split", ["Y"], ["Z"], axis=1, split=[1], name="split") @@ -162,6 +162,31 @@ def test_transpose_with_split_dynamic_shape(self, input_shape, specific_input, o self.run_transpose_compare(["B"], {"X": np.random.randn(*specific_input).astype(np.float32)}, model_proto, remaining_transpose_num=0) + @parameterized.expand([ + ((3, 1, 1), (1, 1, 3), (1), [0, 2, 3, 1]), + ((256, 1, 1), (1, 1, 256), (1), [0, 2, 3, 1]) + ]) + @check_opset_min_version(13, "split attribute changed to input since opset 13") + def test_transpose_with_split_opset13(self, input_shape, output_shape, split_val, perm): + unsqueeze_axes = self._make_onnx_const(np.array([0], dtype=np.int64), "axes1") + unsqueeze = helper.make_node("Unsqueeze", ["X", "axes1"], ["Y"], name="unsqueeze") + trans = helper.make_node("Transpose", ["Y"], ["Z"], perm=perm, name="trans") + split_attr = self._make_onnx_const(np.array([split_val], dtype=np.int64), "split_attr") + split = helper.make_node("Split", ["Z", "split_attr"], ["A"], axis=0, name="split") + squeeze_axes = self._make_onnx_const(np.array([1], dtype=np.int64), "axes2") + squeeze = helper.make_node("Squeeze", ["A", "axes2"], ["B"], name="squeeze") + + graph = helper.make_graph( + [unsqueeze_axes, unsqueeze, trans, split_attr, split, squeeze_axes, squeeze], + "test_transpose_with_split_opset13", + [helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape)], + [helper.make_tensor_value_info("B", TensorProto.FLOAT, output_shape)], + ) + + model_proto = self.make_model(graph, producer_name="onnx-tests") + self.run_transpose_compare(["B"], {"X": np.random.randn(*input_shape).astype(np.float32)}, + model_proto, remaining_transpose_num=0) + @parameterized.expand([ ((2, 3, 4), [2, 0, 1], [1, 2, 0]), ((2, 3, 4, 5), [0, 2, 3, 1], [0, 3, 1, 2]), @@ -717,7 +742,7 @@ def test_transpose_sqrt(self, shape, perm_input, perm_output): ((1, 3, 4, 5), (4, 5, 3), [0, 2, 3, 1], [1, 2, 0]), ((1, 3, 4, 5, 6), (4, 5, 6, 3), [0, 2, 3, 4, 1], [1, 2, 3, 0]), ]) - @check_opset_max_version(12, "Squeeze/Unsqueeze changed in opset 13") + @check_opset_max_version(12, "Squeeze/Unsqueeze changed since opset 13") def test_transpose_with_squeeze1(self, input_shape, output_shape, perm, expected_perm): # squeeze the first dim node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm, name="trans") @@ -768,7 +793,7 @@ def test_transpose_with_unsqueeze(self, input_shape, output_shape, perm, axes_va ((1, 3, 4, 5), (4, 5, 3), [0, 2, 3, 1], [1, 2, 0]), ((1, 3, 4, 5, 6), (4, 5, 6, 3), [0, 2, 3, 4, 1], [1, 2, 3, 0]), ]) - @check_opset_min_version(13, "Squeeze/Unsqueeze changed in opset 13") + @check_opset_min_version(13, "Squeeze/Unsqueeze changed since opset 13") def test_transpose_with_squeeze1_13(self, input_shape, output_shape, perm, expected_perm): # squeeze the first dim node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm, name="trans") @@ -791,7 +816,7 @@ def test_transpose_with_squeeze1_13(self, input_shape, output_shape, perm, expec ((3, 4, 1, 5), (3, 5, 4), [0, 2, 3, 1], [0, 2, 1]), ((3, 4, 1, 5, 6), (3, 5, 6, 4), [0, 2, 3, 4, 1], [0, 2, 3, 1]), ]) - @check_opset_max_version(12, "Squeeze/Unsqueeze changed in opset 13") + @check_opset_max_version(12, "Squeeze/Unsqueeze changed since opset 13") def test_transpose_with_squeeze2(self, input_shape, output_shape, perm, expected_perm): # squeeze the second dim node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm, name="trans") @@ -813,7 +838,7 @@ def test_transpose_with_squeeze2(self, input_shape, output_shape, perm, expected ((3, 4, 1, 5), (3, 5, 4), [0, 2, 3, 1], [0, 2, 1]), ((3, 4, 1, 5, 6), (3, 5, 6, 4), [0, 2, 3, 4, 1], [0, 2, 3, 1]), ]) - @check_opset_min_version(13, "Squeeze/Unsqueeze changed in opset 13") + @check_opset_min_version(13, "Squeeze/Unsqueeze changed since opset 13") def test_transpose_with_squeeze2_13(self, input_shape, output_shape, perm, expected_perm): # squeeze the second dim node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm, name="trans") @@ -836,7 +861,7 @@ def test_transpose_with_squeeze2_13(self, input_shape, output_shape, perm, expec ((3, 1, 4, 5), (3, 4, 5), [0, 2, 3, 1]), ((3, 1, 4, 5, 6), (3, 4, 5, 6), [0, 2, 3, 4, 1]), ]) - @check_opset_max_version(12, "Squeeze/Unsqueeze changed in opset 13") + @check_opset_max_version(12, "Squeeze/Unsqueeze changed since opset 13") def test_transpose_with_squeeze3(self, input_shape, output_shape, perm): # squeeze the last dim node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm, name="trans") @@ -857,7 +882,7 @@ def test_transpose_with_squeeze3(self, input_shape, output_shape, perm): ((3, 1, 4, 5), (3, 4, 5), [0, 2, 3, 1]), ((3, 1, 4, 5, 6), (3, 4, 5, 6), [0, 2, 3, 4, 1]), ]) - @check_opset_min_version(13, "Squeeze/Unsqueeze changed in opset 13") + @check_opset_min_version(13, "Squeeze/Unsqueeze changed since opset 13") def test_transpose_with_squeeze3_13(self, input_shape, output_shape, perm): # squeeze the last dim node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm, name="trans") @@ -879,7 +904,7 @@ def test_transpose_with_squeeze3_13(self, input_shape, output_shape, perm): ((3, 1, 1, 5), (3, 5), [0, 2, 3, 1]), ((3, 1, 1, 5, 4), (3, 5, 4), [0, 2, 3, 4, 1]), ]) - @check_opset_max_version(12, "Squeeze/Unsqueeze changed in opset 13") + @check_opset_max_version(12, "Squeeze/Unsqueeze changed since opset 13") def test_transpose_with_squeeze4(self, input_shape, output_shape, perm): # squeeze the two dims node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm, name="trans") @@ -900,7 +925,7 @@ def test_transpose_with_squeeze4(self, input_shape, output_shape, perm): ((3, 1, 1, 5), (3, 5), [0, 2, 3, 1]), ((3, 1, 1, 5, 4), (3, 5, 4), [0, 2, 3, 4, 1]), ]) - @check_opset_min_version(13, "Squeeze/Unsqueeze changed in opset 13") + @check_opset_min_version(13, "Squeeze/Unsqueeze changed since opset 13") def test_transpose_with_squeeze4_13(self, input_shape, output_shape, perm): # squeeze the two dims node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm, name="trans") @@ -2156,7 +2181,7 @@ def test_const_fold_concat(self): self.run_and_compare(["res"], {"inp": np.random.randn(6, 12).astype(np.float32)}, model_proto, "Concat", 0) - @check_opset_max_version(12, "Squeeze/Unsqueeze changed in opset 13") + @check_opset_max_version(12, "Squeeze/Unsqueeze changed since opset 13") def test_const_fold_unsqueeze_with_const(self): shape = (6, 6) const_tensor = helper.make_tensor(name='const_tensor', data_type=TensorProto.FLOAT, dims=shape, @@ -2176,7 +2201,7 @@ def test_const_fold_unsqueeze_with_const(self): self.run_and_compare(["res"], {"X": np.random.randn(1).astype(np.float32)}, model_proto, "Unsqueeze", 0) - @check_opset_min_version(13, "Squeeze/Unsqueeze changed in opset 13") + @check_opset_min_version(13, "Squeeze/Unsqueeze changed since opset 13") def test_const_fold_unsqueeze_with_const_13(self): shape = (6, 6) const_tensor = helper.make_tensor(name='const_tensor', data_type=TensorProto.FLOAT, dims=shape, @@ -2254,7 +2279,7 @@ def test_const_fold_split_one(self): self.run_and_compare(["out4"], {"inp": np.random.randn(2, 6, 1).astype(np.float32)}, model_proto, "Split", 0) - @check_opset_min_version(13, "Split changed in opset 13") + @check_opset_min_version(13, "Split changed since opset 13") def test_const_fold_split_const_splits_13(self): shape = (2, 6, 1) const_tensor = helper.make_tensor(name='const_tensor', data_type=TensorProto.FLOAT, dims=shape, @@ -2277,7 +2302,7 @@ def test_const_fold_split_const_splits_13(self): self.run_and_compare(["out4"], {"inp": np.random.randn(2, 3, 1).astype(np.float32)}, model_proto, "Split", 0) - @check_opset_max_version(12, "Split changed in opset 13") + @check_opset_max_version(12, "Split changed since opset 13") def test_const_fold_split_const_splits(self): shape = (2, 6, 1) const_tensor = helper.make_tensor(name='const_tensor', data_type=TensorProto.FLOAT, dims=shape, diff --git a/tf2onnx/optimizer/transpose_optimizer.py b/tf2onnx/optimizer/transpose_optimizer.py index dd5377f0a..1a82a11fa 100644 --- a/tf2onnx/optimizer/transpose_optimizer.py +++ b/tf2onnx/optimizer/transpose_optimizer.py @@ -671,11 +671,19 @@ def _concat_handler(self, trans, node): def _split_handler(self, trans, node): # Todo: need handle cases where Split node has more than 1 outputs. + split = None + if self._g.opset >= 13 and len(node.input) > 1 and node.inputs[1].is_const(): + # split is an input not attr since opset 13 + split = node.inputs[1].get_tensor_value(as_list=True) if self._handle_node_having_branches(trans, node): perm = trans.get_attr_value("perm") axis = node.get_attr_value("axis", 0) new_axis = perm[axis] node.set_attr("axis", new_axis) + if split: + new_axes_np = np.array(split, dtype=np.int64) + new_axes_const = self._g.make_const(utils.make_name(node.inputs[1].name), new_axes_np) + self._g.replace_inputs(node, [node.input[0], new_axes_const.output[0]]) return True return False @@ -747,7 +755,7 @@ def _calculate_new_attr(ori_perm, ori_squeeze_axes): shape_after_trans = [input_shape[i] for i in ori_perm] output_shape = [shape_after_trans[i] for i in range(n) if i not in ori_squeeze_axes] # calculate new_perm - # after switch, the output shape should be same, using this condtion we can figure the new perm + # after switch, the output shape should be same, using this condition we can figure the new perm shape_after_squeeze = [input_shape[i] for i in range(n) if i not in new_squeeze_axes] new_perm = [shape_after_squeeze.index(i) for i in output_shape] @@ -757,7 +765,7 @@ def _calculate_new_attr(ori_perm, ori_squeeze_axes): return False axes = None - # in opset 13, axes is an input not attr + # axes is an input not attr since opset 13 if node.get_attr("axes"): axes = node.get_attr("axes").ints if len(node.input) > 1 and node.inputs[1].is_const():