
Poor performance of BlackJAX SGLD for a Bayesian CNN on the MNIST dataset #516

Open
HitarthGandhi opened this issue Apr 12, 2023 · 1 comment

Comments

@HitarthGandhi

We are trying to use BlackJAX SGLD for a Bayesian CNN (BCNN) on the MNIST dataset to solve blackjax-devs/sampling-book#14. Using SGLD from the SGMCMCJax library, we obtain an accuracy of 94.4%, but BlackJAX fails to reproduce this result in the same setting. The detailed code and our findings can be found in this notebook.

The plot below shows that BlackJAX SGLD does not match these results.

@albcab
Member

albcab commented Apr 13, 2023

Try installing the latest development version with pip install blackjax-nightly. Note that you'll need to change your code slightly for it to work with the latest SGLD kernel. Specifically, you no longer initialize a state; instead, you pass the previous position directly to the step function. You also pass the learning rate (step_size) to the step function rather than to the kernel initializer:

sgld = blackjax.sgld(grad_fn)
new_position = sgld.step(rng_key, position, minibatch, step_size)
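To make the stateless pattern above concrete, here is a minimal hand-rolled sketch in plain JAX. It is not BlackJAX's implementation; it only mirrors the call shape step(rng_key, position, minibatch, step_size) and threads the position through a loop. Assumptions: a toy Bayesian logistic regression stands in for the CNN, and the names grad_fn, sgld_step, and DATA_SIZE are hypothetical. The update uses the convention drift = step_size * grad and noise variance = 2 * step_size; other implementations use step_size / 2 in the drift, so it may be worth checking how each library scales step_size before comparing results. The minibatch log-likelihood is rescaled by DATA_SIZE / batch_size so the gradient is an unbiased estimate of the full-data log-posterior gradient.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy setup: Bayesian logistic regression instead of a CNN.
DATA_SIZE = 1_000  # total number of training points (assumption)


def logprior_fn(position):
    # Standard normal prior on the weights.
    return -0.5 * jnp.sum(position ** 2)


def loglikelihood_fn(position, x, y):
    # Bernoulli log-likelihood of one example with logit x @ w.
    logit = jnp.dot(x, position)
    return y * jax.nn.log_sigmoid(logit) + (1.0 - y) * jax.nn.log_sigmoid(-logit)


def logdensity_estimate(position, minibatch):
    # Unbiased log-posterior estimate: scale the minibatch
    # log-likelihood by DATA_SIZE / batch_size.
    x, y = minibatch
    batch_loglik = jax.vmap(loglikelihood_fn, in_axes=(None, 0, 0))(position, x, y)
    return logprior_fn(position) + DATA_SIZE / x.shape[0] * jnp.sum(batch_loglik)


# Hypothetical grad_fn(position, minibatch), analogous to the one passed to blackjax.sgld.
grad_fn = jax.grad(logdensity_estimate)


@jax.jit
def sgld_step(rng_key, position, minibatch, step_size):
    # Stateless SGLD update: takes the previous position, returns the new one.
    grad = grad_fn(position, minibatch)
    noise = jax.random.normal(rng_key, position.shape)
    return position + step_size * grad + jnp.sqrt(2.0 * step_size) * noise


# Synthetic data generated from known weights.
rng_key = jax.random.PRNGKey(0)
data_key, rng_key = jax.random.split(rng_key)
true_w = jnp.array([1.0, -2.0, 0.5, 0.0, 1.5])
X = jax.random.normal(data_key, (DATA_SIZE, 5))
Y = (X @ true_w > 0).astype(jnp.float32)

# No state object: the position itself is threaded through the loop.
position = jnp.zeros(5)
batch_size = 100
samples = []
for i in range(200):
    rng_key, step_key = jax.random.split(rng_key)
    start = (i * batch_size) % DATA_SIZE
    minibatch = (X[start:start + batch_size], Y[start:start + batch_size])
    position = sgld_step(step_key, position, minibatch, 1e-4)
    samples.append(position)
samples = jnp.stack(samples)
```

In this sketch the step size is a per-call argument, so a decaying schedule can be applied simply by passing a different value at each iteration.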

Hope this helps; let us know if it works. We're working on a new release soon.
