Unclear evaluation semantics of gradients of random functions #3702

Closed
JonyEpsilon opened this issue Jul 9, 2020 · 5 comments
Labels
question Questions for the JAX team

Comments

@JonyEpsilon
Contributor

This isn't a bug; it's more a request to open a discussion about clarifying how random functions interact with jax.grad.

Consider the following example:

```python
import jax

def r(x, key):
  return x * jax.random.uniform(key)

rg = jax.value_and_grad(r)
key = jax.random.PRNGKey(42)
rg(10.0, key)
```

> (DeviceArray(4.2672753, dtype=float32), DeviceArray(0.42672753, dtype=float32))

The question is how a user might understand the gradient values that are being returned. Chatting with some of my colleagues about their mental model for this yielded three answers:

  1. Some found this behaviour surprising, perhaps expecting the gradient not to be well defined.
  2. Some thought this was expected: uniform is a deterministic function of key, independent of x, so if uniform(key) = c and r = c * x, it's reasonable that dr/dx = c.
  3. Some thought this was expected because the random number is instantiated first, and the partial derivative is then taken with the random number held fixed.

I think these are all reasonable guesses about what the evaluation semantics would be. The actual evaluation semantics (the second on this list) are also reasonable, and fit well with JAX's "just transforming the code" ethos.
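A minimal check of interpretation 2, assuming a recent JAX install (the printed numbers may differ from the 2020 output above across JAX versions): for a fixed key, `jax.random.uniform(key)` is just a number `c`, and the gradient of `r` with respect to `x` is that same number. The names `dr_dx` and `c` are illustrative, not from the thread.

```python
import jax
import jax.numpy as jnp

def r(x, key):
  return x * jax.random.uniform(key)

key = jax.random.PRNGKey(42)
value, dr_dx = jax.value_and_grad(r)(10.0, key)

# uniform(key) depends only on the key, so for a fixed key it acts as a
# constant c; then r = c * x and dr/dx = c.
c = jax.random.uniform(key)
print(value, dr_dx, c)

# Calling uniform again with the same key reproduces the same c.
assert jax.random.uniform(key) == c
```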

Consider, though, this example:

```python
import jax

def r2(x, key):
  return jax.random.uniform(key, maxval=x)

rg2 = jax.value_and_grad(r2)
key = jax.random.PRNGKey(42)
rg2(10.0, key)
```

> (DeviceArray(4.2672753, dtype=float32), DeviceArray(0.42672753, dtype=float32))

Quite a few people I talked with found this unexpected, more than had found the first example unexpected. People with the option 3 mental model would find it surprising because, once the random number has been generated, it no longer depends on x, so they would expect the derivative of r2 with respect to x to be zero. This implicit option 3 view also surfaced when one person noted that it looks like "leaking an implementation detail", since it suggests that the random number is calculated as x * U(0,1).
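The r2 result follows from how `jax.random.uniform` maps a standard draw onto `[minval, maxval)`. As far as I know the JAX source computes roughly `u * (maxval - minval) + minval` from a standard draw `u` (an implementation detail, not a documented guarantee), so with the default `minval=0` the sample is `x * u` and its derivative with respect to `maxval=x` is `u`. A sketch that checks this, under that assumption:

```python
import jax
import jax.numpy as jnp

def r2(x, key):
  return jax.random.uniform(key, maxval=x)

key = jax.random.PRNGKey(42)
u = jax.random.uniform(key)  # standard draw on [0, 1) for this key

value, dr2_dx = jax.value_and_grad(r2)(10.0, key)
# Assumes the sample is computed as u * (maxval - minval) + minval, i.e. x * u
# here, so that d(sample)/d(maxval) = u.
print(value, 10.0 * u)  # should agree
print(dr2_dx, u)        # should agree
```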

If I had to guess from the (very anecdotal) results of chatting about these two examples with people, I'd say a small minority of people think the evaluation semantics would be option 1, with a roughly even split between those that (perhaps implicitly) are thinking it's option 2 and option 3.

I'm not really trying to take a position on which evaluation semantics are best here; rather, I want to note that they're not obvious and can lead to results that many users might find unexpected. It's also worth noting that the current semantics don't necessarily yield the gradients that would be useful for training an ML model, so user misunderstandings here could lead to research errors.

I think the right thing to do would be to make sure that there's first a clear consensus on what the semantics should be. Then I think it would be wise to document this quite prominently.

@shoyer
Collaborator

shoyer commented Jul 9, 2020

For understanding how JAX defines gradients for random number generation (RNG) functions, it's helpful to realize that pseudo-random number generators are actually deterministic functions in JAX, based on key. (Interpretation (2) in your list.)

So the first example is exactly the same as if jax.random.uniform were replaced by any other deterministic function of the integer key, e.g.,

```python
def r(x, key):
  return x * any_function(key)
```
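A runnable version of that snippet, with a hypothetical `any_function` (a toy stand-in, not a JAX API) that reads a number off the raw key array. It assumes `jax.random.PRNGKey` returns a raw `uint32` key array, which is the default for that legacy constructor. The gradient with respect to `x` is simply `any_function(key)`, just as it is with `jax.random.uniform`.

```python
import jax
import jax.numpy as jnp

def any_function(key):
  # Hypothetical deterministic function of the key: any map from the integer
  # key array to a float would do.
  return (key[-1] % 1000).astype(jnp.float32) / 1000.0

def r(x, key):
  return x * any_function(key)

key = jax.random.PRNGKey(42)
dr_dx = jax.grad(r)(10.0, key)  # key is an integer input; only x is differentiated
print(dr_dx, any_function(key))  # the two agree
```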

I'll let someone else comment on differentiation with respect to the float parameters of random number generation functions. I agree that the r2 example is a little confusing, but I think there is a case to be made that this is the only meaningful way to define the gradient; the alternative would be raising an error.

@mattjj
Collaborator

mattjj commented Jul 9, 2020

One short way to describe JAX's behavior is that the semantics match formal probability theory (which is functional!), and the functional PRNG makes that possible.

A "random variable" is really a deterministic function that takes a random element (and potentially some parameters) and produces a real value (for simplicity), so that we can think of it as having type $\Omega \to \mathbb{R}$. A parameterized random variable might have type $\mathbb{R} \times \Omega \to \mathbb{R}$, so that its value for a particular parameter value $\theta$ and a particular random element $\omega$ might be denoted $X(\theta, \omega)$. If the partial derivative of $X$ with respect to its first input exists for each $\omega$, then we can think of $\partial_0 X : \mathbb{R} \times \Omega \to \mathbb{R}$ as a new well-defined random variable (where $\partial_0$ means the partial derivative with respect to the first argument, i.e. the argument at index 0).

In Python, we can model random elements as PRNG keys, and random variables as Python functions that take keys as arguments. In particular, we can model $X$ as a Python function r = lambda theta, key: .... Then grad(r) models the random variable $\partial_0 X$.
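A sketch of that correspondence, using a hypothetical random variable $X(\theta, \omega) = \sin(\theta)\, Z(\omega)$ with $Z$ a standard normal (chosen for illustration, not from the thread). `jax.grad(X)` plays the role of $\partial_0 X$, and evaluating it at a fixed `theta` across several keys yields samples of the new random variable $\cos(\theta)\, Z$.

```python
import jax
import jax.numpy as jnp

def X(theta, key):
  # Hypothetical parameterized random variable X(theta, omega); the PRNG key
  # plays the role of the random element omega.
  return jnp.sin(theta) * jax.random.normal(key)

# d_0 X: the partial derivative with respect to the first argument, theta.
dX = jax.grad(X)

theta = 0.3
keys = jax.random.split(jax.random.PRNGKey(0), 5)  # five outcomes omega

# Evaluating d_0 X at a fixed theta for several keys gives samples of a new
# random variable, here cos(theta) * Z with Z = normal(key).
samples = jax.vmap(dX, in_axes=(None, 0))(theta, keys)
expected = jnp.cos(theta) * jax.vmap(jax.random.normal)(keys)
print(samples, expected)
```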

So partial derivatives are well-defined (closest to your interpretation 2) so long as the random variables in question are well-defined as functions. But just specifying the distribution of a random variable is not enough to unambiguously pin down its definition. I think that's the ambiguity that arises in the second example r2.

The function $\theta \mapsto \partial_0 X(\theta, \omega)$ represents the answer to "for this particular fixed seed, how does the value of $X$ change as I change the parameter?" This derivative need not be defined, and whether it is depends on more details about $X$ than just its distribution. Let's say our base probability space has a uniform measure over $\Omega$. As you've observed, we could take $X(\theta, \omega) = \theta \cdot U(\omega)$, where $U$ has a uniform law on $[0, 1)$. But if we instead define $X(\theta, \omega) = \theta \cdot U(f(\theta, \omega))$, where $f : \mathbb{R} \times \Omega \to \Omega$ is an indexed, hash-like bijection that scrambles all the elements of the sample space by hashing its two inputs together, then this new $X$ still has the right distribution, but won't be smooth enough for $\partial_0 X(\theta, \omega)$ to be defined.
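The contrast can be sketched numerically. `x_smooth` is $X(\theta, \omega) = \theta \cdot U(\omega)$; `x_scrambled` is a hypothetical stand-in for the hash-based construction, assuming $f$ can be modelled by folding the raw float32 bits of $\theta$ into the key with `jax.random.fold_in` (a hash rather than a true bijection, which is enough to illustrate the point, and it requires a concrete, non-traced `theta`). Both have the law of $\theta$ times a uniform draw, but for a fixed key only the first changes smoothly with $\theta$.

```python
import jax
import numpy as np

def x_smooth(theta, key):
  # X(theta, omega) = theta * U(omega): the key does not depend on theta.
  return theta * jax.random.uniform(key)

def x_scrambled(theta, key):
  # Hypothetical X(theta, omega) = theta * U(f(theta, omega)): f is modelled by
  # hashing the raw float32 bits of theta into the key with fold_in.
  # theta must be a concrete number here; this is not traceable by jax.grad.
  theta_bits = np.array(theta, dtype=np.float32).view(np.uint32)
  return theta * jax.random.uniform(jax.random.fold_in(key, theta_bits))

# Same distribution: at fixed theta, both average about theta / 2 over keys.
theta = 1.5
keys = jax.random.split(jax.random.PRNGKey(1), 10_000)
mean_smooth = jax.vmap(lambda k: x_smooth(theta, k))(keys).mean()
mean_scrambled = jax.vmap(lambda k: x_scrambled(theta, k))(keys).mean()

# Different smoothness: fix one key (one omega) and nudge theta by eps.
key = jax.random.PRNGKey(0)
eps = np.float32(1e-3)
thetas = np.linspace(1.0, 2.0, 50, dtype=np.float32)
jump_smooth = np.array(
    [abs(x_smooth(t + eps, key) - x_smooth(t, key)) for t in thetas])
jump_scrambled = np.array(
    [abs(x_scrambled(t + eps, key) - x_scrambled(t, key)) for t in thetas])

print(mean_smooth, mean_scrambled)  # both near 0.75
print(jump_smooth.max())            # at most about eps, since U < 1
print(jump_scrambled.mean())        # order theta: the draw is effectively redrawn
```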

So I'd say yes, grad(r2) is indeed "leaking" more information about r2, and indeed about jax.random.uniform's implementation, than just their distributions alone would tell you. Derivatives do tend to do that: they give us more information about functions!

WDYT?

@mattjj mattjj added the question Questions for the JAX team label Jul 9, 2020
@JonyEpsilon
Contributor Author

Sorry for the slow reply, @mattjj, and thanks for the clear explanation.

On expressions involving random samples from distributions that are not parameterised by the variables being differentiated with respect to: I agree that this seems like a reasonable thing to do. Since it's not what everyone expects, though, I wonder whether a useful action item would be to add something to the docs on this point. Maybe a short addition to the "sharp bits" would be appropriate, perhaps showing the first example (r) and explaining the result?

The second case discussed, where a derivative is taken w.r.t. a distribution parameter, is interesting. Your explanation of why it feels like it's "leaking" makes sense. Another way of putting it is that the gradients reveal something about the generative process behind the random numbers. The thing I can't get straight in my head (frustratingly, as I've tried to prove the relevant point a few times and failed, I suspect due to lack of imagination) is what impact this has on users of the random numbers. The point that's confusing me is this:

Imagine we have two parameterised generative processes, and the outputs of these processes agree in distribution. What can we say about gradients of these processes w.r.t. their parameters? My gut tells me that the expectation of the gradient should be the same for all reasonable processes that have outputs that are equal in distribution, otherwise I don't see how expectations of the variable itself could be made to do the right thing. My gut also tells me, though, that the other moments of the distribution of gradients could differ. Think of a location parameter for the distribution. I'm imagining one generative process where changing the location parameter just moves all of the sampled points along equally by the required amount. And I'm imagining a second generative process which moves the points around much more vigorously, but in a way that is contrived such that the distribution is the same as before, just moved by the right amount. The distribution of gradients from the latter process would be much broader than from the former. And clearly, this could be problematic for an ML experiment where the variance in the gradients matters because of finite sampling to approximate expectations.

What's bugging me is that I can't decide whether the second generative process can exist or not, while still being differentiable w.r.t. its parameters. In your answer above you give an example of a non-differentiable process that is equal in distribution. The question is whether there's space "in between" for something differentiable, different, but still equal in distribution. I tried a few times to construct illustrative maps explicitly on simple distributions, but failed!

I guess the answer to this question informs what the right thing to do is. If there's only really one distribution of gradients that can come out in the case that the process is differentiable, then jax.random is probably doing the right thing already. But if there are many reasonable parameterisations that would give different distributions of gradients, then I'd say it's definitely doing the wrong thing and it should probably force the user to express the generative process directly by erroring when taking gradients w.r.t. distribution parameters.
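Whatever default `jax.random` uses, a user can already make the intended convention explicit. A sketch with two hypothetical helpers (names not from the thread): `r2_explicit` spells out the generative process as a scaled standard draw, and `r2_frozen` applies `jax.lax.stop_gradient` to the distribution parameter, so the sample is treated as fixed with respect to `x` (interpretation 3 from the original post).

```python
import jax
import jax.numpy as jnp

def r2_explicit(x, key):
  # Generative process written out by hand: a standard uniform draw scaled
  # by x. Its gradient with respect to x is the draw itself.
  return x * jax.random.uniform(key)

def r2_frozen(x, key):
  # "Option 3" semantics: block the gradient through the distribution
  # parameter, so the sample counts as a constant with respect to x.
  return jax.random.uniform(key, maxval=jax.lax.stop_gradient(x))

key = jax.random.PRNGKey(42)
print(jax.grad(r2_explicit)(10.0, key))  # equals jax.random.uniform(key)
print(jax.grad(r2_frozen)(10.0, key))    # 0.0
```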

Sorry it's a bit rambling. Hopefully the gist of it makes sense!

@mattjj
Collaborator

mattjj commented Jul 22, 2020

Thanks for the discussion. Rambling is welcome; I do plenty myself!

> My gut tells me that the expectation of the gradient should be the same for all reasonable processes that have outputs that are equal in distribution, otherwise I don't see how expectations of the variable itself could be made to do the right thing.

I don't think that's true, though (unless you offer some definition of "reasonable"!). As a more concrete example, consider the random variable $Y(\omega, \mu) = X(\omega) + \mu$ where $X \sim \mathcal{N}(0, 1)$. Then consider $Z(\omega, \mu) = f(Y(\omega, \mu))$ where $f(x) = -x$ if $x$ is rational and $f(x) = x$ otherwise. Unless I'm mistaken, $Z$ is nowhere differentiable with respect to its second argument, since $f$ is nowhere differentiable (it is continuous only at $0$, and not differentiable even there), yet $Y$ and $Z$ have the same distribution. So the expectation of the gradient is not the same for $Y$ and $Z$, since the latter has no gradient at all.
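The "same distribution" step can be spelled out: $f(x) \neq x$ exactly when $x$ is a nonzero rational, and since $Y(\cdot, \mu) \sim \mathcal{N}(\mu, 1)$ has a density, it assigns probability zero to the countable set $\mathbb{Q}$:

```latex
\Pr\big(Z(\omega,\mu) \neq Y(\omega,\mu)\big)
  = \Pr\big(Y(\omega,\mu) \in \mathbb{Q} \setminus \{0\}\big)
  \le \sum_{q \in \mathbb{Q}} \Pr\big(Y(\omega,\mu) = q\big) = 0 .
```

So $Z = Y$ almost surely, and in particular the two have the same distribution.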

The point is just that saying two RVs have the same distribution is not pinning them down nearly enough to say whether their derivatives agree (or exist at all).
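Even when both derivatives exist, they can differ. A differentiable instance of the location-parameter thought experiment earlier in the thread (an illustration, not from the thread): take two independent standard normals $(z_1, z_2)$ and rotate them by an angle $K\mu$ before reading off the first coordinate. Rotations preserve the standard 2D normal, so $\mu + \cos(K\mu) z_1 + \sin(K\mu) z_2 \sim \mathcal{N}(\mu, 1)$ for every $\mu$, just like the rigid shift $\mu + z_1$. Its pathwise gradient $1 + K(-\sin(K\mu) z_1 + \cos(K\mu) z_2)$, however, has mean 1 and standard deviation $K$ rather than being the constant 1. `K`, `y_shift` and `y_rotated` are hypothetical names.

```python
import jax
import jax.numpy as jnp

K = 5.0  # hypothetical rotation rate: how "vigorously" samples move with mu

def y_shift(mu, key):
  # Process 1: every sample moves rigidly with the location parameter mu.
  z = jax.random.normal(key, (2,))
  return mu + z[0]

def y_rotated(mu, key):
  # Process 2: rotate (z[0], z[1]) by angle K * mu, then add mu.
  # Rotations preserve the standard 2D normal, so this is N(mu, 1) too.
  z = jax.random.normal(key, (2,))
  return mu + jnp.cos(K * mu) * z[0] + jnp.sin(K * mu) * z[1]

mu = 0.7
keys = jax.random.split(jax.random.PRNGKey(0), 100_000)

def values_and_grads(f):
  # One (value, d value / d mu) pair per key, at a fixed mu.
  return jax.vmap(jax.value_and_grad(f), in_axes=(None, 0))(mu, keys)

vals_shift, grads_shift = values_and_grads(y_shift)
vals_rot, grads_rot = values_and_grads(y_rotated)

print(vals_shift.mean(), vals_shift.std())    # about mu and 1
print(vals_rot.mean(), vals_rot.std())        # about mu and 1
print(grads_shift.mean(), grads_shift.std())  # 1 and 0
print(grads_rot.mean(), grads_rot.std())      # about 1 and K
```

In this particular pair the mean gradient agrees, but its spread differs by a factor set by `K`, which is exactly the variance concern raised above.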

> If there's only really one distribution of gradients that can come out in the case that the process is differentiable, then jax.random is probably doing the right thing already.

That condition might be sufficient, but I don't think it's necessary. If there are many possible conventions, we can just pick one that seems reasonable. What would the alternative be: just returning NaNs?

I think jax.random is following a good convention already. If a problem with its behavior arises then we should consider revising it, but unless/until that happens I don't think there's anything to revise.

WDYT?

@jakevdp
Collaborator

jakevdp commented Jun 28, 2022

It seems like this question has been resolved.

@jakevdp jakevdp closed this as completed Jun 28, 2022
4 participants