
Add SGNHT #515

Merged: 8 commits into blackjax-devs:main from the sgnht branch, May 25, 2023

Conversation

@SamDuffield (Contributor) commented Apr 7, 2023

Add the Stochastic Gradient Nosé-Hoover Thermostat (SGNHT) sampler of Ding et al., discussed in #289.

I have also rescaled the alpha, beta parameters for sghmc (and the new sgnht) to align with the description in Ma et al., which IMO has better interpretability (beta is the variance of the stochastic gradient). I understand this could affect existing code, so it can be reverted if need be. I also refactored diffusions.sghmc to read more clearly as an Euler solver.
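For reference, here is a rough sketch of one SGNHT update in the plain form given by Ding et al. This is purely illustrative: it is not the code added in this PR, and the state layout, function names, and the `diffusion` parameter are assumptions rather than the PR's actual parametrization.

```python
# Illustrative sketch of one SGNHT update (Ding et al., 2014); not this PR's code.
# `grad_estimator` returns a minibatch estimate of the gradient of the potential
# energy (negative log-posterior).
import jax
import jax.numpy as jnp

def sgnht_update(rng_key, position, momentum, xi, grad_estimator, minibatch,
                 step_size, diffusion=0.01):
    noise = jax.random.normal(rng_key, shape=jnp.shape(position))
    grad = grad_estimator(position, minibatch)
    momentum = (
        momentum
        - xi * momentum * step_size
        - grad * step_size
        + jnp.sqrt(2.0 * diffusion * step_size) * noise
    )
    position = position + momentum * step_size
    # The thermostat variable xi adapts so that the average kinetic energy
    # per dimension is driven towards its target value of 1.
    xi = xi + (jnp.mean(momentum**2) - 1.0) * step_size
    return position, momentum, xi
```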

I also noticed that the sgld and sghmc kernels return the step functions themselves rather than an MCMCSamplingAlgorithm, which disagrees with the docs.

A few important guidelines and requirements before we can merge your PR:

  • If I add a new sampler, there is an issue discussing it already;
  • We should be able to understand what the PR does from its title only;
  • There is a high-level description of the changes;
  • There are links to all the relevant issues, discussions and PRs;
  • The branch is rebased on the latest main commit;
  • Commit messages follow these guidelines;
  • The code respects the current naming conventions;
  • Docstrings follow the numpy style guide;
  • pre-commit is installed and configured on your machine, and you ran it before opening the PR;
  • There are tests covering the changes;
  • The doc is up-to-date;
  • If I add a new sampler, I added/updated related examples.

@albcab (Member) left a comment

Thanks for this PR! Sorry for the delay; we've been busy doing some restructuring.

Everything looks good, docs are great, tests are good, and the general structure of the algorithm looks great.

I agree that the parametrization from Ma et al. for SGHMC is best; let's keep the rescaling.

You'll only need to follow the new structure of the API introduced in #501 and rebase onto the latest commit. All you really need to do is move what you put in kernels.py to sgmcmc/sgnht.py and use the naming discussed in #280 for your algorithms. Everything else should rebase onto the latest commit history without any conflicts. You probably don't need it, but here is a basic skeleton for the new structure of sampling algorithms.
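For illustration, such a module skeleton might look roughly like the sketch below. The names, signatures, and state layout here are assumptions for the sake of example, not the linked skeleton or the actual BlackJAX code.

```python
# sgmcmc/sgnht.py -- illustrative skeleton only, not actual BlackJAX code.
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp

class SGNHTState(NamedTuple):
    position: jax.Array  # current parameters
    momentum: jax.Array  # auxiliary momentum
    xi: float            # thermostat variable

def init(position, xi_init: float = 0.0) -> SGNHTState:
    # Start from zero momentum and the supplied thermostat value.
    return SGNHTState(position, jnp.zeros_like(position), xi_init)

def build_kernel() -> Callable:
    # Return a step function; its exact signature here is an assumption.
    def kernel(rng_key, state, grad_estimator, minibatch, step_size) -> SGNHTState:
        ...  # one SGNHT update, e.g. as in the earlier sketch
        return state
    return kernel
```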

Since sgld and sghmc don't need a state (only the position needs to be carried over to the next iteration), they also don't need an initializer, which is why they don't return an MCMCSamplingAlgorithm. For now we'll keep it like this, even though returning the momentum could be useful for debugging (if someone complains we'll change it).
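As a toy illustration (not BlackJAX code) of a step function that only threads the position through between iterations, with a hand-rolled SGLD update on a standard Gaussian target:

```python
# Toy SGLD loop: no init function and no extra state, only the position
# is carried from one iteration to the next. All names are illustrative.
import jax
import jax.numpy as jnp

def grad_estimator(position, minibatch):
    # Gradient of the log-density of a standard Gaussian; a real estimator
    # would use the minibatch and rescale by N / batch_size.
    return -position

def sgld_step(rng_key, position, minibatch, step_size):
    noise = jax.random.normal(rng_key, jnp.shape(position))
    grad = grad_estimator(position, minibatch)
    return position + 0.5 * step_size * grad + jnp.sqrt(step_size) * noise

rng_key = jax.random.PRNGKey(0)
position = jnp.ones(3)
for _ in range(100):
    rng_key, step_key = jax.random.split(rng_key)
    position = sgld_step(step_key, position, minibatch=None, step_size=1e-2)
```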

@SamDuffield (Contributor, Author)

I can't believe I was just getting the hang of the BlackJAX structure and then you went and changed it! 😆 Just kidding, big fan of the changes.

Hopefully I have adopted them correctly; let me know what you think.

As discussed, I have included the rescaling of the parameters for SGHMC. I also updated the docs for SGLD and SGHMC to say they return a "step function" rather than an "MCMCSamplingAlgorithm", though I'm not sure whether there is better terminology.

IMO it might be nice for SGLD and SGHMC to return an MCMCSamplingAlgorithm (perhaps with a dummy init function) so that users can switch between samplers seamlessly.
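A minimal sketch of what such a wrapper with a dummy init might look like; SamplingAlgorithm and as_sampling_algorithm below are stand-in names for illustration, not the BlackJAX classes:

```python
# Hypothetical wrapper: expose a bare step function as an init/step pair so
# samplers can be swapped interchangeably. Names are illustrative only.
from typing import Callable, NamedTuple

class SamplingAlgorithm(NamedTuple):
    init: Callable
    step: Callable

def as_sampling_algorithm(step_fn: Callable) -> SamplingAlgorithm:
    def init(position):
        # SGLD/SGHMC carry only the position between iterations, so the
        # "state" is just the position itself.
        return position
    return SamplingAlgorithm(init=init, step=step_fn)
```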

@albcab (Member) left a comment

Just two easy naming changes and we are ready to merge!

> IMO it might be nice for SGLD and SGHMC to return an MCMCSamplingAlgorithm (perhaps with a dummy init function) so that users can switch between samplers seamlessly.

This is an excellent point; I completely agree. If you can/want to make the changes, go ahead with another PR (leaving the doc updates of this PR as they are); otherwise, open an issue and we'll take care of it 👍

(Two review threads on blackjax/sgmcmc/sgnht.py, now resolved.)
@SamDuffield (Contributor, Author)

Good spot! Should be good to go now.

Yep, I'm happy to do a new PR unifying SGLD and SGHMC to also return an MCMCSamplingAlgorithm.

codecov bot commented May 25, 2023

Codecov Report

Merging #515 (5a0c7ce) into main (7100bca) will increase coverage by 0.02%.
The diff coverage is 99.52%.

@@            Coverage Diff             @@
##             main     #515      +/-   ##
==========================================
+ Coverage   99.28%   99.30%   +0.02%     
==========================================
  Files          47       48       +1     
  Lines        1947     2021      +74     
==========================================
+ Hits         1933     2007      +74     
  Misses         14       14              
Impacted Files Coverage Δ
blackjax/vi/pathfinder.py 97.01% <92.30%> (-1.17%) ⬇️
blackjax/vi/meanfield_vi.py 96.77% <92.85%> (-1.19%) ⬇️
blackjax/__init__.py 100.00% <100.00%> (ø)
blackjax/adaptation/base.py 100.00% <100.00%> (ø)
blackjax/adaptation/meads_adaptation.py 100.00% <100.00%> (ø)
blackjax/adaptation/pathfinder_adaptation.py 100.00% <100.00%> (ø)
blackjax/adaptation/step_size.py 100.00% <100.00%> (ø)
blackjax/adaptation/window_adaptation.py 100.00% <100.00%> (ø)
blackjax/mcmc/elliptical_slice.py 95.77% <100.00%> (+0.77%) ⬆️
blackjax/mcmc/ghmc.py 100.00% <100.00%> (ø)
... and 16 more

@albcab merged commit c6149e3 into blackjax-devs:main on May 25, 2023
@SamDuffield deleted the sgnht branch on May 25, 2023, 20:22
junpenglao pushed a commit that referenced this pull request on Mar 12, 2024:
* add sgnht

* reformat

* Restructure kernels

* Reformat

* Clean

* Rename step to kernel