From 2a01eabb96f7bbdc799f73813fe905291398af06 Mon Sep 17 00:00:00 2001 From: PineApple777 Date: Wed, 18 Sep 2024 23:30:25 +0900 Subject: [PATCH 1/3] initial support index_fill_ --- python/tvm/relay/frontend/pytorch.py | 27 +++++++++++++++++++ tests/python/frontend/pytorch/test_forward.py | 16 +++++++++++ 2 files changed, 43 insertions(+) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 0d93ff987c6e..58a483e61436 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -636,6 +636,32 @@ def tensor_split(self, inputs, input_types): return _op.split(data, sections, dim) + def index_fill(self, inputs, input_types): + data = inputs[0] + dim = inputs[1] + indices = inputs[2] + value = inputs[3] + + dtype = self.infer_type(data).dtype + input_shape = self.infer_shape(data) + input_rank = len(input_shape) + dim = input_rank - 1 if dim == -1 else dim + + value = _op.nn.const(value, dtype=dtype) + indices_shape = list(input_shape) + indices_shape[dim] = 1 + + result = data + if np.isscalar(indices): + idx_val = _op.full(fill_value=indices, shape=indices_shape, dtype="int64") + else: + length = self.infer_shape(indices)[0] + for i in range(length): + idx_val = _op.transform.take(indices, indices=_op.nn.const(i), axis=0) + idx_val = _op.full(fill_value=idx_val, shape=indices_shape, dtype="int64") + result = _op.scatter_elements(data=result, indices=idx_val, updates=value, axis=dim, reduction="update") + return result + def select(self, inputs, input_types): data = inputs[0] dim = int(inputs[1]) @@ -4039,6 +4065,7 @@ def create_convert_map(self): "aten::pixel_shuffle": self.pixel_shuffle, "aten::device": self.none, "prim::device": self.none, + "aten::index_fill_": self.index_fill, "aten::sub": self.sub, "aten::max": self.max, "aten::min": self.min, diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 9f8fac93061c..0b5f740b08c4 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -3595,6 +3595,22 @@ def test_func(x): verify_model_with_input(test_func, [torch.rand([1, 3, 10, 10]).float()]) +def test_forward_index_fill_(): + """ text_forward_index_fill """ + torch.set_grad_enabled(False) + + def test_func_vector_index(x): + index = torch.arange(start=0, end=2, step=1) + return x.index_fill_(dim=1, index=index, value=3.0) + + def test_func_scalar_index(x): + return x.index_fill_(dim=1, index=1, value=3.0) + + verify_model_with_input(test_func_vector_index, [torch.rand([1, 3, 224, 224], dtype=torch.float32)]) + verify_model_with_input(test_func_vector_index, [torch.rand([32, 128, 128], dtype=torch.float32)]) + verify_model_with_input(test_func_scalar_index, [torch.rand([1, 32, 224, 224], dtype=torch.float32)]) + verify_model_with_input(test_func_scalar_index, [torch.rand([128, 128], dtype=torch.float32)]) + @tvm.testing.uses_gpu def test_forward_linspace(): """test_forward_linspace""" From 82a8e971ec99761252cb95f18547c004dc64321c Mon Sep 17 00:00:00 2001 From: PineApple777 Date: Wed, 18 Sep 2024 23:41:55 +0900 Subject: [PATCH 2/3] fix scalar index conversion --- python/tvm/relay/frontend/pytorch.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 58a483e61436..acdd3918982c 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -651,10 +651,11 @@ def index_fill(self, inputs, input_types): indices_shape = list(input_shape) indices_shape[dim] = 1 - result = data if np.isscalar(indices): idx_val = _op.full(fill_value=indices, shape=indices_shape, dtype="int64") + result = _op.scatter_elements(data=data, indices=idx_val, updates=value, axis=dim, reduction="update") else: + result = data length = self.infer_shape(indices)[0] for i in range(length): idx_val = _op.transform.take(indices, indices=_op.nn.const(i), axis=0) From 6e9f5b022d2a6a6641faf3128ef44c1b84c198a7 Mon Sep 17 00:00:00 2001 From: PineApple777 Date: Sun, 13 Oct 2024 16:15:40 +0900 Subject: [PATCH 3/3] fix lint --- python/tvm/relay/frontend/pytorch.py | 21 ++++++++++++------- tests/python/frontend/pytorch/test_forward.py | 21 +++++++++++++++---- 2 files changed, 31 insertions(+), 11 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index acdd3918982c..2a28f488489d 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -52,6 +52,7 @@ __all__ = ["from_pytorch"] + # This returns a "subgraph" which puts variables whenever # the type is known. It also records things to map the input # nodes to the extracted graph's nodes. @@ -653,14 +654,18 @@ def index_fill(self, inputs, input_types): if np.isscalar(indices): idx_val = _op.full(fill_value=indices, shape=indices_shape, dtype="int64") - result = _op.scatter_elements(data=data, indices=idx_val, updates=value, axis=dim, reduction="update") + result = _op.scatter_elements( + data=data, indices=idx_val, updates=value, axis=dim, reduction="update" + ) else: result = data length = self.infer_shape(indices)[0] for i in range(length): idx_val = _op.transform.take(indices, indices=_op.nn.const(i), axis=0) idx_val = _op.full(fill_value=idx_val, shape=indices_shape, dtype="int64") - result = _op.scatter_elements(data=result, indices=idx_val, updates=value, axis=dim, reduction="update") + result = _op.scatter_elements( + data=result, indices=idx_val, updates=value, axis=dim, reduction="update" + ) return result def select(self, inputs, input_types): @@ -3895,7 +3900,7 @@ def inplace_copy(self, inputs, input_types): # Create indices nelem = 1 - for (begin, end) in index_map.values(): + for begin, end in index_map.values(): nelem *= end - begin chunk_sizes = [nelem] for i in range(1, last_index_dim + 1): @@ -4505,7 +4510,7 @@ def body(*current_vals): # Update loop variables using the prev iteration outputs assert len(current_vals) == num_block_inputs + len(free_vars) - for (i, val) in enumerate(current_vals): + for i, val in enumerate(current_vals): if i < num_block_inputs: outputs[block_input_names[i]] = val else: @@ -5486,9 +5491,11 @@ def from_pytorch( } data_inputs = sorted( data_inputs, - key=lambda data_input: order_input_infos[data_input.name_hint] - if data_input.name_hint in order_input_infos - else -1, + key=lambda data_input: ( + order_input_infos[data_input.name_hint] + if data_input.name_hint in order_input_infos + else -1 + ), reverse=True, ) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 0b5f740b08c4..3b3919899115 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -2940,6 +2940,7 @@ def forward(self, inp): @tvm.testing.uses_gpu def test_simple_rnn(): """test_simple_rnn""" + # The mixed tracing and scripting example from # https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html#mixing-scripting-and-tracing class DecisionGate(torch.nn.Module): @@ -3596,7 +3597,7 @@ def test_func(x): def test_forward_index_fill_(): - """ text_forward_index_fill """ + """text_forward_index_fill""" torch.set_grad_enabled(False) def test_func_vector_index(x): @@ -3606,11 +3607,18 @@ def test_func_vector_index(x): def test_func_scalar_index(x): return x.index_fill_(dim=1, index=1, value=3.0) - verify_model_with_input(test_func_vector_index, [torch.rand([1, 3, 224, 224], dtype=torch.float32)]) - verify_model_with_input(test_func_vector_index, [torch.rand([32, 128, 128], dtype=torch.float32)]) - verify_model_with_input(test_func_scalar_index, [torch.rand([1, 32, 224, 224], dtype=torch.float32)]) + verify_model_with_input( + test_func_vector_index, [torch.rand([1, 3, 224, 224], dtype=torch.float32)] + ) + verify_model_with_input( + test_func_vector_index, [torch.rand([32, 128, 128], dtype=torch.float32)] + ) + verify_model_with_input( + test_func_scalar_index, [torch.rand([1, 32, 224, 224], dtype=torch.float32)] + ) verify_model_with_input(test_func_scalar_index, [torch.rand([128, 128], dtype=torch.float32)]) + @tvm.testing.uses_gpu def test_forward_linspace(): """test_forward_linspace""" @@ -4216,6 +4224,7 @@ def test_weight_names(): @tvm.testing.uses_gpu def test_duplicate_weight_use(): """test_duplicate_weight_use""" + # The test cases doesn't make any sense as a neural network, # the issue popped up in shared input/output embeddings of bert, # but this is quicker @@ -4471,6 +4480,7 @@ def forward(self, data): def test_forward_scatter(): """test_forward_scatter""" + # integer cannot be traced def test_fn_scatter(dim): return lambda data, index, src: torch.scatter(data, dim=dim, index=index, src=src) @@ -4507,6 +4517,7 @@ def test_fn_scatter_add(dim): def test_forward_scatter_reduce(): """test_forward_scatter_reduce""" + # integer cannot be traced def test_fn_scatter_reduce(dim, reduce): return lambda data, index, src: torch.scatter_reduce( @@ -4531,6 +4542,7 @@ def test_fn_scatter_reduce(dim, reduce): def test_forward_index_put(): """test_forward_index_put""" + # torch.index_put for 2D tensor and default accumulate (False) def test_fn_index_put2(): return lambda data, xidx, yidx, values: torch.index_put( @@ -5326,6 +5338,7 @@ def test_remainder(x, y): def test_softmax_fuse(): """test_softmax_fuse""" + # https://github.com/apache/tvm/issues/12001 class Model(torch.nn.Module): """Pytorch model module"""