
Best practices for debugging vjp blowups #22343

Answered by dfm
lengstrom asked this question in Q&A

I'm not sure that I have a great answer to the broad question here, but there are some options. One place to start is to try printing the jaxpr of your reverse pass:

```python
# ct: a single cotangent matching the structure of fun(*args)
print(jax.make_jaxpr(jax.vjp(fun, *args)[1])(ct))
```
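Here is a slightly more self-contained version of that one-liner. It is a sketch that assumes a hypothetical two-argument `fun` whose output is a dict. The main point is that the function returned by `jax.vjp` takes exactly one cotangent, which must have the same structure as the output even when that output is a container. Here `jax.tree_util.tree_map(jnp.ones_like, ...)` builds that cotangent.

```python
import jax
import jax.numpy as jnp


def fun(x, y):
    # Hypothetical function under investigation; substitute your own.
    return {"out": jnp.tanh(x) * jnp.log1p(y**2)}


args = (jnp.linspace(-1.0, 1.0, 4), jnp.arange(4.0))

# jax.vjp returns (primal_out, vjp_fn). vjp_fn takes ONE argument: a cotangent
# with the same pytree structure, shapes and dtypes as primal_out.
primal_out, vjp_fn = jax.vjp(fun, *args)
ct = jax.tree_util.tree_map(jnp.ones_like, primal_out)

# Trace only the backward pass. Residuals saved by the forward pass typically
# appear as constants of the jaxpr; the cotangent is its single input.
reverse = jax.make_jaxpr(vjp_fn)(ct)
print(reverse)
```

Because `fun` takes two inputs, the printed jaxpr returns two cotangents, one per argument.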

You can also combine this with `jax.named_scope` to attach a bit more metadata. For example:

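```python
import jax
import jax.numpy as jnp

@jax.jit
def fun(x):
  return f1(x) * f2(x)

@jax.jit
@jax.named_scope("f1")
def f1(x):
  return jnp.sin(x)

@jax.jit
@jax.named_scope("f2")
def f2(x):
  return jnp.exp(0.5 * x)

# Trace only the reverse pass: vjp returns (primal_out, vjp_fn), and vjp_fn
# takes a single cotangent shaped like fun's output.
print(jax.make_jaxpr(jax.vjp(fun, jnp.ones(5))[1])(jnp.ones(5)))
```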

This prints:

```
{ lambda a:f32[5] b:f32[5] c:f32[5]; d:f32[5]. let
    e:f32[5] = pjit[
      name=fun
      j…
```
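The labels recorded by `jax.named_scope` do not appear in the default jaxpr printout. The `name=` fields on `pjit` equations come from the names of the jitted functions. The sketch below assumes a JAX version where `ClosedJaxpr.pretty_print` accepts `name_stack=True`. It uses `named_scope` as a context manager inside a single function. It then prints the reverse-pass jaxpr with each equation annotated by its name stack, which helps attribute backward-pass operations to the piece of the forward computation they came from.

```python
import jax
import jax.numpy as jnp


def fun(x):
    # Each `with` block pushes a label onto the name stack of every
    # equation traced inside it.
    with jax.named_scope("f1"):
        a = jnp.sin(x)
    with jax.named_scope("f2"):
        b = jnp.exp(0.5 * x)
    return a * b


x = jnp.ones(5)
_, vjp_fn = jax.vjp(fun, x)
reverse = jax.make_jaxpr(vjp_fn)(jnp.ones(5))

# name_stack=True annotates each equation with its name stack, so a
# backward-pass op (e.g. a mul that overflows) can be traced to "f1" or "f2".
text = reverse.pretty_print(name_stack=True)
print(text)
```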


dfm
Jul 9, 2024
Collaborator

Answer selected by lengstrom