Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge remote to fork #1

Merged
merged 32 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
bc89395
Update README.md (#638)
AdrienCorenflos Feb 21, 2024
8a9dfbf
Indexing the notebook showing how to reproduce the GIF. (#640)
AdrienCorenflos Feb 28, 2024
2ccdfb0
Bump python version (#645)
junpenglao Mar 7, 2024
3dc3809
SMC: allow each mutation kernel to have different parameters. (#649)
ciguaran Mar 25, 2024
2e25624
Migrate from deprecated `host_callback` to `io_callback` (#651)
junpenglao Mar 28, 2024
f77297f
Fix MALA transition energy (#653)
ksnxr Mar 31, 2024
7cf4f9d
Change variable names (#654)
ksnxr Apr 1, 2024
a5f7482
Replace iterative RNG split and carry with `jax.random.fold_in` (#656)
junpenglao Apr 8, 2024
1bc6f93
Removal of Algorithm classes. (#657)
ciguaran Apr 22, 2024
3f92393
Fix deprecated call to jnp.clip (#664)
GaetanLepage May 8, 2024
40efb6c
Update jax version requirements (#666)
junpenglao May 8, 2024
af79fa4
Make tests pass on `aarch64-linux` (#671)
albcab May 13, 2024
cd91e41
Enable fitlering of AdaptationInfo (#674)
andrewdipper May 16, 2024
e0a7f9e
Update `run_inference_algorithm` to split `initial_position` and `ini…
reubenharry May 20, 2024
5831740
Preconditioned mclmc (#673)
reubenharry May 25, 2024
20666de
New integrator, and add some metadata to integrators.py (#681)
reubenharry May 27, 2024
360ac3b
Minor formatting (#685)
junpenglao May 27, 2024
3fbdac6
MAKE WINDOW ADAPTATION TAKE INTEGRATOR AS ARGUMENT (#687)
reubenharry Jun 3, 2024
a3e235f
FIX KWARG BUG (#686)
reubenharry Jun 3, 2024
83bc3a0
Change isokinetic_integrator generation API (#689)
junpenglao Jun 3, 2024
a4408d3
Apply function on pytree directly. (#692)
junpenglao Jun 5, 2024
dd9ba03
Fix sampling test. (#693)
junpenglao Jun 5, 2024
3353209
Enable shared mcmc parameters with tempered smc (#694)
andrewdipper Jun 15, 2024
eca35ab
convert to bit twiddling (#696)
andrewdipper Jun 20, 2024
5764a2b
Remove nightly release (#699)
junpenglao Jun 24, 2024
f8db9aa
Fix doc mistakes (#701)
gil2rok Jun 24, 2024
441412a
Update index.md (#711)
johannahaffner Jul 31, 2024
27dfc9e
Enable progress bar under pmap (#712)
andrewdipper Aug 7, 2024
148c028
remove labels (#716)
andrewdipper Aug 9, 2024
7135fd7
Simplify `run_inference_algorithm` (#714)
reubenharry Aug 12, 2024
834f55d
Harmonize Quickstart example (#717)
gil2rok Aug 13, 2024
4a11236
Update README.md (#719)
gil2rok Aug 13, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 0 additions & 48 deletions .github/workflows/nightly.yml

This file was deleted.

4 changes: 2 additions & 2 deletions .github/workflows/publish_documentation.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ jobs:
with:
persist-credentials: false

- name: Set up Python 3.9
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: 3.9
python-version: 3.11

- name: Build the documentation with Sphinx
run: |
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: 3.9
python-version: 3.11
- name: Build sdist and wheel
run: |
python -m pip install -U pip
Expand Down Expand Up @@ -51,7 +51,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: 3.9
python-version: 3.11
- name: Give PyPI some time to update the index
run: sleep 240
- name: Attempt install from PyPI
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: 3.9
python-version: 3.11
- uses: pre-commit/[email protected]

test:
Expand All @@ -24,7 +24,7 @@ jobs:
- style
strategy:
matrix:
python-version: [ '3.9', '3.11']
python-version: ['3.11', '3.12']
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
Expand Down
26 changes: 11 additions & 15 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,6 @@ or via conda-forge:
conda install -c conda-forge blackjax
```

Nightly builds (bleeding edge) of Blackjax can also be installed using `pip`:

```bash
pip install blackjax-nightly
```

BlackJAX is written in pure Python but depends on XLA via JAX. By default, the
version of JAX that will be installed along with BlackJAX will make your code
run on CPU only. **If you want to use BlackJAX on GPU/TPU** we recommend you follow
Expand Down Expand Up @@ -81,9 +75,10 @@ state = nuts.init(initial_position)

# Iterate
rng_key = jax.random.key(0)
for _ in range(100):
rng_key, nuts_key = jax.random.split(rng_key)
state, _ = nuts.step(nuts_key, state)
step = jax.jit(nuts.step)
for i in range(100):
nuts_key = jax.random.fold_in(rng_key, i)
state, _ = step(nuts_key, state)
```

See [the documentation](https://blackjax-devs.github.io/blackjax/index.html) for more examples of how to use the library: how to write inference loops for one or several chains, how to use the Stan warmup, etc.
Expand Down Expand Up @@ -138,12 +133,13 @@ Please follow our [short guide](https://github.com/blackjax-devs/blackjax/blob/m
To cite this repository:

```
@software{blackjax2020github,
author = {Cabezas, Alberto, Lao, Junpeng, and Louf, R\'emi},
title = {{B}lackjax: A sampling library for {JAX}},
url = {http://github.com/blackjax-devs/blackjax},
version = {<insert current release tag>},
year = {2023},
@misc{cabezas2024blackjax,
title={BlackJAX: Composable {B}ayesian inference in {JAX}},
author={Alberto Cabezas and Adrien Corenflos and Junpeng Lao and Rémi Louf},
year={2024},
eprint={2402.10797},
archivePrefix={arXiv},
primaryClass={cs.MS}
}
```
In the above bibtex entry, names are in alphabetical order, the version number should be the last tag on the `main` branch.
Expand Down
188 changes: 140 additions & 48 deletions blackjax/__init__.py
Original file line number Diff line number Diff line change
@@ -1,71 +1,163 @@
import dataclasses
from typing import Callable

from blackjax._version import __version__

from .adaptation.chees_adaptation import chees_adaptation
from .adaptation.mclmc_adaptation import mclmc_find_L_and_step_size
from .adaptation.meads_adaptation import meads_adaptation
from .adaptation.pathfinder_adaptation import pathfinder_adaptation
from .adaptation.window_adaptation import window_adaptation
from .base import SamplingAlgorithm, VIAlgorithm
from .diagnostics import effective_sample_size as ess
from .diagnostics import potential_scale_reduction as rhat
from .mcmc.barker import barker_proposal
from .mcmc.dynamic_hmc import dynamic_hmc
from .mcmc.elliptical_slice import elliptical_slice
from .mcmc.ghmc import ghmc
from .mcmc.hmc import hmc
from .mcmc.mala import mala
from .mcmc.marginal_latent_gaussian import mgrad_gaussian
from .mcmc.mclmc import mclmc
from .mcmc.nuts import nuts
from .mcmc.periodic_orbital import orbital_hmc
from .mcmc.random_walk import additive_step_random_walk, irmh, rmh
from .mcmc.rmhmc import rmhmc
from .mcmc import barker
from .mcmc import dynamic_hmc as _dynamic_hmc
from .mcmc import elliptical_slice as _elliptical_slice
from .mcmc import ghmc as _ghmc
from .mcmc import hmc as _hmc
from .mcmc import mala as _mala
from .mcmc import marginal_latent_gaussian
from .mcmc import mclmc as _mclmc
from .mcmc import nuts as _nuts
from .mcmc import periodic_orbital, random_walk
from .mcmc import rmhmc as _rmhmc
from .mcmc.random_walk import additive_step_random_walk as _additive_step_random_walk
from .mcmc.random_walk import (
irmh_as_top_level_api,
normal_random_walk,
rmh_as_top_level_api,
)
from .optimizers import dual_averaging, lbfgs
from .sgmcmc.csgld import csgld
from .sgmcmc.sghmc import sghmc
from .sgmcmc.sgld import sgld
from .sgmcmc.sgnht import sgnht
from .smc.adaptive_tempered import adaptive_tempered_smc
from .smc.inner_kernel_tuning import inner_kernel_tuning
from .smc.tempered import tempered_smc
from .vi.meanfield_vi import meanfield_vi
from .vi.pathfinder import pathfinder
from .vi.schrodinger_follmer import schrodinger_follmer
from .vi.svgd import svgd
from .sgmcmc import csgld as _csgld
from .sgmcmc import sghmc as _sghmc
from .sgmcmc import sgld as _sgld
from .sgmcmc import sgnht as _sgnht
from .smc import adaptive_tempered
from .smc import inner_kernel_tuning as _inner_kernel_tuning
from .smc import tempered
from .vi import meanfield_vi as _meanfield_vi
from .vi import pathfinder as _pathfinder
from .vi import schrodinger_follmer as _schrodinger_follmer
from .vi import svgd as _svgd
from .vi.pathfinder import PathFinderAlgorithm

"""
The above three classes exist as a backwards compatible way of exposing both the high level, differentiable
factory and the low level components, which may not be differentiable. Moreover, this design allows for the lower
level to be mostly functional programming in nature and reducing boilerplate code.
"""


@dataclasses.dataclass
class GenerateSamplingAPI:
differentiable: Callable
init: Callable
build_kernel: Callable

def __call__(self, *args, **kwargs) -> SamplingAlgorithm:
return self.differentiable(*args, **kwargs)

def register_factory(self, name, callable):
setattr(self, name, callable)


@dataclasses.dataclass
class GenerateVariationalAPI:
differentiable: Callable
init: Callable
step: Callable
sample: Callable

def __call__(self, *args, **kwargs) -> VIAlgorithm:
return self.differentiable(*args, **kwargs)


@dataclasses.dataclass
class GeneratePathfinderAPI:
differentiable: Callable
approximate: Callable
sample: Callable

def __call__(self, *args, **kwargs) -> PathFinderAlgorithm:
return self.differentiable(*args, **kwargs)


def generate_top_level_api_from(module):
return GenerateSamplingAPI(
module.as_top_level_api, module.init, module.build_kernel
)


# MCMC
hmc = generate_top_level_api_from(_hmc)
nuts = generate_top_level_api_from(_nuts)
rmh = GenerateSamplingAPI(rmh_as_top_level_api, random_walk.init, random_walk.build_rmh)
irmh = GenerateSamplingAPI(
irmh_as_top_level_api, random_walk.init, random_walk.build_irmh
)
dynamic_hmc = generate_top_level_api_from(_dynamic_hmc)
rmhmc = generate_top_level_api_from(_rmhmc)
mala = generate_top_level_api_from(_mala)
mgrad_gaussian = generate_top_level_api_from(marginal_latent_gaussian)
orbital_hmc = generate_top_level_api_from(periodic_orbital)

additive_step_random_walk = GenerateSamplingAPI(
_additive_step_random_walk, random_walk.init, random_walk.build_additive_step
)

additive_step_random_walk.register_factory("normal_random_walk", normal_random_walk)

mclmc = generate_top_level_api_from(_mclmc)
elliptical_slice = generate_top_level_api_from(_elliptical_slice)
ghmc = generate_top_level_api_from(_ghmc)
barker_proposal = generate_top_level_api_from(barker)

hmc_family = [hmc, nuts]

# SMC
adaptive_tempered_smc = generate_top_level_api_from(adaptive_tempered)
tempered_smc = generate_top_level_api_from(tempered)
inner_kernel_tuning = generate_top_level_api_from(_inner_kernel_tuning)

smc_family = [tempered_smc, adaptive_tempered_smc]
"Step_fn returning state has a .particles attribute"

# stochastic gradient mcmc
sgld = generate_top_level_api_from(_sgld)
sghmc = generate_top_level_api_from(_sghmc)
sgnht = generate_top_level_api_from(_sgnht)
csgld = generate_top_level_api_from(_csgld)
svgd = generate_top_level_api_from(_svgd)

# variational inference
meanfield_vi = GenerateVariationalAPI(
_meanfield_vi.as_top_level_api,
_meanfield_vi.init,
_meanfield_vi.step,
_meanfield_vi.sample,
)
schrodinger_follmer = GenerateVariationalAPI(
_schrodinger_follmer.as_top_level_api,
_schrodinger_follmer.init,
_schrodinger_follmer.step,
_schrodinger_follmer.sample,
)

pathfinder = GeneratePathfinderAPI(
_pathfinder.as_top_level_api, _pathfinder.approximate, _pathfinder.sample
)


__all__ = [
"__version__",
"dual_averaging", # optimizers
"lbfgs",
"hmc", # mcmc
"dynamic_hmc",
"rmhmc",
"mala",
"mgrad_gaussian",
"nuts",
"orbital_hmc",
"additive_step_random_walk",
"rmh",
"irmh",
"mclmc",
"elliptical_slice",
"ghmc",
"barker_proposal",
"sgld", # stochastic gradient mcmc
"sghmc",
"sgnht",
"csgld",
"window_adaptation", # mcmc adaptation
"meads_adaptation",
"chees_adaptation",
"pathfinder_adaptation",
"mclmc_find_L_and_step_size", # mclmc adaptation
"adaptive_tempered_smc", # smc
"tempered_smc",
"inner_kernel_tuning",
"meanfield_vi", # variational inference
"pathfinder",
"schrodinger_follmer",
"svgd",
"ess", # diagnostics
"rhat",
]
Loading