Skip to content

Commit

Permalink
feat: add torch.special.erfc in torch frontend and turn the implement…
Browse files Browse the repository at this point in the history
…ation in pointwise ops into an alias for the special func
  • Loading branch information
Ishticode committed Feb 13, 2024
1 parent 902cf55 commit 58227ae
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 6 deletions.
5 changes: 5 additions & 0 deletions erfissues.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import torch
print(torch.special.erfc(torch.tensor([0, -1., 10.])))
from ivy.functional.frontends import torch as ivy_torch
print(ivy_torch.special.erfc(ivy_torch.tensor([0, -1., 10.])))
print(ivy_torch.erfc(ivy_torch.tensor([0, -1., 10.])))
1 change: 1 addition & 0 deletions ivy/functional/frontends/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ def promote_types_of_torch_inputs(

from . import nn
from .nn.functional import softmax, relu, lstm
from . import special
from . import tensor
from .tensor import *
from . import blas_and_lapack_ops
Expand Down
9 changes: 3 additions & 6 deletions ivy/functional/frontends/torch/pointwise_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
)


erfc = torch_frontend.special.erfc


@to_ivy_arrays_and_back
def abs(input, *, out=None):
return ivy.abs(input, out=out)
Expand Down Expand Up @@ -189,12 +192,6 @@ def erf(input, *, out=None):
return ivy.erf(input, out=out)


@with_unsupported_dtypes({"2.2 and below": ("float16", "complex")}, "torch")
@to_ivy_arrays_and_back
def erfc(input, *, out=None):
return 1.0 - ivy.erf(input, out=out)


@with_unsupported_dtypes({"2.2 and below": ("float16",)}, "torch")
@to_ivy_arrays_and_back
def exp(input, *, out=None):
Expand Down
1 change: 1 addition & 0 deletions ivy/functional/frontends/torch/special/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .special_funcs import *
13 changes: 13 additions & 0 deletions ivy/functional/frontends/torch/special/special_funcs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import ivy
from ivy.func_wrapper import (
with_unsupported_dtypes,
)
from ivy.functional.frontends.torch.func_wrapper import (
to_ivy_arrays_and_back,
)


@with_unsupported_dtypes({"2.2 and below": ("float16", "complex")}, "torch")
@to_ivy_arrays_and_back
def erfc(input, *, out=None):
return 1.0 - ivy.erf(input, out=out)

0 comments on commit 58227ae

Please sign in to comment.