How to perform a strided dynamic update slice? #22246

Answered by dfm
EdwardRaff asked this question in Q&A

I'm not sure I totally understand the question or issues that you're encountering, but something like your first example works just fine under jit:

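```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def fun(S, Q):
  for z in range(9):
    # Use 'np' (or plain Python ints), not 'jnp', here: i and j must stay static
    # integers so that the strided slices i::3 and j::3 are concrete under jit.
    i, j = np.mod(z, 3), np.mod(z // 3, 3)
    Q = Q.at[:, i::3, j::3].add(S[z, :, :, :])
  return Q

C_out = 2
H, W = 11, 11
Q = jnp.zeros((C_out, H * 3, W * 3))
S = jnp.arange(9 * C_out * H * W).reshape(9, C_out, H, W)
Q = fun(S, Q)
```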
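The loop above works because `z` is a Python int, so `i` and `j` are static and `i::3` is an ordinary slice. If the offsets are traced values instead (for example inside `jax.lax.scan`), a slice can no longer be built from them. A minimal sketch that sidesteps this, assuming the same shapes as above: build the strided row and column indices explicitly with `jnp.arange` and scatter through integer-array indexing, which does accept traced values. `fun_scan` is a hypothetical name and `unroll=3` is an arbitrary choice.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import lax


@jax.jit
def fun_scan(S, Q):
    # S: (9, C, H, W), Q: (C, 3*H, 3*W); H and W come from static shapes.
    _, _, H, W = S.shape

    def body(Q, xs):
        z, s = xs                     # z is a traced int32 scalar here
        i, j = z % 3, z // 3          # traced offsets: Q.at[:, i::3, j::3] would fail
        rows = i + 3 * jnp.arange(H)  # i, i+3, i+6, ...
        cols = j + 3 * jnp.arange(W)  # j, j+3, j+6, ...
        # Broadcast to an (H, W) index grid and scatter-add the (C, H, W) slab.
        Q = Q.at[:, rows[:, None], cols[None, :]].add(s)
        return Q, None

    Q, _ = lax.scan(body, Q, (jnp.arange(9), S), unroll=3)
    return Q


C_out, H, W = 2, 11, 11
Q = jnp.zeros((C_out, H * 3, W * 3))
S = jnp.arange(9 * C_out * H * W, dtype=jnp.float32).reshape(9, C_out, H, W)
Q_scan = fun_scan(S, Q)

# Reference: the static-slice loop from the answer, run eagerly in NumPy.
Q_ref = np.zeros((C_out, H * 3, W * 3), dtype=np.float32)
for z in range(9):
    i, j = z % 3, z // 3
    Q_ref[:, i::3, j::3] += np.asarray(S[z])
print(np.array_equal(np.asarray(Q_scan), Q_ref))  # True
```

Integer-array scatters are usually slower than static slices, so when the offsets can be kept static (as in the answer), the unrolled Python loop is likely the better choice.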

Depending on your actual use case, you might also be able to use `jax.lax.scan` with the `unroll` parameter set greater than 1, or (even better!) rewrite your problem as a convolution and use JAX's convolution implementation (e.g. `jax.lax.conv_general_dilated`)!
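A sketch of the convolution route, assuming the same interleaving pattern as the loop (`S[z]` lands at row offset `z % 3` and column offset `z // 3`): treat the nine slabs of `S` as `9 * C_out` input channels and call `jax.lax.conv_general_dilated` with `lhs_dilation=(3, 3)`, which amounts to a stride-3 transposed convolution with a fixed one-hot 3×3 kernel. `make_interleave_kernel` and `fun_conv` are hypothetical helpers, not part of JAX. For this exact one-hot pattern a reshape/transpose would also do the job; the convolution form mainly pays off if the kernel is meant to be learned or is not one-hot.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import lax


def make_interleave_kernel(C):
    # Hypothetical helper: one-hot kernel of shape (C_out=C, C_in=9*C, 3, 3).
    # Input channel z*C + c (i.e. S[z, c]) feeds output channel c through the
    # kernel tap (2 - i, 2 - j), with i = z % 3 and j = z // 3.
    K = np.zeros((C, 9 * C, 3, 3), dtype=np.float32)
    for z in range(9):
        i, j = z % 3, z // 3
        for c in range(C):
            K[c, z * C + c, 2 - i, 2 - j] = 1.0
    return jnp.asarray(K)


@jax.jit
def fun_conv(S, Q, K):
    n, C, H, W = S.shape
    x = S.reshape(1, n * C, H, W)  # stack the 9 slabs as channels (NCHW)
    # lhs_dilation places input pixel (h, w) at (3h, 3w); padding 2 on each side
    # yields a 3H x 3W output in which every pixel sees exactly one nonzero tap.
    out = lax.conv_general_dilated(
        x, K,
        window_strides=(1, 1),
        padding=((2, 2), (2, 2)),
        lhs_dilation=(3, 3),
        dimension_numbers=("NCHW", "OIHW", "NCHW"),
        precision=lax.Precision.HIGHEST,  # avoid reduced-precision matmuls on GPU/TPU
    )
    return Q + out[0]


C_out, H, W = 2, 11, 11
Q = jnp.zeros((C_out, H * 3, W * 3))
S = jnp.arange(9 * C_out * H * W, dtype=jnp.float32).reshape(9, C_out, H, W)
Q_conv = fun_conv(S, Q, make_interleave_kernel(C_out))

# Reference: the static-slice loop from the answer, run eagerly in NumPy.
Q_ref = np.zeros((C_out, H * 3, W * 3), dtype=np.float32)
for z in range(9):
    i, j = z % 3, z // 3
    Q_ref[:, i::3, j::3] += np.asarray(S[z])
print(np.array_equal(np.asarray(Q_conv), Q_ref))  # True
```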


dfm
Jul 3, 2024
Collaborator

Answer selected by EdwardRaff