Skip to content

Commit c8b22df

Browse files
authored
Reduce blackjax sampling memory usage (#7407)
* Reduce blackjax sampling memory usage ... by not outputing the warmup diagnositics * Update jax env * fix pre-commit * skip also RuntimeWarning * ping jax versions
1 parent 641a60b commit c8b22df

File tree

3 files changed

+9
-7
lines changed

3 files changed

+9
-7
lines changed

conda-envs/environment-jax.yml

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@ dependencies:
1111
- cloudpickle
1212
- h5py>=2.7
1313
# Jaxlib version must not be greater than jax version!
14-
- blackjax==1.2.0 # Blackjax>=1.2.1 is incompatible with latest available version of jaxlib in conda-forge
15-
- jaxlib==0.4.23 # Latest available version in conda-forge, update when new version is available
16-
- jax==0.4.23
14+
- blackjax>=1.2.2
15+
- jax>=0.4.28
16+
- jaxlib>=0.4.28
1717
- libblas=*=*mkl
1818
- mkl-service
1919
- numpy>=1.15.0
@@ -25,9 +25,8 @@ dependencies:
2525
- networkx
2626
- rich>=13.7.1
2727
- threadpoolctl>=3.1.0
28-
# JAX is only compatible with Scipy 1.13.0 from >=0.4.26, but the respective version of
29-
# JAXlib is still not on conda: https://github.com/conda-forge/jaxlib-feedstock/pull/243
30-
- scipy>=1.4.1,<1.13.0
28+
# JAX is only compatible with Scipy 1.13.0 from >=0.4.26
29+
- scipy>=1.13.0
3130
- typing-extensions>=3.7.4
3231
# Extra dependencies for testing
3332
- ipython>=7.16

pymc/sampling/jax.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,8 @@ def _blackjax_inference_loop(
243243
):
244244
import blackjax
245245

246+
from blackjax.adaptation.base import get_filter_adapt_info_fn
247+
246248
algorithm_name = adaptation_kwargs.pop("algorithm", "nuts")
247249
if algorithm_name == "nuts":
248250
algorithm = blackjax.nuts
@@ -255,6 +257,7 @@ def _blackjax_inference_loop(
255257
algorithm=algorithm,
256258
logdensity_fn=logprob_fn,
257259
target_acceptance_rate=target_accept,
260+
adaptation_info_fn=get_filter_adapt_info_fn(),
258261
**adaptation_kwargs,
259262
)
260263
(last_state, tuned_params), _ = adapt.run(seed, init_position, num_steps=tune)

tests/sampling/test_mcmc_external.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def test_external_nuts_sampler(recwarn, nuts_sampler):
5151
warns = {
5252
(warn.category, warn.message.args[0])
5353
for warn in recwarn
54-
if warn.category not in (FutureWarning, DeprecationWarning)
54+
if warn.category not in (FutureWarning, DeprecationWarning, RuntimeWarning)
5555
}
5656
expected = set()
5757
if nuts_sampler == "nutpie":

0 commit comments

Comments
 (0)