Skip to content

Commit

Permalink
Fixed bug in delta_logp for MLDA that broke AEM and VR
Browse files Browse the repository at this point in the history
  • Loading branch information
mikkelbue authored and twiecki committed Oct 26, 2021
1 parent 38295b7 commit 9e7e8aa
Showing 1 changed file with 8 additions and 35 deletions.
43 changes: 8 additions & 35 deletions pymc/step_methods/mlda.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,22 +66,14 @@ def __init__(self, *args, **kwargs):
self.Q_last = np.nan
self.Q_reg = [np.nan] * self.mlda_subsampling_rate_above

# extract some necessary variables
vars = kwargs.get("vars", None)
if vars is None:
vars = model.value_vars
else:
vars = [model.rvs_to_values.get(var, var) for var in vars]
vars = pm.inputvars(vars)
shared = pm.make_shared_replacements(initial_values, vars, model)

# call parent class __init__
super().__init__(*args, **kwargs)

# modify the delta function and point to model if VR is used
if self.mlda_variance_reduction:
self.delta_logp = delta_logp_inverse(initial_values, model.logpt, vars, shared)
self.model = model
self.delta_logp_factory = self.delta_logp
self.delta_logp = lambda q, q0: -self.delta_logp_factory(q0, q)

def reset_tuning(self):
"""
Expand Down Expand Up @@ -136,22 +128,14 @@ def __init__(self, *args, **kwargs):
self.Q_last = np.nan
self.Q_reg = [np.nan] * self.mlda_subsampling_rate_above

# extract some necessary variables
vars = kwargs.get("vars", None)
if vars is None:
vars = model.value_vars
else:
vars = [model.rvs_to_values.get(var, var) for var in vars]
vars = pm.inputvars(vars)
shared = pm.make_shared_replacements(initial_values, vars, model)

# call parent class __init__
super().__init__(*args, **kwargs)

# modify the delta function and point to model if VR is used
if self.mlda_variance_reduction:
self.delta_logp = delta_logp_inverse(initial_values, model.logpt, vars, shared)
self.model = model
self.delta_logp_factory = self.delta_logp
self.delta_logp = lambda q, q0: -self.delta_logp_factory(q0, q)

def reset_tuning(self):
"""Skips resetting of tuned sampler parameters
Expand Down Expand Up @@ -556,7 +540,7 @@ def __init__(
# Construct Aesara function for current-level model likelihood
# (for use in acceptance)
shared = pm.make_shared_replacements(initial_values, vars, model)
self.delta_logp = delta_logp_inverse(initial_values, model.logpt, vars, shared)
self.delta_logp = delta_logp(initial_values, model.logpt, vars, shared)

# Construct Aesara function for below-level model likelihood
# (for use in acceptance)
Expand Down Expand Up @@ -749,7 +733,9 @@ def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
accept = np.float64(0.0)
skipped_logp = True
else:
accept = self.delta_logp(q.data, q0.data) + self.delta_logp_below(q0.data, q.data)
# NB! The order and sign of the first term are swapped compared
# to the convention to make sure the proposal is evaluated last.
accept = -self.delta_logp(q0.data, q.data) + self.delta_logp_below(q0.data, q.data)
skipped_logp = False

# Accept/reject sample - next sample is stored in q_new
Expand Down Expand Up @@ -954,19 +940,6 @@ def update(self, x):
self.t += 1


def delta_logp_inverse(point, logp, vars, shared):
[logp0], inarray0 = pm.join_nonshared_inputs(point, [logp], vars, shared)

tensor_type = inarray0.type
inarray1 = tensor_type("inarray1")

logp1 = pm.CallableTensor(logp0)(inarray1)

f = compile_rv_inplace([inarray1, inarray0], -logp0 + logp1)
f.trust_input = True
return f


def extract_Q_estimate(trace, levels):
"""
Returns expectation and standard error of quantity of interest,
Expand Down

0 comments on commit 9e7e8aa

Please sign in to comment.