Skip to content

Commit

Permalink
chunk was replaced by new one in pytorch frontend. it is faster in 2.…
Browse files Browse the repository at this point in the history
…5 times
  • Loading branch information
vvchernov committed Aug 12, 2021
1 parent ee3e462 commit 7365266
Showing 1 changed file with 2 additions and 45 deletions.
47 changes: 2 additions & 45 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1574,7 +1574,7 @@ def func(x):

return func(data)

def chunk_dev(self, inputs, input_types):
def chunk(self, inputs, input_types):
data = inputs[0]

num_chunks = int(inputs[1])
Expand All @@ -1595,54 +1595,11 @@ def chunk_dev(self, inputs, input_types):
unif_size = int(dim / num_chunks)

indeces = []
for i in range(0, dim, unif_size):
for i in range(unif_size, dim, unif_size):
indeces.append(i)

return _op.split(data, indeces, axis)

def chunk(self, inputs, input_types):
data = inputs[0]

num_chunks = int(inputs[1])
axis = int(inputs[2])

if isinstance(data, _expr.Expr):
inferred_shape = self.infer_shape_with_prelude(data)

shape = []
for infer in inferred_shape:
shape.append(infer)

dim = int(shape[axis])

if dim % num_chunks:
unif_size = int(dim / (num_chunks - 1))
else:
unif_size = int(dim / num_chunks)

chunks = []
for i in range(0, dim, unif_size):
begin = [0] * len(shape)
end = shape[:]
begin[axis] = i
end[axis] = i + unif_size
stride = [1] * len(shape)

chunk_out = _op.transform.strided_slice(data, begin=begin, end=end, strides=stride)
chunks.append(chunk_out)

if dim % num_chunks:
begin = [0] * len(shape)
end = shape[:]
begin[axis] = unif_size * (num_chunks - 1)
end[axis] = dim
stride = [1] * len(shape)

chunk_out = _op.transform.strided_slice(data, begin=begin, end=end, strides=stride)
chunks.append(chunk_out)

return chunks

def matmul(self, inputs, input_types):

inputs_0 = inputs[0]
Expand Down

0 comments on commit 7365266

Please sign in to comment.