SVD
Add svdvals implementation

SVD for when not tall and skinny
tomwhite committed Nov 8, 2024
1 parent a2b1053 commit af965e5
Showing 3 changed files with 125 additions and 20 deletions.
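
For orientation, a minimal sketch of what this commit enables, assuming the usual Cubed test conventions (`import cubed.array_api as xp`); the shape and chunking below are illustrative:

```python
import numpy as np

import cubed
import cubed.array_api as xp

A = np.reshape(np.arange(32, dtype=np.float64), (16, 2))
Ax = xp.asarray(A, chunks=(4, 2))

# Reduced SVD of a chunked array; only full_matrices=False is supported.
U, S, Vh = xp.linalg.svd(Ax, full_matrices=False)

# Singular values alone, via the new svdvals.
vals = xp.linalg.svdvals(Ax)

U, S, Vh, vals = cubed.compute(U, S, Vh, vals)
assert np.allclose((U * S) @ Vh, A)
```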
4 changes: 2 additions & 2 deletions api_status.md
@@ -113,8 +113,8 @@ A few of the [linear algebra extension](https://data-apis.org/array-api/2022.12/
| | `qr` | :white_check_mark: | | |
| | `slogdet` | :x: | | |
| | `solve` | :x: | | |
-| | `svd` | :x: | | |
-| | `svdvals` | :x: | | |
+| | `svd` | :white_check_mark: | | |
+| | `svdvals` | :white_check_mark: | | |
| | `tensordot` | :white_check_mark: | | |
| | `trace` | :x: | | |
| | `vecdot` | :white_check_mark: | | |
97 changes: 79 additions & 18 deletions cubed/array_api/linalg.py
@@ -27,6 +27,12 @@ class QRResult(NamedTuple):
R: Array


class SVDResult(NamedTuple):
U: Array
S: Array
Vh: Array


def qr(x, /, *, mode="reduced") -> QRResult:
if x.ndim != 2:
raise ValueError("qr requires x to have 2 dimensions.")
@@ -43,10 +49,11 @@ def qr(x, /, *, mode="reduced") -> QRResult:
"Consider rechunking so there is only a single column chunk."
)

-    return tsqr(x)
+    Q, R, _, _, _ = tsqr(x)
+    return QRResult(Q, R)


-def tsqr(x) -> QRResult:
+def tsqr(x, compute_svd=False, final_u=True):
"""Direct Tall-and-Skinny QR algorithm
From:
@@ -57,18 +64,21 @@ def tsqr(x) -> QRResult:
https://arxiv.org/abs/1301.1071
"""

-    # follows Algorithm 2 from Benson et al
+    # follows Algorithm 2 from Benson et al, modified for SVD
Q1, R1 = _qr_first_step(x)

if _r1_is_too_big(R1):
R1 = _rechunk_r1(R1)
-        Q2, R2 = tsqr(R1)
+        Q2, R2, U, S, Vh = tsqr(R1, compute_svd=compute_svd, final_u=False)
else:
-        Q2, R2 = _qr_second_step(R1)
+        Q2, R2, U, S, Vh = _qr_second_step(R1, compute_svd=compute_svd)

Q, R = _qr_third_step(Q1, Q2), R2

-    return QRResult(Q, R)
+    if compute_svd and final_u:
+        U = Q @ U  # fourth step (SVD only)
+
+    return Q, R, U, S, Vh
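
The structure mirrors the QR case: TSQR gives A = QR with R small, an in-memory SVD factors R = U_r S Vh, and since A = (Q U_r) S Vh the fourth step folds the small U_r into the tall Q. A plain-NumPy sketch of the single-level (non-recursive) case, illustrative rather than Cubed code:

```python
import numpy as np

def tsqr_svd_sketch(A, nblocks=4):
    # First step: independent QR of each row block.
    qs, rs = zip(*(np.linalg.qr(b) for b in np.vsplit(A, nblocks)))

    # Second step: QR of the stacked R factors, then SVD of the small R2.
    Q2, R2 = np.linalg.qr(np.vstack(rs))
    U_r, S, Vh = np.linalg.svd(R2)

    # Third step: combine each block's Q with the matching rows of Q2.
    n = A.shape[1]
    Q = np.vstack([q @ Q2[i * n : (i + 1) * n] for i, q in enumerate(qs)])

    # Fourth step (SVD only): fold the small U_r into the tall Q.
    return Q @ U_r, S, Vh

A = np.random.default_rng(0).normal(size=(16, 2))
U, S, Vh = tsqr_svd_sketch(A)
assert np.allclose((U * S) @ Vh, A)
```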


def _qr_first_step(A):
@@ -108,7 +118,7 @@ def _rechunk_r1(R1, split_every=4):
return merge_chunks(R1, chunks=chunks)


-def _qr_second_step(R1):
+def _qr_second_step(R1, compute_svd=False):
R1_single = _merge_into_single_chunk(R1)

Q2_shape = R1.shape
@@ -117,17 +127,38 @@
n = R1.shape[1]
R2_shape = (n, n)
R2_chunks = R2_shape # single chunk
-    # qr implementation creates internal array buffers
-    extra_projected_mem = R1_single.chunkmem * 4
-    Q2, R2 = map_blocks_multiple_outputs(
-        nxp.linalg.qr,
-        R1_single,
-        shapes=[Q2_shape, R2_shape],
-        dtypes=[R1.dtype, R1.dtype],
-        chunkss=[Q2_chunks, R2_chunks],
-        extra_projected_mem=extra_projected_mem,
-    )
-    return QRResult(Q2, R2)
+    if not compute_svd:
+        # qr implementation creates internal array buffers
+        extra_projected_mem = R1_single.chunkmem * 4
+        Q2, R2 = map_blocks_multiple_outputs(
+            nxp.linalg.qr,
+            R1_single,
+            shapes=[Q2_shape, R2_shape],
+            dtypes=[R1.dtype, R1.dtype],
+            chunkss=[Q2_chunks, R2_chunks],
+            extra_projected_mem=extra_projected_mem,
+        )
+        return Q2, R2, None, None, None
+    else:
+        U_shape = (n, n)
+        U_chunks = U_shape
+        S_shape = (n,)
+        S_chunks = S_shape
+        Vh_shape = (n, n)
+        Vh_chunks = Vh_shape
+
+        # qr implementation creates internal array buffers
+        extra_projected_mem = R1_single.chunkmem * 4
+        Q2, R2, U, S, Vh = map_blocks_multiple_outputs(
+            _qr2,
+            R1_single,
+            shapes=[Q2_shape, R2_shape, U_shape, S_shape, Vh_shape],
+            dtypes=[R1.dtype, R1.dtype, R1.dtype, R1.dtype, R1.dtype],
+            chunkss=[Q2_chunks, R2_chunks, U_chunks, S_chunks, Vh_chunks],
+            extra_projected_mem=extra_projected_mem,
+        )
+        return Q2, R2, U, S, Vh


def _merge_into_single_chunk(x, split_every=4):
@@ -138,6 +169,12 @@
return x


def _qr2(a):
Q, R = nxp.linalg.qr(a)
U, S, Vh = nxp.linalg.svd(R)
return Q, R, U, S, Vh
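
`_qr2` runs inside a single merged chunk, so plain in-memory routines apply; its contract is that `a == Q @ R` and `R == (U * S) @ Vh`. A quick check with NumPy standing in for `nxp` (illustrative only):

```python
import numpy as np

def _qr2_reference(a):
    Q, R = np.linalg.qr(a)
    U, S, Vh = np.linalg.svd(R)
    return Q, R, U, S, Vh

a = np.random.default_rng(42).normal(size=(8, 3))
Q, R, U, S, Vh = _qr2_reference(a)
assert np.allclose(Q @ R, a)             # QR of the chunk
assert np.allclose((U * S) @ Vh, R)      # SVD of the small R
assert np.allclose((Q @ U * S) @ Vh, a)  # combined: an SVD of the chunk itself
```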


def _qr_third_step(Q1, Q2):
m, n = Q1.chunksize
k, _ = Q1.numblocks
@@ -174,6 +211,30 @@ def _q_matmul(a1, a2, q2_chunks=None, block_id=None):
return q1 @ q2


def svd(x, /, *, full_matrices=True) -> SVDResult:
if full_matrices:
raise ValueError("Cubed arrays only support using full_matrices=False")

nb = x.numblocks
# TODO: what about nb[0] == nb[1] == 1
if nb[0] > nb[1]:
_, _, U, S, Vh = tsqr(x, compute_svd=True)
truncate = x.shape[0] < x.shape[1]
else:
_, _, Vht, S, Ut = tsqr(x.T, compute_svd=True)
U, S, Vh = Ut.T, S, Vht.T
truncate = x.shape[0] > x.shape[1]
if truncate: # from dask
k = min(x.shape)
U, Vh = U[:, :k], Vh[:k, :]
return SVDResult(U, S, Vh)
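
The short-and-fat branch uses the transpose identity: if `x.T = U S Vh`, then `x = Vh.T S U.T`, so running TSQR on `x.T` yields the factors of `x` with the roles of U and Vh swapped and transposed. A quick NumPy check of the identity (not Cubed code):

```python
import numpy as np

x = np.random.default_rng(0).normal(size=(2, 16))  # more columns than rows

# SVD of the tall-and-skinny transpose...
Vht, S, Ut = np.linalg.svd(x.T, full_matrices=False)

# ...transposes back into an SVD of x itself.
U, Vh = Ut.T, Vht.T
assert np.allclose((U * S) @ Vh, x)
```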


def svdvals(x, /):
_, S, _ = svd(x, full_matrices=False)
return S


def map_blocks_multiple_outputs(
func,
*args,
44 changes: 44 additions & 0 deletions cubed/tests/test_linalg.py
@@ -57,3 +57,47 @@ def test_qr_chunking():
match=r"qr only supports tall-and-skinny \(single column chunk\) arrays.",
):
xp.linalg.qr(A)


def test_svd():
A = np.reshape(np.arange(32, dtype=np.float64), (16, 2))

U, S, Vh = xp.linalg.svd(xp.asarray(A, chunks=(4, 2)), full_matrices=False)
U, S, Vh = cubed.compute(U, S, Vh)

assert_allclose(U * S @ Vh, A, atol=1e-08)
assert_allclose(U.T @ U, np.eye(2, 2), atol=1e-08) # U must be orthonormal
assert_allclose(Vh @ Vh.T, np.eye(2, 2), atol=1e-08) # Vh must be orthonormal


def test_svd_recursion():
A = np.reshape(np.arange(128, dtype=np.float64), (64, 2))

# find a memory setting where recursion happens
found = False
for factor in range(4, 16):
spec = cubed.Spec(allowed_mem=128 * factor, reserved_mem=0)

try:
U, S, Vh = xp.linalg.svd(
xp.asarray(A, chunks=(8, 2), spec=spec), full_matrices=False
)

found = True
plan_unopt = arrays_to_plan(U, S, Vh)._finalize()
assert plan_unopt.num_primitive_ops() > 4 # more than without recursion

U, S, Vh = cubed.compute(U, S, Vh)

assert_allclose(U * S @ Vh, A, atol=1e-08)
assert_allclose(U.T @ U, np.eye(2, 2), atol=1e-08) # U must be orthonormal
assert_allclose(
Vh @ Vh.T, np.eye(2, 2), atol=1e-08
) # Vh must be orthonormal

break

except ValueError:
pass # not enough memory

assert found
