Do not use "fork" as default mp_ctx when compiling JAX functions in the PyMC sampler #7668

Open
@ricardoV94

Description

JAX does not play well with fork, which is the default start method we currently use on Linux and on ARM-based macOS. With the default, sampling a JAX-compiled model hangs:

import pymc as pm
N_OBSERVATIONS = 50

with pm.Model() as model:
    mu = pm.Normal("mu")
    sigma = pm.HalfNormal("sigma", sigma=0.5)
    y = pm.Normal("y", mu=mu, sigma=sigma, shape=N_OBSERVATIONS)
    prior_trace = pm.sample_prior_predictive(random_seed=100)

data = prior_trace.prior.y.isel(chain=0, draw=0)
with pm.observe(model, {y: data}):
    pm.sample(compile_kwargs=dict(mode="JAX"), mp_ctx="forkserver")  # fine
    pm.sample(compile_kwargs=dict(mode="JAX"))  # hangs forever

Wherever we're defaulting to fork, we should switch to forkserver/spawn instead (whichever is supported)
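
As a rough sketch of what "whichever is supported" could mean in practice, the standard library can report which start methods the current platform offers. The helper name `pick_non_fork_start_method` is hypothetical and not part of PyMC; it only assumes that forkserver should be preferred over spawn when both exist.

```python
import multiprocessing


def pick_non_fork_start_method():
    """Hypothetical helper: return "forkserver" if available, else "spawn"."""
    # get_all_start_methods() lists the start methods supported on this platform.
    available = multiprocessing.get_all_start_methods()
    for method in ("forkserver", "spawn"):
        if method in available:
            return method
    # "spawn" is supported on every platform, so this should be unreachable.
    raise RuntimeError("No non-fork start method is available")


if __name__ == "__main__":
    method = pick_non_fork_start_method()
    ctx = multiprocessing.get_context(method)
    print(f"Using start method: {ctx.get_start_method()}")
```

Preferring forkserver here is a judgment call: it avoids forking the already-multithreaded parent process while usually starting workers faster than spawn.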

Relevant code:

if mp_ctx is None or isinstance(mp_ctx, str):
    # Closes issue https://github.com/pymc-devs/pymc/issues/3849
    # Related issue https://github.com/pymc-devs/pymc/issues/5339
    if mp_ctx is None and platform.system() == "Darwin":
        if platform.processor() == "arm":
            mp_ctx = "fork"
            logger.debug(
                "mp_ctx is set to 'fork' for MacOS with ARM architecture. "
                + "This might cause unexpected behavior with JAX, which is inherently multithreaded."
            )
        else:
            mp_ctx = "forkserver"
    mp_ctx = multiprocessing.get_context(mp_ctx)

To detect whether the JAX backend is being used, something like this could work:

from pytensor.compile.mode import get_mode
from pytensor.link.jax import JAXLinker
...
  # Somewhere inside/downstream of pm.sample
  mode = compile_kwargs.get("mode", None)
  using_jax = isinstance(get_mode(mode).linker, JAXLinker)
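
One way the detection result might feed into the context selection is sketched below. This is only an illustration, not the PyMC implementation: `resolve_mp_ctx` and its `using_jax` parameter are hypothetical names, and the flag is passed in as a plain boolean so the snippet runs without PyTensor installed. An explicit `mp_ctx` from the user is left untouched.

```python
import multiprocessing


def resolve_mp_ctx(mp_ctx=None, using_jax=False):
    """Hypothetical helper: build a multiprocessing context, avoiding fork for JAX.

    mp_ctx may be None, a start-method name, or an existing context object.
    using_jax is assumed to come from a check like the JAXLinker test above.
    """
    if mp_ctx is None and using_jax:
        # No explicit choice was made: pick a start method that does not fork.
        available = multiprocessing.get_all_start_methods()
        mp_ctx = "forkserver" if "forkserver" in available else "spawn"
    if mp_ctx is None or isinstance(mp_ctx, str):
        # get_context(None) returns the platform's default context.
        mp_ctx = multiprocessing.get_context(mp_ctx)
    return mp_ctx


if __name__ == "__main__":
    ctx = resolve_mp_ctx(using_jax=True)
    print(ctx.get_start_method())
```

Keeping the override limited to the `mp_ctx is None` case means a user who deliberately passes `mp_ctx="fork"` still gets what they asked for.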
