Incorrect gradient of function with segment_prod #9296
Comments
Yes, this looks like a legitimate bug in `segment_prod`.
Referenced commit: "The current gradients are incorrect if unique_indices=False. No gradient is better than an incorrect gradient." (#9296, PiperOrigin-RevId: 423896106)
Does that suffice for your purposes? As you note, if you know your values are positive, you can do the computation in log space.
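For concreteness, a minimal sketch of the log-space route (hypothetical names; it assumes the values being multiplied are strictly positive so the log is defined):

```python
import jax.numpy as jnp
from jax import ops

# Hypothetical positive data grouped into two segments.
data = jnp.array([0.2, 0.5, 0.3, 0.7])
segment_ids = jnp.array([0, 0, 1, 1])

# Direct per-segment product.
direct = ops.segment_prod(data, segment_ids, num_segments=2)

# The same quantity computed in log space: exp of the per-segment sum of logs.
via_logs = jnp.exp(ops.segment_sum(jnp.log(data), segment_ids, num_segments=2))
```

Differentiating the second form only goes through `log`, `segment_sum`, and `exp`, so it avoids the scatter-mul gradient discussed in this issue.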
By non-unique indices you mean something like `segment_ids = jnp.array([0, 0, 1, 1, 2, 2])` and `data = jnp.array([1, 1, 1, 1, 2, 2])`, where the products for indices 0 and 1 are non-unique? For the purpose of conditional logit, the log transform should work, since the individual contribution to the likelihood is a probability and therefore can't be negative. A group product doesn't appear that often in likelihood functions of other estimators; at least for the moment I can't recall any. Though I guess it would be nice to have the option of non-unique indices with `segment_prod`.
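A small sketch of what repeated (non-unique) segment ids look like with `segment_prod`, using the arrays above; the gradient noted in the comment is the textbook product rule, not the output of any particular JAX version:

```python
import jax.numpy as jnp
from jax import ops

segment_ids = jnp.array([0, 0, 1, 1, 2, 2])  # every segment id appears twice, i.e. non-unique
data = jnp.array([1., 1., 1., 1., 2., 2.])

# Forward pass: per-segment products -> [1., 1., 4.]
print(ops.segment_prod(data, segment_ids, num_segments=3))

# Analytically, d(sum of per-segment products)/d(data[i]) is the product of the *other*
# entries in data[i]'s segment, i.e. [1., 1., 1., 1., 2., 2.] here. This is the quantity
# the issue reports as coming out wrong from jax.grad on affected versions.
```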
Well, you'd need […] So I'm tempted to just say "use log space" if you can.
Referenced commit: "The current gradients are incorrect if unique_indices=False. No gradient is better than an incorrect gradient." (#9296, PiperOrigin-RevId: 423917753)
Closing because the "wrong output" bug is fixed. If someone needs the gradient for the non-unique indices case, please open a new issue.
This problem came up for me too; I am completely happy with the log->add->exp pathway, and agree that the straight mul route could be horrifically numerically unstable. Thank you for the solution.
Hello,
I am using automatic differentiation in JAX (0.2.27, jaxlib 0.1.75, Python 3.9) to obtain the gradient of a simple conditional logit model during maximum likelihood estimation of the model parameters. If agents face several trials of choices, the probability of observing a sequence of choices is a product of conditional logit formulas. If I implement this in the likelihood formula using `jax.ops.segment_prod`, the log likelihood produced by the function is correct, but the computed gradient is incorrect. I can obtain the correct gradient if the product is instead implemented using `segment_sum` and a log/exp transformation. To demonstrate the issue, I attach a dummy dataset where agents face several trials of choices from different alternatives.
Choice.csv
The following code reproduces the issue:
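The original snippet is not reproduced here, so below is a rough, self-contained sketch of the kind of function described, using synthetic data in place of the attached Choice.csv; names such as `neg_log_lik`, `trial_ids`, and `agent_ids` are illustrative and not taken from the original report:

```python
import jax
import jax.numpy as jnp
from jax import ops

# Synthetic stand-in for Choice.csv: one row per (agent, trial, alternative).
key = jax.random.PRNGKey(0)
n_agents, n_trials, n_alts = 5, 3, 4
n_rows = n_agents * n_trials * n_alts

x = jax.random.normal(key, (n_rows, 2))                          # alternative attributes
trial_ids = jnp.repeat(jnp.arange(n_agents * n_trials), n_alts)  # (agent, trial) id of each row
agent_ids = jnp.repeat(jnp.arange(n_agents), n_trials)           # agent id of each trial
chosen = (jnp.tile(jnp.arange(n_alts), n_agents * n_trials) == 0).astype(x.dtype)  # chosen alternative

def neg_log_lik(beta):
    # Conditional logit probability of every alternative within its trial.
    util = x @ beta
    denom = ops.segment_sum(jnp.exp(util), trial_ids, num_segments=n_agents * n_trials)
    prob = jnp.exp(util) / denom[trial_ids]
    # Probability of the chosen alternative in each trial.
    p_choice = ops.segment_sum(prob * chosen, trial_ids, num_segments=n_agents * n_trials)
    # Probability of an agent's whole sequence of choices: product over that agent's trials.
    p_seq = ops.segment_prod(p_choice, agent_ids, num_segments=n_agents)
    return -jnp.sum(jnp.log(p_seq))

beta0 = jnp.zeros(2)
print(neg_log_lik(beta0))            # the likelihood value itself, reported as correct
print(jax.grad(neg_log_lik)(beta0))  # the gradient, reported as incorrect on jax 0.2.27
```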
Output:
This is not correct, as we can see if we check:
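One possible check (a sketch, not necessarily the check used in the original report) is a central finite-difference approximation of the gradient, reusing `neg_log_lik` and `beta0` from the sketch above:

```python
import jax.numpy as jnp

def fd_grad(f, beta, eps=1e-4):
    # Crude central finite differences as an independent reference for the gradient.
    grads = []
    for i in range(beta.shape[0]):
        e = jnp.zeros_like(beta).at[i].set(eps)
        grads.append((f(beta + e) - f(beta - e)) / (2 * eps))
    return jnp.stack(grads)

print(fd_grad(neg_log_lik, beta0))  # disagrees with jax.grad(neg_log_lik) on the affected versions
```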
Output:
Now, if we replace `segment_prod` in the function with an exponent of a `segment_sum` of logs instead:
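A sketch of that replacement, with the same hypothetical names as above; the per-agent product becomes the exponent of a per-agent sum of logs:

```python
def neg_log_lik_logsum(beta):
    util = x @ beta
    denom = ops.segment_sum(jnp.exp(util), trial_ids, num_segments=n_agents * n_trials)
    prob = jnp.exp(util) / denom[trial_ids]
    p_choice = ops.segment_sum(prob * chosen, trial_ids, num_segments=n_agents * n_trials)
    # Product over an agent's trials rewritten as exp(segment_sum of logs).
    p_seq = jnp.exp(ops.segment_sum(jnp.log(p_choice), agent_ids, num_segments=n_agents))
    return -jnp.sum(jnp.log(p_seq))

print(jax.grad(neg_log_lik_logsum)(beta0))  # now matches the finite-difference reference
```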
The gradient is now correct: