diff --git a/numpyro/distributions/discrete.py b/numpyro/distributions/discrete.py index 4ea1b2dc1..abf8f70b1 100644 --- a/numpyro/distributions/discrete.py +++ b/numpyro/distributions/discrete.py @@ -638,10 +638,11 @@ def Multinomial( """Multinomial distribution. :param total_count: number of trials. If this is a JAX array, - `total_count_max` is required to specify. + it is required to specify `total_count_max`. :param probs: event probabilities :param logits: event log probabilities - :param int total_count_max: the maximum number of trials + :param int total_count_max: the maximum number of trials, + i.e. `max(total_count)` """ if probs is not None: return MultinomialProbs(