
Move kernel functions to their algorithm-class folder. #501

Merged: 3 commits merged into blackjax-devs:main from kernel_refactoring on May 12, 2023

Conversation

@albcab (Member) commented Feb 27, 2023

Closes #492 before people start building new algorithms with the previous structure.

Besides losing the aliases, the other major breaking change is that, for adaptation algorithms, the user can no longer pass a class as the algorithm to adapt with. Hence, the user now needs to pass a string with the algorithm_name for the adaptation kernel to invoke the kernel and init method. This applies to window_adaptation and pathfinder_adaptation.
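
For illustration only, here is a minimal sketch of the breaking change described above. It assumes the usual two-argument window_adaptation call with a toy log-density; the exact signature at this point of the refactor may differ.

```python
import jax.numpy as jnp
import blackjax

def logdensity_fn(x):
    # toy target: standard normal log-density
    return -0.5 * jnp.sum(x**2)

# Previously: pass the high-level algorithm class to adapt with.
warmup = blackjax.window_adaptation(blackjax.hmc, logdensity_fn)

# With this change: pass the algorithm's name as a string instead; the
# adaptation routine looks up the corresponding kernel and init function.
warmup = blackjax.window_adaptation("hmc", logdensity_fn)
```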

@albcab albcab requested a review from rlouf February 27, 2023 17:33
@albcab albcab force-pushed the kernel_refactoring branch 2 times, most recently from c300d43 to 8041a3f on February 27, 2023 18:01
@codecov (bot) commented Feb 27, 2023

Codecov Report

Merging #501 (3077fa6) into main (65c01a8) will increase coverage by 0.01%.
The diff coverage is 99.35%.

@@            Coverage Diff             @@
##             main     #501      +/-   ##
==========================================
+ Coverage   99.28%   99.29%   +0.01%     
==========================================
  Files          47       47              
  Lines        1948     1978      +30     
==========================================
+ Hits         1934     1964      +30     
  Misses         14       14              
Impacted Files | Coverage | Δ
blackjax/vi/pathfinder.py | 97.01% <92.30%> | (-1.17%) ⬇️
blackjax/vi/meanfield_vi.py | 96.77% <92.85%> | (-1.19%) ⬇️
blackjax/__init__.py | 100.00% <100.00%> | (ø)
blackjax/adaptation/base.py | 100.00% <100.00%> | (ø)
blackjax/adaptation/meads_adaptation.py | 100.00% <100.00%> | (ø)
blackjax/adaptation/pathfinder_adaptation.py | 100.00% <100.00%> | (ø)
blackjax/adaptation/window_adaptation.py | 100.00% <100.00%> | (ø)
blackjax/mcmc/elliptical_slice.py | 95.77% <100.00%> | (+0.77%) ⬆️
blackjax/mcmc/ghmc.py | 100.00% <100.00%> | (ø)
blackjax/mcmc/hmc.py | 100.00% <100.00%> | (ø)
... and 10 more

@albcab (Member, Author) commented Feb 27, 2023

not sure what is up with the building of the docs by readthedocs

@junpenglao (Member) commented:

Love the overall direction this is going; some high-level comments:

  1. we now have things like blackjax.mcmc.hmc.hmc for the high-level API, instead of blackjax.hmc (alias of blackjax.kernel.hmc). Should we think of some ways to make this a bit better on the eye? For example, renaming the file names to something more explicit (hamiltonian_monte_carlo.py instead of hmc.py).
  2. window_adaptation is taking a string as input, but it might be unnecessarily restrictive (a user can wrap the random walk MH sampler with a step size and inverse mass matrix parameter and window_adaptation would work; see the sketch below).
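
To make point 2 concrete, below is a hypothetical, self-contained random-walk Metropolis step that exposes the (rng_key, state, logdensity_fn, step_size, inverse_mass_matrix) signature an adapted kernel is expected to have. It is not code from this PR; it only illustrates why accepting any kernel with that signature is less restrictive than a string-based lookup.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

class RWState(NamedTuple):
    position: jnp.ndarray
    logdensity: jnp.ndarray

def rw_metropolis_step(rng_key, state, logdensity_fn, step_size, inverse_mass_matrix):
    """One random-walk Metropolis step.

    `inverse_mass_matrix` is accepted only to match the signature of the
    adapted kernels; this toy sampler uses `step_size` as its proposal scale.
    """
    key_proposal, key_accept = jax.random.split(rng_key)
    proposal = state.position + step_size * jax.random.normal(key_proposal, state.position.shape)
    proposal_logdensity = logdensity_fn(proposal)
    log_accept_ratio = proposal_logdensity - state.logdensity
    do_accept = jnp.log(jax.random.uniform(key_accept)) < log_accept_ratio
    position = jnp.where(do_accept, proposal, state.position)
    logdensity = jnp.where(do_accept, proposal_logdensity, state.logdensity)
    return RWState(position, logdensity)
```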

@@ -449,6 +452,7 @@ def normal_logprob(self, x):
     def test_univariate_normal(
         self, algorithm, initial_position, parameters, num_sampling_steps, burnin
     ):
+        algorithm = eval(algorithm)
Review comment (Member):

Not sure I understand this change, seems not needed?

Reply (Member, Author):

Because blackjax.hmc (or any other kernel constructor) is now a function, it creates an object at a specific location in memory when evaluated. If this object is created outside the test function test_univariate_normal, then calls to the chex variant function inference_loop have a different call signature on each of the devices/cores (because each call signature includes a different memory location), which throws an error when collecting tests. The error looks like: ERROR collecting gw1. Different tests were collected between gw0 and gw1.
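
A minimal sketch of the resulting test pattern (the names and parametrization are illustrative, not the actual blackjax test suite): the constructor is referenced by name in the parametrization and only evaluated inside the test body, so every pytest-xdist worker (gw0, gw1, ...) collects identical test IDs.

```python
import pytest
import blackjax

@pytest.mark.parametrize("algorithm", ["blackjax.hmc", "blackjax.nuts"])
def test_univariate_normal(algorithm):
    # Evaluate the constructor here, once per test call, rather than at
    # collection time; objects created at collection time live at different
    # memory locations on each worker and make the collected tests diverge.
    algorithm = eval(algorithm)
    ...
```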

@@ -68,8 +68,8 @@ def logprior_fn(x):
     iterates = []
     results = []  # type: List[TemperedSMCState]

-    hmc_kernel = blackjax.hmc.kernel()
-    hmc_init = blackjax.hmc.init
+    hmc_kernel = blackjax.mcmc.hmc.kernel()
Review comment (@junpenglao, Member) Feb 28, 2023:

IIUC, in the top-level __init__.py you imported the symbol, so you don't need to make this change.
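
Schematically, the re-export being referred to looks something like the snippet below; this is a sketch of the idea, not the literal contents of blackjax/__init__.py.

```python
# blackjax/__init__.py (sketch)
from blackjax.mcmc import hmc  # re-export, so `blackjax.hmc` keeps resolving

# downstream code could then stay unchanged:
#     hmc_kernel = blackjax.hmc.kernel()
#     hmc_init = blackjax.hmc.init
```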

@rlouf (Member) commented Feb 28, 2023

> not sure what is up with the building of the docs by readthedocs

Fixed this temporarily by allowing the build to pass even if there are warnings (we'll need to figure these out later) in #502.

@rlouf (Member) commented Feb 28, 2023

> Hence, the user now needs to pass a string with the algorithm_name for the adaptation kernel to invoke the kernel and init method. This applies to window_adaptation and pathfinder_adaptation.

Ok, so that's why I actually added these classes :/ Not a big fan of passing names as strings.

@junpenglao (Member) commented:

> Ok, so that's why I actually added these classes :/ Not a big fan of passing names as strings.

Similar issue for when we are writing tests: #501 (comment)

@albcab (Member, Author) commented Feb 28, 2023

> 1. we now have things like blackjax.mcmc.hmc.hmc for the high-level API, instead of blackjax.hmc (alias of blackjax.kernel.hmc). Should we think of some ways to make this a bit better on the eye? For example, renaming the file names to something more explicit (hamiltonian_monte_carlo.py instead of hmc.py).

We still have blackjax.hmc. What we don't have is blackjax.hmc.kernel and blackjax.hmc.init; instead we have blackjax.mcmc.hmc.kernel and blackjax.mcmc.hmc.init.

> 2. window_adaptation is taking a string as input, but it might be unnecessarily restrictive (a user can wrap the random walk MH sampler with a step size and inverse mass matrix parameter and window_adaptation would work).

> Not a big fan of passing names as strings.

Agree with both. Our other option is to have the user pass both the init and kernel function directly (without aliases, though), for example as sketched below.
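
A hypothetical sketch of that alternative; this is not an existing blackjax API, and the keyword names are made up for illustration.

```python
import jax.numpy as jnp
import blackjax

def logdensity_fn(x):
    # toy target: standard normal log-density
    return -0.5 * jnp.sum(x**2)

# The user passes the kernel constructor and init function explicitly,
# rather than a class or a string naming the algorithm.
warmup = blackjax.window_adaptation(
    kernel=blackjax.mcmc.hmc.kernel,  # hypothetical keyword argument
    init=blackjax.mcmc.hmc.init,      # hypothetical keyword argument
    logdensity_fn=logdensity_fn,
)
```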

@albcab (Member, Author) commented Feb 28, 2023

> are we turning our backs to Marx and staying on a class structure?

Python does seem to be, inherently, an object-oriented language...

> Similar issue for when we are writing tests: #501 (comment)

@albcab (Member, Author) commented Feb 28, 2023

Back to classes (but we still have the function commits if we want to go back to that); also, I've made the naming of sampling algorithms consistent (#280).

@albcab albcab requested a review from junpenglao March 1, 2023 14:29
junpenglao previously approved these changes Mar 6, 2023
@junpenglao (Member) left a review comment:

Thanks for working on this.
Since we need to keep the class implementation for the high-level API, the devil's advocate view is that keeping them in the same file (like what we had originally) makes the implementation pattern clear and easy to reference/compare.

@albcab (Member, Author) commented Mar 6, 2023

Also to consider: the kernel.py file will grow to unsustainable levels, and I think it's easier to give general guidelines if new algorithms are implemented in one file (assuming all basic components of the method are already implemented).

@albcab albcab changed the title from "Refactor to functions for high-level kernel API, move kernel functions to their algorithm-class folder." to "Move kernel functions to their algorithm-class folder." on May 8, 2023
Review comment (@albcab, Member, Author) May 8, 2023:

@junpenglao this file is a good example of various algorithms in one script. Since both additive_step_random_walk and irmh are special cases of rmh it makes sense to keep them all together, wdyt?

junpenglao previously approved these changes May 12, 2023
@junpenglao (Member) left a review comment:

LGTM.

Only one minor comment: any reason why you are putting the top level API for blackjax.adaptation before def base(...)?

@albcab (Member, Author) commented May 12, 2023

> Only one minor comment: any reason why you are putting the top level API for blackjax.adaptation before def base(...)?

Didn't notice this; probably best to keep the setup consistent...

@albcab albcab merged commit 661874d into blackjax-devs:main May 12, 2023
@albcab albcab deleted the kernel_refactoring branch May 12, 2023 12:55
@albcab albcab mentioned this pull request May 16, 2023
junpenglao pushed a commit that referenced this pull request Mar 12, 2024
* Move kernel function constructors to their respective algorithm-class folder

* base classes for adaptation algorithms
Merging this pull request may close: Deprecate the "class" approach to the high-level kernels