From bb393feb3b35b150684a7ee5aa1809162ccfb4fb Mon Sep 17 00:00:00 2001 From: Wang Yao Date: Thu, 20 Aug 2020 17:51:25 -0700 Subject: [PATCH 1/3] Add Pytorch advanced indexing --- python/tvm/relay/frontend/pytorch.py | 48 +++++++++++++++++-- tests/python/frontend/pytorch/test_forward.py | 34 ++++++++++++- 2 files changed, 77 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 85dd5f4ce48b..a6440194e51d 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -274,16 +274,18 @@ def _impl(inputs, input_types): end[dim] = min(end[dim], int(inputs[3])) else: if isinstance(inputs[3], _expr.Call): - end[dim] = np.asscalar(_infer_value(inputs[3], {}).asnumpy().astype(np.int)) + target_end = np.asscalar(_infer_value(inputs[3], {}).asnumpy().astype(np.int)) else: - end[dim] = inputs[3] + target_end = inputs[3] + + end[dim] = min(end[dim], target_end) strides.append(int(inputs[4])) return _op.transform.strided_slice(data, begin=_expr.const(begin), end=_expr.const(end), strides=_expr.const(strides), - slice_mode="size") + slice_mode="end") return _impl def _split(): @@ -1755,6 +1757,45 @@ def _impl(inputs, input_types): return _impl +def _index(): + def _impl(inputs, input_types): + data = inputs[0] + indices = [] + max_indices_len = -1 + for index in inputs[1]: + if not isinstance(index, _expr.Constant): + raise RuntimeError("Only supports constant indices for " + "pytorch advanced indexing ") + cindex_len = index.data.shape[0] + if cindex_len > max_indices_len: + max_indices_len = cindex_len + + for index in inputs[1]: + cnp = index.data.asnumpy() + cindex_len = cnp.shape[0] + if cindex_len < max_indices_len: + cnp = np.tile(cnp, max_indices_len // cindex_len) + indices.append(cnp) + + ret = [] + slice_map = {} + for i in range(indices[0].shape[0]): + tmp = data + current_indices = [] + for index in indices: + current_indices.append(index[i]) + index_key = tuple(current_indices) + if index_key in slice_map: + tmp = slice_map[index_key] + else: + tmp = _op.take(tmp, _expr.const(index[i]), axis=0) + slice_map[index_key] = tmp + ret.append(_op.expand_dims(tmp, axis=0)) + + return _op.concatenate(ret, axis=0) + return _impl + + def _meshgrid(): def _impl(inputs, input_types): data = inputs[0] @@ -2060,6 +2101,7 @@ def _get_convert_map(prelude): "aten::type_as" : _type_as(), "aten::gather" : _gather(), "aten::index_select" : _select(), + "aten::index" : _index(), } return convert_map diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index d5b4ed2fc9c8..37c60bf6b25d 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -1202,13 +1202,13 @@ def forward(self, *args): class Slice2(Module): def forward(self, *args): - return args[0][0, :, :, :] + return args[0][0, :, :-3, :] class Slice3(Module): def forward(self, *args): x0 = torch.tensor(2) - torch.tensor(1) x1 = torch.tensor(3) + torch.tensor(1) - return args[0][:, x0:, :x1, :] + return args[0][:, x0:, 1:x1, :] input_data = torch.rand(input_shape).float() verify_model(Slice1().float().eval(), input_data=input_data) @@ -2607,6 +2607,35 @@ def forward(self, *args): verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2]) +def test_forward_index(): + torch.set_grad_enabled(False) + input_shape = [3, 4, 5, 6] + + class Index0(Module): + def __init__(self): + super().__init__() + if torch.cuda.is_available(): + self.inp = self.inp.cuda() + + def forward(self, x): + return x[[0, 1], [0, 2], :2, 4] + + input_data = torch.rand(input_shape).float() + verify_model(Index0().eval(), input_data=input_data) + + class Index1(Module): + def __init__(self): + super().__init__() + if torch.cuda.is_available(): + self.inp = self.inp.cuda() + + def forward(self, x): + return x[[0], [1, 2, 3, 0], [3, 1, 2, 2], [4, 2, 1, 0]] + + input_data = torch.rand(input_shape).float() + verify_model(Index1().eval(), input_data=input_data) + + def test_forward_pretrained_bert_base_uncased(): ###################################################################### # This is an example how to run BERT models using TVM @@ -2846,6 +2875,7 @@ def test_forward_pretrained_bert_base_uncased(): test_adaptive_pool3d() test_conv3d() test_conv3d_transpose() + test_forward_index() # Model tests test_resnet18() From 5ede75f93c598ce94cc97fc6b1ffffadaa9b7245 Mon Sep 17 00:00:00 2001 From: Wang Yao Date: Thu, 20 Aug 2020 18:01:23 -0700 Subject: [PATCH 2/3] Minor fix for test --- tests/python/frontend/pytorch/test_forward.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 37c60bf6b25d..9ce2b566b085 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -2612,11 +2612,6 @@ def test_forward_index(): input_shape = [3, 4, 5, 6] class Index0(Module): - def __init__(self): - super().__init__() - if torch.cuda.is_available(): - self.inp = self.inp.cuda() - def forward(self, x): return x[[0, 1], [0, 2], :2, 4] @@ -2624,11 +2619,6 @@ def forward(self, x): verify_model(Index0().eval(), input_data=input_data) class Index1(Module): - def __init__(self): - super().__init__() - if torch.cuda.is_available(): - self.inp = self.inp.cuda() - def forward(self, x): return x[[0], [1, 2, 3, 0], [3, 1, 2, 2], [4, 2, 1, 0]] From ce41b8b69a1b452e73d8595add6be878c85cd8fd Mon Sep 17 00:00:00 2001 From: Yao Wang Date: Fri, 21 Aug 2020 18:03:19 +0000 Subject: [PATCH 3/3] Fix for cuda --- python/tvm/relay/frontend/pytorch.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index a6440194e51d..e990bdd56107 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1761,16 +1761,21 @@ def _index(): def _impl(inputs, input_types): data = inputs[0] indices = [] + raw_indices = [] max_indices_len = -1 for index in inputs[1]: if not isinstance(index, _expr.Constant): - raise RuntimeError("Only supports constant indices for " - "pytorch advanced indexing ") + try: + index = _expr.const(_infer_value(index, {})) + except Exception: + raise RuntimeError("Only supports constant indices for " + "pytorch advanced indexing ") + raw_indices.append(index) cindex_len = index.data.shape[0] if cindex_len > max_indices_len: max_indices_len = cindex_len - for index in inputs[1]: + for index in raw_indices: cnp = index.data.asnumpy() cindex_len = cnp.shape[0] if cindex_len < max_indices_len: