From ee3e4623c72e1ebd8ff748785619ba8801035e65 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Tue, 10 Aug 2021 15:46:02 +0300 Subject: [PATCH 1/2] alternative chunk op was implemented in pytorch frontend. aten::unsafe_chunk was added to op map in pytorch frontend --- python/tvm/relay/frontend/pytorch.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 83ee1d3377f4..50c89c5a73de 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1574,6 +1574,32 @@ def func(x): return func(data) + def chunk_dev(self, inputs, input_types): + data = inputs[0] + + num_chunks = int(inputs[1]) + axis = int(inputs[2]) + + if isinstance(data, _expr.Expr): + inferred_shape = self.infer_shape_with_prelude(data) + + shape = [] + for infer in inferred_shape: + shape.append(infer) + + dim = int(shape[axis]) + + if dim % num_chunks: + unif_size = int(dim / (num_chunks - 1)) + else: + unif_size = int(dim / num_chunks) + + indeces = [] + for i in range(0, dim, unif_size): + indeces.append(i) + + return _op.split(data, indeces, axis) + def chunk(self, inputs, input_types): data = inputs[0] @@ -2681,6 +2707,7 @@ def create_convert_map(self): "aten::alpha_dropout": self.dropout, "aten::mean": self.mean, "aten::chunk": self.chunk, + "aten::unsafe_chunk": self.chunk, "aten::matmul": self.matmul, "aten::bmm": self.matmul, "aten::expand": self.expand, From 736526674dbf29dca70611cda68a3ae4b8567dd5 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Wed, 11 Aug 2021 13:21:23 +0300 Subject: [PATCH 2/2] chunk was replaced by new one in pytorch frontend. 
it is 2.5 times faster --- python/tvm/relay/frontend/pytorch.py | 47 ++-------------------------- 1 file changed, 2 insertions(+), 45 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 50c89c5a73de..9406c3b2ea9b 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1574,7 +1574,7 @@ def func(x): return func(data) - def chunk_dev(self, inputs, input_types): + def chunk(self, inputs, input_types): data = inputs[0] num_chunks = int(inputs[1]) @@ -1595,54 +1595,11 @@ def chunk_dev(self, inputs, input_types): unif_size = int(dim / num_chunks) indeces = [] - for i in range(0, dim, unif_size): + for i in range(unif_size, dim, unif_size): indeces.append(i) return _op.split(data, indeces, axis) - def chunk(self, inputs, input_types): - data = inputs[0] - - num_chunks = int(inputs[1]) - axis = int(inputs[2]) - - if isinstance(data, _expr.Expr): - inferred_shape = self.infer_shape_with_prelude(data) - - shape = [] - for infer in inferred_shape: - shape.append(infer) - - dim = int(shape[axis]) - - if dim % num_chunks: - unif_size = int(dim / (num_chunks - 1)) - else: - unif_size = int(dim / num_chunks) - - chunks = [] - for i in range(0, dim, unif_size): - begin = [0] * len(shape) - end = shape[:] - begin[axis] = i - end[axis] = i + unif_size - stride = [1] * len(shape) - - chunk_out = _op.transform.strided_slice(data, begin=begin, end=end, strides=stride) - chunks.append(chunk_out) - - if dim % num_chunks: - begin = [0] * len(shape) - end = shape[:] - begin[axis] = unif_size * (num_chunks - 1) - end[axis] = dim - stride = [1] * len(shape) - - chunk_out = _op.transform.strided_slice(data, begin=begin, end=end, strides=stride) - chunks.append(chunk_out) - - return chunks - def matmul(self, inputs, input_types): inputs_0 = inputs[0]