Skip to content

Commit

Permalink
(feat)(torch frontends): added the frontend functions for `torch.Tens…
Browse files Browse the repository at this point in the history
…or.bernoulli_` and `torch.Tensor.numel`
  • Loading branch information
YushaArif99 committed Feb 5, 2024
1 parent c856413 commit decce56
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
4 changes: 2 additions & 2 deletions ivy/functional/frontends/torch/random_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
"torch",
)
@to_ivy_arrays_and_back
def bernoulli(input, *, generator=None, out=None):
def bernoulli(input, p, *, generator=None, out=None):
seed = generator.initial_seed() if generator is not None else None
return ivy.bernoulli(input, seed=seed, out=out)
return ivy.bernoulli(p, logits=input, seed=seed, out=out)


@to_ivy_arrays_and_back
Expand Down
15 changes: 13 additions & 2 deletions ivy/functional/frontends/torch/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1201,8 +1201,19 @@ def dot(self, tensor):
return torch_frontend.dot(self, tensor)

@with_supported_dtypes({"2.1.2 and below": ("float32", "float64")}, "torch")
def bernoulli(self, *, generator=None, out=None):
return torch_frontend.bernoulli(self._ivy_array, generator=generator, out=out)
def bernoulli(self, p, *, generator=None, out=None):
return torch_frontend.bernoulli(
self._ivy_array, p, generator=generator, out=out
)

@with_supported_dtypes({"2.1.2 and below": ("float32", "float64")}, "torch")
def bernoulli_(self, p, *, generator=None, out=None):
self.ivy_array = self.bernoulli(p, generator=generator, out=out).ivy_array
return self

def numel(self):
shape = self.shape
return int(ivy.astype(ivy.prod(shape), ivy.int64))

# Special Methods #
# -------------------#
Expand Down

0 comments on commit decce56

Please sign in to comment.