Skip to content

Commit

Permalink
added solve_triangular to frontend
Browse files Browse the repository at this point in the history
  • Loading branch information
alia2109 committed Aug 15, 2023
1 parent 19f9818 commit 21c6ecc
Showing 1 changed file with 45 additions and 0 deletions.
45 changes: 45 additions & 0 deletions ivy/functional/frontends/torch/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,51 @@ def solve(A, B, *, left=True, out=None):
return ivy.solve(A, B, out=out)


@to_ivy_arrays_and_back
@with_supported_dtypes(
{"2.0.1 and below": ("float32", "float64", "complex32", "complex64")}, "torch"
)
def solve_triangular(A, B, *, upper, left=True, unitriangular=False, out=None):
# TOD0: Implement left and out
X = ivy.zeros_like(B)

#check if A is triangular
if ivy.all(ivy.triu(A)) != 0 or ivy.all(ivy.tril(A)) != 0:
raise RuntimeError("Matrix is not triangular")

#check if matrix is inversible
for row_idx, _ in enumerate(A):
if A[row_idx][row_idx] == 0:
raise RuntimeError("Matrix is not invertible")

#need to make leading variables 1
if unitriangular == False:
if upper == True:
for row_idx, row in enumerate(A):
div = A[row_idx][row_idx]
row[:] = [x / div for x in row]
B[row_idx][:] = [y / div for y in B[row_idx]]
else:
for row_idx, row in enumerate(A):
div = A[row_idx][0]
row[:] = [x / div for x in row]
B[row_idx][:] = [y / div for y in B[row_idx]]

#equation for lower
if upper == False:
for i in range(len(A)):
X[i] = (B[i] - ivy.dot(A[i][:i], X[:i])) / A[i][i]

return X

#equation for upper
if upper == True:
for i in range(len(A) - 1, -1, -1):
X[i] = (B[i] - ivy.dot(A[i][i:], X[i:])) / A[i][i]

return X


@to_ivy_arrays_and_back
@with_supported_dtypes(
{"2.0.1 and below": ("float32", "float64", "complex32", "complex64")}, "torch"
Expand Down

0 comments on commit 21c6ecc

Please sign in to comment.