Skip to content

Commit

Permalink
Fix problem with adding more than one tf.newaxis at the same time (#2007
Browse files Browse the repository at this point in the history
)

Signed-off-by: southfreebird <[email protected]>

Co-authored-by: iolkhovsky <[email protected]>
  • Loading branch information
southfreebird and iolkhovsky authored Jul 27, 2022
1 parent 404e2b7 commit 1c7d4ce
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 0 deletions.
18 changes: 18 additions & 0 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -5893,5 +5893,23 @@ def func(x):
x_val = make_xval([3, 4])
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})

@check_opset_min_version(10, "Slice")
def test_addition_two_newaxis_simultaneously(self):
def func(x):
op = x[..., tf.newaxis, tf.newaxis]
return tf.identity(op, name=_TFOUTPUT)

x_val = make_xval([2, 3])
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})

@check_opset_min_version(10, "Slice")
def test_addition_three_newaxis_simultaneously(self):
def func(x):
op = x[..., tf.newaxis, tf.newaxis, tf.newaxis]
return tf.identity(op, name=_TFOUTPUT)

x_val = make_xval([2, 3])
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})

if __name__ == '__main__':
unittest_main()
23 changes: 23 additions & 0 deletions tf2onnx/onnx_opset/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -974,6 +974,29 @@ def any_version_after10(cls, opset, ctx, node, **kwargs):
begin_mask |= 1 << bit
end_mask |= 1 << bit

if ellipsis_mask:
unqueeze_at = []
ellipsis_gap = 0
num_new = 0
end_mask = node.get_attr("end_mask")
end_mask = end_mask.i if end_mask is not None else 0
begin_mask = node.get_attr("begin_mask")
begin_mask = begin_mask.i if begin_mask is not None else 0

for bit in range(32):
new_axis_flag = (new_axis_mask >> bit) & 1
ellipsis_flag = (ellipsis_mask >> bit) & 1
num_new += not ellipsis_flag and new_axis_flag

for bit in range(32):
if (ellipsis_mask >> bit) & 1:
ellipsis_gap = len(ctx.get_shape(input_x)) - param_rank + num_new + 1
elif (new_axis_mask >> bit) & 1:
effective_bit = bit if not ellipsis_gap else bit + ellipsis_gap - 1
unqueeze_at.append(effective_bit)
begin_mask |= 1 << bit
end_mask |= 1 << bit

input_x = GraphBuilder(ctx).make_unsqueeze(
{'data': input_x, 'axes': unqueeze_at})

Expand Down

0 comments on commit 1c7d4ce

Please sign in to comment.