Properly handle contraction of guide plates in TraceEnum_ELBO (#1537)
* fix guide plate contraction

* remove comment line

* Use Poisson dist
ordabayevy authored Feb 9, 2023
1 parent 643412d commit 562e1be
Showing 2 changed files with 74 additions and 43 deletions.
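
For orientation, here is a minimal sketch (not part of the diff) of the contraction pattern the new code adopts: build the dice-factor sum-product lazily under funsor's normalize interpretation, run it through funsor.optimizer.apply_optimizer, and eliminate guide plate dimensions that are absent from the cost term with funsor.ops.add (a product in log space), while enumerated dimensions are reduced with funsor.ops.logaddexp. All names, shapes, and values below are made up for illustration.

    from collections import OrderedDict

    import numpy as np
    import funsor

    funsor.set_backend("numpy")  # funsor's default backend; stated for clarity

    # Hypothetical dice factor: log q(a_i) for two draws inside a guide plate "a_axis".
    log_measure = funsor.Tensor(
        np.array([-1.2, -0.7]), OrderedDict(a_axis=funsor.Bint[2])
    )
    # Hypothetical downstream cost term that does not carry the "a_axis" dimension.
    cost = funsor.Tensor(np.array(0.5))

    dice_factor_vars = frozenset(log_measure.inputs)
    cost_vars = frozenset(cost.inputs)

    with funsor.interpretations.normalize:  # build the expression lazily
        dice_factor = funsor.sum_product.sum_product(
            funsor.ops.logaddexp,  # enumerated dims (none in this toy case) are marginalized
            funsor.ops.add,        # plate dims are combined multiplicatively in log space
            (log_measure,),
            plates=dice_factor_vars | cost_vars,     # {"a_axis"}
            eliminate=dice_factor_vars - cost_vars,  # contract the guide plate away
        )
    dice_factor = funsor.optimizer.apply_optimizer(dice_factor)  # pick a contraction order

    weighted = cost * funsor.ops.exp(dice_factor)  # Dice-weighted cost term
    print(weighted.reduce(funsor.ops.add))         # scalar surrogate-ELBO contribution
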
76 changes: 33 additions & 43 deletions numpyro/infer/elbo.py
@@ -2,7 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

from collections import OrderedDict, defaultdict
from functools import partial, reduce
from functools import partial
from operator import itemgetter
import warnings

@@ -871,15 +871,8 @@ def __init__(self, num_particles=1, max_plate_nesting=float("inf")):

def loss(self, rng_key, param_map, model, guide, *args, **kwargs):
def single_particle_elbo(rng_key):
from opt_einsum import shared_intermediates

import funsor
from funsor.cnf import _eager_contract_tensors
from numpyro.contrib.funsor import to_data, to_funsor

logsumexp_backend = "funsor.einsum.numpy_log"
with shared_intermediates() as cache: # create a cache
pass
from numpyro.contrib.funsor import to_data

model_seed, guide_seed = random.split(rng_key)

@@ -935,7 +928,6 @@ def single_particle_elbo(rng_key):
cost = model_trace[name]["log_prob"]
scale = model_trace[name]["scale"]
deps = model_deps[name]
dice_factors = [guide_trace[key]["log_measure"] for key in deps]
else:
# compute contracted cost term
group_factors = tuple(
@@ -952,13 +944,16 @@
*(frozenset(f.inputs) & group_plates for f in group_factors)
)
elim_plates = group_plates - outermost_plates
cost = funsor.sum_product.sum_product(
funsor.ops.logaddexp,
funsor.ops.add,
group_factors,
plates=group_plates,
eliminate=group_sum_vars | elim_plates,
)
with funsor.interpretations.normalize:
cost = funsor.sum_product.sum_product(
funsor.ops.logaddexp,
funsor.ops.add,
group_factors,
plates=group_plates,
eliminate=group_sum_vars | elim_plates,
)
# TODO: add memoization
cost = funsor.optimizer.apply_optimizer(cost)
# incorporate the effects of subsampling and handlers.scale through a common scale factor
scales_set = set()
for name in group_names | group_sum_vars:
@@ -992,43 +987,38 @@ def single_particle_elbo(rng_key):
f"model enumeration sites upstream of guide site '{key}' in plate('{plate}')."
"Try converting some model enumeration sites to guide enumeration sites."
)
# combine dice factors
dice_factors = [
guide_trace[key]["log_measure"].reduce(
funsor.ops.add,
frozenset(guide_trace[key]["log_measure"].inputs)
& elim_plates,
)
for key in deps
]
cost_terms.append((cost, scale, dice_factors))
cost_terms.append((cost, scale, deps))

for name, deps in guide_deps.items():
# -logq cost term
cost = -guide_trace[name]["log_prob"]
scale = guide_trace[name]["scale"]
dice_factors = [guide_trace[key]["log_measure"] for key in deps]
cost_terms.append((cost, scale, dice_factors))
cost_terms.append((cost, scale, deps))

# compute elbo
elbo = 0.0
for cost, scale, dice_factors in cost_terms:
if dice_factors:
reduced_vars = (
frozenset().union(*[f.input_vars for f in dice_factors])
- cost.input_vars
for cost, scale, deps in cost_terms:
if deps:
dice_factors = tuple(
guide_trace[key]["log_measure"] for key in deps
)
if reduced_vars:
# use opt_einsum to reduce vars not present in the cost term
with shared_intermediates(cache):
dice_factor = _eager_contract_tensors(
reduced_vars, dice_factors, backend=logsumexp_backend
)
else:
dice_factor = reduce(lambda a, b: a + b, dice_factors)
dice_factor_vars = frozenset().union(
*[f.inputs for f in dice_factors]
)
cost_vars = frozenset(cost.inputs)
with funsor.interpretations.normalize:
dice_factor = funsor.sum_product.sum_product(
funsor.ops.logaddexp,
funsor.ops.add,
dice_factors,
plates=(dice_factor_vars | cost_vars) - model_vars,
eliminate=dice_factor_vars - cost_vars,
)
# TODO: add memoization
dice_factor = funsor.optimizer.apply_optimizer(dice_factor)
cost = cost * funsor.ops.exp(dice_factor)
if (scale is not None) and (not is_identically_one(scale)):
cost = cost * to_funsor(scale)
cost = cost * scale

elbo = elbo + cost.reduce(funsor.ops.add)

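An aside before the test diff: the distinction the fix turns on can be checked with plain numpy. Independent plate slices of a dice factor must be combined multiplicatively (a sum in log space); contracting them with logsumexp, which is what a log-space einsum over every reduced dimension amounts to, yields a different weight. The numbers below are made up.

    import numpy as np

    log_q = np.array([-1.2, -0.7])                # log q(a_i) for two slices of a guide plate
    plate_product = np.sum(log_q)                 # log prod_i q(a_i): the multiplicative contraction
    plate_logsumexp = np.logaddexp.reduce(log_q)  # logsumexp over the plate dimension
    print(np.exp(plate_product))                  # ~0.150
    print(np.exp(plate_logsumexp))                # ~0.798
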
41 changes: 41 additions & 0 deletions test/contrib/test_enum_elbo.py
@@ -2472,3 +2472,44 @@ def actual_loss_fn(params_raw):

assert_equal(actual_loss, expected_loss, prec=1e-3)
assert_equal(actual_grads, expected_grads, prec=1e-5)


def test_guide_plate_contraction():
def model(params):
with pyro.plate("a_axis", size=2):
a = pyro.sample("a", dist.Poisson(jnp.array(3.0)))
pyro.sample("b", dist.Normal(jnp.sum(a), 1.0), obs=1)

def guide(params):
probs_a = pyro.param(
"probs_a", params["probs_a"], constraint=constraints.positive
)
with pyro.plate("a_axis", size=2):
pyro.sample("a", dist.Poisson(probs_a))

params = {
"probs_a": jnp.array([3.0, 2.5]),
}
transform = dist.biject_to(dist.constraints.positive)
params_raw = jax.tree_util.tree_map(transform.inv, params)

# TraceGraph_ELBO grads averaged over num_particles
elbo = infer.TraceGraph_ELBO(num_particles=50_000)

def graph_loss_fn(params_raw):
params = jax.tree_util.tree_map(transform, params_raw)
return elbo.loss(random.PRNGKey(0), {}, model, guide, params)

graph_loss, graph_grads = jax.value_and_grad(graph_loss_fn)(params_raw)

# TraceEnum_ELBO grads averaged over num_particles (no enumeration)
elbo = infer.TraceEnum_ELBO(num_particles=50_000)

def enum_loss_fn(params_raw):
params = jax.tree_util.tree_map(transform, params_raw)
return elbo.loss(random.PRNGKey(0), {}, model, guide, params)

enum_loss, enum_grads = jax.value_and_grad(enum_loss_fn)(params_raw)

assert_equal(enum_loss, graph_loss, prec=1e-3)
assert_equal(enum_grads, graph_grads, prec=1e-2)
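
As a usage note (not part of the diff), a guide with its own plate can be trained end to end through SVI with the fixed estimator; the sketch below mirrors the test's model/guide, and the parameter name, optimizer settings, particle count, and step count are arbitrary choices for illustration.

    import jax
    import jax.numpy as jnp
    import numpyro
    import numpyro.distributions as dist
    from numpyro import infer
    from numpyro.distributions import constraints


    def model():
        with numpyro.plate("a_axis", size=2):
            a = numpyro.sample("a", dist.Poisson(jnp.array(3.0)))
        numpyro.sample("b", dist.Normal(jnp.sum(a), 1.0), obs=jnp.array(1.0))


    def guide():
        rate = numpyro.param(
            "rate", jnp.array([3.0, 2.5]), constraint=constraints.positive
        )
        with numpyro.plate("a_axis", size=2):
            numpyro.sample("a", dist.Poisson(rate))


    svi = infer.SVI(
        model,
        guide,
        numpyro.optim.Adam(step_size=0.01),
        loss=infer.TraceEnum_ELBO(num_particles=100),
    )
    svi_result = svi.run(jax.random.PRNGKey(0), 1000)
    print(svi_result.params["rate"])
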
