Skip to content

Commit

Permalink
Merge pull request #39 from alan-turing-institute/mlda_develop
Browse files Browse the repository at this point in the history
Improve stats calculation
  • Loading branch information
gmingas authored Jul 16, 2020
2 parents 9278ddc + 9d25cd2 commit db022a5
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 27 deletions.
25 changes: 13 additions & 12 deletions pymc3/step_methods/metropolis.py
Original file line number Diff line number Diff line change
Expand Up @@ -981,7 +981,8 @@ class MLDA(ArrayStepShared):
stats_dtypes = [{
'accept': np.float64,
'accepted': np.bool,
'tune': np.bool
'tune': np.bool,
'base_scaling': object
}]

def __init__(self, coarse_models, vars=None, base_S=None, base_proposal_dist=None,
Expand Down Expand Up @@ -1102,13 +1103,6 @@ def __init__(self, coarse_models, vars=None, base_S=None, base_proposal_dist=Non
self.tune,
self.subsampling_rates[-1])

# Update stats data types dictionary given vars and base_blocked
if self.base_blocked or len(self.vars) == 1:
self.stats_dtypes[0]['base_scaling'] = np.float64
else:
for name in self.var_names:
self.stats_dtypes[0]['base_scaling_' + name] = np.float64

def astep(self, q0):
"""One MLDA step, given current sample q0"""
# Check if the tuning flag has been changed and if yes,
Expand Down Expand Up @@ -1137,11 +1131,15 @@ def astep(self, q0):
# do not calculate likelihood, just set accept to 0.0
if (q == q0).all():
accept = np.float(0.0)
skipped_logp = True
else:
accept = self.delta_logp(q, q0) + self.delta_logp_next(q0, q)
skipped_logp = False

# Accept/reject sample - next sample is stored in q_new
q_new, accepted = metrop_select(accept, q, q0)
if skipped_logp:
accepted = False

# Update acceptance counter
self.accepted += accepted
Expand All @@ -1155,12 +1153,15 @@ def astep(self, q0):
# Capture latest base chain scaling stats from next step method
self.base_scaling_stats = {}
if isinstance(self.next_step_method, CompoundStep):
scaling_list = []
for method in self.next_step_method.methods:
self.base_scaling_stats["base_scaling_" + method.vars[0].name] = method.scaling
elif isinstance(self.next_step_method, Metropolis):
self.base_scaling_stats["base_scaling"] = self.next_step_method.scaling
scaling_list.append(method.scaling)
self.base_scaling_stats = {"base_scaling": np.array(scaling_list)}
elif not isinstance(self.next_step_method, MLDA):
# next method is any block sampler
self.base_scaling_stats = {"base_scaling": np.array(self.next_step_method.scaling)}
else:
# next method is MLDA
# next method is MLDA - propagate dict from lower levels
self.base_scaling_stats = self.next_step_method.base_scaling_stats
stats = {**stats, **self.base_scaling_stats}

Expand Down
32 changes: 17 additions & 15 deletions pymc3/tests/test_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -1126,10 +1126,10 @@ def test_acceptance_rate_against_coarseness(self):
Normal("x", 5.0, 1.0)

with Model() as coarse_model_1:
Normal("x", 5.5, 1.5)
Normal("x", 6.0, 2.0)

with Model() as coarse_model_2:
Normal("x", 6.0, 2.0)
Normal("x", 20.0, 5.0)

possible_coarse_models = [coarse_model_0,
coarse_model_1,
Expand All @@ -1139,9 +1139,9 @@ def test_acceptance_rate_against_coarseness(self):
with Model():
Normal("x", 5.0, 1.0)
for coarse_model in possible_coarse_models:
step = MLDA(coarse_models=[coarse_model], subsampling_rates=1,
tune=False)
trace = sample(chains=1, draws=500, tune=0, step=step)
step = MLDA(coarse_models=[coarse_model], subsampling_rates=3,
tune=True)
trace = sample(chains=1, draws=500, tune=100, step=step)
acc.append(trace.get_sampler_stats('accepted').mean())
assert acc[0] > acc[1] > acc[2], "Acceptance rate is not " \
"strictly increasing when" \
Expand Down Expand Up @@ -1197,10 +1197,10 @@ def test_tuning_and_scaling_on(self):
assert trace.get_sampler_stats('tune', chains=0)[ts - 1]
assert not trace.get_sampler_stats('tune', chains=0)[ts]
assert not trace.get_sampler_stats('tune', chains=0)[-1]
assert trace.get_sampler_stats('base_scaling_x', chains=0)[0] == 100.
assert trace.get_sampler_stats('base_scaling_y_logodds__', chains=0)[0] == 100.
assert trace.get_sampler_stats('base_scaling_x', chains=0)[-1] < 100.
assert trace.get_sampler_stats('base_scaling_y_logodds__', chains=0)[-1] < 100.
assert trace.get_sampler_stats('base_scaling', chains=0)[0][0] == 100.
assert trace.get_sampler_stats('base_scaling', chains=0)[0][1] == 100.
assert trace.get_sampler_stats('base_scaling', chains=0)[-1][0] < 100.
assert trace.get_sampler_stats('base_scaling', chains=0)[-1][1] < 100.

def test_tuning_and_scaling_off(self):
"""Test that tuning is deactivated when sample()'s tune=0 and that
Expand Down Expand Up @@ -1239,17 +1239,19 @@ def test_tuning_and_scaling_off(self):

assert not trace_0.get_sampler_stats('tune', chains=0)[0]
assert not trace_0.get_sampler_stats('tune', chains=0)[-1]
assert trace_0.get_sampler_stats('base_scaling_x', chains=0)[0] == \
trace_0.get_sampler_stats('base_scaling_x', chains=0)[-1] == 100.
assert trace_0.get_sampler_stats('base_scaling', chains=0)[0][0] == \
trace_0.get_sampler_stats('base_scaling', chains=0)[-1][0] == \
trace_0.get_sampler_stats('base_scaling', chains=0)[0][1] == \
trace_0.get_sampler_stats('base_scaling', chains=0)[-1][1] == 100.

assert trace_1.get_sampler_stats('tune', chains=0)[0]
assert trace_1.get_sampler_stats('tune', chains=0)[ts_1 - 1]
assert not trace_1.get_sampler_stats('tune', chains=0)[ts_1]
assert not trace_1.get_sampler_stats('tune', chains=0)[-1]
assert trace_1.get_sampler_stats('base_scaling_x', chains=0)[0] == 100.
assert trace_1.get_sampler_stats('base_scaling_y_logodds__', chains=0)[0] == 100.
assert trace_1.get_sampler_stats('base_scaling_x', chains=0)[-1] < 100.
assert trace_1.get_sampler_stats('base_scaling_y_logodds__', chains=0)[-1] < 100.
assert trace_1.get_sampler_stats('base_scaling', chains=0)[0][0] == 100.
assert trace_1.get_sampler_stats('base_scaling', chains=0)[0][1] == 100.
assert trace_1.get_sampler_stats('base_scaling', chains=0)[-1][0] < 100.
assert trace_1.get_sampler_stats('base_scaling', chains=0)[-1][1] < 100.

def test_trace_length(self):
"""Check if trace length is as expected."""
Expand Down

0 comments on commit db022a5

Please sign in to comment.