Incorrect gradient of function with segment_prod #9296

Closed
jakuboc opened this issue Jan 24, 2022 · 6 comments
Labels: bug (Something isn't working)

jakuboc commented Jan 24, 2022

Hello,

I am using automatic differentiation in JAX (0.2.27, jaxlib 0.1.75, Python 3.9) to obtain the gradient of a simple conditional logit model during maximum likelihood estimation of its parameters. If agents face several trials of choices, the probability of observing a sequence of choices is a product of conditional logit formulas. If I implement this product in the likelihood using jax.ops.segment_prod, the log likelihood returned by the function is correct, but the computed gradient is not. I do obtain the correct gradient if the product is instead implemented with segment_sum and a log/exp transformation. To demonstrate the issue, I attach a dummy dataset in which agents face several trials of choices among different alternatives.
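For context, the quantity the code below computes is the standard conditional logit log likelihood (the notation here is ours): for individual i, trial t, and chosen alternative j,

P_{it}(b) = \exp(x_{itj}'b) / \sum_k \exp(x_{itk}'b), \qquad \log L(b) = \sum_i \log \prod_t P_{it}(b),

and segment_prod is used for the inner product over trials t.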

Choice.csv

The following code reproduces the issue:

import jax
import jax.numpy as jnp
import jax.ops as jo
from numpy import genfromtxt

# Column "id" in the Choice.csv file denotes individual identifier, "choice_id" denotes trials, 
# "choice" is a binary variable indicating whether an alternative was chosen.
# The remaining columns are explanatory variables.
data = genfromtxt('Choice.csv', delimiter=';', skip_header=1)

id = data[:,0]
id = id.astype(jnp.int64)
choice_id = data[:,1]
choice_id = choice_id.astype(jnp.int64)
choice = data[:,2]
x1 = data[:,3]
x2 = data[:,4]
x3 = data[:,5]

def clogit_ll(b):
    # Numerator of the conditional logit formula for every row.
    numer = jnp.exp(b[0]*x1 + b[1]*x2 + b[2]*x3)
    ch = numer[choice == 1]
    # Denominator: sum of numerators within each trial (segment = choice_id).
    denom = jo.segment_sum(numer, choice_id)
    denom = jnp.delete(denom, 0)
    # Choice probability of the chosen alternative in each trial.
    l1 = ch/denom
    ids = id[choice == 1]
    # Individual likelihood contribution: product of probabilities across trials.
    ll1 = jo.segment_prod(l1, ids)
    ll1 = jnp.delete(ll1, 0)
    lli = jnp.log(ll1)
    return jnp.sum(lli)

coef = jnp.array([ -.3400659,  -.9289953,  -.6646674])
grad_clogit = jax.value_and_grad(clogit_ll)
grad_clogit(coef)

Output:

(DeviceArray(-1477.655, dtype=float32), DeviceArray([ 377694.53, -213394.  ,  367053.06], dtype=float32))

This does not match a finite-difference check:

def first_finite_differences(f, x):
  eps = 1e-3
  return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)
                   for v in jnp.eye(len(x))])

first_finite_differences(clogit_ll, coef)

Output:

DeviceArray([-146.72852,  126.95312, -173.21776], dtype=float32)

Now, if we replace segment_prod in the function with the exponential of a segment_sum of logs:

def clogit_ll(b):
    numer = jnp.exp(b[0]*x1 + b[1]*x2 + b[2]*x3)
    ch = numer[choice == 1]
    denom = jo.segment_sum(numer, choice_id)
    denom = jnp.delete(denom, 0)
    l1 = ch/denom
    ids = id[choice == 1]
    # Product over trials computed in log space: prod(l1) == exp(sum(log(l1))).
    ll1 = jnp.exp(jo.segment_sum(jnp.log(l1), ids))
    ll1 = jnp.delete(ll1, 0)
    lli = jnp.log(ll1)
    return jnp.sum(lli)

grad_clogit = jax.value_and_grad(clogit_ll)
grad_clogit(coef)

The gradient is now correct:

(DeviceArray(-1477.655, dtype=float32), DeviceArray([-146.70026 ,  126.960014, -173.30954 ], dtype=float32))
jakuboc added the bug (Something isn't working) label Jan 24, 2022
hawkinsp self-assigned this Jan 24, 2022
hawkinsp (Collaborator) commented Jan 24, 2022

Yes, this looks like a legitimate bug in scatter_mul's gradient. I'm looking into it.

copybara-service bot pushed a commit that referenced this issue Jan 24, 2022
The current gradients are incorrect if unique_indices=False. No gradient is better than an incorrect gradient.

#9296

PiperOrigin-RevId: 423896106
hawkinsp (Collaborator) commented Jan 24, 2022

scatter_mul's gradient is only correct if there aren't colliding indices. For now, I'm going to make the non-unique case an error. That means your example in this issue will report an error, which is much better than a wrong output.

Does that suffice for your purposes? As you note, if you know your values are positive, you can do the computation in log space.
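For reference, a minimal sketch of that log-space workaround as a small helper (the name segment_prod_via_logs is ours, and it assumes strictly positive data):

import jax.numpy as jnp
import jax.ops as jo

def segment_prod_via_logs(data, segment_ids, num_segments=None):
    # Differentiable substitute for segment_prod when data > 0:
    # prod_i x_i == exp(sum_i log x_i), so this reuses segment_sum's gradient.
    return jnp.exp(jo.segment_sum(jnp.log(data), segment_ids, num_segments=num_segments))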

jakuboc (Author) commented Jan 24, 2022

By non-unique indices you mean something like:

segment_ids = jnp.array([0, 0, 1, 1, 2, 2])
data = jnp.array([1, 1, 1, 1, 2, 2])

where the products for indices 0 and 1 are non-unique?

For the purpose of the conditional logit, the log transform should work, as the individual contribution to the likelihood cannot be negative, since it is a probability. A within-group product doesn't appear that often in likelihood functions of other estimators; at least for the moment I can't recall one. Still, I guess it would be nice to have the option of non-unique indices with scatter_mul gradients anyway :)
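(For concreteness, a runnable sketch of the example above; the values are arbitrary, and each segment id appears twice, which is the non-unique case being discussed. The forward pass of segment_prod is unaffected; it is the gradient path that this issue is about.)

import jax.numpy as jnp
import jax.ops as jo

segment_ids = jnp.array([0, 0, 1, 1, 2, 2])   # each id appears twice -> indices are not unique
data = jnp.array([1., 2., 3., 4., 5., 6.])

# Forward value is fine: per-segment products [2., 12., 30.]
print(jo.segment_prod(data, segment_ids, num_segments=3))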

hawkinsp (Collaborator) commented Jan 24, 2022

Well, you'd need segment_ids to be all different, and to pass unique_indices=True to segment_prod. I acknowledge that makes the gradient much less useful...

  • Using log space is probably the best approach.
  • If it is useful, it would be possible for us to add a gradient that mirrors TensorFlow's gradient for its version of segment_prod:
    https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/python/ops/math_grad.py;l=534
    However, I don't like that approach very much: using division feels like it has a good chance of being numerically unstable if, for example, one of the elements is small. I also looked at PyTorch for inspiration, but as far as I can tell PyTorch doesn't have an equivalent differentiable operator.
  • Yet another approach would be to individually scatter each element into its own array of ones and then multiply the arrays together (a minimal sketch follows after this comment). This would waste a lot of memory but has the virtue of having a derivative.
  • If we had a bunch of free time, it would be possible to define an efficient forward-mode derivative using a variadic version of the scatter operator, but unfortunately XLA doesn't implement such an operator. You could then turn the forward-mode derivative into the reverse-mode derivative using a trick similar to the one TF uses.

So I'm tempted to just say "use log space" if you can.
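A minimal sketch of the scatter-into-ones idea from the list above, for small inputs only (the function name segment_prod_via_ones is ours; it materializes an array of shape (n, num_segments), which is where the memory cost comes from):

import jax
import jax.numpy as jnp

def segment_prod_via_ones(data, segment_ids, num_segments):
    # Scatter each element into its own row of ones, so row i is all ones
    # except data[i] at column segment_ids[i], then multiply the rows together.
    # Every step has a well-defined gradient, at the cost of O(n * num_segments) memory.
    rows = jnp.ones((data.shape[0], num_segments), dtype=data.dtype)
    rows = rows.at[jnp.arange(data.shape[0]), segment_ids].set(data)
    return jnp.prod(rows, axis=0)

segment_ids = jnp.array([0, 0, 1, 1, 2, 2])
data = jnp.array([1., 2., 3., 4., 5., 6.])
f = lambda x: jnp.sum(jnp.log(segment_prod_via_ones(x, segment_ids, num_segments=3)))
print(jax.grad(f)(data))   # equals 1/data for these positive inputs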

copybara-service bot pushed a commit that referenced this issue Jan 24, 2022
The current gradients are incorrect if unique_indices=False. No gradient is better than an incorrect gradient.

#9296

PiperOrigin-RevId: 423917753
hawkinsp (Collaborator) commented:

Closing because the "wrong output" bug is fixed. If someone needs the segment_prod gradient and isn't happy with the proposed workaround, feel free to reopen.

oliverdutton (Contributor) commented Mar 29, 2022

This problem came up for me. I am completely happy with the log → add → exp pathway and agree the straight multiplication route could be horrifically numerically unstable. Thank you for the solution.
