Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes random_flax_module with flax.linen.BatchNorm #1823

Merged
merged 8 commits into from
Jul 1, 2024

Conversation

juanitorduz
Copy link
Contributor

Fixes #1446

Comment on lines +310 to +313
# test svi
guide = AutoDelta(model)
svi = SVI(model, guide, numpyro.optim.Adam(0.01), Trace_ELBO())
svi.run(random.PRNGKey(100), 10)
Copy link
Contributor Author

@juanitorduz juanitorduz Jun 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without these SVI lines, the tests were passing in the master branch. Once added, we got the error. The suggested fix in the issue does solve it.

@juanitorduz
Copy link
Contributor Author

The first two commits are from an old branch 🤦 . We can squash and merge instead.

numpyro/handlers.py Outdated Show resolved Hide resolved
@juanitorduz
Copy link
Contributor Author

image

@juanitorduz
Copy link
Contributor Author

@fehiepsi should I leave _substitute_default_key in utils.py or is handlers.py a better place?

@juanitorduz juanitorduz requested a review from fehiepsi July 1, 2024 15:32
@fehiepsi
Copy link
Member

fehiepsi commented Jul 1, 2024

Leaving it in utils sounds reasonable to me. It is just a workaround for the edge case.

Thanks for fixing the issue!

@fehiepsi fehiepsi merged commit d40f0e9 into pyro-ppl:master Jul 1, 2024
4 checks passed
@juanitorduz juanitorduz deleted the nn_batch_issue_1446 branch July 1, 2024 18:02
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Example code for using random_flax_module with flax.linen.BatchNorm (mutable)
2 participants