[Torch] Avoid adding unnecessary slicing #7479

Merged · 3 commits · Feb 26, 2021
55 changes: 24 additions & 31 deletions python/tvm/relay/frontend/pytorch.py
@@ -385,23 +385,28 @@ def tensor_array_concat(lst, axis):

     def slice(self, inputs, input_types):
         axis_dtype = "int64"
-        index_size_limit = 2 ** 63 - 1
+        index_size_limit = sys.maxsize
         data = inputs[0]
         dshape = self.infer_shape(data)
         ndim = len(dshape)
-        end = []
-        for dim in dshape:
-            if isinstance(dim, tvm.tir.Any):
-                end = _op.shape_of(data)
-                break
-            end.append(int(dim))
-
-        begin = [0] * ndim
         dim = int(inputs[1])
-        stride = int(inputs[4])
-        begin[dim], _ = try_infer_value(inputs[2], lambda ret: np.asscalar(ret.astype(np.int)))
+        stride = inputs[4]
+
+        target_begin, is_begin_const = try_infer_value(
+            inputs[2], lambda ret: np.asscalar(ret.astype(np.int))
+        )
+        target_end, is_end_const = try_infer_value(
+            inputs[3], lambda ret: np.asscalar(ret.astype(np.int))
+        )
+
+        # A fast path when slicing is nop.
+        if target_begin == 0 and target_end >= index_size_limit and stride == 1:
+            return data
+
+        # Process begin
+        begin = [0] * ndim
+        begin[dim] = target_begin
+
         if not isinstance(begin[dim], int):
             tmp = []
             for b in begin:
@@ -414,27 +419,15 @@ def slice(self, inputs, input_types):
             if str(btype) != axis_dtype:
                 begin = _op.cast(begin, axis_dtype)

-        if isinstance(inputs[3], str) and inputs[3].isdigit():
-            target_end = int(inputs[3])
+        # Process end
+        if isinstance(target_end, int) and target_end >= index_size_limit:
+            target_end = dshape[dim]
+
+        if any([isinstance(d, tvm.tir.Any) for d in dshape]):
+            end = _op.shape_of(data)
         else:
-            if isinstance(inputs[3], _expr.Expr):
-                target_end, _ = try_infer_value(
-                    inputs[3], lambda ret: np.asscalar(ret.astype(np.int))
-                )
-            else:
-                target_end = inputs[3]
-
-            if isinstance(target_end, int) and target_end >= index_size_limit:
-                # Quick path for original data.
-                if (
-                    isinstance(begin, _expr.Constant)
-                    and begin.data.asnumpy().tolist()[dim] == 0
-                    and stride == 1
-                ):
-                    return data
-                target_end = dshape[dim]
+            end = dshape

-        # Process end
         if isinstance(target_end, int):
             if isinstance(end, list):
                 end[dim] = target_end
@@ -474,7 +467,7 @@ def slice(self, inputs, input_types):
             end = _op.cast(end, axis_dtype)

         strides = [1] * ndim
-        strides[dim] = int(inputs[4])
+        strides[dim] = stride

         return _op.transform.strided_slice(
             data, begin=begin, end=end, strides=strides, slice_mode="end"
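
The diff above keys off a PyTorch convention: when a slice's end index is omitted (e.g. `x[2:]` or the full-range `x[:]`), TorchScript fills in `sys.maxsize` as the end of the resulting `aten::slice` call (the PR also switches the sentinel constant to `sys.maxsize`, which equals `2 ** 63 - 1` on 64-bit builds). A full-range slice with begin 0 and stride 1 is therefore a no-op, and the new fast path returns the input expression directly instead of emitting a Relay `strided_slice`. A minimal sketch of the TorchScript behavior being exploited, assuming only a stock PyTorch install (the module is illustrative, not from the PR):

import torch

class FullSlice(torch.nn.Module):
    def forward(self, x):
        # A no-op slice: TorchScript lowers this to
        # aten::slice(x, dim=0, start=0, end=9223372036854775807, step=1),
        # i.e. end == sys.maxsize, which the frontend can now skip entirely.
        return x[:] * 2

# The printed graph should show the sys.maxsize sentinel as the slice end
# (newer PyTorch versions may show None for an omitted end instead).
print(torch.jit.script(FullSlice()).graph)
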
6 changes: 1 addition & 5 deletions python/tvm/relay/frontend/pytorch_utils.py
@@ -97,15 +97,11 @@ def batched_nms(boxes, scores, idxs, iou_threshold):
     add = is_op("add")(mx, one)
     mul = is_op("multiply")(cast, add)

-    # The following doesn't appear in the above Relay snippet. It is required for dynamic
-    # stride_slice handling
-    shape_of = is_op("shape_of")(mul)
-    cast = is_op("cast")(shape_of)
-    # This corresponds to offsets[:, None], where offsets is the result of multiplication
-    dyn_strided_slice = dyn_strided_slice_pattern(mul, cast)

     # Add offsets to the boxes
-    expand_dims = is_op("expand_dims")(dyn_strided_slice)
+    expand_dims = is_op("expand_dims")(mul)
     add = is_op("add")(boxes, expand_dims)

     # The rest of patterns correspond to the PyTorch frontend conversion
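
For context, this pattern targets TorchVision's `batched_nms`, which shifts boxes by a per-class offset roughly as `offsets = idxs.to(boxes) * (max_coordinate + 1); boxes_for_nms = boxes + offsets[:, None]`. Since `offsets[:, None]` is a full-range slice followed by an unsqueeze, the frontend change means only an `expand_dims` now reaches the Relay graph, so the `shape_of`/`cast`/dynamic `strided_slice` sub-pattern can be dropped. A rough sketch of the simplified sub-pattern using `tvm.relay.dataflow_pattern` (names mirror the diff; the surrounding pattern is elided):

from tvm.relay.dataflow_pattern import is_constant, is_op, wildcard

boxes = wildcard()
idxs = wildcard()
one = is_constant()

# offsets = idxs.to(boxes) * (max_coordinate + 1)
cast = is_op("cast")(idxs)
mx = is_op("max")(boxes)
add = is_op("add")(mx, one)
mul = is_op("multiply")(cast, add)

# offsets[:, None]: with the nop slice gone, a bare expand_dims remains
expand_dims = is_op("expand_dims")(mul)
shifted_boxes = is_op("add")(boxes, expand_dims)
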