Skip to content

Commit

Permalink
Add float64 support to PHMC.
Browse files Browse the repository at this point in the history
Previously, had the exception: "TypeError: Tensors in list passed to 'inputs' of 'AddN' Op have types [float64, float64, float32] that don't all match."

PiperOrigin-RevId: 347871205
  • Loading branch information
brianwa84 authored and jburnim committed Dec 21, 2020
1 parent d3bf5a0 commit 5f0dbec
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ def _prepare_args(target_log_prob_fn,
def _batched_isotropic_normal_like(state_part):
event_ndims = ps.rank(state_part) - batch_rank
return independent.Independent(
normal.Normal(ps.zeros_like(state_part, tf.float32), 1.),
normal.Normal(ps.zeros_like(state_part), 1.),
reinterpreted_batch_ndims=event_ndims)

momentum_distribution = jds.JointDistributionSequential(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,36 @@ def test_correctness_with_200d_mvn_tril(self, precondition_scheme):
dict(testcase_name='_explicit', use_default=False))
class PreconditionedHMCTest(test_util.TestCase):

def test_f64(self, use_default):
if use_default:
momentum_distribution = None
else:
momentum_distribution = tfp.experimental.as_composite(
tfd.Normal(0., tf.constant(.5, dtype=tf.float64)))
kernel = tfp.experimental.mcmc.PreconditionedHamiltonianMonteCarlo(
lambda x: -x**2, step_size=.5, num_leapfrog_steps=2,
momentum_distribution=momentum_distribution)
kernel = tfp.mcmc.SimpleStepSizeAdaptation(kernel, num_adaptation_steps=3)
self.evaluate(tfp.mcmc.sample_chain(
1, kernel=kernel, current_state=tf.ones([], tf.float64),
num_burnin_steps=5, trace_fn=None))

# TODO(b/175787154): Enable this test
def DISABLED_test_f64_multichain(self, use_default):
if use_default:
momentum_distribution = None
else:
momentum_distribution = tfp.experimental.as_composite(
tfd.Normal(0., tf.constant(.5, dtype=tf.float64)))
kernel = tfp.experimental.mcmc.PreconditionedHamiltonianMonteCarlo(
lambda x: -x**2, step_size=.5, num_leapfrog_steps=2,
momentum_distribution=momentum_distribution)
kernel = tfp.mcmc.SimpleStepSizeAdaptation(kernel, num_adaptation_steps=3)
nchains = 7
self.evaluate(tfp.mcmc.sample_chain(
1, kernel=kernel, current_state=tf.ones([nchains], tf.float64),
num_burnin_steps=5, trace_fn=None))

def test_diag(self, use_default):
"""Test that a diagonal multivariate normal can be effectively sampled from.
Expand Down

0 comments on commit 5f0dbec

Please sign in to comment.