Replies: 2 comments 6 replies
-
Problem in one sentence: how to compute the mean of the gradients of an intermediate activation tensor over multiple training steps. Can we use a custom gradient function for that?
-
One way to do it would be to define a "perturbation function" (related to this idea). As a simpler example, let's say you want access to the gradient with respect to the intermediate value `y` in this function:

```python
import jax.numpy as jnp
from jax import grad

def f(x):
    y = jnp.sin(x)
    return y ** 2

print(grad(f)(3.))  # -0.2794155
```

We can write a perturbation function as

```python
def f_perturbed(x, delta_y):
    y = jnp.sin(x)
    y = y + delta_y  # perturb the intermediate; delta_y is zero at evaluation
    return y ** 2
```

Notice that the gradient with respect to `delta_y`, evaluated at `delta_y = 0.`, is exactly the gradient with respect to the intermediate `y`:

```python
print(grad(f_perturbed, (0, 1))(3., 0.))
# (DeviceArray(-0.2794155, dtype=float32), DeviceArray(0.28224, dtype=float32))
```

Here's an example implementation of K-FAC on fully-connected networks using this trick (in Autograd, but the differentiation API is the same as JAX's).

Something annoying about this is that you need to know the shapes of all the intermediates, so that you can construct appropriately-shaped zeros arrays for the perturbation arguments (see the sketch below). The zeros could have some runtime cost too, but I wouldn't worry about that; under a `jit` the compiler can likely optimize them away. One could automate the construction of both the perturbation function itself and the appropriately-shaped perturbation values by writing a custom jaxpr interpreter. I don't think a custom gradient function would help on its own, because it doesn't give you a way to plumb out additional outputs.

Another way to do it, if you're willing to go beyond "core JAX", would be to use a library that adds "state management" features on top of JAX. I'm sure Oryx can do this for you (I believe with its Harvest API), and it's possible that libraries like Flax and Haiku can also do this for you, but I'm less familiar with how it would work with those. Maybe folks on those three respective libraries could help answer questions about their APIs. (Oryx's implementation basically has custom jaxpr interpreters for this kind of thing.) With these other libraries, you may be able to simply save the values in a `custom_vjp` rule. They also might give you a more convenient way to implement the perturbation function approach, though if you can just save values inside a `custom_vjp` rule, that approach seems the simplest. The difference between these and the
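To make the shape bookkeeping above concrete, here is a minimal sketch (mine, not from the reply) of one way to build the zero-valued perturbation automatically, using `jax.eval_shape` so the intermediate's shape can be learned without computing it. The helper name `make_zero_perturbation` is hypothetical:

```python
import jax
import jax.numpy as jnp

def intermediate(x):
    # the sub-computation producing the intermediate we care about
    return jnp.sin(x)

def make_zero_perturbation(fn, *args):
    # jax.eval_shape traces fn abstractly, so we learn the output's
    # shape/dtype without spending any FLOPs computing it.
    out = jax.eval_shape(fn, *args)
    return jnp.zeros(out.shape, out.dtype)

def f_perturbed(x, delta_y):
    y = intermediate(x) + delta_y
    return y ** 2

x = jnp.array(3.0)
delta_y = make_zero_perturbation(intermediate, x)
dx, dy = jax.grad(f_perturbed, (0, 1))(x, delta_y)
print(dy)  # 0.28224 -- the gradient with respect to the intermediate y
```

Extending this to all intermediates of a network is exactly the automation step mentioned above; a custom jaxpr interpreter would do the same thing systematically.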
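And to illustrate the `custom_vjp` idea: here is a minimal sketch (again mine, with placeholder names `tap_grad` and `SAVED_COTANGENTS`) of recording an intermediate's cotangent in the backward rule of an identity `custom_vjp`. A plain Python list only works outside `jit`, where the backward rule runs with concrete arrays; under `jit` you would route the value through a host callback or a state-management library instead.

```python
import jax
import jax.numpy as jnp

SAVED_COTANGENTS = []  # host-side store; the side effect is the "hacky" part

@jax.custom_vjp
def tap_grad(y):
    return y  # identity on the forward pass

def tap_grad_fwd(y):
    return y, None  # no residuals needed

def tap_grad_bwd(_, g):
    SAVED_COTANGENTS.append(g)  # record the gradient flowing through y
    return (g,)  # pass the cotangent through unchanged

tap_grad.defvjp(tap_grad_fwd, tap_grad_bwd)

def f(x):
    y = jnp.sin(x)
    y = tap_grad(y)  # no-op forward, records on the backward pass
    return y ** 2

print(jax.grad(f)(3.))   # -0.2794155, as before
print(SAVED_COTANGENTS)  # [0.28224] -- dL/dy for the intermediate
```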
-
I am using JAX to implement a simple neural network (NN) and I want to access and save the gradients from the backward pass for further analysis after the NN has run. I can access and inspect the gradients temporarily with the Python debugger (as long as I am not using `jit`). But I want to save all gradients over the whole training process and analyze them after training is done. I have come up with a rather hacky solution for this using `id_tap` and a global variable (see the code below). But I was wondering whether there is a better solution that does not violate the functional principles of JAX.
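The code referenced above did not survive in this export; here is a minimal sketch of the kind of `id_tap`-plus-global-variable approach described, using `jax.experimental.host_callback` (names like `GRAD_LOG` and `save_grads` are placeholders, not from the original post):

```python
import jax
import jax.numpy as jnp
from jax.experimental import host_callback

GRAD_LOG = []  # global mutable store -- the "hacky" part

def save_grads(grads, transforms):
    # Runs on the host; receives concrete arrays even under jit.
    GRAD_LOG.append(grads)

def loss(params, x):
    return jnp.sum((params * x) ** 2)

@jax.jit
def train_step(params, x):
    grads = jax.grad(loss)(params, x)
    # id_tap sends `grads` to the host-side callback as a side effect
    # and returns its argument unchanged.
    grads = host_callback.id_tap(save_grads, grads)
    return params - 0.1 * grads

params = jnp.ones(3)
for step in range(5):
    params = train_step(params, jnp.arange(3.0))
# After training, GRAD_LOG holds the gradients from every step.
```

(In newer JAX versions, `jax.debug.callback` provides a similar host-side tap.)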
Many thanks!