Skip to content

Commit

Permalink
Remove Dirichlet distribution type restrictions
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Jul 5, 2020
1 parent 8770259 commit 9a2da91
Showing 1 changed file with 1 addition and 16 deletions.
17 changes: 1 addition & 16 deletions pymc3/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,23 +486,9 @@ class Dirichlet(Continuous):

def __init__(self, a, transform=transforms.stick_breaking,
*args, **kwargs):

if not isinstance(a, pm.model.TensorVariable):
if not isinstance(a, list) and not isinstance(a, np.ndarray):
raise TypeError(
'The vector of concentration parameters (a) must be a python list '
'or numpy array.')
a = np.array(a)
if (a <= 0).any():
raise ValueError("All concentration parameters (a) must be > 0.")

shape = np.atleast_1d(a.shape)[-1]

kwargs.setdefault("shape", shape)
super().__init__(transform=transform, *args, **kwargs)

self.size_prefix = tuple(self.shape[:-1])
self.k = tt.as_tensor_variable(shape)
self.a = a = tt.as_tensor_variable(a)
self.mean = a / tt.sum(a)

Expand Down Expand Up @@ -569,14 +555,13 @@ def logp(self, value):
-------
TensorVariable
"""
k = self.k
a = self.a

# only defined for sum(value) == 1
return bound(tt.sum(logpow(value, a - 1) - gammaln(a), axis=-1)
+ gammaln(tt.sum(a, axis=-1)),
tt.all(value >= 0), tt.all(value <= 1),
k > 1, tt.all(a > 0),
np.logical_not(a.broadcastable), tt.all(a > 0),
broadcast_conditions=False)

def _repr_latex_(self, name=None, dist=None):
Expand Down

0 comments on commit 9a2da91

Please sign in to comment.