Skip to content

Commit

Permalink
fix comments
Browse files Browse the repository at this point in the history
Signed-off-by: Deyu Huang <[email protected]>
  • Loading branch information
hwangdeyu committed Jul 15, 2022
1 parent 7e8606f commit a759597
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions tf2onnx/optimizer/transpose_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,15 +672,15 @@ def _concat_handler(self, trans, node):
def _split_handler(self, trans, node):
# Todo: need handle cases where Split node has more than 1 outputs.
split = None
if len(node.input) > 1 and node.inputs[1].is_const():
if self._g.opset >= 13 and len(node.input) > 1 and node.inputs[1].is_const():
# in opset 13, split is an input not attr
split = node.inputs[1].get_tensor_value(as_list=True)
if self._handle_node_having_branches(trans, node):
perm = trans.get_attr_value("perm")
axis = node.get_attr_value("axis", 0)
new_axis = perm[axis]
node.set_attr("axis", new_axis)
if self._g.opset >= 13 and split:
# in opset 13, split attr is an input not attr
if split:
new_axes_np = np.array(split, dtype=np.int64)
new_axes_const = self._g.make_const(utils.make_name(node.inputs[1].name), new_axes_np)
self._g.replace_inputs(node, [node.input[0], new_axes_const.output[0]])
Expand Down

0 comments on commit a759597

Please sign in to comment.