
Samples are outside the support for DiscreteUniform distribution #1834

Open
Deathn0t opened this issue Jul 22, 2024 · 3 comments
Labels
bug Something isn't working

Comments

@Deathn0t

Deathn0t commented Jul 22, 2024

Hello,

I noticed that samples take values outside the support of the DiscreteUniform distribution. Here is a minimal reproducible example:

import jax.random
import numpyro

import numpyro.distributions as dist

from numpyro.infer import HMC, MCMC, MixedHMC


def model():
    x = numpyro.sample("x", dist.DiscreteUniform(1, 2))


num_samples = 10
kernel = HMC(model, trajectory_length=1.2)
kernel = MixedHMC(kernel, num_discrete_updates=20)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=num_samples, progress_bar=False)
key = jax.random.PRNGKey(0)
mcmc.run(key)
samples = mcmc.get_samples()

print(samples)

This outputs:

{'x': Array([1, 1, 0, 0, 0, 0, 0, 0, 0, 1], dtype=int32)}

I expected the values of x to be in {1, 2}.
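As a quick sanity check, the reported draws can be compared against the support of DiscreteUniform(1, 2), i.e. {1, 2}. A minimal sketch with JAX, assuming the support is written out by hand as an array; the helper name all_in_support is hypothetical:

```python
import jax.numpy as jnp

def all_in_support(draws, support):
    """Return True if every draw is one of the values in `support`."""
    return bool(jnp.isin(draws, support).all())

support = jnp.array([1, 2])  # support of DiscreteUniform(1, 2), written by hand
reported = jnp.array([1, 1, 0, 0, 0, 0, 0, 0, 0, 1], dtype=jnp.int32)  # output above

print(all_in_support(reported, support))  # False: the 0s fall outside {1, 2}
```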

Am I using it incorrectly, or is this a real bug?

Thank you very much for your help.

fehiepsi added the bug label on Jul 22, 2024
@fehiepsi
Member

fehiepsi commented Jul 22, 2024

Thanks @Deathn0t! It is a bug in this line:

proposal = random.randint(rng_proposal, (), minval=0, maxval=support_size)

We should pass in the enumerated support there, something like:

support_size = enumerate_support.shape[0]
proposal_idx = random.randint(rng_proposal, (), minval=0, maxval=support_size)
proposal = enumerate_support[proposal_idx]
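The same idea as a self-contained sketch: draw a position uniformly in [0, support_size), then index into the enumerated support so the proposal is a support value rather than an index. The function name propose_from_support is hypothetical, and the support array is written by hand, assumed to match what DiscreteUniform(1, 2).enumerate_support() returns:

```python
import jax
import jax.numpy as jnp

def propose_from_support(rng_key, enumerate_support):
    """Uniformly propose a *value* from a 1-D array of support values."""
    support_size = enumerate_support.shape[0]
    # Draw a position in [0, support_size) ...
    proposal_idx = jax.random.randint(rng_key, (), minval=0, maxval=support_size)
    # ... and map it to the actual support value.
    return enumerate_support[proposal_idx]

support = jnp.array([1, 2])  # assumed support of DiscreteUniform(1, 2)
keys = jax.random.split(jax.random.PRNGKey(0), 100)
draws = jax.vmap(lambda k: propose_from_support(k, support))(keys)

# For comparison, the buggy version proposes the index itself, i.e. values in {0, 1}.
buggy = jax.vmap(
    lambda k: jax.random.randint(k, (), minval=0, maxval=support.shape[0])
)(keys)
```

Because both versions use the same keys, each fixed draw is exactly one more than the corresponding buggy draw, which is consistent with the 0/1 values in the reported output.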

Do you want to try to fix the issue?

@Deathn0t
Author

Hi @fehiepsi, thank you for the hint! I will try to look at it today and keep you updated if I get blocked. If I get things working, I will open a PR.

@Deathn0t
Author

Deathn0t commented Jul 23, 2024

Hi @fehiepsi, I had a look at it today. It seems that the Gibbs proposal is used here rather than the random-walk (RW) proposal:

def _discrete_gibbs_proposal(rng_key, z_discrete, pe, potential_fn, idx, support_size):
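For a Gibbs-style update the same rule applies: the potential energy has to be evaluated at the support values, not at the positions 0..support_size-1. Below is a simplified sketch that samples from the full conditional over the enumerated support. It does not reproduce NumPyro's internal _discrete_gibbs_proposal, and the names gibbs_update and potential_fn are hypothetical:

```python
import jax
import jax.numpy as jnp

def gibbs_update(rng_key, potential_fn, enumerate_support):
    """Sample a value with probability proportional to exp(-potential_fn(value))."""
    # Evaluate the potential energy at every support *value* (e.g. 1 and 2),
    # not at the indices 0..support_size-1.
    pes = jax.vmap(potential_fn)(enumerate_support)
    # Categorical over support positions with logits = -potential energy.
    idx = jax.random.categorical(rng_key, -pes)
    return enumerate_support[idx]

def potential_fn(value):
    # Hypothetical flat potential, as for a uniform distribution over the support.
    return jnp.asarray(value, dtype=jnp.float32) * 0.0

support = jnp.array([1, 2])  # assumed support of DiscreteUniform(1, 2)
keys = jax.random.split(jax.random.PRNGKey(1), 50)
draws = jax.vmap(lambda k: gibbs_update(k, potential_fn, support))(keys)
```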

For simplicity and to keep code changes minimal, I was thinking of mapping z_discrete to the enumerate_support values here:

z = {**z_discrete, **hmc_state.z}

What do you think?

I tried the following and it seems to work:

z_discrete = jax.tree.map(
    lambda idx, support: support[idx],
    z_discrete,
    self._support_enumerates,
)
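A self-contained illustration of that mapping, assuming (as in the snippet above) that the sampler keeps each discrete site as an index and that a dict of enumerated supports with the same keys is available. The site names and support values are made up; jax.tree_util.tree_map is used because it exists in both older and newer JAX releases and behaves like jax.tree.map here:

```python
import jax
import jax.numpy as jnp

# Hypothetical sampler state: each discrete site holds an index into its support.
z_discrete_idx = {"x": jnp.array(1), "y": jnp.array(0)}
# Hypothetical enumerated supports, keyed like z_discrete_idx.
support_enumerates = {"x": jnp.array([1, 2]), "y": jnp.array([10, 20, 30])}

# Map every index to the corresponding support value, site by site.
z_discrete = jax.tree_util.tree_map(
    lambda idx, support: support[idx],
    z_discrete_idx,
    support_enumerates,
)
```

Here x maps to 2 and y maps to 10. Whether the stored samples also need this mapping depends on where z is assembled inside the sampler, which this sketch does not settle.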

Development

No branches or pull requests

2 participants