Description
It would be great to utilise the GPU on Apple Silicon chips. The path of least resistance is probably the jax backend; see https://jax.readthedocs.io/en/latest/installation.html#apple-silicon-gpu-arm-based and the Apple docs, Accelerated JAX training on Mac.
I don't have the stats, but a sizeable portion of PyMC users run code on Apple Silicon hardware, and that share will only increase over time as more people upgrade from Intel. Full utilisation of those chips (i.e. the GPU component) would likely unlock some speed-ups in sampling.
So far I have partial progress (h/t @twiecki). I have the following environment file, metal_test_env.yaml:
name: metal_test_env
channels:
  - conda-forge
dependencies:
  - blackjax
  - ipykernel
  - jax==0.4.26
  - jupyter
  - numpy
  - pip
  - pymc
  - python<3.11
  - pip:
      - jax-metal
      - jaxlib==0.4.26
      - ml-dtypes==0.2.0
NOTE: It seems that pinning python<3.11 is necessary at this point in time.
I build that with:
mamba env create -f metal_test_env.yaml
mamba activate metal_test_env
Then, in an ipython session, we can confirm that jax has detected the Apple Silicon GPU:
import jax
jax.print_environment_info()
gives
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1716460927.826518 4468516 mps_client.cc:510] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1 Max
systemMemory: 64.00 GB
maxCacheSize: 24.00 GB
I0000 00:00:1716460927.843089 4468516 service.cc:145] XLA service 0x1276ac990 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1716460927.843114 4468516 service.cc:153] StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1716460927.844725 4468516 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1716460927.844746 4468516 mps_client.cc:384] XLA backend will use up to 51539214336 bytes on device 0 for SimpleAllocator.
jax: 0.4.26
jaxlib: 0.4.26
numpy: 1.26.4
python: 3.10.14 | packaged by conda-forge | (main, Mar 20 2024, 12:51:49) [Clang 16.0.6 ]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='BenjamicStudio7', release='23.5.0', version='Darwin Kernel Version 23.5.0: Wed May 1 20:12:58 PDT 2024; root:xnu-10063.121.3~5/RELEASE_ARM64_T6000', machine='arm64')
The key line is: jax.devices (1 total, 1 local): [METAL(id=0)]
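As a quick programmatic check (rather than reading the log output), the device list can also be inspected directly; the exact string jax-metal reports may vary by version:

import jax

# With a working jax-metal install, the device list should contain a
# METAL device rather than only the CPU.
devices = jax.devices()
print(devices)              # e.g. [METAL(id=0)]
print(devices[0].platform)  # e.g. "METAL" (exact casing may vary)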
So the next step is to see if we can do sampling:
import numpy as np
import pymc as pm
x = np.random.normal(size=10)
with pm.Model() as model:
    mu = pm.Normal("mu", 0, 1)
    pm.Normal("x_obs", mu=mu, sigma=1, observed=x)
    idata = pm.sample(nuts_sampler="blackjax")
which, as of now, results in this traceback:
Traceback
XlaRuntimeError Traceback (most recent call last)
Cell In[6], line 8
6 mu = pm.Normal("mu", 0, 1)
7 pm.Normal("x_obs", mu=mu, sigma=1, observed=x)
----> 8 idata = pm.sample(nuts_sampler="blackjax")
File ~/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/pymc/sampling/mcmc.py:688, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, model, **kwargs)
684 if not isinstance(step, NUTS):
685 raise ValueError(
686 "Model can not be sampled with NUTS alone. Your model is probably not continuous."
687 )
--> 688 return _sample_external_nuts(
689 sampler=nuts_sampler,
690 draws=draws,
691 tune=tune,
692 chains=chains,
693 target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
694 random_seed=random_seed,
695 initvals=initvals,
696 model=model,
697 var_names=var_names,
698 progressbar=progressbar,
699 idata_kwargs=idata_kwargs,
700 nuts_sampler_kwargs=nuts_sampler_kwargs,
701 **kwargs,
702 )
704 if isinstance(step, list):
705 step = CompoundStep(step)
File ~/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/pymc/sampling/mcmc.py:351, in _sample_external_nuts(sampler, draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progressbar, idata_kwargs, nuts_sampler_kwargs, **kwargs)
348 elif sampler in ("numpyro", "blackjax"):
349 import pymc.sampling.jax as pymc_jax
--> 351 idata = pymc_jax.sample_jax_nuts(
352 draws=draws,
353 tune=tune,
354 chains=chains,
355 target_accept=target_accept,
356 random_seed=random_seed,
357 initvals=initvals,
358 model=model,
359 var_names=var_names,
360 progressbar=progressbar,
361 nuts_sampler=sampler,
362 idata_kwargs=idata_kwargs,
363 **nuts_sampler_kwargs,
364 )
365 return idata
367 else:
File ~/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/pymc/sampling/jax.py:567, in sample_jax_nuts(draws, tune, chains, target_accept, random_seed, initvals, jitter, model, var_names, nuts_kwargs, progressbar, keep_untransformed, chain_method, postprocessing_backend, postprocessing_vectorize, postprocessing_chunks, idata_kwargs, compute_convergence_checks, nuts_sampler)
564 raise ValueError(f"{nuts_sampler=} not recognized")
566 tic1 = datetime.now()
--> 567 raw_mcmc_samples, sample_stats, library = sampler_fn(
568 model=model,
569 target_accept=target_accept,
570 tune=tune,
571 draws=draws,
572 chains=chains,
573 chain_method=chain_method,
574 progressbar=progressbar,
575 random_seed=random_seed,
576 initial_points=initial_points,
577 nuts_kwargs=nuts_kwargs,
578 )
579 tic2 = datetime.now()
581 jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
File ~/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/pymc/sampling/jax.py:398, in _sample_blackjax_nuts(model, target_accept, tune, draws, chains, chain_method, progressbar, random_seed, initial_points, nuts_kwargs)
395 if chains == 1:
396 initial_points = [np.stack(init_state) for init_state in zip(initial_points)]
--> 398 logprob_fn = get_jaxified_logp(model)
400 seed = jax.random.PRNGKey(random_seed)
401 keys = jax.random.split(seed, chains)
File ~/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/pymc/sampling/jax.py:153, in get_jaxified_logp(model, negative_logp)
151 if not negative_logp:
152 model_logp = -model_logp
--> 153 logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logp])
155 def logp_fn_wrap(x):
156 return logp_fn(*x)[0]
File ~/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/pymc/sampling/jax.py:146, in get_jaxified_graph(inputs, outputs)
143 mode.JAX.optimizer.rewrite(fgraph)
145 # We now jaxify the optimized fgraph
--> 146 return jax_funcify(fgraph)
File ~/mambaforge/envs/metal_test_env/lib/python3.10/functools.py:889, in singledispatch.<locals>.wrapper(*args, **kw)
885 if not args:
886 raise TypeError(f'{funcname} requires at least '
887 '1 positional argument')
--> 889 return dispatch(args[0].__class__)(*args, **kw)
File ~/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/pytensor/link/jax/dispatch/basic.py:51, in jax_funcify_FunctionGraph(fgraph, node, fgraph_name, **kwargs)
44 @jax_funcify.register(FunctionGraph)
45 def jax_funcify_FunctionGraph(
46 fgraph,
(...)
49 **kwargs,
50 ):
---> 51 return fgraph_to_python(
52 fgraph,
53 jax_funcify,
54 type_conversion_fn=jax_typify,
55 fgraph_name=fgraph_name,
56 **kwargs,
57 )
File ~/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/pytensor/link/utils.py:742, in fgraph_to_python(fgraph, op_conversion_fn, type_conversion_fn, order, storage_map, fgraph_name, global_env, local_env, get_name_for_object, squeeze_output, **kwargs)
737 input_storage = storage_map.setdefault(
738 i, [None if not isinstance(i, Constant) else i.data]
739 )
740 if input_storage[0] is not None or isinstance(i, Constant):
741 # Constants need to be assigned locally and referenced
--> 742 global_env[local_input_name] = type_conversion_fn(
743 input_storage[0], variable=i, storage=input_storage, **kwargs
744 )
745 # TODO: We could attempt to use the storage arrays directly
746 # E.g. `local_input_name = f"{local_input_name}[0]"`
747 node_input_names.append(local_input_name)
File ~/mambaforge/envs/metal_test_env/lib/python3.10/functools.py:889, in singledispatch.<locals>.wrapper(*args, **kw)
885 if not args:
886 raise TypeError(f'{funcname} requires at least '
887 '1 positional argument')
--> 889 return dispatch(args[0].__class__)(*args, **kw)
File ~/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/pytensor/link/jax/dispatch/basic.py:35, in jax_typify_ndarray(data, dtype, **kwargs)
33 if len(data.shape) == 0:
34 return data.item()
---> 35 return jnp.array(data, dtype=dtype)
File ~/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:2197, in array(object, dtype, copy, order, ndmin)
2194 else:
2195 raise TypeError(f"Unexpected input type for array: {type(object)}")
-> 2197 out_array: Array = lax_internal._convert_element_type(
2198 out, dtype, weak_type=weak_type)
2199 if ndmin > ndim(out_array):
2200 out_array = lax.expand_dims(out_array, range(ndmin - ndim(out_array)))
File ~/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/jax/_src/lax/lax.py:558, in _convert_element_type(operand, new_dtype, weak_type)
556 return type_cast(Array, operand)
557 else:
--> 558 return convert_element_type_p.bind(operand, new_dtype=new_dtype,
559 weak_type=bool(weak_type))
File ~/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/jax/_src/core.py:422, in Primitive.bind(self, *args, **params)
419 def bind(self, *args, **params):
420 assert (not config.enable_checks.value or
421 all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
--> 422 return self.bind_with_trace(find_top_trace(args), args, params)
File ~/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/jax/_src/core.py:425, in Primitive.bind_with_trace(self, trace, args, params)
424 def bind_with_trace(self, trace, args, params):
--> 425 out = trace.process_primitive(self, map(trace.full_raise, args), params)
426 return map(full_lower, out) if self.multiple_results else full_lower(out)
File ~/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/jax/_src/core.py:913, in EvalTrace.process_primitive(self, primitive, tracers, params)
912 def process_primitive(self, primitive, tracers, params):
--> 913 return primitive.impl(*tracers, **params)
File ~/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/jax/_src/dispatch.py:87, in apply_primitive(prim, *args, **params)
85 prev = lib.jax_jit.swap_thread_local_state_disable_jit(False)
86 try:
---> 87 outs = fun(*args)
88 finally:
89 lib.jax_jit.swap_thread_local_state_disable_jit(prev)
[... skipping hidden 14 frame]
File ~/mambaforge/envs/metal_test_env/lib/python3.10/site-packages/jax/_src/compiler.py:238, in backend_compile(backend, module, options, host_callbacks)
233 return backend.compile(built_c, compile_options=options,
234 host_callbacks=host_callbacks)
235 # Some backends don't have `host_callbacks` option yet
236 # TODO(sharadmv): remove this fallback when all backends allow `compile`
237 # to take in `host_callbacks`
--> 238 return backend.compile(built_c, compile_options=options)
XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
<unknown>:0: note: see current operation:
"func.func"() <{arg_attrs = [{mhlo.layout_mode = "default"}], function_type = (tensor<10xf64>) -> tensor<10xf64>, res_attrs = [{jax.result_info = "", mhlo.layout_mode = "default"}], sym_name = "main", sym_visibility = "public"}> ({
^bb0(%arg0: tensor<10xf64>):
"func.return"(%arg0) : (tensor<10xf64>) -> ()
}) : () -> ()
<unknown>:0: error: failed to legalize operation 'func.func'
<unknown>:0: note: see current operation:
"func.func"() <{arg_attrs = [{mhlo.layout_mode = "default"}], function_type = (tensor<10xf64>) -> tensor<10xf64>, res_attrs = [{jax.result_info = "", mhlo.layout_mode = "default"}], sym_name = "main", sym_visibility = "public"}> ({
^bb0(%arg0: tensor<10xf64>):
"func.return"(%arg0) : (tensor<10xf64>) -> ()
}) : () -> ()
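One observation: the function_type = (tensor<10xf64>) -> tensor<10xf64> in the error suggests the Metal backend is rejecting float64 inputs, which the jaxified logp graph appears to produce because pytensor's floatX defaults to float64. A speculative, untested workaround, assuming float64 is the only blocker, would be to build everything in float32, e.g.:

import numpy as np
import pymc as pm
import pytensor

# Speculative workaround (untested on jax-metal): compile the pytensor graph
# in float32 so the jaxified logp avoids the f64 tensors that the Metal
# backend appears unable to legalize.
pytensor.config.floatX = "float32"

x = np.random.normal(size=10).astype("float32")

with pm.Model() as model:
    mu = pm.Normal("mu", 0, 1)
    pm.Normal("x_obs", mu=mu, sigma=1, observed=x)
    idata = pm.sample(nuts_sampler="blackjax")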
For all I know, the problem is on the jax side and may require issues to be filed in that repo. But I think it makes sense to have a pymc issue to raise this goal as a priority, and perhaps to coordinate any additional issues on the pymc or jax side.