
ENH: Get sampling working using Apple Silicon GPU via jax backend #7332

@drbenvincent

Description


It would be great to utilise the GPU on Apple Silicon chips. The path of least resistance is probably the jax backend; see https://jax.readthedocs.io/en/latest/installation.html#apple-silicon-gpu-arm-based and the Apple docs, Accelerated JAX training on Mac.

I don't have the stats, but a sizeable portion of PyMC users run code on Apple Silicon hardware, and that share will only grow as more people upgrade from Intel to Apple Silicon. Making full use of those chips (i.e. the GPU) would likely unlock speed-ups in sampling.

So far I have partial progress (hat tip to @twiecki). I am using the following environment file, metal_test_env.yaml:

name: metal_test_env
channels:
  - conda-forge
dependencies:
  - blackjax
  - ipykernel
  - jax==0.4.26
  - jupyter
  - numpy
  - pip
  - pymc
  - python<3.11
  - pip:
    - jax-metal
    - jaxlib==0.4.26
    - ml-dtypes==0.2.0

NOTE: It seems that pinning python<3.11 is a necessity at this point in time.

I build and activate the environment with:

mamba env create -f metal_test_env.yaml
mamba activate metal_test_env

Then, in an IPython session, we can confirm that jax has detected the Apple Silicon GPU:

import jax
jax.print_environment_info()

gives

Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1716460927.826518 4468516 mps_client.cc:510] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1 Max

systemMemory: 64.00 GB
maxCacheSize: 24.00 GB

I0000 00:00:1716460927.843089 4468516 service.cc:145] XLA service 0x1276ac990 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1716460927.843114 4468516 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1716460927.844725 4468516 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1716460927.844746 4468516 mps_client.cc:384] XLA backend will use up to 51539214336 bytes on device 0 for SimpleAllocator.
jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.26.4
python: 3.10.14 | packaged by conda-forge | (main, Mar 20 2024, 12:51:49) [Clang 16.0.6 ]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='BenjamicStudio7', release='23.5.0', version='Darwin Kernel Version 23.5.0: Wed May  1 20:12:58 PDT 2024; root:xnu-10063.121.3~5/RELEASE_ARM64_T6000', machine='arm64')

The key line is: jax.devices (1 total, 1 local): [METAL(id=0)]
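
As an extra sanity check (not part of the output above, just a minimal sketch), we can run a small float32 computation and confirm that it executes on the Metal device without errors:

import jax
import jax.numpy as jnp

print(jax.devices())  # expect [METAL(id=0)]

# The Metal backend does not appear to support float64, so keep the dtype explicit
x = jnp.ones((1000, 1000), dtype=jnp.float32)
y = jax.jit(lambda a: a @ a.T)(x)
y.block_until_ready()
print(y.shape, y.dtype)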

So the next step is to see if we can do sampling:

import numpy as np
import pymc as pm

x = np.random.normal(size=10)
with pm.Model() as model:
    mu = pm.Normal("mu", 0, 1)
    pm.Normal("x_obs", mu=mu, sigma=1, observed=x)
    idata = pm.sample(nuts_sampler="blackjax")

which, as of now, results in this traceback:

Traceback
XlaRuntimeError                           Traceback (most recent call last)
Cell In[6], line 8
      6 mu = pm.Normal("mu", 0, 1)
      7 pm.Normal("x_obs", mu=mu, sigma=1, observed=x)
----> 8 idata = pm.sample(nuts_sampler="blackjax")

File ~/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/pymc/sampling/mcmc.py:688, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, model, **kwargs)
    684     if not isinstance(step, NUTS):
    685         raise ValueError(
    686             "Model can not be sampled with NUTS alone. Your model is probably not continuous."
    687         )
--> 688     return _sample_external_nuts(
    689         sampler=nuts_sampler,
    690         draws=draws,
    691         tune=tune,
    692         chains=chains,
    693         target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
    694         random_seed=random_seed,
    695         initvals=initvals,
    696         model=model,
    697         var_names=var_names,
    698         progressbar=progressbar,
    699         idata_kwargs=idata_kwargs,
    700         nuts_sampler_kwargs=nuts_sampler_kwargs,
    701         **kwargs,
    702     )
    704 if isinstance(step, list):
    705     step = CompoundStep(step)

File ~/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/pymc/sampling/mcmc.py:351, in _sample_external_nuts(sampler, draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progressbar, idata_kwargs, nuts_sampler_kwargs, **kwargs)
    348 elif sampler in ("numpyro", "blackjax"):
    349     import pymc.sampling.jax as pymc_jax
--> 351     idata = pymc_jax.sample_jax_nuts(
    352         draws=draws,
    353         tune=tune,
    354         chains=chains,
    355         target_accept=target_accept,
    356         random_seed=random_seed,
    357         initvals=initvals,
    358         model=model,
    359         var_names=var_names,
    360         progressbar=progressbar,
    361         nuts_sampler=sampler,
    362         idata_kwargs=idata_kwargs,
    363         **nuts_sampler_kwargs,
    364     )
    365     return idata
    367 else:

File ~/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/pymc/sampling/jax.py:567, in sample_jax_nuts(draws, tune, chains, target_accept, random_seed, initvals, jitter, model, var_names, nuts_kwargs, progressbar, keep_untransformed, chain_method, postprocessing_backend, postprocessing_vectorize, postprocessing_chunks, idata_kwargs, compute_convergence_checks, nuts_sampler)
    564     raise ValueError(f"{nuts_sampler=} not recognized")
    566 tic1 = datetime.now()
--> 567 raw_mcmc_samples, sample_stats, library = sampler_fn(
    568     model=model,
    569     target_accept=target_accept,
    570     tune=tune,
    571     draws=draws,
    572     chains=chains,
    573     chain_method=chain_method,
    574     progressbar=progressbar,
    575     random_seed=random_seed,
    576     initial_points=initial_points,
    577     nuts_kwargs=nuts_kwargs,
    578 )
    579 tic2 = datetime.now()
    581 jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)

File ~/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/pymc/sampling/jax.py:398, in _sample_blackjax_nuts(model, target_accept, tune, draws, chains, chain_method, progressbar, random_seed, initial_points, nuts_kwargs)
    395 if chains == 1:
    396     initial_points = [np.stack(init_state) for init_state in zip(initial_points)]
--> 398 logprob_fn = get_jaxified_logp(model)
    400 seed = jax.random.PRNGKey(random_seed)
    401 keys = jax.random.split(seed, chains)

File ~/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/pymc/sampling/jax.py:153, in get_jaxified_logp(model, negative_logp)
    151 if not negative_logp:
    152     model_logp = -model_logp
--> 153 logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logp])
    155 def logp_fn_wrap(x):
    156     return logp_fn(*x)[0]

File ~/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/pymc/sampling/jax.py:146, in get_jaxified_graph(inputs, outputs)
    143 mode.JAX.optimizer.rewrite(fgraph)
    145 # We now jaxify the optimized fgraph
--> 146 return jax_funcify(fgraph)

File ~/mambaforge/envs/metal_test_env/lib/python3.10/functools.py:889, in singledispatch.<locals>.wrapper(*args, **kw)
    885 if not args:
    886     raise TypeError(f'{funcname} requires at least '
    887                     '1 positional argument')
--> 889 return dispatch(args[0].__class__)(*args, **kw)

File ~/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/pytensor/link/jax/dispatch/basic.py:51, in jax_funcify_FunctionGraph(fgraph, node, fgraph_name, **kwargs)
     44 @jax_funcify.register(FunctionGraph)
     45 def jax_funcify_FunctionGraph(
     46     fgraph,
   (...)
     49     **kwargs,
     50 ):
---> 51     return fgraph_to_python(
     52         fgraph,
     53         jax_funcify,
     54         type_conversion_fn=jax_typify,
     55         fgraph_name=fgraph_name,
     56         **kwargs,
     57     )

File ~/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/pytensor/link/utils.py:742, in fgraph_to_python(fgraph, op_conversion_fn, type_conversion_fn, order, storage_map, fgraph_name, global_env, local_env, get_name_for_object, squeeze_output, **kwargs)
    737 input_storage = storage_map.setdefault(
    738     i, [None if not isinstance(i, Constant) else i.data]
    739 )
    740 if input_storage[0] is not None or isinstance(i, Constant):
    741     # Constants need to be assigned locally and referenced
--> 742     global_env[local_input_name] = type_conversion_fn(
    743         input_storage[0], variable=i, storage=input_storage, **kwargs
    744     )
    745     # TODO: We could attempt to use the storage arrays directly
    746     # E.g. `local_input_name = f"{local_input_name}[0]"`
    747 node_input_names.append(local_input_name)

File ~/mambaforge/envs/metal_test_env/lib/python3.10/functools.py:889, in singledispatch.<locals>.wrapper(*args, **kw)
    885 if not args:
    886     raise TypeError(f'{funcname} requires at least '
    887                     '1 positional argument')
--> 889 return dispatch(args[0].__class__)(*args, **kw)

File ~/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/pytensor/link/jax/dispatch/basic.py:35, in jax_typify_ndarray(data, dtype, **kwargs)
     33 if len(data.shape) == 0:
     34     return data.item()
---> 35 return jnp.array(data, dtype=dtype)

File ~/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:2197, in array(object, dtype, copy, order, ndmin)
   2194 else:
   2195   raise TypeError(f"Unexpected input type for array: {type(object)}")
-> 2197 out_array: Array = lax_internal._convert_element_type(
   2198     out, dtype, weak_type=weak_type)
   2199 if ndmin > ndim(out_array):
   2200   out_array = lax.expand_dims(out_array, range(ndmin - ndim(out_array)))

File ~/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/jax/_src/lax/lax.py:558, in _convert_element_type(operand, new_dtype, weak_type)
    556   return type_cast(Array, operand)
    557 else:
--> 558   return convert_element_type_p.bind(operand, new_dtype=new_dtype,
    559                                      weak_type=bool(weak_type))

File ~/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/jax/_src/core.py:422, in Primitive.bind(self, *args, **params)
    419 def bind(self, *args, **params):
    420   assert (not config.enable_checks.value or
    421           all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
--> 422   return self.bind_with_trace(find_top_trace(args), args, params)

File ~/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/jax/_src/core.py:425, in Primitive.bind_with_trace(self, trace, args, params)
    424 def bind_with_trace(self, trace, args, params):
--> 425   out = trace.process_primitive(self, map(trace.full_raise, args), params)
    426   return map(full_lower, out) if self.multiple_results else full_lower(out)

File ~/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/jax/_src/core.py:913, in EvalTrace.process_primitive(self, primitive, tracers, params)
    912 def process_primitive(self, primitive, tracers, params):
--> 913   return primitive.impl(*tracers, **params)

File ~/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/jax/_src/dispatch.py:87, in apply_primitive(prim, *args, **params)
     85 prev = lib.jax_jit.swap_thread_local_state_disable_jit(False)
     86 try:
---> 87   outs = fun(*args)
     88 finally:
     89   lib.jax_jit.swap_thread_local_state_disable_jit(prev)

    [... skipping hidden 14 frame]

File ~/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/jax/_src/compiler.py:238, in backend_compile(backend, module, options, host_callbacks)
    233   return backend.compile(built_c, compile_options=options,
    234                          host_callbacks=host_callbacks)
    235 # Some backends don't have `host_callbacks` option yet
    236 # TODO(sharadmv): remove this fallback when all backends allow `compile`
    237 # to take in `host_callbacks`
--> 238 return backend.compile(built_c, compile_options=options)

XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
<unknown>:0: note: see current operation: 
"func.func"() <{arg_attrs = [{mhlo.layout_mode = "default"}], function_type = (tensor<10xf64>) -> tensor<10xf64>, res_attrs = [{jax.result_info = "", mhlo.layout_mode = "default"}], sym_name = "main", sym_visibility = "public"}> ({
^bb0(%arg0: tensor<10xf64>):
  "func.return"(%arg0) : (tensor<10xf64>) -> ()
}) : () -> ()
<unknown>:0: error: failed to legalize operation 'func.func'
<unknown>:0: note: see current operation: 
"func.func"() <{arg_attrs = [{mhlo.layout_mode = "default"}], function_type = (tensor<10xf64>) -> tensor<10xf64>, res_attrs = [{jax.result_info = "", mhlo.layout_mode = "default"}], sym_name = "main", sym_visibility = "public"}> ({
^bb0(%arg0: tensor<10xf64>):
  "func.return"(%arg0) : (tensor<10xf64>) -> ()
}) : () -> ()

For all I know the problem is on the jax side and may require issues to be filed in that repo. But I think it makes sense to have a PyMC issue to flag this goal as a priority and to coordinate any additional issues on the pymc or jax side.
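
One observation from the MLIR error above: the failing function signature involves tensor<10xf64>, and the Metal plugin apparently rejects float64 inputs/outputs. Purely as an untested sketch of something worth trying (the settings below are standard PyTensor/JAX configuration options, not anything established in this issue), forcing everything to float32 before building the model might sidestep this particular failure:

import jax
import numpy as np
import pytensor
import pymc as pm

# Untested workaround sketch: keep everything in float32, since the Metal
# backend appears to reject float64 tensors.
jax.config.update("jax_enable_x64", False)  # JAX's default, stated for clarity
pytensor.config.floatX = "float32"          # make PyTensor emit float32 graphs

x = np.random.normal(size=10).astype("float32")
with pm.Model() as model:
    mu = pm.Normal("mu", 0, 1)
    pm.Normal("x_obs", mu=mu, sigma=1, observed=x)
    idata = pm.sample(nuts_sampler="blackjax")

Even if that works, it only masks the float64 limitation rather than fixing it, so the upstream behaviour would still be worth tracking.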
