Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sample from distribution without storing #1790

Merged
merged 6 commits into from
May 2, 2024
Merged

Conversation

amifalk
Copy link
Contributor

@amifalk amifalk commented May 1, 2024

As discussed in #1695. Users can exclude a sample site from collection by adding ~{sample_field}.{site_name} to extra_fields in MCMC.run or MCMC.warmup.

Because model initialization doesn't happen until _single_chain_mcmc is called and the logic to set up collect_fields happens before that, I opted not to make default_fields mutable/settable with infer.

Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice solution, thanks @amifalk!

test/infer/test_mcmc.py Show resolved Hide resolved
numpyro/infer/mcmc.py Outdated Show resolved Hide resolved
numpyro/infer/mcmc.py Outdated Show resolved Hide resolved
numpyro/infer/mcmc.py Outdated Show resolved Hide resolved
if field_name.startswith(f"~{self._sample_field}."):
remove_sites.append(field_name[len(self._sample_field) + 2 :])
else:
collect_fields.append(field_name)
Copy link
Member

@fehiepsi fehiepsi May 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, it seems that making collect_fields a set is slightly better (to avoid collecting duplicating fields). Maybe setting collect_fields=tuple(set(collect_fields)) below? Or using dict like what you did before.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

set is not insertion order preserving and the current solution relies on self._sample_field being the first item in collect_fields so I'll revert to the dictionary approach.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Beautiful work! Thanks, @amifalk!

@fehiepsi fehiepsi merged commit 0fd8c2e into pyro-ppl:master May 2, 2024
4 checks passed
OlaRonning pushed a commit to aleatory-science/numpyro that referenced this pull request May 6, 2024
* exclude sample sites with "~"

* handle repeat remove_sites

* test exclude sites

* fix test case and len(1) collect_fields edge case

* add dict check, switch to list, add documentation

* back to dict solution
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants