How to keep hashable arguments static through a scan #16667
HenriLamarre asked this question in Q&A
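Thanks for the question! Any of the arguments that pass through the …

```python
from jax import lax

def loop_function(itermax):
    iter = 0
    nonstatic = 1
    static = 2

    # `static` is captured from the enclosing scope instead of being part of
    # the scan carry, so it stays a plain Python value. `scanner` is the
    # function from the question's example (not preserved on this page).
    def scanner_closure(a, x):
        a, x = scanner((a[0], a[1], static), x)
        return (a[0], a[1]), x

    a, x = lax.scan(scanner_closure, (iter, nonstatic), xs=None, length=itermax)
```

Now …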
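To make the closure idea concrete, here is a small self-contained sketch. Values in the scan carry are traced (abstract) while `lax.scan` traces the loop body, so they cannot serve as static arguments to a jitted function; a Python value captured by the body function is never traced and stays hashable. The `scanner` below is a hypothetical stand-in with its static argument in position 2, since the question's real function was not preserved.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


# Hypothetical stand-in for the question's jitted function.
# Argument 2 (`static`) is static, so it must be a hashable Python value.
@partial(jax.jit, static_argnums=2)
def scanner(count, value, static):
    # `static` is a plain Python int here, so it can drive Python control flow.
    for _ in range(static):
        value = value * 2
    return count + 1, value


def loop_function(itermax, static=2):
    def step(carry, _):
        count, value = carry  # these are tracers inside the scan body
        # `static` is captured from the enclosing scope, not carried,
        # so jit receives a hashable int rather than a tracer.
        count, value = scanner(count, value, static)
        return (count, value), None

    init = (jnp.int32(0), jnp.int32(1))
    (count, value), _ = lax.scan(step, init, xs=None, length=itermax)
    return count, value
```

For example, `loop_function(3)` runs three steps that each multiply `value` by 2**2, giving a count of 3 and a value of 64.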
This is a very common issue. As a follow-up, you may like …
Hi, disclaimer: I am new to JAX but it has been awesome so far! :D
I have a function ('jitfunc') that is jit-compiled and works properly. It is called inside a Python while loop. The while loop itself is very slow, whereas the function is fast. I would like to convert my while loop to lax.while_loop or lax.fori_loop, but most likely lax.scan, since it is reverse-mode differentiable and the next step for this project is AI stuff.
I have a maximum number of iterations, itermax, for my loop. The reason I use a while loop instead of a for loop is that the jitted function can sometimes speed up the calculation and increase the iteration counter by more than 1.
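A loop whose counter can jump by more than 1 maps most directly onto lax.while_loop. Below is a minimal sketch; `jitfunc` is a hypothetical stand-in (it advances the counter by 2 whenever `state` is even, otherwise by 1) because the real function was not shared. Keep in mind that lax.while_loop supports forward-mode but not reverse-mode differentiation, which is the usual reason to prefer lax.scan when gradients are needed.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def jitfunc(i, state):
    # Hypothetical stand-in: advance the counter by 2 when `state` is even,
    # otherwise by 1, to mimic "sometimes skips ahead".
    step = jnp.where(state % 2 == 0, 2, 1)
    return i + step, state + 1


def run_while(itermax):
    def cond_fun(carry):
        i, _ = carry
        return i < itermax  # stop once the counter reaches itermax

    def body_fun(carry):
        i, state = carry
        return jitfunc(i, state)

    init = (jnp.int32(0), jnp.int32(0))
    return lax.while_loop(cond_fun, body_fun, init)
```

With `itermax=10`, the loop performs 7 body calls and ends with the counter at 11, since the last jump overshoots itermax.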
For the scan version, I was planning to run itermax steps and have the function do nothing once the counter has reached itermax (I think this is doable, but I have no idea whether it is optimal).
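One way to express that plan, sketched with the same hypothetical `jitfunc` as a stand-in: run lax.scan for a fixed `length=itermax` and, once the counter has reached itermax, return the old carry unchanged via jnp.where. Because the counter grows by at least 1 on every real step, itermax scan steps are always enough to reach the end.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def jitfunc(i, state):
    # Hypothetical stand-in: advance the counter by 2 when `state` is even,
    # otherwise by 1.
    step = jnp.where(state % 2 == 0, 2, 1)
    return i + step, state + 1


def run_scan(itermax):
    def body(carry, _):
        i, state = carry
        new_i, new_state = jitfunc(i, state)
        done = i >= itermax
        # Once the counter has reached itermax, keep the old carry (a no-op step).
        i = jnp.where(done, i, new_i)
        state = jnp.where(done, state, new_state)
        return (i, state), None

    init = (jnp.int32(0), jnp.int32(0))
    (i, state), _ = lax.scan(body, init, xs=None, length=itermax)
    return i, state
```

This gives the same final carry as the while-loop form (for `itermax=10`, 7 real steps followed by 3 no-op steps), and it can be reverse-differentiated. The trade-off is that jnp.where still evaluates `jitfunc` on the no-op steps; lax.cond with a scalar predicate can skip that work, although under vmap it is converted to a select that evaluates both branches anyway.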
However, when I try running lax.scan, I get this error:
ValueError: Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses. Static argument (index 6) of type for function in_loop is non-hashable.
From looking around online, it seems that using lax.scan makes my hashable static variables non-hashable? (I'm not sure about this.)
Is there a fix for what I am trying to do?
My code is not very readable, so here is a minimal example that reproduces my error:
Thanks! :)