Skip to content

Commit

Permalink
added masked_scatter and masked_scatter_, make index_put and index_pu…
Browse files Browse the repository at this point in the history
…t_ use same approach
  • Loading branch information
Jin Wang committed Jul 2, 2024
1 parent dce10a6 commit ef20966
Showing 1 changed file with 18 additions and 2 deletions.
20 changes: 18 additions & 2 deletions ivy/functional/frontends/torch/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1071,6 +1071,16 @@ def masked_fill_(self, mask, value):
self.ivy_array = self.masked_fill(mask, value).ivy_array
return self

def masked_scatter(self, mask, source):
ret = self.clone()
ret.index_put(torch_frontend.nonzero(mask, as_tuple=True), source)
return ret


def masked_scatter_(self, mask, source):
self.index_put(torch_frontend.nonzero(mask, as_tuple=True), source)
return self

@with_unsupported_dtypes({"2.2 and below": ("float16", "bfloat16")}, "torch")
def index_add_(self, dim, index, source, *, alpha=1):
self.ivy_array = torch_frontend.index_add(
Expand Down Expand Up @@ -2300,10 +2310,16 @@ def corrcoef(self):

def index_put(self, indices, values, accumulate=False):
ret = self.clone()
def _set_add(index):
ret[index] += values

def _set(index):
ret[index] = values

if accumulate:
ret[indices[0]] += values
ivy.map(fn=_set_add, unique={"index": indices})
else:
ret[indices[0]] = values
ivy.map(fn=_set, unique={"index": indices})
return ret

def index_put_(self, indices, values, accumulate=False):
Expand Down

0 comments on commit ef20966

Please sign in to comment.