This issue corresponds to forum question 5066.
Example:
Five customers are going to buy apples from 4 available categories of apples.
The total number of purchases for each customer follows a negative binomial distribution, with zero purchases being the most common outcome.
For each customer, the probabilities of choosing each category are drawn from a Dirichlet distribution.
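For comparison, the same generative story can be simulated with plain NumPy. Here every purchase count is a concrete integer before it reaches the multinomial draw. This is an illustrative sketch, not part of the original report, and `simulate_purchases` is a hypothetical helper. It assumes the mean/concentration form of `NegativeBinomial2(mu, alpha)` (mean `mu`, variance `mu + mu**2 / alpha`), which corresponds to NumPy's `negative_binomial(n=alpha, p=alpha / (alpha + mu))`.

```python
import numpy as np


def simulate_purchases(rng, num_apple_category=4, n_customers=5, mu=1.0, alpha=2.0):
    """Hypothetical NumPy sketch of the apple-purchase generative process."""
    # Negative binomial with mean mu and variance mu + mu**2 / alpha,
    # matching the mean/concentration parametrization of NegativeBinomial2.
    p = alpha / (alpha + mu)
    num_purchase = rng.negative_binomial(alpha, p, size=n_customers)
    # Per-customer category probabilities from a symmetric Dirichlet(0.5).
    desired_probs = rng.dirichlet(0.5 * np.ones(num_apple_category), size=n_customers)
    # Counts are concrete integers here, so each multinomial draw is well defined.
    apples = np.stack(
        [rng.multinomial(n, pv) for n, pv in zip(num_purchase, desired_probs)]
    )
    return num_purchase, desired_probs, apples


rng = np.random.default_rng(1234)
num_purchase, desired_probs, apples = simulate_purchases(rng)
print(num_purchase, apples, sep="\n")
```

Each row of `apples` sums to that customer's `num_purchase`, so customers with zero purchases get an all-zero row.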
Here is the minimal code:
```python
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpyro
import numpyro.distributions as dist
from numpyro.infer import Predictive

seed = 1234  # numpyro seed
rng_key = jax.random.PRNGKey(seed)
rng_trace, rng_prior_pred, rng_posterior_pred = jax.random.split(rng_key, 3)


def gen_data(num_apple_category: int = 4, n_customers: int = 5):
    mu = 1
    alpha = 2
    with numpyro.plate("customers", n_customers):
        # The latent number of purchases for each customer is negative binomial.
        # Lower purchase counts are more likely.
        num_purchase = numpyro.sample("num_purchase", dist.NegativeBinomial2(mu, alpha))
        desired_probs = numpyro.sample(
            "desired_probs", dist.Dirichlet(0.5 * jnp.ones(num_apple_category))
        )
        numpyro.sample(
            "apples", dist.Multinomial(total_count=num_purchase, probs=desired_probs)
        )


num_samples = 1
# Draw from the priors to make the fake data
prior_predictive = Predictive(gen_data, num_samples=num_samples)
prior_predictions = prior_predictive(rng_prior_pred)
```
Error:

```
AssertionError: The total count parameter `n` should not be a jax abstract array.
```
Versions:
```
jax=='0.4.1'
jaxlib=='0.4.1'
numpyro=='0.10.1'
```
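The error message indicates that multinomial sampling expects `total_count` to be a concrete value, while the `num_purchase` sampled inside `Predictive` is apparently an abstract (traced) JAX array. One workaround idea, which is not from this report and is not a numpyro API, is to build each multinomial draw from a fixed number `max_count` of categorical draws and mask out the draws beyond each customer's count. This way array shapes never depend on the sampled count. The sketch below shows only the masking logic, written in NumPy. `masked_multinomial` and `max_count` are hypothetical names, and the approach assumes every count is at most `max_count`.

```python
import numpy as np


def masked_multinomial(rng, counts, probs, max_count):
    """Hypothetical sketch: multinomial draws from a fixed number of categorical draws.

    counts: (n_customers,) integer array, each entry assumed <= max_count
    probs:  (n_customers, k) array whose rows sum to 1
    """
    n_customers, k = probs.shape
    # Inverse-CDF categorical sampling: max_count uniform draws per customer.
    u = rng.random((n_customers, max_count))
    cdf = np.cumsum(probs, axis=-1)
    # Category = index of the first CDF entry >= u; clip guards against round-off.
    cats = np.minimum((u[..., None] > cdf[:, None, :]).sum(axis=-1), k - 1)
    # Keep only the first counts[i] draws for customer i (static shape throughout).
    mask = np.arange(max_count)[None, :] < counts[:, None]
    one_hot = np.eye(k, dtype=np.int64)[cats]  # (n_customers, max_count, k)
    return (one_hot * mask[..., None]).sum(axis=1)


counts = np.array([0, 3, 5])
probs = np.array([[0.2, 0.3, 0.5], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
print(masked_multinomial(np.random.default_rng(0), counts, probs, max_count=6))
```

Because the draws are independent and identically distributed categorical samples, summing the first `counts[i]` one-hot vectors gives a multinomial draw with `counts[i]` trials. Whether an equivalent written with `jax.numpy` fits inside a numpyro model is left as an assumption here.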