Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce blackjax sampling memory usage #7407

Merged
merged 6 commits into from
Jul 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions conda-envs/environment-jax.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ dependencies:
- cloudpickle
- h5py>=2.7
# Jaxlib version must not be greater than jax version!
- blackjax==1.2.0 # Blackjax>=1.2.1 is incompatible with latest available version of jaxlib in conda-forge
- jaxlib==0.4.23 # Latest available version in conda-forge, update when new version is available
- jax==0.4.23
- blackjax>=1.2.2
- jax>=0.4.28
- jaxlib>=0.4.28
- libblas=*=*mkl
- mkl-service
- numpy>=1.15.0
Expand All @@ -25,9 +25,8 @@ dependencies:
- networkx
- rich>=13.7.1
- threadpoolctl>=3.1.0
# JAX is only compatible with Scipy 1.13.0 from >=0.4.26, but the respective version of
# JAXlib is still not on conda: https://github.com/conda-forge/jaxlib-feedstock/pull/243
- scipy>=1.4.1,<1.13.0
# JAX is only compatible with Scipy 1.13.0 from >=0.4.26
- scipy>=1.13.0
- typing-extensions>=3.7.4
# Extra dependencies for testing
- ipython>=7.16
Expand Down
3 changes: 3 additions & 0 deletions pymc/sampling/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,8 @@ def _blackjax_inference_loop(
):
import blackjax

from blackjax.adaptation.base import get_filter_adapt_info_fn

algorithm_name = adaptation_kwargs.pop("algorithm", "nuts")
if algorithm_name == "nuts":
algorithm = blackjax.nuts
Expand All @@ -255,6 +257,7 @@ def _blackjax_inference_loop(
algorithm=algorithm,
logdensity_fn=logprob_fn,
target_acceptance_rate=target_accept,
adaptation_info_fn=get_filter_adapt_info_fn(),
**adaptation_kwargs,
)
(last_state, tuned_params), _ = adapt.run(seed, init_position, num_steps=tune)
Expand Down
2 changes: 1 addition & 1 deletion tests/sampling/test_mcmc_external.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def test_external_nuts_sampler(recwarn, nuts_sampler):
warns = {
(warn.category, warn.message.args[0])
for warn in recwarn
if warn.category not in (FutureWarning, DeprecationWarning)
if warn.category not in (FutureWarning, DeprecationWarning, RuntimeWarning)
junpenglao marked this conversation as resolved.
Show resolved Hide resolved
}
expected = set()
if nuts_sampler == "nutpie":
Expand Down
Loading