Skip to content

Commit

Permalink
Merge pull request ivy-llc#9924 from ra9hur/master
Browse files Browse the repository at this point in the history
Torch FE tensorsolve feature added
  • Loading branch information
karalleyna authored Jan 27, 2023
2 parents 4f05436 + 40827f6 commit cfee4d7
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 0 deletions.
6 changes: 6 additions & 0 deletions ivy/functional/frontends/torch/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,3 +175,9 @@ def eig(input, *, out=None):
@with_unsupported_dtypes({"1.11.0 and below": ("bfloat16", "float16")}, "torch")
def solve(input, other, *, out=None):
return ivy.solve(input, other, out=out)


@to_ivy_arrays_and_back
@with_unsupported_dtypes({"1.11.0 and below": ("bfloat16", "float16")}, "torch")
def tensorsolve(A, B, dims=None, *, out=None):
return ivy.tensorsolve(A, B, axes=dims, out=out)
70 changes: 70 additions & 0 deletions ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,3 +777,73 @@ def test_torch_tensorinv(
input=x,
ind=ind,
)


# tensorsolve
@st.composite
def _get_solve_matrices(draw):
# batch_shape, random_size, shared

# float16 causes a crash when filtering out matrices
# for which `np.linalg.cond` is large.
input_dtype_strategy = st.shared(
st.sampled_from(draw(helpers.get_dtypes("float"))).filter(
lambda x: "float16" not in x
),
key="shared_dtype",
)
input_dtype = draw(input_dtype_strategy)

dim = draw(helpers.ints(min_value=2, max_value=5))

first_matrix = draw(
helpers.array_values(
dtype=input_dtype,
shape=(dim, dim, dim, dim),
min_value=1.2,
max_value=5,
).filter(
lambda x: np.linalg.cond(x.reshape((dim**2, dim**2)))
< 1 / sys.float_info.epsilon
)
)

second_matrix = draw(
helpers.array_values(
dtype=input_dtype,
shape=(dim, dim),
min_value=1.2,
max_value=3,
).filter(
lambda x: np.linalg.cond(x.reshape((dim, dim))) < 1 / sys.float_info.epsilon
)
)

return input_dtype, first_matrix, second_matrix


@handle_frontend_test(
fn_tree="torch.linalg.tensorsolve",
a_and_b=_get_solve_matrices(),
)
def test_torch_tensorsolve(
*,
a_and_b,
on_device,
fn_tree,
frontend,
test_flags,
):
input_dtype, A, B = a_and_b
test_flags.num_positional_args = 2
helpers.test_frontend_function(
input_dtypes=[input_dtype],
test_flags=test_flags,
frontend=frontend,
fn_tree=fn_tree,
on_device=on_device,
atol=1e-3,
rtol=1e-3,
A=A,
B=B,
)

0 comments on commit cfee4d7

Please sign in to comment.