diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 0d93ff987c6e..acdd3918982c 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -636,6 +636,33 @@ def tensor_split(self, inputs, input_types): return _op.split(data, sections, dim) + def index_fill(self, inputs, input_types): + data = inputs[0] + dim = inputs[1] + indices = inputs[2] + value = inputs[3] + + dtype = self.infer_type(data).dtype + input_shape = self.infer_shape(data) + input_rank = len(input_shape) + dim = input_rank - 1 if dim == -1 else dim + + value = _op.nn.const(value, dtype=dtype) + indices_shape = list(input_shape) + indices_shape[dim] = 1 + + if np.isscalar(indices): + idx_val = _op.full(fill_value=indices, shape=indices_shape, dtype="int64") + result = _op.scatter_elements(data=data, indices=idx_val, updates=value, axis=dim, reduction="update") + else: + result = data + length = self.infer_shape(indices)[0] + for i in range(length): + idx_val = _op.transform.take(indices, indices=_op.nn.const(i), axis=0) + idx_val = _op.full(fill_value=idx_val, shape=indices_shape, dtype="int64") + result = _op.scatter_elements(data=result, indices=idx_val, updates=value, axis=dim, reduction="update") + return result + def select(self, inputs, input_types): data = inputs[0] dim = int(inputs[1]) @@ -4039,6 +4066,7 @@ def create_convert_map(self): "aten::pixel_shuffle": self.pixel_shuffle, "aten::device": self.none, "prim::device": self.none, + "aten::index_fill_": self.index_fill, "aten::sub": self.sub, "aten::max": self.max, "aten::min": self.min, diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 9f8fac93061c..0b5f740b08c4 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -3595,6 +3595,22 @@ def test_func(x): verify_model_with_input(test_func, [torch.rand([1, 3, 10, 10]).float()]) +def test_forward_index_fill_(): + """ text_forward_index_fill """ + torch.set_grad_enabled(False) + + def test_func_vector_index(x): + index = torch.arange(start=0, end=2, step=1) + return x.index_fill_(dim=1, index=index, value=3.0) + + def test_func_scalar_index(x): + return x.index_fill_(dim=1, index=1, value=3.0) + + verify_model_with_input(test_func_vector_index, [torch.rand([1, 3, 224, 224], dtype=torch.float32)]) + verify_model_with_input(test_func_vector_index, [torch.rand([32, 128, 128], dtype=torch.float32)]) + verify_model_with_input(test_func_scalar_index, [torch.rand([1, 32, 224, 224], dtype=torch.float32)]) + verify_model_with_input(test_func_scalar_index, [torch.rand([128, 128], dtype=torch.float32)]) + @tvm.testing.uses_gpu def test_forward_linspace(): """test_forward_linspace"""