From e9c044fb8163684e4deda74436b19d213cb2f34e Mon Sep 17 00:00:00 2001 From: Jin Wang Date: Tue, 2 Jul 2024 11:52:28 +0800 Subject: [PATCH 1/8] added masked_scatter and masked_scatter_ with test for masked_scatter --- ivy/functional/frontends/torch/tensor.py | 20 +++++++++- .../test_frontends/test_torch/test_tensor.py | 40 ++++++++++++++++++- 2 files changed, 56 insertions(+), 4 deletions(-) diff --git a/ivy/functional/frontends/torch/tensor.py b/ivy/functional/frontends/torch/tensor.py index da9c3a2dc243..23e3c906159e 100644 --- a/ivy/functional/frontends/torch/tensor.py +++ b/ivy/functional/frontends/torch/tensor.py @@ -1074,6 +1074,16 @@ def masked_fill_(self, mask, value): def masked_select(self, mask): return torch_frontend.masked_select(self, mask) + def masked_scatter(self, mask, tensor): + ret = self.clone() + ret.index_put(torch_frontend.nonzero(mask, as_tuple=True), tensor) + return ret + + + def masked_scatter_(self, mask, source): + self.index_put(torch_frontend.nonzero(mask, as_tuple=True), source) + return self + @with_unsupported_dtypes({"2.2 and below": ("float16", "bfloat16")}, "torch") def index_add_(self, dim, index, source, *, alpha=1): self.ivy_array = torch_frontend.index_add( @@ -2314,10 +2324,16 @@ def corrcoef(self): def index_put(self, indices, values, accumulate=False): ret = self.clone() + def _set_add(index): + ret[index] += values + + def _set(index): + ret[index] = values + if accumulate: - ret[indices[0]] += values + ivy.map(fn=_set_add, unique={"index": indices}) else: - ret[indices[0]] = values + ivy.map(fn=_set, unique={"index": indices}) return ret def index_put_(self, indices, values, accumulate=False): diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py index 25920923264d..c8949918b07d 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py @@ -321,9 +321,11 @@ def _get_dtype_input_and_vectors(draw, with_input=False, same_size=False): @st.composite -def _masked_fill_helper(draw): +def _masked_fill_helper(draw, scatter=False): cond, xs, dtypes = draw(_broadcastable_trio()) - if ivy.is_uint_dtype(dtypes[0]): + if scatter: + fill_value = draw(helpers.array_values(dtype=dtypes[0], shape=cond.shape, min_value=0, max_value=5)) + elif ivy.is_uint_dtype(dtypes[0]): fill_value = draw(helpers.ints(min_value=0, max_value=5)) elif ivy.is_int_dtype(dtypes[0]): fill_value = draw(helpers.ints(min_value=-5, max_value=5)) @@ -9345,6 +9347,40 @@ def test_torch_masked_select( on_device=on_device, ) +# masked_scatter +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="torch.tensor", + method_name="masked_scatter", + x_mask_val=_masked_fill_helper(), +) +def test_torch_masked_fill( + x_mask_val, + frontend_method_data, + init_flags, + method_flags, + frontend, + on_device, + backend_fw, +): + dtype, x, mask, val = x_mask_val + helpers.test_frontend_method( + init_input_dtypes=[dtype], + backend_to_test=backend_fw, + init_all_as_kwargs_np={ + "data": x, + }, + method_input_dtypes=["bool", dtype], + method_all_as_kwargs_np={ + "mask": mask, + "tensor": val, + }, + frontend_method_data=frontend_method_data, + init_flags=init_flags, + method_flags=method_flags, + frontend=frontend, + on_device=on_device, + ) # matmul @handle_frontend_method( From 3e529050830c23d68fa950cafdad603fc905975e Mon Sep 17 00:00:00 2001 From: Jin Wang Date: Tue, 2 Jul 2024 16:54:46 +0800 Subject: [PATCH 2/8] small fix on test --- ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py index c8949918b07d..f76b066ff9f5 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py @@ -9354,7 +9354,7 @@ def test_torch_masked_select( method_name="masked_scatter", x_mask_val=_masked_fill_helper(), ) -def test_torch_masked_fill( +def test_torch_masked_scatter( x_mask_val, frontend_method_data, init_flags, @@ -9373,7 +9373,7 @@ def test_torch_masked_fill( method_input_dtypes=["bool", dtype], method_all_as_kwargs_np={ "mask": mask, - "tensor": val, + "source": val, }, frontend_method_data=frontend_method_data, init_flags=init_flags, From 1253cec535d5939e8a1e8453980c407d7fcd7e07 Mon Sep 17 00:00:00 2001 From: Jin Wang Date: Thu, 4 Jul 2024 16:33:42 +0800 Subject: [PATCH 3/8] trying to fix failing test --- ivy/functional/frontends/torch/tensor.py | 12 ++++++-- .../test_frontends/test_torch/test_tensor.py | 29 ++++++++++++++----- 2 files changed, 31 insertions(+), 10 deletions(-) diff --git a/ivy/functional/frontends/torch/tensor.py b/ivy/functional/frontends/torch/tensor.py index 23e3c906159e..5f5e00e1d322 100644 --- a/ivy/functional/frontends/torch/tensor.py +++ b/ivy/functional/frontends/torch/tensor.py @@ -1074,14 +1074,20 @@ def masked_fill_(self, mask, value): def masked_select(self, mask): return torch_frontend.masked_select(self, mask) - def masked_scatter(self, mask, tensor): + def masked_scatter(self, mask, source): ret = self.clone() - ret.index_put(torch_frontend.nonzero(mask, as_tuple=True), tensor) + if torch_frontend.count_nonzero(mask) == 0: + return ret + conv = torch_frontend.nonzero(mask, as_tuple=True) + ret.index_put(conv, source) return ret def masked_scatter_(self, mask, source): - self.index_put(torch_frontend.nonzero(mask, as_tuple=True), source) + if torch_frontend.count_nonzero(mask) == 0: + return self + conv = torch_frontend.nonzero(mask, as_tuple=True) + self.index_put(conv, source) return self @with_unsupported_dtypes({"2.2 and below": ("float16", "bfloat16")}, "torch") diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py index f76b066ff9f5..8fc3b86f0797 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py @@ -321,11 +321,9 @@ def _get_dtype_input_and_vectors(draw, with_input=False, same_size=False): @st.composite -def _masked_fill_helper(draw, scatter=False): +def _masked_fill_helper(draw): cond, xs, dtypes = draw(_broadcastable_trio()) - if scatter: - fill_value = draw(helpers.array_values(dtype=dtypes[0], shape=cond.shape, min_value=0, max_value=5)) - elif ivy.is_uint_dtype(dtypes[0]): + if ivy.is_uint_dtype(dtypes[0]): fill_value = draw(helpers.ints(min_value=0, max_value=5)) elif ivy.is_int_dtype(dtypes[0]): fill_value = draw(helpers.ints(min_value=-5, max_value=5)) @@ -334,6 +332,23 @@ def _masked_fill_helper(draw, scatter=False): return dtypes[0], xs[0], cond, fill_value +def _masked_scatter_helper(draw): + shape = draw(helpers.get_shape(min_num_dims=1, min_dim_size=1)) + dtypes, xs = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=2, + shape=shape, + shared_dtype=True, + large_abs_safety_factor=16, + small_abs_safety_factor=16, + safety_factor_scale="log", + ) + ) + mask = draw(helpers.array_values(dtype=["bool"], shape=shape)) + return dtypes[0], xs[0], mask, xs[1] + + @st.composite def _repeat_helper(draw): shape = draw( @@ -9352,10 +9367,10 @@ def test_torch_masked_select( class_tree=CLASS_TREE, init_tree="torch.tensor", method_name="masked_scatter", - x_mask_val=_masked_fill_helper(), + dtype_x_mask_val=_masked_scatter_helper(), ) def test_torch_masked_scatter( - x_mask_val, + dtype_x_mask_val, frontend_method_data, init_flags, method_flags, @@ -9363,7 +9378,7 @@ def test_torch_masked_scatter( on_device, backend_fw, ): - dtype, x, mask, val = x_mask_val + dtype, x, mask, val = dtype_x_mask_val helpers.test_frontend_method( init_input_dtypes=[dtype], backend_to_test=backend_fw, From 76112cd55761794463bd57f6749f1d6019adf882 Mon Sep 17 00:00:00 2001 From: Jin Wang Date: Sat, 6 Jul 2024 11:02:28 +0800 Subject: [PATCH 4/8] small fix --- ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py index 8fc3b86f0797..76a9a3b8328a 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py @@ -332,6 +332,7 @@ def _masked_fill_helper(draw): return dtypes[0], xs[0], cond, fill_value +@st.composite def _masked_scatter_helper(draw): shape = draw(helpers.get_shape(min_num_dims=1, min_dim_size=1)) dtypes, xs = draw( From fa41368156872fa03c09e344daedd579d3bf9d94 Mon Sep 17 00:00:00 2001 From: Jin Wang Date: Sat, 6 Jul 2024 12:31:59 +0800 Subject: [PATCH 5/8] tests are all passing --- ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py index 76a9a3b8328a..9d98cf89d5ae 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py @@ -346,7 +346,7 @@ def _masked_scatter_helper(draw): safety_factor_scale="log", ) ) - mask = draw(helpers.array_values(dtype=["bool"], shape=shape)) + mask = draw(helpers.array_values(dtype=helpers.get_dtypes("bool"), shape=shape)) return dtypes[0], xs[0], mask, xs[1] From 7a94a9724b87f326f28c304104e8f20fc7db9811 Mon Sep 17 00:00:00 2001 From: Sam-Armstrong Date: Wed, 10 Jul 2024 02:38:26 +0100 Subject: [PATCH 6/8] fix test_torch_masked_scatter --- ivy/functional/frontends/torch/tensor.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/ivy/functional/frontends/torch/tensor.py b/ivy/functional/frontends/torch/tensor.py index 5f5e00e1d322..dc55a9c86567 100644 --- a/ivy/functional/frontends/torch/tensor.py +++ b/ivy/functional/frontends/torch/tensor.py @@ -1075,12 +1075,12 @@ def masked_select(self, mask): return torch_frontend.masked_select(self, mask) def masked_scatter(self, mask, source): - ret = self.clone() - if torch_frontend.count_nonzero(mask) == 0: - return ret - conv = torch_frontend.nonzero(mask, as_tuple=True) - ret.index_put(conv, source) - return ret + flat_self = torch_frontend.flatten(self.clone()) + flat_mask = torch_frontend.flatten(mask) + flat_source = torch_frontend.flatten(source) + indices = torch_frontend.squeeze(torch_frontend.nonzero(flat_mask), -1) + flat_self.scatter_(0, indices, flat_source[:indices.shape[0]]) + return flat_self.reshape(self.shape) def masked_scatter_(self, mask, source): From c693d0e80cf0ca73ecebe16d4a37caa1d356c8a4 Mon Sep 17 00:00:00 2001 From: Jin Wang Date: Wed, 10 Jul 2024 11:44:40 +0800 Subject: [PATCH 7/8] now only paddle test is failing when put_along_axis of paddle is called, during with the dtype is passed as None when converting result back to tensor for return --- ivy/functional/frontends/torch/tensor.py | 21 +++++------ .../test_frontends/test_torch/test_tensor.py | 37 ++++++++++++++++++- 2 files changed, 45 insertions(+), 13 deletions(-) diff --git a/ivy/functional/frontends/torch/tensor.py b/ivy/functional/frontends/torch/tensor.py index 7d159c1a46c0..96200af3baac 100644 --- a/ivy/functional/frontends/torch/tensor.py +++ b/ivy/functional/frontends/torch/tensor.py @@ -1087,12 +1087,15 @@ def masked_scatter(self, mask, source): def masked_scatter_(self, mask, source): - if torch_frontend.count_nonzero(mask) == 0: - return self - conv = torch_frontend.nonzero(mask, as_tuple=True) - self.index_put(conv, source) + flat_self = torch_frontend.flatten(self.clone()) + flat_mask = torch_frontend.flatten(mask) + flat_source = torch_frontend.flatten(source) + indices = torch_frontend.squeeze(torch_frontend.nonzero(flat_mask), -1) + flat_self.scatter_(0, indices, flat_source[:indices.shape[0]]) + self = flat_self.reshape(self.shape) return self + @with_unsupported_dtypes({"2.2 and below": ("float16", "bfloat16")}, "torch") def index_add_(self, dim, index, source, *, alpha=1): self.ivy_array = torch_frontend.index_add( @@ -2333,16 +2336,10 @@ def corrcoef(self): def index_put(self, indices, values, accumulate=False): ret = self.clone() - def _set_add(index): - ret[index] += values - - def _set(index): - ret[index] = values - if accumulate: - ivy.map(fn=_set_add, unique={"index": indices}) + ret[indices[0]] += values else: - ivy.map(fn=_set, unique={"index": indices}) + ret[indices[0]] = values return ret def index_put_(self, indices, values, accumulate=False): diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py index 9d98cf89d5ae..478ee9a8c3ae 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py @@ -346,7 +346,7 @@ def _masked_scatter_helper(draw): safety_factor_scale="log", ) ) - mask = draw(helpers.array_values(dtype=helpers.get_dtypes("bool"), shape=shape)) + mask = draw(helpers.array_values(dtype="bool", shape=shape)) return dtypes[0], xs[0], mask, xs[1] @@ -9398,6 +9398,41 @@ def test_torch_masked_scatter( on_device=on_device, ) +# masked_scatter_ +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="torch.tensor", + method_name="masked_scatter_", + dtype_x_mask_val=_masked_scatter_helper(), +) +def test_torch_masked_scatter_( + dtype_x_mask_val, + frontend_method_data, + init_flags, + method_flags, + frontend, + on_device, + backend_fw, +): + dtype, x, mask, val = dtype_x_mask_val + helpers.test_frontend_method( + init_input_dtypes=[dtype], + backend_to_test=backend_fw, + init_all_as_kwargs_np={ + "data": x, + }, + method_input_dtypes=["bool", dtype], + method_all_as_kwargs_np={ + "mask": mask, + "source": val, + }, + frontend_method_data=frontend_method_data, + init_flags=init_flags, + method_flags=method_flags, + frontend=frontend, + on_device=on_device, + ) + # matmul @handle_frontend_method( class_tree=CLASS_TREE, From 337b56eff33985a175a8ff863e98632a1cffd262 Mon Sep 17 00:00:00 2001 From: Sam-Armstrong Date: Wed, 10 Jul 2024 12:05:30 +0100 Subject: [PATCH 8/8] minor fix --- ivy/functional/frontends/torch/tensor.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/ivy/functional/frontends/torch/tensor.py b/ivy/functional/frontends/torch/tensor.py index 96200af3baac..34320c1a071c 100644 --- a/ivy/functional/frontends/torch/tensor.py +++ b/ivy/functional/frontends/torch/tensor.py @@ -1085,17 +1085,15 @@ def masked_scatter(self, mask, source): flat_self.scatter_(0, indices, flat_source[:indices.shape[0]]) return flat_self.reshape(self.shape) - def masked_scatter_(self, mask, source): flat_self = torch_frontend.flatten(self.clone()) flat_mask = torch_frontend.flatten(mask) flat_source = torch_frontend.flatten(source) indices = torch_frontend.squeeze(torch_frontend.nonzero(flat_mask), -1) flat_self.scatter_(0, indices, flat_source[:indices.shape[0]]) - self = flat_self.reshape(self.shape) + self.ivy_array = flat_self.reshape(self.shape).ivy_array return self - @with_unsupported_dtypes({"2.2 and below": ("float16", "bfloat16")}, "torch") def index_add_(self, dim, index, source, *, alpha=1): self.ivy_array = torch_frontend.index_add(