From 8af3ae7206287e4114b907261c6d3568266388fb Mon Sep 17 00:00:00 2001 From: Daniel4078 <45633544+Daniel4078@users.noreply.github.com> Date: Wed, 10 Jul 2024 19:08:01 +0800 Subject: [PATCH] feat: add torch.Tensor frontend masked_scatter and masked_scatter_ (#28783) Co-authored-by: Jin Wang Co-authored-by: Sam-Armstrong --- ivy/functional/frontends/torch/tensor.py | 17 ++++ .../test_frontends/test_torch/test_tensor.py | 87 +++++++++++++++++++ 2 files changed, 104 insertions(+) diff --git a/ivy/functional/frontends/torch/tensor.py b/ivy/functional/frontends/torch/tensor.py index 52293d9cbf1e..34320c1a071c 100644 --- a/ivy/functional/frontends/torch/tensor.py +++ b/ivy/functional/frontends/torch/tensor.py @@ -1077,6 +1077,23 @@ def masked_fill_(self, mask, value): def masked_select(self, mask): return torch_frontend.masked_select(self, mask) + def masked_scatter(self, mask, source): + flat_self = torch_frontend.flatten(self.clone()) + flat_mask = torch_frontend.flatten(mask) + flat_source = torch_frontend.flatten(source) + indices = torch_frontend.squeeze(torch_frontend.nonzero(flat_mask), -1) + flat_self.scatter_(0, indices, flat_source[:indices.shape[0]]) + return flat_self.reshape(self.shape) + + def masked_scatter_(self, mask, source): + flat_self = torch_frontend.flatten(self.clone()) + flat_mask = torch_frontend.flatten(mask) + flat_source = torch_frontend.flatten(source) + indices = torch_frontend.squeeze(torch_frontend.nonzero(flat_mask), -1) + flat_self.scatter_(0, indices, flat_source[:indices.shape[0]]) + self.ivy_array = flat_self.reshape(self.shape).ivy_array + return self + @with_unsupported_dtypes({"2.2 and below": ("float16", "bfloat16")}, "torch") def index_add_(self, dim, index, source, *, alpha=1): self.ivy_array = torch_frontend.index_add( diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py index 25920923264d..478ee9a8c3ae 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py @@ -332,6 +332,24 @@ def _masked_fill_helper(draw): return dtypes[0], xs[0], cond, fill_value +@st.composite +def _masked_scatter_helper(draw): + shape = draw(helpers.get_shape(min_num_dims=1, min_dim_size=1)) + dtypes, xs = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=2, + shape=shape, + shared_dtype=True, + large_abs_safety_factor=16, + small_abs_safety_factor=16, + safety_factor_scale="log", + ) + ) + mask = draw(helpers.array_values(dtype="bool", shape=shape)) + return dtypes[0], xs[0], mask, xs[1] + + @st.composite def _repeat_helper(draw): shape = draw( @@ -9345,6 +9363,75 @@ def test_torch_masked_select( on_device=on_device, ) +# masked_scatter +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="torch.tensor", + method_name="masked_scatter", + dtype_x_mask_val=_masked_scatter_helper(), +) +def test_torch_masked_scatter( + dtype_x_mask_val, + frontend_method_data, + init_flags, + method_flags, + frontend, + on_device, + backend_fw, +): + dtype, x, mask, val = dtype_x_mask_val + helpers.test_frontend_method( + init_input_dtypes=[dtype], + backend_to_test=backend_fw, + init_all_as_kwargs_np={ + "data": x, + }, + method_input_dtypes=["bool", dtype], + method_all_as_kwargs_np={ + "mask": mask, + "source": val, + }, + frontend_method_data=frontend_method_data, + init_flags=init_flags, + method_flags=method_flags, + frontend=frontend, + on_device=on_device, + ) + +# masked_scatter_ +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="torch.tensor", + method_name="masked_scatter_", + dtype_x_mask_val=_masked_scatter_helper(), +) +def test_torch_masked_scatter_( + dtype_x_mask_val, + frontend_method_data, + init_flags, + method_flags, + frontend, + on_device, + backend_fw, +): + dtype, x, mask, val = dtype_x_mask_val + helpers.test_frontend_method( + init_input_dtypes=[dtype], + backend_to_test=backend_fw, + init_all_as_kwargs_np={ + "data": x, + }, + method_input_dtypes=["bool", dtype], + method_all_as_kwargs_np={ + "mask": mask, + "source": val, + }, + frontend_method_data=frontend_method_data, + init_flags=init_flags, + method_flags=method_flags, + frontend=frontend, + on_device=on_device, + ) # matmul @handle_frontend_method(