Description
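JAX does not play well with fork, which is the multiprocessing start method we currently default to on Linux and ARM-based macOS. JAX is multithreaded, and forking a multithreaded process can deadlock. Minimal reproducer: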
import pymc as pm

N_OBSERVATIONS = 50

with pm.Model() as model:
    mu = pm.Normal("mu")
    sigma = pm.HalfNormal("sigma", sigma=0.5)
    y = pm.Normal("y", mu=mu, sigma=sigma, shape=N_OBSERVATIONS)
    prior_trace = pm.sample_prior_predictive(random_seed=100)

data = prior_trace.prior.y.isel(chain=0, draw=0)

with pm.observe(model, {y: data}):
    pm.sample(compile_kwargs=dict(mode="JAX"), mp_ctx="forkserver")  # fine
    pm.sample(compile_kwargs=dict(mode="JAX"))  # hangs forever
Wherever we default to fork, we should switch to forkserver or spawn instead (whichever the platform supports).
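A minimal sketch of picking a non-fork start method, assuming we prefer forkserver and fall back to spawn when forkserver is unavailable (e.g. on Windows, where spawn is the only option). The helper name pick_start_method is hypothetical and not part of PyMC; it only uses the standard multiprocessing module.

```python
import multiprocessing


def pick_start_method() -> str:
    """Hypothetical helper (not part of PyMC): avoid fork.

    Prefer "forkserver" when the platform supports it, otherwise "spawn".
    """
    available = multiprocessing.get_all_start_methods()
    if "forkserver" in available:
        return "forkserver"
    # spawn is available on every platform CPython supports
    return "spawn"


# Build a context object; pm.sample's mp_ctx also accepts the plain string
ctx = multiprocessing.get_context(pick_start_method())
print(ctx.get_start_method())
```

Using get_all_start_methods rather than checking the OS name keeps the choice tied to what the running interpreter actually supports.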
Relevant code: pymc/pymc/sampling/parallel.py, lines 437 to 450 (at commit 268e13b).
To detect which backend is being used, something like this should work:
from pytensor.compile.mode import get_mode
from pytensor.link.jax import JAXLinker
...
# Somewhere inside/downstream of pm.sample
mode = compile_kwargs.get("mode", None)
using_jax = isinstance(get_mode(mode).linker, JAXLinker)
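One way the backend check could feed into the default context, sketched under assumptions: resolve_mp_ctx is a hypothetical helper (not PyMC API), it takes using_jax as a plain boolean (computed as in the snippet above) so it does not depend on pytensor, it only overrides the default when JAX is in use, and it always respects an explicit mp_ctx from the user.

```python
import multiprocessing
from typing import Optional


def resolve_mp_ctx(mp_ctx: Optional[str], using_jax: bool) -> Optional[str]:
    """Hypothetical helper (not part of PyMC).

    mp_ctx: the start method the user passed explicitly, or None.
    using_jax: whether the compiled model uses the JAX linker.
    Returns None to mean "keep the existing default".
    """
    if mp_ctx is not None:
        # Respect an explicit user choice, even if it is "fork"
        return mp_ctx
    if not using_jax:
        return None
    # JAX is in use: avoid fork, prefer forkserver, fall back to spawn
    available = multiprocessing.get_all_start_methods()
    return "forkserver" if "forkserver" in available else "spawn"
```

For example, resolve_mp_ctx(None, using_jax=True) yields "forkserver" on Linux and macOS and "spawn" on Windows, while resolve_mp_ctx("fork", using_jax=True) leaves the user's choice alone.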