From ba08c5668926cef35d472524fcce9280c56b5833 Mon Sep 17 00:00:00 2001 From: Haris Mahmood <70361308+hmahmood24@users.noreply.github.com> Date: Sat, 3 Feb 2024 20:35:06 +0000 Subject: [PATCH] feat(frontends): Added erfinv and erfinv_ to torch frontend along with tests --- .../frontends/torch/miscellaneous_ops.py | 5 ++ ivy/functional/frontends/torch/tensor.py | 12 +++ .../test_torch/test_miscellaneous_ops.py | 31 ++++++++ .../test_frontends/test_torch/test_tensor.py | 77 +++++++++++++++++++ 4 files changed, 125 insertions(+) diff --git a/ivy/functional/frontends/torch/miscellaneous_ops.py b/ivy/functional/frontends/torch/miscellaneous_ops.py index ee0143f3b70df..b9d15f8a976df 100644 --- a/ivy/functional/frontends/torch/miscellaneous_ops.py +++ b/ivy/functional/frontends/torch/miscellaneous_ops.py @@ -239,6 +239,11 @@ def einsum(equation, *operands): return ivy.einsum(equation, *operands) +@to_ivy_arrays_and_back +def erfinv(input, *, out=None): + return ivy.erfinv(input, out=out) + + @to_ivy_arrays_and_back def finfo(dtype): return ivy.finfo(dtype) diff --git a/ivy/functional/frontends/torch/tensor.py b/ivy/functional/frontends/torch/tensor.py index ec96b640d5092..22195e6f3008d 100644 --- a/ivy/functional/frontends/torch/tensor.py +++ b/ivy/functional/frontends/torch/tensor.py @@ -2279,6 +2279,18 @@ def _set(index): return self + @with_unsupported_dtypes({"2.1.2 and below": ("float16", "complex")}, "torch") + def erfinv(self, *, out=None): + return torch_frontend.erfinv(self, out=out) + + @with_unsupported_dtypes({"2.1.2 and below": ("float16", "complex")}, "torch") + def erfinv_(self, *, out=None): + ret = self.erfinv(out=out) + self._ivy_array = ivy.inplace_update( + self._ivy_array, ivy.astype(ret.ivy_array, self._ivy_array.dtype) + ) + return self + # Method aliases absolute, absolute_ = abs, abs_ clip, clip_ = clamp, clamp_ diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_miscellaneous_ops.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_miscellaneous_ops.py index 5b15e0ec81308..46eb2ada3d210 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_miscellaneous_ops.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_miscellaneous_ops.py @@ -1017,6 +1017,37 @@ def test_torch_einsum( ) +# erfinv +@handle_frontend_test( + fn_tree="torch.erfinv", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_value=-1, + max_value=1, + abs_smallest_val=1e-05, + ), +) +def test_torch_erfinv( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=x[0], + ) + + @handle_frontend_test( fn_tree="torch.flatten", dtype_input_axes=helpers.dtype_values_axis( diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py index a0e89163f2f76..b5490c5ac5f5e 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py @@ -6374,6 +6374,83 @@ def test_torch_erf_( ) +# erfinv_ tests +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="torch.tensor", + method_name="erfinv_", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_value=-1, + max_value=1, + abs_smallest_val=1e-05, + ), +) +def test_torch_erfinv( + dtype_and_x, + frontend_method_data, + init_flags, + method_flags, + frontend, + on_device, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_method( + init_input_dtypes=input_dtype, + backend_to_test=backend_fw, + init_all_as_kwargs_np={ + "data": x[0], + }, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={}, + frontend_method_data=frontend_method_data, + init_flags=init_flags, + method_flags=method_flags, + frontend=frontend, + on_device=on_device, + ) + + +# erfinv_ tests +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="torch.tensor", + method_name="erfinv_", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_value=-1, + max_value=1, + abs_smallest_val=1e-05, + ), + test_inplace=st.just(True), +) +def test_torch_erfinv_( + dtype_and_x, + frontend_method_data, + init_flags, + method_flags, + frontend, + on_device, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_method( + init_input_dtypes=input_dtype, + backend_to_test=backend_fw, + init_all_as_kwargs_np={ + "data": x[0], + }, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={}, + frontend_method_data=frontend_method_data, + init_flags=init_flags, + method_flags=method_flags, + frontend=frontend, + on_device=on_device, + ) + + # exp @handle_frontend_method( class_tree=CLASS_TREE,