Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove Dirichlet distribution type restrictions #4000

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 15 additions & 14 deletions pymc3/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from scipy import stats, linalg

from theano.gof.op import get_test_value
from theano.tensor.nlinalg import det, matrix_inverse, trace, eigh
from theano.tensor.slinalg import Cholesky
import pymc3 as pm
Expand Down Expand Up @@ -487,22 +488,23 @@ class Dirichlet(Continuous):
def __init__(self, a, transform=transforms.stick_breaking,
*args, **kwargs):

if not isinstance(a, pm.model.TensorVariable):
if not isinstance(a, list) and not isinstance(a, np.ndarray):
raise TypeError(
'The vector of concentration parameters (a) must be a python list '
'or numpy array.')
a = np.array(a)
if (a <= 0).any():
raise ValueError("All concentration parameters (a) must be > 0.")
brandonwillard marked this conversation as resolved.
Show resolved Hide resolved

shape = np.atleast_1d(a.shape)[-1]
brandonwillard marked this conversation as resolved.
Show resolved Hide resolved
if kwargs.get('shape') is None:
warnings.warn(
(
"Shape not explicitly set. "
"Please, set the value using the `shape` keyword argument. "
"Using the test value to infer the shape."
),
DeprecationWarning
)
try:
kwargs['shape'] = get_test_value(tt.shape(a))
except AttributeError:
pass

kwargs.setdefault("shape", shape)
super().__init__(transform=transform, *args, **kwargs)

self.size_prefix = tuple(self.shape[:-1])
self.k = tt.as_tensor_variable(shape)
self.a = a = tt.as_tensor_variable(a)
self.mean = a / tt.sum(a)

Expand Down Expand Up @@ -569,14 +571,13 @@ def logp(self, value):
-------
TensorVariable
"""
k = self.k
a = self.a

# only defined for sum(value) == 1
return bound(tt.sum(logpow(value, a - 1) - gammaln(a), axis=-1)
+ gammaln(tt.sum(a, axis=-1)),
tt.all(value >= 0), tt.all(value <= 1),
k > 1, tt.all(a > 0),
np.logical_not(a.broadcastable), tt.all(a > 0),
lucianopaz marked this conversation as resolved.
Show resolved Hide resolved
broadcast_conditions=False)

def _repr_latex_(self, name=None, dist=None):
Expand Down
4 changes: 2 additions & 2 deletions pymc3/tests/test_dist_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,11 +126,11 @@ def test_multinomial_bound():
n = x.sum()

with pm.Model() as modelA:
p_a = pm.Dirichlet('p', floatX(np.ones(2)))
p_a = pm.Dirichlet('p', floatX(np.ones(2)), shape=(2,))
MultinomialA('x', n, p_a, observed=x)

with pm.Model() as modelB:
p_b = pm.Dirichlet('p', floatX(np.ones(2)))
p_b = pm.Dirichlet('p', floatX(np.ones(2)), shape=(2,))
MultinomialB('x', n, p_b, observed=x)

assert np.isclose(modelA.logp({'p_stickbreaking__': [0]}),
Expand Down
19 changes: 8 additions & 11 deletions pymc3/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1328,17 +1328,14 @@ def test_dirichlet(self, n):
Dirichlet, Simplex(n), {"a": Vector(Rplus, n)}, dirichlet_logpdf
)

@pytest.mark.parametrize("n", [3, 4])
def test_dirichlet_init_fail(self, n):
with Model():
with pytest.raises(
ValueError, match=r"All concentration parameters \(a\) must be > 0."
):
_ = Dirichlet("x", a=np.zeros(n), shape=n)
with pytest.raises(
ValueError, match=r"All concentration parameters \(a\) must be > 0."
):
_ = Dirichlet("x", a=np.array([-1.0] * n), shape=n)
brandonwillard marked this conversation as resolved.
Show resolved Hide resolved
def test_dirichlet_shape(self):
a = tt.as_tensor_variable(np.r_[1, 2])
with pytest.warns(DeprecationWarning):
dir_rv = Dirichlet.dist(a)
assert dir_rv.shape == (2,)

with pytest.warns(DeprecationWarning), theano.change_flags(compute_test_value="ignore"):
dir_rv = Dirichlet.dist(tt.vector())

def test_dirichlet_2D(self):
self.pymc3_matches_scipy(
Expand Down
8 changes: 4 additions & 4 deletions pymc3/tests/test_distributions_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -912,15 +912,15 @@ def test_mixture_random_shape():
nr.poisson(9, size=10)])
with pm.Model() as m:
comp0 = pm.Poisson.dist(mu=np.ones(2))
w0 = pm.Dirichlet('w0', a=np.ones(2))
w0 = pm.Dirichlet('w0', a=np.ones(2), shape=(2,))
like0 = pm.Mixture('like0',
w=w0,
comp_dists=comp0,
observed=y)

comp1 = pm.Poisson.dist(mu=np.ones((20, 2)),
shape=(20, 2))
w1 = pm.Dirichlet('w1', a=np.ones(2))
w1 = pm.Dirichlet('w1', a=np.ones(2), shape=(2,))
like1 = pm.Mixture('like1',
w=w1,
comp_dists=comp1,
Expand Down Expand Up @@ -967,15 +967,15 @@ def test_mixture_random_shape_fast():
nr.poisson(9, size=10)])
with pm.Model() as m:
comp0 = pm.Poisson.dist(mu=np.ones(2))
w0 = pm.Dirichlet('w0', a=np.ones(2))
w0 = pm.Dirichlet('w0', a=np.ones(2), shape=(2,))
like0 = pm.Mixture('like0',
w=w0,
comp_dists=comp0,
observed=y)

comp1 = pm.Poisson.dist(mu=np.ones((20, 2)),
shape=(20, 2))
w1 = pm.Dirichlet('w1', a=np.ones(2))
w1 = pm.Dirichlet('w1', a=np.ones(2), shape=(2,))
like1 = pm.Mixture('like1',
w=w1,
comp_dists=comp1,
Expand Down
24 changes: 12 additions & 12 deletions pymc3/tests/test_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def test_dimensions(self):

def test_mixture_list_of_normals(self):
with Model() as model:
w = Dirichlet('w', floatX(np.ones_like(self.norm_w)))
w = Dirichlet('w', floatX(np.ones_like(self.norm_w)), shape=self.norm_w.size)
mu = Normal('mu', 0., 10., shape=self.norm_w.size)
tau = Gamma('tau', 1., 1., shape=self.norm_w.size)
Mixture('x_obs', w,
Expand All @@ -98,7 +98,7 @@ def test_mixture_list_of_normals(self):

def test_normal_mixture(self):
with Model() as model:
w = Dirichlet('w', floatX(np.ones_like(self.norm_w)))
w = Dirichlet('w', floatX(np.ones_like(self.norm_w)), shape=self.norm_w.size)
mu = Normal('mu', 0., 10., shape=self.norm_w.size)
tau = Gamma('tau', 1., 1., shape=self.norm_w.size)
NormalMixture('x_obs', w, mu, tau=tau, observed=self.norm_x)
Expand Down Expand Up @@ -135,7 +135,7 @@ def test_normal_mixture_nd(self, nd, ncomp):
with Model() as model0:
mus = Normal('mus', shape=comp_shape)
taus = Gamma('taus', alpha=1, beta=1, shape=comp_shape)
ws = Dirichlet('ws', np.ones(ncomp))
ws = Dirichlet('ws', np.ones(ncomp), shape=(ncomp,))
mixture0 = NormalMixture('m', w=ws, mu=mus, tau=taus, shape=nd,
comp_shape=comp_shape)
obs0 = NormalMixture('obs', w=ws, mu=mus, tau=taus, shape=nd,
Expand All @@ -145,7 +145,7 @@ def test_normal_mixture_nd(self, nd, ncomp):
with Model() as model1:
mus = Normal('mus', shape=comp_shape)
taus = Gamma('taus', alpha=1, beta=1, shape=comp_shape)
ws = Dirichlet('ws', np.ones(ncomp))
ws = Dirichlet('ws', np.ones(ncomp), shape=(ncomp,))
comp_dist = [Normal.dist(mu=mus[..., i], tau=taus[..., i],
shape=nd)
for i in range(ncomp)]
Expand All @@ -163,7 +163,7 @@ def test_normal_mixture_nd(self, nd, ncomp):
# comp_dists.
mus = Normal('mus', shape=comp_shape)
taus = Gamma('taus', alpha=1, beta=1, shape=comp_shape)
ws = Dirichlet('ws', np.ones(ncomp))
ws = Dirichlet('ws', np.ones(ncomp), shape=(ncomp,))
if len(nd) > 1:
if nd[-1] != ncomp:
with pytest.raises(ValueError):
Expand Down Expand Up @@ -208,7 +208,7 @@ def test_normal_mixture_nd(self, nd, ncomp):

def test_poisson_mixture(self):
with Model() as model:
w = Dirichlet('w', floatX(np.ones_like(self.pois_w)))
w = Dirichlet('w', floatX(np.ones_like(self.pois_w)), shape=self.pois_w.shape)
mu = Gamma('mu', 1., 1., shape=self.pois_w.size)
Mixture('x_obs', w, Poisson.dist(mu), observed=self.pois_x)
step = Metropolis()
Expand All @@ -224,7 +224,7 @@ def test_poisson_mixture(self):

def test_mixture_list_of_poissons(self):
with Model() as model:
w = Dirichlet('w', floatX(np.ones_like(self.pois_w)))
w = Dirichlet('w', floatX(np.ones_like(self.pois_w)), shape=self.pois_w.shape)
mu = Gamma('mu', 1., 1., shape=self.pois_w.size)
Mixture('x_obs', w,
[Poisson.dist(mu[0]), Poisson.dist(mu[1])],
Expand All @@ -247,7 +247,7 @@ def test_mixture_of_mvn(self):
cov2 = np.diag([2.5, 3.5])
obs = np.asarray([[.5, .5], mu1, mu2])
with Model() as model:
w = Dirichlet('w', floatX(np.ones(2)), transform=None)
w = Dirichlet('w', floatX(np.ones(2)), transform=None, shape=(2,))
mvncomp1 = MvNormal.dist(mu=mu1, cov=cov1)
mvncomp2 = MvNormal.dist(mu=mu2, cov=cov2)
y = Mixture('x_obs', w, [mvncomp1, mvncomp2],
Expand Down Expand Up @@ -291,13 +291,13 @@ def test_mixture_of_mixture(self):
sigma=1,
shape=nbr)
# weight vector for the mixtures
g_w = Dirichlet('g_w', a=floatX(np.ones(nbr)*0.0000001), transform=None)
l_w = Dirichlet('l_w', a=floatX(np.ones(nbr)*0.0000001), transform=None)
g_w = Dirichlet('g_w', a=floatX(np.ones(nbr)*0.0000001), transform=None, shape=(nbr,))
l_w = Dirichlet('l_w', a=floatX(np.ones(nbr)*0.0000001), transform=None, shape=(nbr,))
# mixture components
g_mix = Mixture.dist(w=g_w, comp_dists=g_comp)
l_mix = Mixture.dist(w=l_w, comp_dists=l_comp)
# mixture of mixtures
mix_w = Dirichlet('mix_w', a=floatX(np.ones(2)), transform=None)
mix_w = Dirichlet('mix_w', a=floatX(np.ones(2)), transform=None, shape=(2,))
mix = Mixture('mix', w=mix_w,
comp_dists=[g_mix, l_mix],
observed=np.exp(self.norm_x))
Expand Down Expand Up @@ -378,7 +378,7 @@ def build_toy_dataset(N, K):
X, y = build_toy_dataset(N, K)

with pm.Model() as model:
pi = pm.Dirichlet('pi', np.ones(K))
pi = pm.Dirichlet('pi', np.ones(K), shape=(K,))

comp_dist = []
mu = []
Expand Down