Skip to content

Commit

Permalink
fix: handle dtype casting for masked_fill in torch frontend (#28740)
Browse files Browse the repository at this point in the history
  • Loading branch information
Kacper-W-Kozdon authored Apr 30, 2024
1 parent 65a5363 commit 6e7807c
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion ivy/functional/frontends/torch/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1062,8 +1062,9 @@ def acosh(self):
return torch_frontend.acosh(self)

def masked_fill(self, mask, value):
dtype = ivy.as_native_dtype(self.dtype)
return torch_frontend.tensor(
torch_frontend.where(mask, value, self), dtype=self.dtype
ivy.astype(torch_frontend.where(mask, value, self), dtype)
)

def masked_fill_(self, mask, value):
Expand Down

0 comments on commit 6e7807c

Please sign in to comment.