Skip to content

Commit

Permalink
fix slice axes dispatch
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 27, 2021
1 parent 9a79560 commit 1b3969a
Showing 1 changed file with 12 additions and 6 deletions.
18 changes: 12 additions & 6 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1341,13 +1341,16 @@ def _impl_v10(cls, inputs, attr, params):
axes = inputs[3]
steps = inputs[4]

data_rank = len(infer_shape(inputs[0]))
ishape = infer_shape(inputs[0])
data_rank = len(ishape)

def has_static_axes():
return (isinstance(axes, _expr.Constant) and
isinstance(starts, _expr.Constant) and
isinstance(ends, _expr.Constant) and
(steps is None or isinstance(steps, _expr.Constant)))
return (
isinstance(axes, _expr.Constant)
and isinstance(starts, _expr.Constant)
and isinstance(ends, _expr.Constant)
and (steps is None or isinstance(steps, _expr.Constant))
)

# Update the starts and ends according to axes if required.
if axes is not None and has_static_axes():
Expand All @@ -1359,7 +1362,10 @@ def has_static_axes():
else:
strides_np = steps.data.asnumpy().astype("int64")

return _op.strided_slice(inputs[0], list(begin_np), list(end_np), list(strides_np), axes=list(axes_np))
if all([isinstance(ishape[i], int) for i in axes_np]):
return _op.strided_slice(
inputs[0], list(begin_np), list(end_np), list(strides_np), axes=list(axes_np)
)

# Update the starts and ends according to axes if required.
if axes is not None:
Expand Down

0 comments on commit 1b3969a

Please sign in to comment.