From 1176454364335ae254c02129089965684707907d Mon Sep 17 00:00:00 2001 From: "rakshith.nagaraj6@gmail.com" Date: Fri, 7 Jul 2023 20:47:31 +0100 Subject: [PATCH 1/4] added lu factor to torch frontend and tested --- ivy/functional/frontends/torch/linalg.py | 4 +++ .../test_frontends/test_torch/test_linalg.py | 26 +++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/ivy/functional/frontends/torch/linalg.py b/ivy/functional/frontends/torch/linalg.py index 0219a658b77cd..6ba370c62898a 100644 --- a/ivy/functional/frontends/torch/linalg.py +++ b/ivy/functional/frontends/torch/linalg.py @@ -280,3 +280,7 @@ def solve_ex(A, B, *, left=True, check_errors=False, out=None): info = ivy.ones(A.shape[:-2], dtype=ivy.int32) return result, info + +@to_ivy_arrays_and_back +def lu(a, *, pivot=True, out=None): + return ivy.lu_factor(a, pivot=pivot, out=out) diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py index c0a666487a7f3..79ae41ba976ac 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py @@ -1355,3 +1355,29 @@ def test_torch_solve_ex( B=other, check_errors=check, ) + +@handle_frontend_test( + fn_tree="torch.linalg.lu_factor", input_dtype_and_input=_lu_factor_helper() +) +def test_torch_lu(*, input_dtype_and_input, on_device, fn_tree, frontend, test_flags): + dtype, input = input_dtype_and_input + ret, frontend_ret = helpers.test_frontend_function( + input_dtypes=dtype, + test_flags=test_flags, + frontend=frontend, + fn_tree=fn_tree, + on_device=on_device, + test_values=False, + A=input, + ) + ret = [ivy.to_numpy(x) for x in ret] + frontend_ret = [np.asarray(x) for x in frontend_ret] + + LU, pivot = ret + frontend_LU, frontend_pivot = frontend_ret + + assert_all_close( + ret_np=[LU, pivot], + ret_from_gt_np=[frontend_LU, frontend_pivot], + ground_truth_backend=frontend, + ) From 909037b762e305c37badcaa2f7437469a9c6836d Mon Sep 17 00:00:00 2001 From: "rakshith.nagaraj6@gmail.com" Date: Sat, 22 Jul 2023 17:31:35 +0100 Subject: [PATCH 2/4] added missing helper function for lu_factor --- .../test_frontends/test_torch/test_linalg.py | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py index 79ae41ba976ac..e43090fcbb357 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py @@ -1356,6 +1356,36 @@ def test_torch_solve_ex( check_errors=check, ) + +@st.composite +def _lu_factor_helper(draw): + ip_dtype = draw(helpers.get_dtypes("float")) + + dim1 = draw(helpers.ints(min_value=2, max_value=3)) + dim2 = draw(helpers.ints(min_value=2, max_value=3)) + batch_dim = 0 + + if batch_dim == 0: + input_matrix = draw( + helpers.array_values( + dtype=ip_dtype[0], + shape=(dim1, dim2), + min_value=-1, + max_value=1, + ) + ) + else: + input_matrix = draw( + helpers.array_values( + dtype=ip_dtype[0], + shape=(batch_dim, dim1, dim2), + min_value=-1, + max_value=1, + ) + ) + + return input_dtype, input_matrix + @handle_frontend_test( fn_tree="torch.linalg.lu_factor", input_dtype_and_input=_lu_factor_helper() ) From c2ba40547affc6a0e3955afa5f77dbbaa49dc04b Mon Sep 17 00:00:00 2001 From: "rakshith.nagaraj6@gmail.com" Date: Tue, 8 Aug 2023 21:16:02 +0100 Subject: [PATCH 3/4] corrected name error in lu_factor helper function --- ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py index e43090fcbb357..01d26dcbce626 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py @@ -1384,7 +1384,7 @@ def _lu_factor_helper(draw): ) ) - return input_dtype, input_matrix + return ip_dtype, input_matrix @handle_frontend_test( fn_tree="torch.linalg.lu_factor", input_dtype_and_input=_lu_factor_helper() From aa25cbc4028b4beb218552b509b120358a75f174 Mon Sep 17 00:00:00 2001 From: ivy-branch Date: Sat, 23 Dec 2023 19:16:49 +0000 Subject: [PATCH 4/4] =?UTF-8?q?=F0=9F=A4=96=20Lint=20code?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../test_frontends/test_torch/test_linalg.py | 113 +++++++++--------- 1 file changed, 57 insertions(+), 56 deletions(-) diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py index 9a2ad4cf3edf6..203719a9dcc0b 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py @@ -190,6 +190,36 @@ def _get_solve_matrices(draw): return input_dtype, first_matrix, second_matrix +@st.composite +def _lu_factor_helper(draw): + ip_dtype = draw(helpers.get_dtypes("float")) + + dim1 = draw(helpers.ints(min_value=2, max_value=3)) + dim2 = draw(helpers.ints(min_value=2, max_value=3)) + batch_dim = 0 + + if batch_dim == 0: + input_matrix = draw( + helpers.array_values( + dtype=ip_dtype[0], + shape=(dim1, dim2), + min_value=-1, + max_value=1, + ) + ) + else: + input_matrix = draw( + helpers.array_values( + dtype=ip_dtype[0], + shape=(batch_dim, dim1, dim2), + min_value=-1, + max_value=1, + ) + ) + + return ip_dtype, input_matrix + + # tensorinv @st.composite def _tensorinv_helper(draw): @@ -674,6 +704,33 @@ def test_torch_inv_ex( ) +@handle_frontend_test( + fn_tree="torch.linalg.lu_factor", input_dtype_and_input=_lu_factor_helper() +) +def test_torch_lu(*, input_dtype_and_input, on_device, fn_tree, frontend, test_flags): + dtype, input = input_dtype_and_input + ret, frontend_ret = helpers.test_frontend_function( + input_dtypes=dtype, + test_flags=test_flags, + frontend=frontend, + fn_tree=fn_tree, + on_device=on_device, + test_values=False, + A=input, + ) + ret = [ivy.to_numpy(x) for x in ret] + frontend_ret = [np.asarray(x) for x in frontend_ret] + + LU, pivot = ret + frontend_LU, frontend_pivot = frontend_ret + + assert_all_close( + ret_np=[LU, pivot], + ret_from_gt_np=[frontend_LU, frontend_pivot], + ground_truth_backend=frontend, + ) + + # lu_factor @handle_frontend_test( fn_tree="torch.linalg.lu_factor", @@ -1346,59 +1403,3 @@ def test_torch_vector_norm( keepdim=kd, dtype=dtype[0], ) - - -@st.composite -def _lu_factor_helper(draw): - ip_dtype = draw(helpers.get_dtypes("float")) - - dim1 = draw(helpers.ints(min_value=2, max_value=3)) - dim2 = draw(helpers.ints(min_value=2, max_value=3)) - batch_dim = 0 - - if batch_dim == 0: - input_matrix = draw( - helpers.array_values( - dtype=ip_dtype[0], - shape=(dim1, dim2), - min_value=-1, - max_value=1, - ) - ) - else: - input_matrix = draw( - helpers.array_values( - dtype=ip_dtype[0], - shape=(batch_dim, dim1, dim2), - min_value=-1, - max_value=1, - ) - ) - - return ip_dtype, input_matrix - -@handle_frontend_test( - fn_tree="torch.linalg.lu_factor", input_dtype_and_input=_lu_factor_helper() -) -def test_torch_lu(*, input_dtype_and_input, on_device, fn_tree, frontend, test_flags): - dtype, input = input_dtype_and_input - ret, frontend_ret = helpers.test_frontend_function( - input_dtypes=dtype, - test_flags=test_flags, - frontend=frontend, - fn_tree=fn_tree, - on_device=on_device, - test_values=False, - A=input, - ) - ret = [ivy.to_numpy(x) for x in ret] - frontend_ret = [np.asarray(x) for x in frontend_ret] - - LU, pivot = ret - frontend_LU, frontend_pivot = frontend_ret - - assert_all_close( - ret_np=[LU, pivot], - ret_from_gt_np=[frontend_LU, frontend_pivot], - ground_truth_backend=frontend, - )