Skip to content

Commit

Permalink
fix issue 1446
Browse files Browse the repository at this point in the history
  • Loading branch information
juanitorduz committed Jun 30, 2024
1 parent f38ef77 commit cbed1f5
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
2 changes: 1 addition & 1 deletion numpyro/infer/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,7 +660,7 @@ def initialize_model(
data={
k: site["value"]
for k, site in model_trace.items()
if site["type"] in ["param"]
if site["type"] in ["param", "mutable"]
},
)
constrained_values = {
Expand Down
17 changes: 14 additions & 3 deletions test/contrib/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
random_haiku_module,
)
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoDelta

pytestmark = pytest.mark.filterwarnings(
"ignore:jax.tree_.+ is deprecated:FutureWarning"
Expand Down Expand Up @@ -242,7 +243,7 @@ def model():
nn = haiku_module("nn", transform(fn), apply_rng=dropout, input_shape=(4, 3))
x = numpyro.sample("x", dist.Normal(0, 1).expand([4, 3]).to_event(2))
if dropout:
y = nn(numpyro.prng_key(), x)
y = nn(random.PRNGKey(0), x)
else:
y = nn(x)
numpyro.deterministic("y", y)
Expand All @@ -256,6 +257,11 @@ def model():
else:
assert set(tr.keys()) == {"nn$params", "x", "y"}

# test svi
guide = AutoDelta(model)
svi = SVI(model, guide, numpyro.optim.Adam(0.01), Trace_ELBO())
svi.run(random.PRNGKey(100), 10)


@pytest.mark.parametrize("dropout", [True, False])
@pytest.mark.parametrize("batchnorm", [True, False])
Expand Down Expand Up @@ -287,7 +293,7 @@ def model():
)
x = numpyro.sample("x", dist.Normal(0, 1).expand([4, 3]).to_event(2))
if dropout:
y = net(x, rngs={"dropout": numpyro.prng_key()})
y = net(x, rngs={"dropout": random.PRNGKey(0)})
else:
y = net(x)
numpyro.deterministic("y", y)
Expand All @@ -300,3 +306,8 @@ def model():
assert tr["nn$state"]["type"] == "mutable"
else:
assert set(tr.keys()) == {"nn$params", "x", "y"}

# test svi
guide = AutoDelta(model)
svi = SVI(model, guide, numpyro.optim.Adam(0.01), Trace_ELBO())
svi.run(random.PRNGKey(100), 10)

0 comments on commit cbed1f5

Please sign in to comment.