Skip to content

Commit

Permalink
Allow to use NeuTra on models with plates (#1826)
Browse files Browse the repository at this point in the history
* allow to use NeuTra with plate

* Fix typo in reparam.py
  • Loading branch information
fehiepsi authored Jul 2, 2024
1 parent 2ed9f92 commit 89ca117
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 2 deletions.
8 changes: 7 additions & 1 deletion numpyro/infer/reparam.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ class NeuTraReparam(Reparam):
# Step 2. Use trained guide in NeuTra MCMC
neutra = NeuTraReparam(guide)
model = netra.reparam(model)
model = neutra.reparam(model)
nuts = NUTS(model)
# ...now use the model in HMC or NUTS...
Expand Down Expand Up @@ -281,9 +281,15 @@ def __call__(self, name, fn, obs):
compute_density = numpyro.get_mask() is not False
if not self._x_unconstrained: # On first sample site.
# Sample a shared latent.
model_plates = {
msg["name"]
for msg in self.guide.prototype_trace.values()
if msg["type"] == "plate"
}
z_unconstrained = numpyro.sample(
"{}_shared_latent".format(self.guide.prefix),
self.guide.get_base_dist().mask(False),
infer={"block_plates": model_plates},
)

# Differentiably transform.
Expand Down
6 changes: 6 additions & 0 deletions numpyro/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,12 @@ def process_message(self, msg):
)
return

if (
"block_plates" in msg.get("infer", {})
and self.name in msg["infer"]["block_plates"]
):
return

cond_indep_stack = msg["cond_indep_stack"]
frame = CondIndepStackFrame(self.name, self.dim, self.subsample_size)
cond_indep_stack.append(frame)
Expand Down
18 changes: 17 additions & 1 deletion test/infer/test_reparam.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from numpyro.distributions.transforms import AffineTransform, ExpTransform
import numpyro.handlers as handlers
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoIAFNormal
from numpyro.infer.autoguide import AutoDiagonalNormal, AutoIAFNormal
from numpyro.infer.reparam import (
CircularReparam,
ExplicitReparam,
Expand Down Expand Up @@ -228,6 +228,22 @@ def test_neutra_reparam_unobserved_model():
reparam_model(data=None)


def test_neutra_reparam_with_plate():
def model():
with numpyro.plate("N", 3, dim=-1):
x = numpyro.sample("x", dist.Normal(0, 1))
assert x.shape == (3,)

guide = AutoDiagonalNormal(model)
svi = SVI(model, guide, Adam(1e-3), Trace_ELBO())
svi_state = svi.init(random.PRNGKey(0))
params = svi.get_params(svi_state)
neutra = NeuTraReparam(guide, params)
reparam_model = neutra.reparam(model)
with handlers.seed(rng_seed=0):
reparam_model()


@pytest.mark.parametrize("shape", [(), (4,), (3, 2)], ids=str)
@pytest.mark.parametrize("centered", [0.0, 0.6, 1.0, None])
@pytest.mark.parametrize("dist_type", ["Normal", "StudentT"])
Expand Down

0 comments on commit 89ca117

Please sign in to comment.