diff --git a/pymc3/distributions/multivariate.py b/pymc3/distributions/multivariate.py
index 3b7196682d..e9b9a87e99 100755
--- a/pymc3/distributions/multivariate.py
+++ b/pymc3/distributions/multivariate.py
@@ -486,23 +486,9 @@ class Dirichlet(Continuous):
 
     def __init__(self, a, transform=transforms.stick_breaking,
                  *args, **kwargs):
-
-        if not isinstance(a, pm.model.TensorVariable):
-            if not isinstance(a, list) and not isinstance(a, np.ndarray):
-                raise TypeError(
-                    'The vector of concentration parameters (a) must be a python list '
-                    'or numpy array.')
-            a = np.array(a)
-            if (a <= 0).any():
-                raise ValueError("All concentration parameters (a) must be > 0.")
-
-        shape = np.atleast_1d(a.shape)[-1]
-
-        kwargs.setdefault("shape", shape)
         super().__init__(transform=transform, *args, **kwargs)
 
         self.size_prefix = tuple(self.shape[:-1])
-        self.k = tt.as_tensor_variable(shape)
         self.a = a = tt.as_tensor_variable(a)
         self.mean = a / tt.sum(a)
 
@@ -569,14 +555,13 @@ def logp(self, value):
         -------
         TensorVariable
         """
-        k = self.k
         a = self.a
 
         # only defined for sum(value) == 1
         return bound(tt.sum(logpow(value, a - 1) - gammaln(a), axis=-1)
                      + gammaln(tt.sum(a, axis=-1)),
                      tt.all(value >= 0), tt.all(value <= 1),
-                     k > 1, tt.all(a > 0),
+                     np.logical_not(a.broadcastable), tt.all(a > 0),
                      broadcast_conditions=False)
 
     def _repr_latex_(self, name=None, dist=None):
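
Usage sketch (illustrative, not part of the patch): with the eager validation and
automatic shape inference removed from __init__, callers are assumed to pass the
shape themselves, and the old k > 1 runtime check in logp is replaced by a
graph-construction-time check that `a` has no broadcastable (length-1) dimensions.
The variable names below are hypothetical:

    import numpy as np
    import pymc3 as pm

    with pm.Model():
        # Shape is no longer inferred from `a`, so pass it explicitly
        # (assumption based on the removed kwargs.setdefault("shape", shape)).
        p = pm.Dirichlet("p", a=np.ones(3), shape=3)

        # Batched concentration parameters (assumed supported, per
        # self.size_prefix = tuple(self.shape[:-1])): the last axis is
        # the simplex dimension.
        q = pm.Dirichlet("q", a=np.ones((2, 3)), shape=(2, 3))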