efficient computation of block diagonal matrix vector multiplication when matrix is in compressed format #14288
parshakova asked this question in Q&A
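The variable slice sizes make this somewhat more difficult, but often you can use a segment sum for this kind of computation:
import jax
import jax.numpy as jnp

@jax.jit
def segment_sum(B, x, indices):
    # Label each row of B with the index of the block it belongs to.
    i = jnp.zeros(B.shape[0], int).at[jnp.array(indices)].set(1).cumsum()
    # Sum the rows of B * x within each block, then stack the results into a column.
    return jax.ops.segment_sum(B * x, i, num_segments=1 + len(indices)).reshape(-1, 1)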
ans1 = block_diag_expand(B, x, indices)
ans2 = matvec_slicing(B, x, indices)
ans3 = segment_sum(B, x, indices)
np.testing.assert_allclose(ans1, ans2)
np.testing.assert_allclose(ans1, ans3, rtol=1E-4) # larger tolerance due to float accumulation errors
%timeit block_diag_expand(B, x, indices).block_until_ready()
# 4.7 ms ± 947 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit matvec_slicing(B, x, indices).block_until_ready()
# 2.77 ms ± 114 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit segment_sum(B, x, indices).block_until_ready()
# 640 µs ± 78.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
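To see why the `cumsum` trick yields per-row block labels, here is a small worked sketch. The sizes and split points below are made up for illustration, and `x` is assumed to be a column vector of shape `(n, 1)` so that `B * x` broadcasts row-wise.

```python
import jax
import jax.numpy as jnp

# Hypothetical example: n = 7 rows split at rows 2 and 5 (interior split points only).
indices = [2, 5]
n = 7

# Put a 1 at the first row of every block after the first, then take a running sum:
# rows 0-1 get label 0, rows 2-4 get label 1, rows 5-6 get label 2.
seg_ids = jnp.zeros(n, int).at[jnp.array(indices)].set(1).cumsum()

B = jnp.arange(14.0).reshape(n, 2)  # r = 2 columns
x = jnp.ones((n, 1))                # column vector (assumed shape)

# Each row of B is scaled by its entry of x, then rows with the same label are summed,
# giving one length-r result per block.
per_block = jax.ops.segment_sum(B * x, seg_ids, num_segments=1 + len(indices))
print(seg_ids)    # block label for every row
print(per_block)  # shape (3, 2): one row per block
```

Flattening `per_block` row by row gives the same ordering as stacking the per-block products, which is what the final `reshape(-1, 1)` in `segment_sum` does.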
Given a matrix $B \in \mathbb{R}^{n \times r}$, a vector $x \in \mathbb{R}^{n}$, and a set of slicing indices $i_1, \ldots, i_k$, we want to compute the following matvec efficiently using JAX:
$\mathrm{block\_diag}(B_{:i_1}, \ldots, B_{i_{k-1}:i_k})^T x$, where $B$ is sliced along the first dimension and the slices are not all the same size.
I tried these two functions, but both are pretty slow.
The product above should take only $n r$ flops, so these functions are very slow by comparison, e.g., compared to other operations costing $n r$ flops.
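To make the target concrete, here is a minimal reference sketch of the product written out directly. It is not one of the two functions mentioned above, whose code is not shown here; the name `block_diag_matvec_reference` and the argument `splits` are hypothetical, and `splits` is assumed to hold only the interior split points (the final boundary $n$ is added inside).

```python
import jax.numpy as jnp
from jax.scipy.linalg import block_diag

def block_diag_matvec_reference(B, x, splits):
    """Dense reference: build the block-diagonal matrix explicitly, then multiply."""
    # Block boundaries along the first axis of B, including both ends.
    bounds = [0, *splits, B.shape[0]]
    # Slices of B with (possibly) different numbers of rows, each with r columns.
    blocks = [B[lo:hi] for lo, hi in zip(bounds[:-1], bounds[1:])]
    D = block_diag(*blocks)  # shape (n, k * r), mostly zeros
    return D.T @ x           # shape (k * r,)

B = jnp.arange(10.0).reshape(5, 2)  # n = 5, r = 2
x = jnp.ones(5)
out = block_diag_matvec_reference(B, x, [2])  # two blocks: rows 0-1 and rows 2-4
print(out)
```

The dense matrix has $n \cdot k r$ entries, so this form does far more work than the $n r$ flops the product actually needs; it is useful mainly as a correctness check for faster variants.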
Does anyone have any idea on how to improve this?
I could not figure out how to use vmap in this case, where the slices have different sizes, because we cannot simply unfold the vector and matrix and do some kind of batched matrix-vector multiplication.
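One possible way around the unequal slice sizes is to pad every slice to the largest size and mask out the padding, so that `vmap` sees a regular batch. The sketch below is only an illustration under that assumption: `padded_vmap_matvec` and `splits` are hypothetical names, `splits` holds interior split points only, and the split points must be concrete (not traced) values because they determine array shapes.

```python
import jax
import jax.numpy as jnp
import numpy as np

def padded_vmap_matvec(B, x, splits):
    """Pad each row-slice of B and x to a common length, then vmap a small matvec."""
    n, r = B.shape
    bounds = np.array([0, *splits, n])
    starts, sizes = bounds[:-1], np.diff(bounds)
    max_size = int(sizes.max())

    # Row-index grid of shape (k, max_size); mask is False on padded positions.
    offsets = np.arange(max_size)
    mask = offsets[None, :] < sizes[:, None]
    rows = np.where(mask, starts[:, None] + offsets[None, :], 0)  # keep indices in range

    B_pad = jnp.where(mask[..., None], B[rows], 0.0)  # (k, max_size, r), zeros in padding
    x_pad = jnp.where(mask, x[rows], 0.0)             # (k, max_size)

    # One small transposed matvec per block, batched over the k blocks.
    per_block = jax.vmap(lambda Bj, xj: Bj.T @ xj)(B_pad, x_pad)  # (k, r)
    return per_block.reshape(-1)

B = jnp.arange(10.0).reshape(5, 2)
x = jnp.ones(5)
out_vmap = padded_vmap_matvec(B, x, [2])
print(out_vmap)
```

The padded arrays have `k * max_size` rows, so when one slice is much larger than the others this does a lot of wasted work; in that case it may well be slower than a segment-sum formulation.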