Skip to content

Commit

Permalink
feat: added the erfinv function to ivy's experimental API (#28159)
Browse files Browse the repository at this point in the history
  • Loading branch information
vedpatwardhan authored Feb 2, 2024
1 parent 6f32c6a commit e786e88
Show file tree
Hide file tree
Showing 9 changed files with 249 additions and 1 deletion.
31 changes: 31 additions & 0 deletions ivy/data_classes/array/experimental/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -1191,3 +1191,34 @@ def erfc(
ivy.array([1.00000000e+00, 1.84270084e+00, 2.80259693e-45])
"""
return ivy.erfc(self._data, out=out)

def erfinv(
self: ivy.Array,
/,
*,
out: Optional[ivy.Array] = None,
) -> ivy.Array:
"""ivy.Array instance method variant of ivy.erfinv. This method simply
wraps the function, and so the docstring for ivy.erfinv also applies to
this method with minimal changes.
Parameters
----------
self
Input array with real or complex valued argument.
out
Alternate output array in which to place the result.
The default is None.
Returns
-------
ret
Values of the inverse error function.
Examples
--------
>>> x = ivy.array([0, -1., 10.])
>>> x.erfinv()
ivy.array([1.00000000e+00, 1.84270084e+00, 2.80259693e-45])
"""
return ivy.erfinv(self._data, out=out)
91 changes: 91 additions & 0 deletions ivy/data_classes/container/experimental/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -3491,3 +3491,94 @@ def erfc(
}
"""
return self.static_erfc(self, out=out)

@staticmethod
def static_erfinv(
x: Union[ivy.Array, ivy.NativeArray, ivy.Container],
/,
*,
key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None,
to_apply: Union[bool, ivy.Container] = True,
prune_unapplied: Union[bool, ivy.Container] = False,
map_sequences: Union[bool, ivy.Container] = False,
out: Optional[ivy.Container] = None,
) -> ivy.Container:
"""ivy.Container static method variant of ivy.erfinv. This method
simply wraps the function, and so the docstring for ivy.erfinv also
applies to this method with minimal changes.
Parameters
----------
x
The container whose array contains real or complex valued argument.
key_chains
The key-chains to apply or not apply the method to. Default is ``None``.
to_apply
If True, the method will be applied to key_chains, otherwise key_chains
will be skipped. Default is ``True``.
prune_unapplied
Whether to prune key_chains for which the function was not applied.
Default is ``False``.
map_sequences
Whether to also map method to sequences (lists, tuples).
Default is ``False``.
out
optional output container, for writing the result to.
Returns
-------
ret
container with values of the inverse error function.
Examples
--------
>>> x = ivy.Container(a=ivy.array([1., 2.]), b=ivy.array([-3., -4.]))
>>> ivy.Container.static_erfinv(x)
{
a: ivy.array([0.15729921, 0.00467773]),
b: ivy.array([1.99997795, 2.])
}
"""
return ContainerBase.cont_multi_map_in_function(
"erfinv",
x,
key_chains=key_chains,
to_apply=to_apply,
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
out=out,
)

def erfinv(
self: ivy.Container,
/,
*,
out: Optional[ivy.Container] = None,
) -> ivy.Container:
"""ivy.Container instance method variant of ivy.erfinv. This method
simply wraps the function, and so the docstring for ivy.erfinv also
applies to this method with minimal changes.
Parameters
----------
self
The container whose array contains real or complex valued argument.
out
optional output container, for writing the result to.
Returns
-------
ret
container with values of the inverse error function.
Examples
--------
With one :class:`ivy.Container` input:
>>> x = ivy.Container(a=ivy.array([1., 2., 3.]), b=ivy.array([-1., -2., -3.]))
>>> x.erfinv()
{
a: ivy.array([1.57299206e-01, 4.67773480e-03, 2.20904985e-05]),
b: ivy.array([1.84270084, 1.99532223, 1.99997795])
}
"""
return self.static_erfinv(self, out=out)
9 changes: 9 additions & 0 deletions ivy/functional/backends/jax/experimental/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,3 +499,12 @@ def erfc(
out: Optional[JaxArray] = None,
) -> JaxArray:
return js.special.erfc(x)


def erfinv(
x: JaxArray,
/,
*,
out: Optional[JaxArray] = None,
) -> JaxArray:
return js.special.erfinv(x)
12 changes: 12 additions & 0 deletions ivy/functional/backends/numpy/experimental/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,3 +602,15 @@ def is_pos_inf(op):
return np.where(underflow, result_underflow, result_no_underflow).astype(
input_dtype
)


# TODO: Remove this once native function is available.
# Compute an approximation of the error function complement (1 - erf(x)).
def erfinv(
x: np.ndarray,
/,
*,
out: Optional[np.ndarray] = None,
) -> np.ndarray:
with ivy.ArrayMode(False):
return np.sqrt(2) * erfc(x)
10 changes: 10 additions & 0 deletions ivy/functional/backends/paddle/experimental/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,3 +815,13 @@ def is_pos_inf(op):
result = paddle.squeeze(result, axis=-1)

return result


@with_supported_dtypes(
{"2.6.0 and below": ("float32", "float64")},
backend_version,
)
def erfinv(
x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None
) -> paddle.Tensor:
return paddle.erfinv(x)
10 changes: 10 additions & 0 deletions ivy/functional/backends/tensorflow/experimental/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,3 +566,13 @@ def erfc(
out: Optional[Union[tf.Tensor, tf.Variable]] = None,
) -> Union[tf.Tensor, tf.Variable]:
return tf.math.erfc(x)


@with_supported_dtypes({"2.15.0 and below": ("float",)}, backend_version)
def erfinv(
x: Union[tf.Tensor, tf.Variable],
/,
*,
out: Optional[Union[tf.Tensor, tf.Variable]] = None,
) -> Union[tf.Tensor, tf.Variable]:
return tf.math.erfinv(x)
13 changes: 13 additions & 0 deletions ivy/functional/backends/torch/experimental/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,3 +442,16 @@ def erfc(
out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return torch.special.erfc(x)


@with_unsupported_dtypes({"2.1.2 and below": ("float16",)}, backend_version)
def erfinv(
x: torch.Tensor,
/,
*,
out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return torch.special.erfinv(x, out=out)


erfinv.support_native_out = True
36 changes: 36 additions & 0 deletions ivy/functional/ivy/experimental/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -1637,3 +1637,39 @@ def erfc(
ivy.array([0.00467773, 1.84270084, 1. ])
"""
return ivy.current_backend(x).erfc(x, out=out)


@handle_exceptions
@handle_nestable
@handle_array_like_without_promotion
@handle_out_argument
@to_native_arrays_and_back
@handle_device
def erfinv(
x: Union[ivy.Array, ivy.NativeArray],
/,
*,
out: Optional[ivy.Array] = None,
):
"""Compute the inverse error function.
Parameters
----------
x
Input array of real or complex valued argument.
out
optional output array, for writing the result to.
It must have a shape that the inputs broadcast to.
Returns
-------
ret
Values of the inverse error function.
Examples
--------
>>> x = ivy.array([0, 0.5, -1.])
>>> ivy.erfinv(x)
ivy.array([0.0000, 0.4769, -inf])
"""
return ivy.current_backend(x).erfinv(x, out=out)
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# global
from hypothesis import strategies as st
from hypothesis import assume, strategies as st

# local
import ivy
Expand Down Expand Up @@ -510,6 +510,42 @@ def test_erfc(
)


# erfinv
@handle_test(
fn_tree="functional.ivy.experimental.erfinv",
dtype_and_x=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("float"),
min_value=-1,
max_value=1,
abs_smallest_val=1e-05,
),
)
def test_erfinv(
*,
dtype_and_x,
backend_fw,
test_flags,
fn_name,
on_device,
):
input_dtype, x = dtype_and_x
if on_device == "cpu":
assume("float16" not in input_dtype and "bfloat16" not in input_dtype)
test_values = True
if backend_fw == "numpy":
# the numpy backend requires an approximation which doesn't pass the value tests
test_values = False
helpers.test_function(
input_dtypes=input_dtype,
backend_to_test=backend_fw,
test_flags=test_flags,
fn_name=fn_name,
on_device=on_device,
test_values=test_values,
x=x[0],
)


# fix
@handle_test(
fn_tree="functional.ivy.experimental.fix",
Expand Down

0 comments on commit e786e88

Please sign in to comment.