From 096a32bca28885b64ad1f5dc3a121062996d3aea Mon Sep 17 00:00:00 2001 From: abhimanyu Date: Thu, 8 Dec 2022 14:20:43 +0530 Subject: [PATCH 1/2] impl svd in torch frotnend, all tests passing --- ivy/functional/frontends/torch/linalg.py | 5 ++ .../test_frontends/test_torch/test_linalg.py | 61 +++++++++++++++++++ 2 files changed, 66 insertions(+) diff --git a/ivy/functional/frontends/torch/linalg.py b/ivy/functional/frontends/torch/linalg.py index 0a606e15a037e..f0eeb0fe9ca4f 100644 --- a/ivy/functional/frontends/torch/linalg.py +++ b/ivy/functional/frontends/torch/linalg.py @@ -52,3 +52,8 @@ def matrix_power(input, n, *, out=None): @to_ivy_arrays_and_back def matrix_rank(input, *, atol=None, rtol=None, hermitian=False, out=None): return ivy.astype(ivy.matrix_rank(input, atol=atol, rtol=rtol, out=out), ivy.int64) + + +@to_ivy_arrays_and_back +def svd(input,/,*,full_matrices=True): + return ivy.svd(input,compute_uv=True,full_matrices=full_matrices) diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py index 8df8c7ac270f4..149a5db9d4030 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py @@ -3,8 +3,11 @@ import numpy as np from hypothesis import strategies as st + # local +import ivy import ivy_tests.test_ivy.helpers as helpers +from ivy_tests.test_ivy.helpers import assert_all_close from ivy_tests.test_ivy.helpers import handle_frontend_test @@ -277,3 +280,61 @@ def test_matrix_rank( rtol=rtol, atol=atol, ) + + +# svd +@handle_frontend_test( + fn_tree="torch.linalg.svd", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_value=0, + max_value=10, + shape=helpers.ints(min_value=2, max_value=5).map(lambda x: tuple([x, x])) + ), + full_matrices=st.booleans() +) +def test_torch_svd( + *, + dtype_and_x, + full_matrices, + with_out, + num_positional_args, + as_variable, + native_array, + frontend, + fn_tree, + on_device, +): + dtype, x = dtype_and_x + x = np.asarray(x[0], dtype=dtype[0]) + # make symmetric positive-definite beforehand + x = np.matmul(x.T, x) + np.identity(x.shape[0]) * 1e-3 + ret, frontend_ret = helpers.test_frontend_function( + input_dtypes=dtype, + as_variable_flags=as_variable, + with_out=with_out, + num_positional_args=num_positional_args, + native_array_flags=native_array, + frontend=frontend, + fn_tree=fn_tree, + on_device=on_device, + test_values=False, + atol=1e-03, + rtol=1e-05, + input=x, + full_matrices=full_matrices + ) + ret = [ivy.to_numpy(x) for x in ret] + frontend_ret = [np.asarray(x) for x in frontend_ret] + + u, s, vh = ret + frontend_u, frontend_s, frontend_vh = frontend_ret + + assert_all_close( + ret_np=u @ np.diag(s) @ vh, + ret_from_gt_np=frontend_u @ np.diag(frontend_s) @ frontend_vh, + rtol=1e-2, + atol=1e-2, + ground_truth_backend=frontend, + ) + From c43eba574bd8c3db3c57b1da79488e4cbd8be254 Mon Sep 17 00:00:00 2001 From: abhimanyu Date: Thu, 8 Dec 2022 14:28:13 +0530 Subject: [PATCH 2/2] ran linter on files --- ivy/functional/frontends/torch/linalg.py | 4 ++-- .../test_ivy/test_frontends/test_torch/test_linalg.py | 9 ++++----- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/ivy/functional/frontends/torch/linalg.py b/ivy/functional/frontends/torch/linalg.py index f0eeb0fe9ca4f..a0bf50909cc61 100644 --- a/ivy/functional/frontends/torch/linalg.py +++ b/ivy/functional/frontends/torch/linalg.py @@ -55,5 +55,5 @@ def matrix_rank(input, *, atol=None, rtol=None, hermitian=False, out=None): @to_ivy_arrays_and_back -def svd(input,/,*,full_matrices=True): - return ivy.svd(input,compute_uv=True,full_matrices=full_matrices) +def svd(input, /, *, full_matrices=True): + return ivy.svd(input, compute_uv=True, full_matrices=full_matrices) diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py index 149a5db9d4030..96c11c31507af 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py @@ -289,9 +289,9 @@ def test_matrix_rank( available_dtypes=helpers.get_dtypes("float"), min_value=0, max_value=10, - shape=helpers.ints(min_value=2, max_value=5).map(lambda x: tuple([x, x])) + shape=helpers.ints(min_value=2, max_value=5).map(lambda x: tuple([x, x])), ), - full_matrices=st.booleans() + full_matrices=st.booleans(), ) def test_torch_svd( *, @@ -307,7 +307,7 @@ def test_torch_svd( ): dtype, x = dtype_and_x x = np.asarray(x[0], dtype=dtype[0]) - # make symmetric positive-definite beforehand + # make symmetric positive definite beforehand x = np.matmul(x.T, x) + np.identity(x.shape[0]) * 1e-3 ret, frontend_ret = helpers.test_frontend_function( input_dtypes=dtype, @@ -322,7 +322,7 @@ def test_torch_svd( atol=1e-03, rtol=1e-05, input=x, - full_matrices=full_matrices + full_matrices=full_matrices, ) ret = [ivy.to_numpy(x) for x in ret] frontend_ret = [np.asarray(x) for x in frontend_ret] @@ -337,4 +337,3 @@ def test_torch_svd( atol=1e-2, ground_truth_backend=frontend, ) -