From d40f0e9c88de7d407cb9f3b90dfcd085a53ba12b Mon Sep 17 00:00:00 2001 From: Juan Orduz Date: Mon, 1 Jul 2024 18:19:09 +0200 Subject: [PATCH] Fixes `random_flax_module` with `flax.linen.BatchNorm` (#1823) * filter oout tests waiting for next tfp release * fix issue 1446 * add feddback (not working) * feedbackl 2 * default handler * rm prng_key from substitute * remove class from function --- numpyro/handlers.py | 2 +- numpyro/infer/util.py | 12 +++++++++++- test/contrib/test_module.py | 13 ++++++++++++- 3 files changed, 24 insertions(+), 3 deletions(-) diff --git a/numpyro/handlers.py b/numpyro/handlers.py index 6e13aeb70..11fdde45c 100644 --- a/numpyro/handlers.py +++ b/numpyro/handlers.py @@ -817,7 +817,7 @@ def process_message(self, msg): return if self.data is not None: - value = self.data.get(msg["name"]) + value = self.data.get(msg.get("name")) else: value = self.substitute_fn(msg) diff --git a/numpyro/infer/util.py b/numpyro/infer/util.py index 3775893a7..56c4d2c4d 100644 --- a/numpyro/infer/util.py +++ b/numpyro/infer/util.py @@ -21,6 +21,7 @@ from numpyro.distributions.util import is_identically_one, sum_rightmost from numpyro.handlers import condition, replay, seed, substitute, trace from numpyro.infer.initialization import init_to_uniform, init_to_value +from numpyro.primitives import Messenger from numpyro.util import ( _validate_model, find_stack_level, @@ -46,6 +47,12 @@ ParamInfo = namedtuple("ParamInfo", ["z", "potential_energy", "z_grad"]) +class _substitute_default_key(Messenger): + def process_message(self, msg): + if msg["type"] == "prng_key" and msg["value"] is None: + msg["value"] = random.PRNGKey(0) + + def log_density(model, model_args, model_kwargs, params): """ (EXPERIMENTAL INTERFACE) Computes log of joint density for the model given @@ -660,9 +667,12 @@ def initialize_model( data={ k: site["value"] for k, site in model_trace.items() - if site["type"] in ["param"] + if site["type"] in ["param", "mutable"] }, ) + + model = _substitute_default_key(model) + constrained_values = { k: v["value"] for k, v in model_trace.items() diff --git a/test/contrib/test_module.py b/test/contrib/test_module.py index d3f27f17e..a1342507f 100644 --- a/test/contrib/test_module.py +++ b/test/contrib/test_module.py @@ -21,7 +21,8 @@ random_haiku_module, ) import numpyro.distributions as dist -from numpyro.infer import MCMC, NUTS +from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO +from numpyro.infer.autoguide import AutoDelta pytestmark = pytest.mark.filterwarnings( "ignore:jax.tree_.+ is deprecated:FutureWarning" @@ -256,6 +257,11 @@ def model(): else: assert set(tr.keys()) == {"nn$params", "x", "y"} + # test svi + guide = AutoDelta(model) + svi = SVI(model, guide, numpyro.optim.Adam(0.01), Trace_ELBO()) + svi.run(random.PRNGKey(100), 10) + @pytest.mark.parametrize("dropout", [True, False]) @pytest.mark.parametrize("batchnorm", [True, False]) @@ -300,3 +306,8 @@ def model(): assert tr["nn$state"]["type"] == "mutable" else: assert set(tr.keys()) == {"nn$params", "x", "y"} + + # test svi + guide = AutoDelta(model) + svi = SVI(model, guide, numpyro.optim.Adam(0.01), Trace_ELBO()) + svi.run(random.PRNGKey(100), 10)