From ef20966b3208578d57ec852db0f00bb51c020b1e Mon Sep 17 00:00:00 2001 From: Jin Wang Date: Tue, 2 Jul 2024 10:25:09 +0800 Subject: [PATCH] added masked_scatter and masked_scatter_, make index_put and index_put_ use same approach --- ivy/functional/frontends/torch/tensor.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/ivy/functional/frontends/torch/tensor.py b/ivy/functional/frontends/torch/tensor.py index 7a9b786c4e9bd..4a6c38211fee6 100644 --- a/ivy/functional/frontends/torch/tensor.py +++ b/ivy/functional/frontends/torch/tensor.py @@ -1071,6 +1071,16 @@ def masked_fill_(self, mask, value): self.ivy_array = self.masked_fill(mask, value).ivy_array return self + def masked_scatter(self, mask, source): + ret = self.clone() + ret.index_put(torch_frontend.nonzero(mask, as_tuple=True), source) + return ret + + + def masked_scatter_(self, mask, source): + self.index_put(torch_frontend.nonzero(mask, as_tuple=True), source) + return self + @with_unsupported_dtypes({"2.2 and below": ("float16", "bfloat16")}, "torch") def index_add_(self, dim, index, source, *, alpha=1): self.ivy_array = torch_frontend.index_add( @@ -2300,10 +2310,16 @@ def corrcoef(self): def index_put(self, indices, values, accumulate=False): ret = self.clone() + def _set_add(index): + ret[index] += values + + def _set(index): + ret[index] = values + if accumulate: - ret[indices[0]] += values + ivy.map(fn=_set_add, unique={"index": indices}) else: - ret[indices[0]] = values + ivy.map(fn=_set, unique={"index": indices}) return ret def index_put_(self, indices, values, accumulate=False):