Skip to content

Commit

Permalink
sample from distribution without storing (#1790)
Browse files Browse the repository at this point in the history
* exclude sample sites with "~"

* handle repeat remove_sites

* test exclude sites

* fix test case and len(1) collect_fields edge case

* add dict check, switch to list, add documentation

* back to dict solution
  • Loading branch information
amifalk authored May 2, 2024
1 parent d6ceae1 commit 0fd8c2e
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 14 deletions.
50 changes: 36 additions & 14 deletions numpyro/infer/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,11 +188,23 @@ def _sample_fn_nojit_args(state, sampler, args, kwargs):
return (sampler.sample(state[0], args, kwargs),)


def _collect_fn(collect_fields):
@cached_by(_collect_fn, collect_fields)
def _collect_fn(collect_fields, remove_sites):
@cached_by(_collect_fn, collect_fields, remove_sites)
def collect(x):
if collect_fields:
return attrgetter(*collect_fields)(x[0])
fields = attrgetter(*collect_fields)(x[0])

if remove_sites != ():
fields = [fields] if len(collect_fields) == 1 else list(fields)
assert isinstance(fields[0], dict)

sample_sites = fields[0].copy()
for site in remove_sites:
sample_sites.pop(site)
fields[0] = sample_sites
fields = fields[0] if len(collect_fields) == 1 else fields

return fields
else:
return x[0]

Expand Down Expand Up @@ -419,7 +431,7 @@ def _get_cached_init_state(self, rng_key, args, kwargs):
except TypeError:
return None

def _single_chain_mcmc(self, init, args, kwargs, collect_fields):
def _single_chain_mcmc(self, init, args, kwargs, collect_fields, remove_sites):
rng_key, init_state, init_params = init
# Check if _sample_fn is None, then we need to initialize the sampler.
if init_state is None or (getattr(self.sampler, "_sample_fn", None) is None):
Expand Down Expand Up @@ -452,7 +464,7 @@ def _single_chain_mcmc(self, init, args, kwargs, collect_fields):
upper_idx,
sample_fn,
init_val,
transform=_collect_fn(collect_fields),
transform=_collect_fn(collect_fields, remove_sites),
progbar=self.progress_bar,
return_last_val=True,
thinning=self.thinning,
Expand Down Expand Up @@ -556,7 +568,8 @@ def warmup(
These are typically the arguments needed by the `model`.
:param extra_fields: Extra fields (aside from :meth:`~numpyro.infer.mcmc.MCMCKernel.default_fields`)
from the state object (e.g. :data:`numpyro.infer.hmc.HMCState` for HMC) to collect during
the MCMC run.
the MCMC run. Exclude sample sites from collection with "~`sampler.sample_field`.`sample_site`".
e.g. "~z.a" will prevent site "a" from being collected if you're using the NUTS sampler.
:type extra_fields: tuple or list
:param bool collect_warmup: Whether to collect samples from the warmup phase. Defaults
to `False`.
Expand Down Expand Up @@ -591,7 +604,9 @@ def run(self, rng_key, *args, extra_fields=(), init_params=None, **kwargs):
:param extra_fields: Extra fields (aside from `"z"`, `"diverging"`) from the
state object (e.g. :data:`numpyro.infer.hmc.HMCState` for HMC) to be collected
during the MCMC run. Note that subfields can be accessed using dots, e.g.
`"adapt_state.step_size"` can be used to collect step sizes at each step.
`"adapt_state.step_size"` can be used to collect step sizes at each step. Exclude sample sites from
collection with "~`sampler.sample_field`.`sample_site`". e.g. "~z.a" will prevent site "a" from
being collected if you're using the NUTS sampler.
:type extra_fields: tuple or list of str
:param init_params: Initial parameters to begin sampling. The type must be consistent
with the input type to `potential_fn` provided to the kernel. If the kernel is
Expand Down Expand Up @@ -626,18 +641,25 @@ def run(self, rng_key, *args, extra_fields=(), init_params=None, **kwargs):
" as `num_chains`."
)
assert isinstance(extra_fields, (tuple, list))
collect_fields = tuple(
set(
(self._sample_field,)
+ tuple(self._default_fields)
+ tuple(extra_fields)
)
)

collect_fields = {}
remove_sites = {}
for field_name in (
(self._sample_field,) + tuple(self._default_fields) + tuple(extra_fields)
):
if field_name.startswith(f"~{self._sample_field}."):
remove_sites[(field_name[len(self._sample_field) + 2 :])] = None
else:
collect_fields[field_name] = None
collect_fields = tuple(collect_fields.keys())
remove_sites = tuple(remove_sites.keys())

partial_map_fn = partial(
self._single_chain_mcmc,
args=args,
kwargs=kwargs,
collect_fields=collect_fields,
remove_sites=remove_sites,
)
map_args = (rng_key, init_state, init_params)
if self.num_chains == 1:
Expand Down
14 changes: 14 additions & 0 deletions test/infer/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1186,3 +1186,17 @@ def model(data):
mcmc.run(rng_key, data, extra_fields=("num_steps",))
num_steps_list = np.array(mcmc.get_extra_fields()["num_steps"])
assert all(step == num_steps for step in num_steps_list)


@pytest.mark.parametrize("kernel_cls", [NUTS, BarkerMH])
@pytest.mark.parametrize("remove_sites", [("~z.a", "~z.b"), ("~z.a", "~z.a")])
def test_remove_sites(kernel_cls, remove_sites):
def model():
numpyro.sample("a", dist.Normal())
numpyro.sample("b", dist.Normal())

mcmc = MCMC(kernel_cls(model), num_warmup=10, num_samples=10)
mcmc.run(random.PRNGKey(0), extra_fields=remove_sites)
samps = mcmc.get_samples()

assert all([site[3:] not in samps for site in remove_sites])

0 comments on commit 0fd8c2e

Please sign in to comment.