From 58227ae570c8562ad51a135e165f35e9af3415a6 Mon Sep 17 00:00:00 2001 From: Ishtiaq Hussain <53497039+Ishticode@users.noreply.github.com> Date: Tue, 13 Feb 2024 21:17:43 +0000 Subject: [PATCH] feat: add torch.special.erfc in torch frontend and turn the implementation in pointwise ops into an alias for the special func --- erfissues.py | 5 +++++ ivy/functional/frontends/torch/__init__.py | 1 + ivy/functional/frontends/torch/pointwise_ops.py | 9 +++------ ivy/functional/frontends/torch/special/__init__.py | 1 + .../frontends/torch/special/special_funcs.py | 13 +++++++++++++ 5 files changed, 23 insertions(+), 6 deletions(-) create mode 100644 erfissues.py create mode 100644 ivy/functional/frontends/torch/special/__init__.py create mode 100644 ivy/functional/frontends/torch/special/special_funcs.py diff --git a/erfissues.py b/erfissues.py new file mode 100644 index 0000000000000..3331f8968a491 --- /dev/null +++ b/erfissues.py @@ -0,0 +1,5 @@ +import torch +print(torch.special.erfc(torch.tensor([0, -1., 10.]))) +from ivy.functional.frontends import torch as ivy_torch +print(ivy_torch.special.erfc(ivy_torch.tensor([0, -1., 10.]))) +print(ivy_torch.erfc(ivy_torch.tensor([0, -1., 10.]))) \ No newline at end of file diff --git a/ivy/functional/frontends/torch/__init__.py b/ivy/functional/frontends/torch/__init__.py index af031aa7c9d26..e9f0986016442 100644 --- a/ivy/functional/frontends/torch/__init__.py +++ b/ivy/functional/frontends/torch/__init__.py @@ -261,6 +261,7 @@ def promote_types_of_torch_inputs( from . import nn from .nn.functional import softmax, relu, lstm +from . import special from . import tensor from .tensor import * from . import blas_and_lapack_ops diff --git a/ivy/functional/frontends/torch/pointwise_ops.py b/ivy/functional/frontends/torch/pointwise_ops.py index 7ae00ab085319..9c0c4601c3522 100644 --- a/ivy/functional/frontends/torch/pointwise_ops.py +++ b/ivy/functional/frontends/torch/pointwise_ops.py @@ -10,6 +10,9 @@ ) +erfc = torch_frontend.special.erfc + + @to_ivy_arrays_and_back def abs(input, *, out=None): return ivy.abs(input, out=out) @@ -189,12 +192,6 @@ def erf(input, *, out=None): return ivy.erf(input, out=out) -@with_unsupported_dtypes({"2.2 and below": ("float16", "complex")}, "torch") -@to_ivy_arrays_and_back -def erfc(input, *, out=None): - return 1.0 - ivy.erf(input, out=out) - - @with_unsupported_dtypes({"2.2 and below": ("float16",)}, "torch") @to_ivy_arrays_and_back def exp(input, *, out=None): diff --git a/ivy/functional/frontends/torch/special/__init__.py b/ivy/functional/frontends/torch/special/__init__.py new file mode 100644 index 0000000000000..7dc36a681a948 --- /dev/null +++ b/ivy/functional/frontends/torch/special/__init__.py @@ -0,0 +1 @@ +from .special_funcs import * diff --git a/ivy/functional/frontends/torch/special/special_funcs.py b/ivy/functional/frontends/torch/special/special_funcs.py new file mode 100644 index 0000000000000..a48c5a2bc006d --- /dev/null +++ b/ivy/functional/frontends/torch/special/special_funcs.py @@ -0,0 +1,13 @@ +import ivy +from ivy.func_wrapper import ( + with_unsupported_dtypes, +) +from ivy.functional.frontends.torch.func_wrapper import ( + to_ivy_arrays_and_back, +) + + +@with_unsupported_dtypes({"2.2 and below": ("float16", "complex")}, "torch") +@to_ivy_arrays_and_back +def erfc(input, *, out=None): + return 1.0 - ivy.erf(input, out=out)