Have you figured it out? Is writing forward+backward with `custom_vjp` or `custom_jvp` required or not?
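For the `custom_jvp` side of this question, here is a minimal sketch. It assumes a recent JAX release that ships `jax.experimental.pallas` with `pallas_call(..., interpret=True)`. The kernel `_square_kernel` and the function `square` are hypothetical stand-ins, chosen because the derivative of x squared is easy to check. The primal output comes from the Pallas kernel. The tangent rule is written in plain `jnp` and is linear in the tangent, so `jax.grad` can transpose it. This is one possible pattern, not a confirmed answer from this thread.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def _square_kernel(x_ref, o_ref):
    # Hypothetical elementwise kernel: o = x * x over one whole-array block.
    o_ref[...] = x_ref[...] * x_ref[...]


def _square_pallas(x):
    return pl.pallas_call(
        _square_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,  # interpreter mode, so this also runs on CPU
    )(x)


@jax.custom_jvp
def square(x):
    return _square_pallas(x)


@square.defjvp
def square_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    # Primal output still comes from the Pallas kernel. The tangent uses
    # ordinary jnp ops that are linear in t, so jax.grad can transpose them.
    return _square_pallas(x), 2.0 * x * t
```

With this definition, `jax.grad(lambda v: square(v).sum())(jnp.arange(4.0))` should give `[0., 2., 4., 6.]`.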
Reading this section in the docs, I would expect that calling `jax.grad` on a Pallas call would work, but potentially with a performance hit. However, using the following snippet:

results in an `AssertionError`:

What is happening here? Is this some issue with using `interpret=True`? Is the expected usage to define different Pallas kernels for fwd and bwd using `jax.custom_jvp`?
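One way to pair a forward kernel with a separate backward kernel is `jax.custom_vjp` rather than `jax.custom_jvp`. The sketch below assumes a recent JAX with `jax.experimental.pallas` and `pallas_call(..., interpret=True)`. The kernel names and `square` are hypothetical. It does not diagnose the `AssertionError` above, since the original snippet is not preserved. It only shows JAX being given an explicit backward pass, so it never has to differentiate through `pallas_call` itself.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def _square_kernel(x_ref, o_ref):
    # Hypothetical forward kernel: elementwise x**2 over one whole-array block.
    o_ref[...] = x_ref[...] * x_ref[...]


def _square_bwd_kernel(x_ref, g_ref, o_ref):
    # Hypothetical backward kernel: d(x**2)/dx * g = 2 * x * g.
    o_ref[...] = 2.0 * x_ref[...] * g_ref[...]


def _square_fwd_call(x):
    return pl.pallas_call(
        _square_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,  # interpreter mode, so this also runs on CPU
    )(x)


def _square_bwd_call(x, g):
    return pl.pallas_call(
        _square_bwd_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,
    )(x, g)


@jax.custom_vjp
def square(x):
    return _square_fwd_call(x)


def square_fwd(x):
    # Run the forward kernel and keep x as the residual for the backward pass.
    return _square_fwd_call(x), x


def square_bwd(x, g):
    # Return one cotangent per primal input, as a tuple.
    return (_square_bwd_call(x, g),)


square.defvjp(square_fwd, square_bwd)
```

`square_fwd` saves `x` as the residual, and `square_bwd` returns a one-element tuple because `square` has a single primal input. `jax.grad(lambda v: square(v).sum())(jnp.arange(4.0))` should then give `[0., 2., 4., 6.]`.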