Vmap and lax.scan with different sequence array #22293
-
Hi everyone , I have this following code written in pure jax, Thanks all from Jax import random
from jax import lax
import jax
import jax.numpy as jnp
import pdb
def fwd_dynamics(x_u, xs):
x0,uk = x_u
Delta_T = 0.001
lwb = 1.2
psi0=x0[2][0]
v0= x0[3][0]
vdot0 = uk[0][0]
delta0 = uk[1][0]
thetadot0 = uk[2][0]
xdot= jnp.asarray([[v0*jnp.cos(psi0) ],
[v0*jnp.sin(psi0)] ,
[v0*jnp.tan(delta0)/(lwb)],
[vdot0],
[thetadot0]])
x_next = x0 + xdot*Delta_T
return (x_next,uk), x_next # ("carryover", "accumulated")
def state_predictor( xk,uk ,sim_timestep):
(x_next,_), _ = lax.scan(fwd_dynamics, (xk,uk) ,jnp.arange(sim_timestep) )
return x_next
low = 0 # Adjust minimum value as needed
high = 100 # Adjust maximum value as needed
key = jax.random.PRNGKey(44)
sim_time = jax.random.randint(key, shape=(10, 1), minval=low, maxval=high)
xk = jax.random.uniform(key, shape=(10,5, 1))
uk = jax.random.uniform(key, shape=(10,2, 1))
state_predictor_vmap = jax.jit(jax.vmap(state_predictor,in_axes= 0 ,out_axes=0 ))
x_next = state_predictor_vmap( xk,uk ,sim_time)
print(x_next.shape) ``` |
Beta Was this translation helpful? Give feedback.
Replies: 1 comment
-
Looks like this was answered here: https://stackoverflow.com/questions/78713478/jax-vmap-with-lax-scan-having-different-sequence-length-in-batch-dimension |
Beta Was this translation helpful? Give feedback.
Looks like this was answered here: https://stackoverflow.com/questions/78713478/jax-vmap-with-lax-scan-having-different-sequence-length-in-batch-dimension