diff --git a/cubed/array_api/linalg.py b/cubed/array_api/linalg.py index 19842b6f..8a907a23 100644 --- a/cubed/array_api/linalg.py +++ b/cubed/array_api/linalg.py @@ -1,10 +1,10 @@ from typing import NamedTuple from cubed.array_api.array_object import Array - -# These functions are in both the main and linalg namespaces from cubed.array_api.data_type_functions import result_type from cubed.array_api.dtypes import _floating_dtypes + +# These functions are in both the main and linalg namespaces from cubed.array_api.linear_algebra_functions import ( # noqa: F401 matmul, matrix_transpose, @@ -12,7 +12,7 @@ vecdot, ) from cubed.backend_array_api import namespace as nxp -from cubed.core.ops import blockwise, general_blockwise, merge_chunks +from cubed.core.ops import blockwise, general_blockwise, merge_chunks, squeeze from cubed.utils import array_memory, get_item @@ -53,7 +53,7 @@ def qr(x, /, *, mode="reduced") -> QRResult: return QRResult(Q, R) -def tsqr(x, compute_svd=False, final_u=True): +def tsqr(x, compute_svd=False, finalize_svd=True): """Direct Tall-and-Skinny QR algorithm From: @@ -69,14 +69,15 @@ def tsqr(x, compute_svd=False, final_u=True): if _r1_is_too_big(R1): R1 = _rechunk_r1(R1) - Q2, R2, U, S, Vh = tsqr(R1, compute_svd=compute_svd, final_u=False) + Q2, R2, U, S, Vh = tsqr(R1, compute_svd=compute_svd, finalize_svd=False) else: Q2, R2, U, S, Vh = _qr_second_step(R1, compute_svd=compute_svd) Q, R = _qr_third_step(Q1, Q2), R2 - if compute_svd and final_u: + if compute_svd and finalize_svd: U = Q @ U # fourth step (SVD only) + S = squeeze(S, axis=1) # remove extra dim return Q, R, U, S, Vh @@ -143,7 +144,7 @@ def _qr_second_step(R1, compute_svd=False): else: U_shape = (n, n) U_chunks = U_shape - S_shape = (n,) + S_shape = (n, 1) # extra dim since multiple outputs must have same numblocks S_chunks = S_shape Vh_shape = (n, n) Vh_chunks = Vh_shape @@ -172,6 +173,7 @@ def _merge_into_single_chunk(x, split_every=4): def _qr2(a): Q, R = nxp.linalg.qr(a) U, S, Vh = nxp.linalg.svd(R) + S = S[:, nxp.newaxis] # add extra dim return Q, R, U, S, Vh @@ -216,7 +218,7 @@ def svd(x, /, *, full_matrices=True) -> SVDResult: raise ValueError("Cubed arrays only support using full_matrices=False") nb = x.numblocks - # TODO: what about nb[0] == nb[1] == 1 + # TODO: optimize case nb[0] == nb[1] == 1 if nb[0] > nb[1]: _, _, U, S, Vh = tsqr(x, compute_svd=True) truncate = x.shape[0] < x.shape[1]