Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes random_flax_module with flax.linen.BatchNorm #1823

Merged
merged 8 commits into from
Jul 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion numpyro/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,7 +817,7 @@ def process_message(self, msg):
return

if self.data is not None:
value = self.data.get(msg["name"])
value = self.data.get(msg.get("name"))
else:
value = self.substitute_fn(msg)

Expand Down
12 changes: 11 additions & 1 deletion numpyro/infer/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from numpyro.distributions.util import is_identically_one, sum_rightmost
from numpyro.handlers import condition, replay, seed, substitute, trace
from numpyro.infer.initialization import init_to_uniform, init_to_value
from numpyro.primitives import Messenger
from numpyro.util import (
_validate_model,
find_stack_level,
Expand All @@ -46,6 +47,12 @@
ParamInfo = namedtuple("ParamInfo", ["z", "potential_energy", "z_grad"])


class _substitute_default_key(Messenger):
def process_message(self, msg):
if msg["type"] == "prng_key" and msg["value"] is None:
msg["value"] = random.PRNGKey(0)


def log_density(model, model_args, model_kwargs, params):
"""
(EXPERIMENTAL INTERFACE) Computes log of joint density for the model given
Expand Down Expand Up @@ -660,9 +667,12 @@ def initialize_model(
data={
k: site["value"]
for k, site in model_trace.items()
if site["type"] in ["param"]
if site["type"] in ["param", "mutable"]
juanitorduz marked this conversation as resolved.
Show resolved Hide resolved
},
)

model = _substitute_default_key(model)

juanitorduz marked this conversation as resolved.
Show resolved Hide resolved
constrained_values = {
k: v["value"]
for k, v in model_trace.items()
Expand Down
13 changes: 12 additions & 1 deletion test/contrib/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
random_haiku_module,
)
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoDelta

pytestmark = pytest.mark.filterwarnings(
"ignore:jax.tree_.+ is deprecated:FutureWarning"
Expand Down Expand Up @@ -256,6 +257,11 @@ def model():
else:
assert set(tr.keys()) == {"nn$params", "x", "y"}

# test svi
guide = AutoDelta(model)
svi = SVI(model, guide, numpyro.optim.Adam(0.01), Trace_ELBO())
svi.run(random.PRNGKey(100), 10)


@pytest.mark.parametrize("dropout", [True, False])
@pytest.mark.parametrize("batchnorm", [True, False])
Expand Down Expand Up @@ -300,3 +306,8 @@ def model():
assert tr["nn$state"]["type"] == "mutable"
else:
assert set(tr.keys()) == {"nn$params", "x", "y"}

# test svi
guide = AutoDelta(model)
svi = SVI(model, guide, numpyro.optim.Adam(0.01), Trace_ELBO())
svi.run(random.PRNGKey(100), 10)
Comment on lines +310 to +313
Copy link
Contributor Author

@juanitorduz juanitorduz Jun 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without these SVI lines, the tests were passing in the master branch. Once added, we got the error. The suggested fix in the issue does solve it.

Loading