Skip to content

Commit

Permalink
Raise ValueError if deterministic site is exposed in sub-guide. (#1757
Browse files Browse the repository at this point in the history
)

* Raise `ValueError` if deterministic site is exposed in sub-guide.

* Add test to verify deterministic site check in `AutoGuideList`.
  • Loading branch information
tillahoffmann authored Mar 14, 2024
1 parent 7769a32 commit 79d9e6b
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 7 deletions.
20 changes: 19 additions & 1 deletion numpyro/infer/autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,14 +274,32 @@ def __call__(self, *args, **kwargs):
if self.prototype_trace is None:
# run model to inspect the model structure
self._setup_prototype(*args, **kwargs)
check_deterministic_sites = True
else:
check_deterministic_sites = False

# create all plates
self._create_plates(*args, **kwargs)

# run slave guides
# run sub-guides
result = {}
for part in self._guides:
result.update(part(*args, **kwargs))

# Check deterministic sites after calling sub-guides because they are not
# initialized prior to the first call. We do not check guides that do not have
# a prototype_trace attribute, e.g., custom guides.
if check_deterministic_sites:
for i, part in enumerate(self._guides):
prototype_trace = getattr(part, "prototype_trace", None)
if prototype_trace:
for key, value in prototype_trace.items():
if value["type"] == "deterministic":
raise ValueError(
f"deterministic site '{key}' in sub-guide at position "
f"{i} should not be exposed"
)

return result

def __getitem__(self, key):
Expand Down
27 changes: 21 additions & 6 deletions test/infer/test_autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def model(data=None, labels=None):
rng_key_init = random.PRNGKey(1)
if auto_class == AutoGuideList:
guide = AutoGuideList(model)
guide.append(AutoNormal(handlers.block(model, hide=[])))
guide.append(AutoNormal(handlers.block(model, hide=["logits"])))
else:
guide = auto_class(model, init_loc_fn=init_strategy)
svi = SVI(model, guide, adam, Elbo())
Expand Down Expand Up @@ -1157,6 +1157,18 @@ def model(x, y=None):
part.quantiles(params=params, quantiles=[0.2, 0.5, 0.8])


def test_autoguidelist_deterministic():
def model():
x = numpyro.sample("x", dist.Normal())
numpyro.deterministic("x2", x**2)

guide = AutoGuideList(model)
guide.append(AutoDiagonalNormal(model))
seeded_guide = handlers.seed(guide, 8)
with pytest.raises(ValueError, match="should not be exposed"):
seeded_guide()


@pytest.mark.parametrize(
"auto_class",
[
Expand All @@ -1183,31 +1195,32 @@ def model():
numpyro.deterministic("x2", x**2)

guide = AutoGuideList(model)
blocked_model = handlers.block(handlers.seed(model, 7), hide=["x2"])

# AutoGuideList does not support AutoDAIS, AutoSemiDAIS, or AutoSurrogateLikelihoodDAIS
if auto_class == AutoDAIS:
with pytest.raises(
ValueError,
match="AutoDAIS, AutoSemiDAIS, and AutoSurrogateLikelihoodDAIS are not supported.",
):
guide.append(auto_class(model))
guide.append(auto_class(blocked_model))
return
if auto_class == AutoSemiDAIS:
with pytest.raises(
ValueError,
match="AutoDAIS, AutoSemiDAIS, and AutoSurrogateLikelihoodDAIS are not supported.",
):
guide.append(auto_class(model, local_model=None, global_guide=None))
guide.append(auto_class(blocked_model, local_model=None, global_guide=None))
return
if auto_class == AutoSurrogateLikelihoodDAIS:
with pytest.raises(
ValueError,
match="AutoDAIS, AutoSemiDAIS, and AutoSurrogateLikelihoodDAIS are not supported.",
):
guide.append(auto_class(model, surrogate_model=None))
guide.append(auto_class(blocked_model, surrogate_model=None))
return

guide.append(auto_class(model))
guide.append(auto_class(blocked_model))
svi = SVI(model, guide, optim.Adam(0.01), Trace_ELBO())
if auto_class in (AutoIAFNormal, AutoBNAFNormal) and max(shape, default=0) <= 1:
with pytest.raises(
Expand All @@ -1228,7 +1241,9 @@ def model():
sample_shape=sample_shape,
)
assert guide_samples["x"].shape == sample_shape + shape
assert guide_samples["x2"].shape == sample_shape + shape
# Substitute and trace to get the deterministic sites.
trace = handlers.trace(handlers.substitute(model, guide_samples)).get_trace()
assert trace["x2"]["value"].shape == sample_shape + shape


@pytest.mark.parametrize("use_global_dais_params", [True, False])
Expand Down

0 comments on commit 79d9e6b

Please sign in to comment.