Skip to content

Commit

Permalink
fix comments
Browse files Browse the repository at this point in the history
Signed-off-by: Deyu Huang <[email protected]>
  • Loading branch information
hwangdeyu committed Jul 15, 2022
1 parent 7e8606f commit bee58e3
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 18 deletions.
28 changes: 14 additions & 14 deletions tests/test_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def test_transpose_with_split(self, input_shape, perm, inner_perm):
((1, -1), (1, 1710), (1710,), [1, 0]),
((3, 1, 1, 5, -1), (3, 1, 1, 5, 6), (3, 5, 6), [0, 2, 3, 4, 1]),
])
@check_opset_max_version(12, "split attribute changed to input in opset 13")
@check_opset_max_version(12, "split attribute changed to input since opset 13")
def test_transpose_with_split_dynamic_shape(self, input_shape, specific_input, output_shape, perm):
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm, name="trans")
node2 = helper.make_node("Split", ["Y"], ["Z"], axis=1, split=[1], name="split")
Expand All @@ -166,7 +166,7 @@ def test_transpose_with_split_dynamic_shape(self, input_shape, specific_input, o
((3, 1, 1), (1, 1, 3), (1), [0, 2, 3, 1]),
((256, 1, 1), (1, 1, 256), (1), [0, 2, 3, 1])
])
@check_opset_min_version(13, "split attribute changed to input in opset 13")
@check_opset_min_version(13, "split attribute changed to input since opset 13")
def test_transpose_with_split_opset13(self, input_shape, output_shape, split_val, perm):
unsqueeze_axes = self._make_onnx_const(np.array([0], dtype=np.int64), "axes1")
unsqueeze = helper.make_node("Unsqueeze", ["X", "axes1"], ["Y"], name="unsqueeze")
Expand Down Expand Up @@ -742,7 +742,7 @@ def test_transpose_sqrt(self, shape, perm_input, perm_output):
((1, 3, 4, 5), (4, 5, 3), [0, 2, 3, 1], [1, 2, 0]),
((1, 3, 4, 5, 6), (4, 5, 6, 3), [0, 2, 3, 4, 1], [1, 2, 3, 0]),
])
@check_opset_max_version(12, "Squeeze/Unsqueeze changed in opset 13")
@check_opset_max_version(12, "Squeeze/Unsqueeze changed since opset 13")
def test_transpose_with_squeeze1(self, input_shape, output_shape, perm, expected_perm):
# squeeze the first dim
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm, name="trans")
Expand Down Expand Up @@ -793,7 +793,7 @@ def test_transpose_with_unsqueeze(self, input_shape, output_shape, perm, axes_va
((1, 3, 4, 5), (4, 5, 3), [0, 2, 3, 1], [1, 2, 0]),
((1, 3, 4, 5, 6), (4, 5, 6, 3), [0, 2, 3, 4, 1], [1, 2, 3, 0]),
])
@check_opset_min_version(13, "Squeeze/Unsqueeze changed in opset 13")
@check_opset_min_version(13, "Squeeze/Unsqueeze changed since opset 13")
def test_transpose_with_squeeze1_13(self, input_shape, output_shape, perm, expected_perm):
# squeeze the first dim
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm, name="trans")
Expand All @@ -816,7 +816,7 @@ def test_transpose_with_squeeze1_13(self, input_shape, output_shape, perm, expec
((3, 4, 1, 5), (3, 5, 4), [0, 2, 3, 1], [0, 2, 1]),
((3, 4, 1, 5, 6), (3, 5, 6, 4), [0, 2, 3, 4, 1], [0, 2, 3, 1]),
])
@check_opset_max_version(12, "Squeeze/Unsqueeze changed in opset 13")
@check_opset_max_version(12, "Squeeze/Unsqueeze changed since opset 13")
def test_transpose_with_squeeze2(self, input_shape, output_shape, perm, expected_perm):
# squeeze the second dim
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm, name="trans")
Expand All @@ -838,7 +838,7 @@ def test_transpose_with_squeeze2(self, input_shape, output_shape, perm, expected
((3, 4, 1, 5), (3, 5, 4), [0, 2, 3, 1], [0, 2, 1]),
((3, 4, 1, 5, 6), (3, 5, 6, 4), [0, 2, 3, 4, 1], [0, 2, 3, 1]),
])
@check_opset_min_version(13, "Squeeze/Unsqueeze changed in opset 13")
@check_opset_min_version(13, "Squeeze/Unsqueeze changed since opset 13")
def test_transpose_with_squeeze2_13(self, input_shape, output_shape, perm, expected_perm):
# squeeze the second dim
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm, name="trans")
Expand All @@ -861,7 +861,7 @@ def test_transpose_with_squeeze2_13(self, input_shape, output_shape, perm, expec
((3, 1, 4, 5), (3, 4, 5), [0, 2, 3, 1]),
((3, 1, 4, 5, 6), (3, 4, 5, 6), [0, 2, 3, 4, 1]),
])
@check_opset_max_version(12, "Squeeze/Unsqueeze changed in opset 13")
@check_opset_max_version(12, "Squeeze/Unsqueeze changed since opset 13")
def test_transpose_with_squeeze3(self, input_shape, output_shape, perm):
# squeeze the last dim
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm, name="trans")
Expand All @@ -882,7 +882,7 @@ def test_transpose_with_squeeze3(self, input_shape, output_shape, perm):
((3, 1, 4, 5), (3, 4, 5), [0, 2, 3, 1]),
((3, 1, 4, 5, 6), (3, 4, 5, 6), [0, 2, 3, 4, 1]),
])
@check_opset_min_version(13, "Squeeze/Unsqueeze changed in opset 13")
@check_opset_min_version(13, "Squeeze/Unsqueeze changed since opset 13")
def test_transpose_with_squeeze3_13(self, input_shape, output_shape, perm):
# squeeze the last dim
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm, name="trans")
Expand All @@ -904,7 +904,7 @@ def test_transpose_with_squeeze3_13(self, input_shape, output_shape, perm):
((3, 1, 1, 5), (3, 5), [0, 2, 3, 1]),
((3, 1, 1, 5, 4), (3, 5, 4), [0, 2, 3, 4, 1]),
])
@check_opset_max_version(12, "Squeeze/Unsqueeze changed in opset 13")
@check_opset_max_version(12, "Squeeze/Unsqueeze changed since opset 13")
def test_transpose_with_squeeze4(self, input_shape, output_shape, perm):
# squeeze the two dims
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm, name="trans")
Expand All @@ -925,7 +925,7 @@ def test_transpose_with_squeeze4(self, input_shape, output_shape, perm):
((3, 1, 1, 5), (3, 5), [0, 2, 3, 1]),
((3, 1, 1, 5, 4), (3, 5, 4), [0, 2, 3, 4, 1]),
])
@check_opset_min_version(13, "Squeeze/Unsqueeze changed in opset 13")
@check_opset_min_version(13, "Squeeze/Unsqueeze changed since opset 13")
def test_transpose_with_squeeze4_13(self, input_shape, output_shape, perm):
# squeeze the two dims
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm, name="trans")
Expand Down Expand Up @@ -2181,7 +2181,7 @@ def test_const_fold_concat(self):
self.run_and_compare(["res"], {"inp": np.random.randn(6, 12).astype(np.float32)}, model_proto,
"Concat", 0)

@check_opset_max_version(12, "Squeeze/Unsqueeze changed in opset 13")
@check_opset_max_version(12, "Squeeze/Unsqueeze changed since opset 13")
def test_const_fold_unsqueeze_with_const(self):
shape = (6, 6)
const_tensor = helper.make_tensor(name='const_tensor', data_type=TensorProto.FLOAT, dims=shape,
Expand All @@ -2201,7 +2201,7 @@ def test_const_fold_unsqueeze_with_const(self):
self.run_and_compare(["res"], {"X": np.random.randn(1).astype(np.float32)}, model_proto,
"Unsqueeze", 0)

@check_opset_min_version(13, "Squeeze/Unsqueeze changed in opset 13")
@check_opset_min_version(13, "Squeeze/Unsqueeze changed since opset 13")
def test_const_fold_unsqueeze_with_const_13(self):
shape = (6, 6)
const_tensor = helper.make_tensor(name='const_tensor', data_type=TensorProto.FLOAT, dims=shape,
Expand Down Expand Up @@ -2279,7 +2279,7 @@ def test_const_fold_split_one(self):
self.run_and_compare(["out4"], {"inp": np.random.randn(2, 6, 1).astype(np.float32)}, model_proto,
"Split", 0)

@check_opset_min_version(13, "Split changed in opset 13")
@check_opset_min_version(13, "Split changed since opset 13")
def test_const_fold_split_const_splits_13(self):
shape = (2, 6, 1)
const_tensor = helper.make_tensor(name='const_tensor', data_type=TensorProto.FLOAT, dims=shape,
Expand All @@ -2302,7 +2302,7 @@ def test_const_fold_split_const_splits_13(self):
self.run_and_compare(["out4"], {"inp": np.random.randn(2, 3, 1).astype(np.float32)}, model_proto,
"Split", 0)

@check_opset_max_version(12, "Split changed in opset 13")
@check_opset_max_version(12, "Split changed since opset 13")
def test_const_fold_split_const_splits(self):
shape = (2, 6, 1)
const_tensor = helper.make_tensor(name='const_tensor', data_type=TensorProto.FLOAT, dims=shape,
Expand Down
8 changes: 4 additions & 4 deletions tf2onnx/optimizer/transpose_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,15 +672,15 @@ def _concat_handler(self, trans, node):
def _split_handler(self, trans, node):
# Todo: need handle cases where Split node has more than 1 outputs.
split = None
if len(node.input) > 1 and node.inputs[1].is_const():
if self._g.opset >= 13 and len(node.input) > 1 and node.inputs[1].is_const():
# split is an input not attr since opset 13
split = node.inputs[1].get_tensor_value(as_list=True)
if self._handle_node_having_branches(trans, node):
perm = trans.get_attr_value("perm")
axis = node.get_attr_value("axis", 0)
new_axis = perm[axis]
node.set_attr("axis", new_axis)
if self._g.opset >= 13 and split:
# in opset 13, split attr is an input not attr
if split:
new_axes_np = np.array(split, dtype=np.int64)
new_axes_const = self._g.make_const(utils.make_name(node.inputs[1].name), new_axes_np)
self._g.replace_inputs(node, [node.input[0], new_axes_const.output[0]])
Expand Down Expand Up @@ -765,7 +765,7 @@ def _calculate_new_attr(ori_perm, ori_squeeze_axes):
return False

axes = None
# in opset 13, axes is an input not attr
# axes is an input not attr since opset 13
if node.get_attr("axes"):
axes = node.get_attr("axes").ints
if len(node.input) > 1 and node.inputs[1].is_const():
Expand Down

0 comments on commit bee58e3

Please sign in to comment.