Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Removal of Algorithm classes. #657

Merged
merged 13 commits into from
Apr 22, 2024

Conversation

ciguaran
Copy link
Contributor

@ciguaran ciguaran commented Apr 18, 2024

This PR removes the algorithm classes. These are static classes we don't use directly, but via calling new and getting a SamplingAlgorithm result. This PR:

  • replaces these classes with instances of a single class so now each module contains: an init, build_kernel and a as_sampling_algorithm functions. The idea is that the latter fixes dependencies/parameters, in particular when these are not differentiable. The init and build_kernel still exist and are exposed since we need that lower level API, specially when composing algorithms (for example, when tuning SMC inner kernel we need to be able to change parameter on every call).

By doing this we still can call algorithms directly, like blackjax.hmc(). What we do loose is the (light) type annotations we are doing, for example in window_adaptation. I have been thinking about this type of annotations, and I think we should remove them.

The reason for doing it is the following: python fosters duck and structural typing, in contrast to nominal typing like you can find in say Java. In the case of window adaptation, we want the type to mean "hmc family" as algorithms that have an inverse_mass_matrix and a step size. But the way is implemented right now, it actually means whatever the class hmc or the class nuts does! So the classes kind of exist just to be able to name them (aka to use nominal typing). Since most of our codebase is functional, from a typing perspective most samplers are Callables that take in matrixes, doubles, pytrees and return something of the same flavours. There's no way to say: this is the type of a callable that takes in an inverse_mass_matrix and a step size and uses it in some consistent way, because that is not duck typing nor structural! aka we are trying to statically type using tools that are not pythonic. I'd suggest we replace this kind of "algorithm level" type annotations with docs and tests. Check for example the smc_compatibility_test

@@ -243,7 +242,7 @@ def final(warmup_state: WindowAdaptationState) -> tuple[float, Array]:


def window_adaptation(
algorithm: Union[mcmc.hmc.hmc, mcmc.nuts.nuts],
algorithm,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this type not SamplingAlgorithm? I guess it is somewhat more specific than that, but it seems like it needs to at least implement the SamplingAlgorithm protocol...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but adding that type annotation would tell the user that any subtype of SamplingAlgorithm can be used as parameter to that function, which would not be true.

@ciguaran ciguaran marked this pull request as ready for review April 19, 2024 18:01
@ciguaran ciguaran changed the title A draft PR on the removal of Algorithm classes. Removal of Algorithm classes. Apr 19, 2024
Copy link
Member

@junpenglao junpenglao left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have some suggestion around naming.
@albcab thoughts?

blackjax/__init__.py Outdated Show resolved Hide resolved
blackjax/__init__.py Outdated Show resolved Hide resolved
blackjax/__init__.py Outdated Show resolved Hide resolved
blackjax/__init__.py Outdated Show resolved Hide resolved
blackjax/__init__.py Outdated Show resolved Hide resolved
blackjax/mcmc/random_walk.py Outdated Show resolved Hide resolved
blackjax/sgmcmc/csgld.py Outdated Show resolved Hide resolved
blackjax/vi/meanfield_vi.py Outdated Show resolved Hide resolved
blackjax/vi/pathfinder.py Outdated Show resolved Hide resolved
tests/smc/test_inner_kernel_tuning.py Outdated Show resolved Hide resolved
@albcab
Copy link
Member

albcab commented Apr 22, 2024

I have some suggestion around naming. @albcab thoughts?

Suggesting algorithm rather than API, this way we could use as_algorithm for all.

junpenglao
junpenglao previously approved these changes Apr 22, 2024
junpenglao
junpenglao previously approved these changes Apr 22, 2024
Copy link
Member

@junpenglao junpenglao left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great job! Thank you for working on this!

blackjax/__init__.py Outdated Show resolved Hide resolved
@junpenglao junpenglao merged commit 1bc6f93 into blackjax-devs:main Apr 22, 2024
5 checks passed
AdrienCorenflos added a commit to AdrienCorenflos/blackjax that referenced this pull request Aug 14, 2024
* Update README.md (blackjax-devs#638)

* Update README.md

Update citation.

* Update README.md

* Indexing the notebook showing how to reproduce the GIF. (blackjax-devs#640)

Co-authored-by: Junpeng Lao <[email protected]>

* Bump python version (blackjax-devs#645)

* Bump python version

* update bool inverse

* SMC: allow each mutation kernel to have different parameters. (blackjax-devs#649)

* vmaping over parameters in base

* switch from mcmc_factory to just passing in parameters

* pre-commit and typing

* CRU and docs improvement

* pre-commit

* code review updates

* pre-commit

* rename test

* Migrate from deprecated `host_callback` to `io_callback` (blackjax-devs#651)

* Migrate from deprecated `host_callback` to `io_callback`

Co-Authored-By:
George Necula <[email protected]>

* Format file

* Fix bug

* Fix MALA transition energy (blackjax-devs#653)

* Fix MALA transition energy

* Use a different logic.

* Change variable names (blackjax-devs#654)

* Replace iterative RNG split and carry with `jax.random.fold_in` (blackjax-devs#656)

* Replace iterative RNG split and carry with `jax.random.fold_in`

* revert unintended change

* file formatting

* change `jax.tree_map` to `jax.tree.map`

* revert unintended file

* fiddle with rng_key

* seed again

* Removal of Algorithm classes. (blackjax-devs#657)

* more

* removing export

* removal of classes, tests passing

* linter

* fix on test

* linter

* removing parametrization on test

* code review updates

* exporting as_top_level_api in dynamic_hmc

* linter

* code review update: replace imports

* Fix deprecated call to jnp.clip (blackjax-devs#664)

* Update jax version requirements (blackjax-devs#666)

Fix blackjax-devs#665

* Make tests pass on `aarch64-linux` (blackjax-devs#671)

* Enable fitlering of AdaptationInfo (blackjax-devs#674)

* enable AdaptationInfo filtering

* revert progress_bar

* fix pre-commit

* fix empty sets

* enable adapt info filtering for all adaptation algorithms

* fix precommit /progressbar=True

* change filter tuple to use tree_map

* Update `run_inference_algorithm` to split `initial_position` and `initial_state` (blackjax-devs#672)

* UPDATE DOCSTRING

* ADD STREAMING VERSION

* UPDATE TESTS

* ADD DOCSTRING

* ADD TEST

* REFACTOR RUN_INFERENCE_ALGORITHM

* UPDATE DOCSTRING

* Precommit

* CLEAN TESTS

* ADD INITIAL_POSITION

* FIX TEST

* RENAME O

* FIX DOCSTRING

* PUT EXPECTATION AFTER TRANSFORM

* Preconditioned mclmc (blackjax-devs#673)

* TESTS

* TESTS

* UPDATE DOCSTRING

* ADD STREAMING VERSION

* ADD PRECONDITIONING TO MCLMC

* ADD PRECONDITIONING TO TUNING FOR MCLMC

* UPDATE GITIGNORE

* UPDATE GITIGNORE

* UPDATE TESTS

* UPDATE TESTS

* ADD DOCSTRING

* ADD TEST

* STREAMING AVERAGE

* ADD TEST

* REFACTOR RUN_INFERENCE_ALGORITHM

* UPDATE DOCSTRING

* Precommit

* CLEAN TESTS

* GITIGNORE

* PRECOMMIT CLEAN UP

* ADD INITIAL_POSITION

* FIX TEST

* ADD TEST

* REMOVE BENCHMARKS

* BUG FIX

* CHANGE PRECISION

* CHANGE PRECISION

* RENAME O

* UPDATE STREAMING AVG

* UPDATE PR

* RENAME STD_MAT

* New integrator, and add some metadata to integrators.py (blackjax-devs#681)

* TESTS

* TESTS

* UPDATE DOCSTRING

* ADD STREAMING VERSION

* ADD PRECONDITIONING TO MCLMC

* ADD PRECONDITIONING TO TUNING FOR MCLMC

* UPDATE GITIGNORE

* UPDATE GITIGNORE

* UPDATE TESTS

* UPDATE TESTS

* ADD DOCSTRING

* ADD TEST

* STREAMING AVERAGE

* ADD TEST

* REFACTOR RUN_INFERENCE_ALGORITHM

* UPDATE DOCSTRING

* Precommit

* CLEAN TESTS

* GITIGNORE

* PRECOMMIT CLEAN UP

* FIX SPELLING, ADD OMELYAN, EXPORT COEFFICIENTS

* TEMPORARILY ADD BENCHMARKS

* ADD INITIAL_POSITION

* FIX TEST

* CLEAN UP

* REMOVE BENCHMARKS

* ADD TEST

* REMOVE BENCHMARKS

* BUG FIX

* CHANGE PRECISION

* CHANGE PRECISION

* ADD OMELYAN TEST

* RENAME O

* UPDATE STREAMING AVG

* UPDATE PR

* RENAME STD_MAT

* MERGE MAIN

* REMOVE COEFFICIENT EXPORTS

* Minor formatting (blackjax-devs#685)

* Minor formatting

* formatting

* fix test

* formatting

* MAKE WINDOW ADAPTATION TAKE INTEGRATOR AS ARGUMENT (blackjax-devs#687)

* FIX KWARG BUG (blackjax-devs#686)

* FIX KWARG BUG

* FIX KWARG BUG

* Change isokinetic_integrator generation API (blackjax-devs#689)

* Apply function on pytree directly. (blackjax-devs#692)

* Apply function on pytree directly.

Avoiding unnecssary unpacking

* Fix kwarg

* Fix sampling test. (blackjax-devs#693)

* Enable shared mcmc parameters with tempered smc (blackjax-devs#694)

* add parameter filtering

* fix parameter split + docstring

* change extend_paramss

* convert to bit twiddling (blackjax-devs#696)

* Remove nightly release (blackjax-devs#699)

* Fix doc mistakes (blackjax-devs#701)

* Fix equation formatting

* Clarify JAX gradient error

* Fix punctuation + capitalization

* Fix grammar

Should not begin sentence with "i.e." in English.

* Fix math formatting error

* Fix typo

Change parallel _ensample_ chain adaptation to parallel _ensemble_ chain adaptation.

* Add SVGD citation to appear in doc

Currently the SVGD paper is only cited in the `kernel` function, which is defined _within_ the `build_kernel` function. Because of this nested function format, the SVGD paper is _not_ cited in the documentation.

To fix this, I added a citation to the SVGD paper in the `as_top_level_api` docstring.

* Fix grammar + clarify doc

* Fix typo

---------

Co-authored-by: Junpeng Lao <[email protected]>

* Update index.md (blackjax-devs#711)

The jitted step remained unused, leading to the example running with an uncompiled nuts.step. 

Changing this reduces the execution time by a factor of 30 on my system and showcases blackjax' speed.

* Enable progress bar under pmap (blackjax-devs#712)

* enable pmap progbar

* fix bar creation

* add locking

* fix formatting

* switch to using chain state

* remove labels (blackjax-devs#716)

* Simplify `run_inference_algorithm` (blackjax-devs#714)

* fix minor type errors

* storing only expectation values

* fixed memory efficient sampling

* clean up

* renaming vars

* precommit fixes

* fixing tests

* fixing tests

* fixing tests

* fixing tests

* fixing tests

* merge main

* burn in and fix tests

* burn in and fix tests

* minor fixes

* minor fixes

* minor fixes

---------

Co-authored-by: [email protected] <[email protected]>

* Harmonize Quickstart example (blackjax-devs#717)

* Update README.md (blackjax-devs#719)

---------

Co-authored-by: Junpeng Lao <[email protected]>
Co-authored-by: Carlos Iguaran <[email protected]>
Co-authored-by: ksnxr <[email protected]>
Co-authored-by: Gaétan Lepage <[email protected]>
Co-authored-by: Alberto Cabezas <[email protected]>
Co-authored-by: andrewdipper <[email protected]>
Co-authored-by: Reuben <[email protected]>
Co-authored-by: Gilad Turok <[email protected]>
Co-authored-by: johannahaffner <[email protected]>
Co-authored-by: [email protected] <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants