Skip to content

Commit

Permalink
Fix Mixture distribution mode computation and logp dimensions
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Jul 4, 2020
1 parent 1bf867e commit 8770259
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 8 deletions.
12 changes: 8 additions & 4 deletions pymc3/distributions/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,13 @@ def __init__(self, w, comp_dists, *args, **kwargs):
dtype = kwargs.pop('dtype', default_dtype)

try:
comp_modes = self._comp_modes()
comp_mode_logps = self.logp(comp_modes)
self.mode = comp_modes[tt.argmax(w * comp_mode_logps, axis=-1)]
if isinstance(comp_dists, Distribution):
comp_mode_logps = comp_dists.logp(comp_dists.mode)
else:
comp_mode_logps = tt.stack([cd.logp(cd.mode) for cd in comp_dists])

mode_idx = tt.argmax(tt.log(w) + comp_mode_logps, axis=-1)
self.mode = self._comp_modes()[mode_idx]

if 'mode' not in defaults:
defaults.append('mode')
Expand Down Expand Up @@ -427,7 +431,7 @@ def logp(self, value):
"""
w = self.w

return bound(logsumexp(tt.log(w) + self._comp_logp(value), axis=-1),
return bound(logsumexp(tt.log(w) + self._comp_logp(value), axis=-1, keepdims=False),
w >= 0, w <= 1, tt.allclose(w.sum(axis=-1), 1),
broadcast_conditions=False)

Expand Down
5 changes: 3 additions & 2 deletions pymc3/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,11 @@ def tround(*args, **kwargs):
return tt.round(*args, **kwargs)


def logsumexp(x, axis=None):
def logsumexp(x, axis=None, keepdims=True):
# Adapted from https://github.com/Theano/Theano/issues/1563
x_max = tt.max(x, axis=axis, keepdims=True)
return tt.log(tt.sum(tt.exp(x - x_max), axis=axis, keepdims=True)) + x_max
res = tt.log(tt.sum(tt.exp(x - x_max), axis=axis, keepdims=True)) + x_max
return res if keepdims else res.squeeze()


def logaddexp(a, b):
Expand Down
15 changes: 13 additions & 2 deletions pymc3/tests/test_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,17 @@ def setup_class(cls):
cls.pois_mu = np.array([5., 20.])
cls.pois_x = generate_poisson_mixture_data(cls.pois_w, cls.pois_mu, size=1000)

def test_dimensions(self):
a1 = Normal.dist(mu=0, sigma=1)
a2 = Normal.dist(mu=10, sigma=1)
mix = Mixture.dist(w=np.r_[0.5, 0.5], comp_dists=[a1, a2])

assert mix.mode.ndim == 0
assert mix.logp(0.0).ndim == 0

value = np.r_[0.0, 1.0, 2.0]
assert mix.logp(value).ndim == 1

def test_mixture_list_of_normals(self):
with Model() as model:
w = Dirichlet('w', floatX(np.ones_like(self.norm_w)))
Expand Down Expand Up @@ -252,7 +263,7 @@ def test_mixture_of_mvn(self):
# check logp of mixture
testpoint = model.test_point
mixlogp_st = logsumexp(np.log(testpoint['w']) + complogp_st,
axis=-1, keepdims=True)
axis=-1, keepdims=False)
assert_allclose(y.logp_elemwise(testpoint),
mixlogp_st)

Expand Down Expand Up @@ -321,7 +332,7 @@ def mixmixlogp(value, point):
complogp_mix = np.concatenate((mixlogp1, mixlogp2), axis=1)
mixmixlogpg = logsumexp(np.log(point['mix_w']).astype(floatX) +
complogp_mix,
axis=-1, keepdims=True)
axis=-1, keepdims=False)
return priorlogp, mixmixlogpg

value = np.exp(self.norm_x)[:, None]
Expand Down

0 comments on commit 8770259

Please sign in to comment.