I'm trying to implement a specific reversible-jump MCMC algorithm. By definition, an RJ algorithm traverses solutions of different dimensionalities, which is obviously tricky with JAX, but I'm hoping I'm simply missing a clever trick to make this work :-)
Here is an MWE of what I am trying to do:
import jax
import jax.numpy as jnp
import jax.random as jrnd

key = jrnd.PRNGKey(0)
key, wishart_key, choice_key = jrnd.split(key, 3)

p = 4
dof = p + 1
scale = jnp.eye(p)

# Create a p x p Wishart-distributed random matrix
u = jrnd.multivariate_normal(wishart_key, mean=jnp.zeros((p,)), cov=scale, shape=(dof,))
Sigma = jnp.einsum('dp,dq->pq', u, u)

# Create a p x p symmetric, binary adjacency matrix
G = jnp.array([[1, 1, 0, 0],
               [1, 1, 1, 0],
               [0, 1, 1, 1],
               [0, 0, 1, 1]])

def take_dynamic_submatrices(i, G, W):
    # Select all variables that node i connects to in G
    ix = jnp.where(G[i] == 1)[0]
    # The length of ix depends on the values in G, so W_sub has a
    # data-dependent shape
    W_sub = W[ix[:, None], ix]
    return W_sub

# Pick a node
i = 1
W_sub = take_dynamic_submatrices(i, G, Sigma)
print('W_sub:\n', W_sub)

# How can we make this work within a vmap or lax.scan?
# Under jit, jnp.where / jnp.nonzero cannot infer the output shape from
# traced values, so this call raises an error
W_sub_jit = jax.jit(take_dynamic_submatrices)(i, G, Sigma)
print('W_sub_jit:\n', W_sub_jit)
Throughout the algorithm, Sigma/W and G change values, but remain of shape $p\times p$. However, depending on random steps, G changes, which means that ix, and hence W_sub, can differ in size across iterations of the algorithm. I do not really need to store elements of different sizes, as I would assign W_sub back into a subpart of W, so everything I track across iterations keeps the same shape. I would, however, need the smaller objects as intermediate variables on the fly.
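For concreteness, the assign-back pattern I have in mind looks roughly like this (a sketch; some_update stands in for whatever RJ step acts on the submatrix and is purely illustrative):

ix = jnp.where(G[i] == 1)[0]               # dynamic length, depends on G
W_sub = W[ix[:, None], ix]                 # (k, k) submatrix, k varies
W_sub_new = some_update(W_sub)             # hypothetical update of the submatrix
W = W.at[ix[:, None], ix].set(W_sub_new)   # W has shape (p, p) again afterwards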
I have tried a couple of directions: static_argnums, and simply masking the matrices that I update. The first does not work because G is an array and therefore not hashable; the second failed as well because I later have to do some linear algebra operations on W_sub, and those do not work on a masked matrix.
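To make the masking attempt concrete, it looked roughly like this (a reconstruction, not my exact code): the mask has a static shape, but the zeroed-out block makes the matrix singular, so downstream linear algebra no longer corresponds to the operation on W_sub:

mask = G[i] == 1                         # boolean, static shape (p,)
mask2d = mask[:, None] & mask[None, :]   # (p, p) mask of the submatrix block
W_masked = jnp.where(mask2d, W, 0.0)     # static shape, but rank-deficient
chol = jnp.linalg.cholesky(W_masked)     # NaNs, since W_masked is not positive definite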
It's definitely possible that what I am trying to do is simply impossible with JAX due to the requirement that array shapes be known at trace time, but since the objects I really care about do keep the same shape, do you perhaps have a suggestion on how to proceed?
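For reference, a padded variant that at least compiles might look like the following (a sketch using the size/fill_value arguments of jnp.where; whether the resulting duplicated rows/columns can be reconciled with the downstream linear algebra is exactly the open question):

def take_padded_submatrix(i, G, W):
    # size/fill_value give ix a static length p, so this function jits;
    # the padded entries repeat index 0, duplicating rows/columns of W_sub
    ix = jnp.where(G[i] == 1, size=G.shape[0], fill_value=0)[0]
    return W[ix[:, None], ix]   # always (p, p)

W_sub_padded = jax.jit(take_padded_submatrix)(i, G, Sigma)
print('W_sub_padded:\n', W_sub_padded)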