diff --git a/ivy/functional/frontends/torch/linalg.py b/ivy/functional/frontends/torch/linalg.py index 6f9abd7afabbb..bc94e95f6661d 100644 --- a/ivy/functional/frontends/torch/linalg.py +++ b/ivy/functional/frontends/torch/linalg.py @@ -175,3 +175,9 @@ def eig(input, *, out=None): @with_unsupported_dtypes({"1.11.0 and below": ("bfloat16", "float16")}, "torch") def solve(input, other, *, out=None): return ivy.solve(input, other, out=out) + + +@to_ivy_arrays_and_back +@with_unsupported_dtypes({"1.11.0 and below": ("bfloat16", "float16")}, "torch") +def tensorsolve(A, B, dims=None, *, out=None): + return ivy.tensorsolve(A, B, axes=dims, out=out) diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py index 9efec16964118..e6f3c1569d2d7 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py @@ -777,3 +777,73 @@ def test_torch_tensorinv( input=x, ind=ind, ) + + +# tensorsolve +@st.composite +def _get_solve_matrices(draw): + # batch_shape, random_size, shared + + # float16 causes a crash when filtering out matrices + # for which `np.linalg.cond` is large. + input_dtype_strategy = st.shared( + st.sampled_from(draw(helpers.get_dtypes("float"))).filter( + lambda x: "float16" not in x + ), + key="shared_dtype", + ) + input_dtype = draw(input_dtype_strategy) + + dim = draw(helpers.ints(min_value=2, max_value=5)) + + first_matrix = draw( + helpers.array_values( + dtype=input_dtype, + shape=(dim, dim, dim, dim), + min_value=1.2, + max_value=5, + ).filter( + lambda x: np.linalg.cond(x.reshape((dim**2, dim**2))) + < 1 / sys.float_info.epsilon + ) + ) + + second_matrix = draw( + helpers.array_values( + dtype=input_dtype, + shape=(dim, dim), + min_value=1.2, + max_value=3, + ).filter( + lambda x: np.linalg.cond(x.reshape((dim, dim))) < 1 / sys.float_info.epsilon + ) + ) + + return input_dtype, first_matrix, second_matrix + + +@handle_frontend_test( + fn_tree="torch.linalg.tensorsolve", + a_and_b=_get_solve_matrices(), +) +def test_torch_tensorsolve( + *, + a_and_b, + on_device, + fn_tree, + frontend, + test_flags, +): + input_dtype, A, B = a_and_b + test_flags.num_positional_args = 2 + helpers.test_frontend_function( + input_dtypes=[input_dtype], + test_flags=test_flags, + frontend=frontend, + fn_tree=fn_tree, + on_device=on_device, + atol=1e-3, + rtol=1e-3, + A=A, + B=B, + )