From 9c96a7554f66101968e437c8cd8867f80c3e6afa Mon Sep 17 00:00:00 2001 From: colin Date: Wed, 19 Apr 2017 13:05:26 +0000 Subject: [PATCH] metrop_accept returns whether the sample was accepted --- pymc3/step_methods/arraystep.py | 9 +++++---- pymc3/step_methods/hmc/hmc.py | 2 +- pymc3/step_methods/metropolis.py | 18 ++++++++---------- pymc3/step_methods/smc.py | 8 ++++---- 4 files changed, 18 insertions(+), 19 deletions(-) diff --git a/pymc3/step_methods/arraystep.py b/pymc3/step_methods/arraystep.py index 7547c18e9e2..168e2bcb499 100644 --- a/pymc3/step_methods/arraystep.py +++ b/pymc3/step_methods/arraystep.py @@ -6,7 +6,7 @@ from numpy.random import uniform from enum import IntEnum, unique -__all__ = ['ArrayStep', 'ArrayStepShared', 'metrop_select', 'Competence'] +__all__ = ('ArrayStep', 'ArrayStepShared', 'metrop_select', 'Competence') @unique @@ -159,7 +159,8 @@ def metrop_select(mr, q, q0): """Perform rejection/acceptance step for Metropolis class samplers. Returns the new sample q if a uniform random number is less than the - metropolis acceptance rate (`mr`), and the old sample otherwise. + metropolis acceptance rate (`mr`), and the old sample otherwise, along + with a boolean indicating whether the sample was accepted. Parameters ---------- @@ -173,6 +174,6 @@ def metrop_select(mr, q, q0): """ # Compare acceptance ratio to uniform random number if np.isfinite(mr) and np.log(uniform()) < mr: - return q + return q, True else: - return q0 + return q0, False diff --git a/pymc3/step_methods/hmc/hmc.py b/pymc3/step_methods/hmc/hmc.py index f8159593ee6..05d7cb48f72 100644 --- a/pymc3/step_methods/hmc/hmc.py +++ b/pymc3/step_methods/hmc/hmc.py @@ -64,7 +64,7 @@ def astep(self, q0): initial_energy = self.compute_energy(q, p) q, p, current_energy = self.leapfrog(q, p, e, n_steps) energy_change = initial_energy - current_energy - return metrop_select(energy_change, q, q0) + return metrop_select(energy_change, q, q0)[0] @staticmethod def competence(var): diff --git a/pymc3/step_methods/metropolis.py b/pymc3/step_methods/metropolis.py index 9d833c149fe..7dc0f11f813 100644 --- a/pymc3/step_methods/metropolis.py +++ b/pymc3/step_methods/metropolis.py @@ -150,10 +150,8 @@ def astep(self, q0): q = floatX(q0 + delta) accept = self.delta_logp(q, q0) - q_new = metrop_select(accept, q, q0) - - if q_new is q: - self.accepted += 1 + q_new, accepted = metrop_select(accept, q, q0) + self.accepted += accepted self.steps_until_tune -= 1 @@ -264,7 +262,8 @@ def astep(self, q0, logp): q[switch_locs] = True - q[switch_locs] accept = logp(q) - logp(q0) - q_new = metrop_select(accept, q, q0) + q_new, accepted = metrop_select(accept, q, q0) + self.accepted += accepted stats = { 'tune': self.tune, @@ -325,8 +324,8 @@ def astep(self, q0, logp): for idx in order: curr_val, q[idx] = q[idx], True - q[idx] logp_prop = logp(q) - q[idx] = metrop_select(logp_prop - logp_curr, q[idx], curr_val) - if q[idx] != curr_val: + q[idx], accepted = metrop_select(logp_prop - logp_curr, q[idx], curr_val) + if accepted: logp_curr = logp_prop return q @@ -408,10 +407,9 @@ def astep_unif(self, q0, logp): for dim, k in dimcats: curr_val, q[dim] = q[dim], sample_except(k, q[dim]) logp_prop = logp(q) - q[dim] = metrop_select(logp_prop - logp_curr, q[dim], curr_val) - if q[dim] != curr_val: + q[dim], accepted = metrop_select(logp_prop - logp_curr, q[dim], curr_val) + if accepted: logp_curr = logp_prop - return q def astep_prop(self, q0, logp): diff --git a/pymc3/step_methods/smc.py b/pymc3/step_methods/smc.py index 64389833d14..98e5b60c7d1 100644 --- a/pymc3/step_methods/smc.py +++ b/pymc3/step_methods/smc.py @@ -242,10 +242,10 @@ def astep(self, q0): if np.isfinite(varlogp): logp = self.logp_forw(q) - q_new = metrop_select( + q_new, accepted = metrop_select( self.beta * (logp[self._llk_index] - l0[self._llk_index]), q, q0) - if q_new is q: + if accepted: self.accepted += 1 l_new = logp self.chain_previous_lpoint[self.chain_index] = l_new @@ -257,10 +257,10 @@ def astep(self, q0): else: logp = self.logp_forw(q) - q_new = metrop_select( + q_new, accepted = metrop_select( self.beta * (logp[self._llk_index] - l0[self._llk_index]), q, q0) - if q_new is q: + if accepted: self.accepted += 1 l_new = logp self.chain_previous_lpoint[self.chain_index] = l_new