Skip to content

Commit

Permalink
replace the unimplemented tensor.mH used to the implemented adjoint, …
Browse files Browse the repository at this point in the history
…fixed the wrong shape and dtype of return. Now there are somehow numerial difference between return of groundtruth torch.svd and ivy.svd
  • Loading branch information
Jin Wang committed Jun 30, 2024
1 parent 46d180a commit dce10a6
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions ivy/functional/frontends/torch/blas_and_lapack_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,15 +193,17 @@ def slogdet(A, *, out=None):
@to_ivy_arrays_and_back
def svd(input, some=True, compute_uv=True, *, out=None):
# TODO: add handling for driver
ret = ivy.svd(input, full_matrices=not some, compute_uv=compute_uv)
retu = ivy.svd(input, full_matrices=not some, compute_uv=compute_uv)
results = namedtuple("svd", ['U', 'S', 'V'])
if compute_uv:
ret = results(ret.U, ret.S, ret.Vh.mH)
ret = results(retu[0], retu[1], ivy.adjoint(retu[2]))
else:
shape = input.shape
m = shape[-2]
n = shape[-1]
ret = results(ivy.zeros((m,m), device=input.device), ret.S, ivy.zeros((n,n), device=input.device))
shape = list(input.shape)
shape1 = shape
shape2 = shape
shape1[-2] = shape[-1]
shape2[-1] = shape[-2]
ret = results(ivy.zeros(shape1, device=input.device, dtype=input.dtype), retu[0], ivy.zeros(shape2, device=input.device, dtype=input.dtype))
if ivy.exists(out):
return ivy.inplace_update(out, ret)
return ret
Expand Down

0 comments on commit dce10a6

Please sign in to comment.