Skip to content

Commit

Permalink
chunk was replaced by new one in pytorch frontend. it is faster in 2.…
Browse files Browse the repository at this point in the history
…5 times
  • Loading branch information
vvchernov committed Aug 11, 2021
1 parent c3a3643 commit 028ed6e
Showing 1 changed file with 2 additions and 45 deletions.
47 changes: 2 additions & 45 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1574,7 +1574,7 @@ def func(x):

return func(data)

def chunk_dev(self, inputs, input_types):
def chunk(self, inputs, input_types):
data = inputs[0]

num_chunks = int(inputs[1])
Expand All @@ -1595,54 +1595,11 @@ def chunk_dev(self, inputs, input_types):
unif_size = int(dim / num_chunks)

indeces = []
for i in range(0, dim, unif_size):
for i in range(unif_size, dim, unif_size):
indeces.append(i)

return _op.split(data, indeces, axis)

def chunk(self, inputs, input_types):
data = inputs[0]

num_chunks = int(inputs[1])
axis = int(inputs[2])

if isinstance(data, _expr.Expr):
inferred_shape = self.infer_shape_with_prelude(data)

shape = []
for infer in inferred_shape:
shape.append(infer)

dim = int(shape[axis])

if dim % num_chunks:
unif_size = int(dim / (num_chunks - 1))
else:
unif_size = int(dim / num_chunks)

chunks = []
for i in range(0, dim, unif_size):
begin = [0] * len(shape)
end = shape[:]
begin[axis] = i
end[axis] = i + unif_size
stride = [1] * len(shape)

chunk_out = _op.transform.strided_slice(data, begin=begin, end=end, strides=stride)
chunks.append(chunk_out)

if dim % num_chunks:
begin = [0] * len(shape)
end = shape[:]
begin[axis] = unif_size * (num_chunks - 1)
end[axis] = dim
stride = [1] * len(shape)

chunk_out = _op.transform.strided_slice(data, begin=begin, end=end, strides=stride)
chunks.append(chunk_out)

return chunks

def matmul(self, inputs, input_types):

inputs_0 = inputs[0]
Expand Down

0 comments on commit 028ed6e

Please sign in to comment.