Skip to content

Commit

Permalink
metrop_accept returns whether the sample was accepted
Browse files Browse the repository at this point in the history
  • Loading branch information
ColCarroll authored and twiecki committed Apr 21, 2017
1 parent 6c9e848 commit b2be720
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 19 deletions.
7 changes: 4 additions & 3 deletions pymc3/step_methods/arraystep.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,8 @@ def metrop_select(mr, q, q0):
"""Perform rejection/acceptance step for Metropolis class samplers.
Returns the new sample q if a uniform random number is less than the
metropolis acceptance rate (`mr`), and the old sample otherwise.
metropolis acceptance rate (`mr`), and the old sample otherwise, along
with a boolean indicating whether the sample was accepted.
Parameters
----------
Expand All @@ -173,6 +174,6 @@ def metrop_select(mr, q, q0):
"""
# Compare acceptance ratio to uniform random number
if np.isfinite(mr) and np.log(uniform()) < mr:
return q
return q, True
else:
return q0
return q0, False
2 changes: 1 addition & 1 deletion pymc3/step_methods/hmc/hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def astep(self, q0):
initial_energy = self.compute_energy(q, p)
q, p, current_energy = self.leapfrog(q, p, e, n_steps)
energy_change = initial_energy - current_energy
return metrop_select(energy_change, q, q0)
return metrop_select(energy_change, q, q0)[0]

@staticmethod
def competence(var):
Expand Down
18 changes: 8 additions & 10 deletions pymc3/step_methods/metropolis.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,8 @@ def astep(self, q0):
q = floatX(q0 + delta)

accept = self.delta_logp(q, q0)
q_new = metrop_select(accept, q, q0)

if q_new is q:
self.accepted += 1
q_new, accepted = metrop_select(accept, q, q0)
self.accepted += accepted

self.steps_until_tune -= 1

Expand Down Expand Up @@ -264,7 +262,8 @@ def astep(self, q0, logp):
q[switch_locs] = True - q[switch_locs]

accept = logp(q) - logp(q0)
q_new = metrop_select(accept, q, q0)
q_new, accepted = metrop_select(accept, q, q0)
self.accepted += accepted

stats = {
'tune': self.tune,
Expand Down Expand Up @@ -325,8 +324,8 @@ def astep(self, q0, logp):
for idx in order:
curr_val, q[idx] = q[idx], True - q[idx]
logp_prop = logp(q)
q[idx] = metrop_select(logp_prop - logp_curr, q[idx], curr_val)
if q[idx] != curr_val:
q[idx], accepted = metrop_select(logp_prop - logp_curr, q[idx], curr_val)
if accepted:
logp_curr = logp_prop

return q
Expand Down Expand Up @@ -408,10 +407,9 @@ def astep_unif(self, q0, logp):
for dim, k in dimcats:
curr_val, q[dim] = q[dim], sample_except(k, q[dim])
logp_prop = logp(q)
q[dim] = metrop_select(logp_prop - logp_curr, q[dim], curr_val)
if q[dim] != curr_val:
q[dim], accepted = metrop_select(logp_prop - logp_curr, q[dim], curr_val)
if accepted:
logp_curr = logp_prop

return q

def astep_prop(self, q0, logp):
Expand Down
10 changes: 5 additions & 5 deletions pymc3/step_methods/smc.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from .arraystep import metrop_select
from ..backends import smc_text as atext

__all__ = ('SMC', 'ATMIP_sample')
__all__ = ['SMC', 'ATMIP_sample']

EXPERIMENTAL_WARNING = "Warning: SMC is an experimental step method, and not yet"\
" recommended for use in PyMC3!"
Expand Down Expand Up @@ -242,10 +242,10 @@ def astep(self, q0):

if np.isfinite(varlogp):
logp = self.logp_forw(q)
q_new = metrop_select(
q_new, accepted = metrop_select(
self.beta * (logp[self._llk_index] - l0[self._llk_index]), q, q0)

if q_new is q:
if accepted:
self.accepted += 1
l_new = logp
self.chain_previous_lpoint[self.chain_index] = l_new
Expand All @@ -257,10 +257,10 @@ def astep(self, q0):

else:
logp = self.logp_forw(q)
q_new = metrop_select(
q_new, accepted = metrop_select(
self.beta * (logp[self._llk_index] - l0[self._llk_index]), q, q0)

if q_new is q:
if accepted:
self.accepted += 1
l_new = logp
self.chain_previous_lpoint[self.chain_index] = l_new
Expand Down

0 comments on commit b2be720

Please sign in to comment.