Skip to content

Commit

Permalink
When calculating the SVD S array, temporarily add an extra dim to comply
Browse files Browse the repository at this point in the history
with the restriction that multiple outputs must have the same numblocks
  • Loading branch information
tomwhite committed Nov 8, 2024
1 parent af965e5 commit 1ca532d
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions cubed/array_api/linalg.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
from typing import NamedTuple

from cubed.array_api.array_object import Array

# These functions are in both the main and linalg namespaces
from cubed.array_api.data_type_functions import result_type
from cubed.array_api.dtypes import _floating_dtypes

# These functions are in both the main and linalg namespaces
from cubed.array_api.linear_algebra_functions import ( # noqa: F401
matmul,
matrix_transpose,
tensordot,
vecdot,
)
from cubed.backend_array_api import namespace as nxp
from cubed.core.ops import blockwise, general_blockwise, merge_chunks
from cubed.core.ops import blockwise, general_blockwise, merge_chunks, squeeze
from cubed.utils import array_memory, get_item


Expand Down Expand Up @@ -53,7 +53,7 @@ def qr(x, /, *, mode="reduced") -> QRResult:
return QRResult(Q, R)


def tsqr(x, compute_svd=False, final_u=True):
def tsqr(x, compute_svd=False, finalize_svd=True):
"""Direct Tall-and-Skinny QR algorithm
From:
Expand All @@ -69,14 +69,15 @@ def tsqr(x, compute_svd=False, final_u=True):

if _r1_is_too_big(R1):
R1 = _rechunk_r1(R1)
Q2, R2, U, S, Vh = tsqr(R1, compute_svd=compute_svd, final_u=False)
Q2, R2, U, S, Vh = tsqr(R1, compute_svd=compute_svd, finalize_svd=False)
else:
Q2, R2, U, S, Vh = _qr_second_step(R1, compute_svd=compute_svd)

Q, R = _qr_third_step(Q1, Q2), R2

if compute_svd and final_u:
if compute_svd and finalize_svd:
U = Q @ U # fourth step (SVD only)
S = squeeze(S, axis=1) # remove extra dim

return Q, R, U, S, Vh

Expand Down Expand Up @@ -143,7 +144,7 @@ def _qr_second_step(R1, compute_svd=False):
else:
U_shape = (n, n)
U_chunks = U_shape
S_shape = (n,)
S_shape = (n, 1) # extra dim since multiple outputs must have same numblocks
S_chunks = S_shape
Vh_shape = (n, n)
Vh_chunks = Vh_shape
Expand Down Expand Up @@ -172,6 +173,7 @@ def _merge_into_single_chunk(x, split_every=4):
def _qr2(a):
Q, R = nxp.linalg.qr(a)
U, S, Vh = nxp.linalg.svd(R)
S = S[:, nxp.newaxis] # add extra dim
return Q, R, U, S, Vh


Expand Down Expand Up @@ -216,7 +218,7 @@ def svd(x, /, *, full_matrices=True) -> SVDResult:
raise ValueError("Cubed arrays only support using full_matrices=False")

nb = x.numblocks
# TODO: what about nb[0] == nb[1] == 1
# TODO: optimize case nb[0] == nb[1] == 1
if nb[0] > nb[1]:
_, _, U, S, Vh = tsqr(x, compute_svd=True)
truncate = x.shape[0] < x.shape[1]
Expand Down

0 comments on commit 1ca532d

Please sign in to comment.