diff --git a/numpyro/infer/autoguide.py b/numpyro/infer/autoguide.py index d64321aa7..677b13a81 100644 --- a/numpyro/infer/autoguide.py +++ b/numpyro/infer/autoguide.py @@ -274,14 +274,32 @@ def __call__(self, *args, **kwargs): if self.prototype_trace is None: # run model to inspect the model structure self._setup_prototype(*args, **kwargs) + check_deterministic_sites = True + else: + check_deterministic_sites = False # create all plates self._create_plates(*args, **kwargs) - # run slave guides + # run sub-guides result = {} for part in self._guides: result.update(part(*args, **kwargs)) + + # Check deterministic sites after calling sub-guides because they are not + # initialized prior to the first call. We do not check guides that do not have + # a prototype_trace attribute, e.g., custom guides. + if check_deterministic_sites: + for i, part in enumerate(self._guides): + prototype_trace = getattr(part, "prototype_trace", None) + if prototype_trace: + for key, value in prototype_trace.items(): + if value["type"] == "deterministic": + raise ValueError( + f"deterministic site '{key}' in sub-guide at position " + f"{i} should not be exposed" + ) + return result def __getitem__(self, key): diff --git a/test/infer/test_autoguide.py b/test/infer/test_autoguide.py index 5557032a9..9394e17a1 100644 --- a/test/infer/test_autoguide.py +++ b/test/infer/test_autoguide.py @@ -157,7 +157,7 @@ def model(data=None, labels=None): rng_key_init = random.PRNGKey(1) if auto_class == AutoGuideList: guide = AutoGuideList(model) - guide.append(AutoNormal(handlers.block(model, hide=[]))) + guide.append(AutoNormal(handlers.block(model, hide=["logits"]))) else: guide = auto_class(model, init_loc_fn=init_strategy) svi = SVI(model, guide, adam, Elbo()) @@ -1157,6 +1157,18 @@ def model(x, y=None): part.quantiles(params=params, quantiles=[0.2, 0.5, 0.8]) +def test_autoguidelist_deterministic(): + def model(): + x = numpyro.sample("x", dist.Normal()) + numpyro.deterministic("x2", x**2) + + guide = AutoGuideList(model) + guide.append(AutoDiagonalNormal(model)) + seeded_guide = handlers.seed(guide, 8) + with pytest.raises(ValueError, match="should not be exposed"): + seeded_guide() + + @pytest.mark.parametrize( "auto_class", [ @@ -1183,6 +1195,7 @@ def model(): numpyro.deterministic("x2", x**2) guide = AutoGuideList(model) + blocked_model = handlers.block(handlers.seed(model, 7), hide=["x2"]) # AutoGuideList does not support AutoDAIS, AutoSemiDAIS, or AutoSurrogateLikelihoodDAIS if auto_class == AutoDAIS: @@ -1190,24 +1203,24 @@ def model(): ValueError, match="AutoDAIS, AutoSemiDAIS, and AutoSurrogateLikelihoodDAIS are not supported.", ): - guide.append(auto_class(model)) + guide.append(auto_class(blocked_model)) return if auto_class == AutoSemiDAIS: with pytest.raises( ValueError, match="AutoDAIS, AutoSemiDAIS, and AutoSurrogateLikelihoodDAIS are not supported.", ): - guide.append(auto_class(model, local_model=None, global_guide=None)) + guide.append(auto_class(blocked_model, local_model=None, global_guide=None)) return if auto_class == AutoSurrogateLikelihoodDAIS: with pytest.raises( ValueError, match="AutoDAIS, AutoSemiDAIS, and AutoSurrogateLikelihoodDAIS are not supported.", ): - guide.append(auto_class(model, surrogate_model=None)) + guide.append(auto_class(blocked_model, surrogate_model=None)) return - guide.append(auto_class(model)) + guide.append(auto_class(blocked_model)) svi = SVI(model, guide, optim.Adam(0.01), Trace_ELBO()) if auto_class in (AutoIAFNormal, AutoBNAFNormal) and max(shape, default=0) <= 1: with pytest.raises( @@ -1228,7 +1241,9 @@ def model(): sample_shape=sample_shape, ) assert guide_samples["x"].shape == sample_shape + shape - assert guide_samples["x2"].shape == sample_shape + shape + # Substitute and trace to get the deterministic sites. + trace = handlers.trace(handlers.substitute(model, guide_samples)).get_trace() + assert trace["x2"]["value"].shape == sample_shape + shape @pytest.mark.parametrize("use_global_dais_params", [True, False])