Skip to content

Commit

Permalink
[Relay] Fix bug in transpose_shape_func (apache#6180)
Browse files Browse the repository at this point in the history
  • Loading branch information
lixiaoquan authored and Trevor Morris committed Aug 26, 2020
1 parent bb68693 commit 3e3980a
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 1 deletion.
3 changes: 2 additions & 1 deletion python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,9 +540,10 @@ def transpose_shape_func(attrs, inputs, _):
if axes is None:
axes = list(range(inputs[0].shape[0].value))
axes.reverse()
axes = list(axes)
for i, axis in enumerate(axes):
if axis < 0:
axes[i] = inputs[0].shape[0] - axis
axes[i] = inputs[0].shape[0] + axis
return [_transpose_shape_func(inputs[0], convert(axes))]

@script
Expand Down
1 change: 1 addition & 0 deletions tests/python/relay/test_any.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,7 @@ def test_any_transpose():
verify_any_transpose(any_dims(3), (1, 0, 2), (10, 3, 2))
verify_any_transpose(any_dims(3), None, (2, 3, 4))
verify_any_transpose(any_dims(6), (0, 1, 3, 2, 5, 4), (11, 12, 2, 1, 9, 17))
verify_any_transpose(any_dims(2), (-1, 0), (3, 2))

def verify_any_squeeze(data_shape, axis, static_data_shape):
mod = tvm.IRModule()
Expand Down

0 comments on commit 3e3980a

Please sign in to comment.