Hi all,
I'm trying to sample from a fairly simple model with the JAX backend (`pymc.sampling_jax.sample_numpyro_nuts`) in PyMC v4 and am running into an error. The same model worked fine with PyMC v3, in case that's useful to know.
Here is the model:
```python
import pymc as pm
import numpy as np
import pymc.sampling_jax

# Generate bogus data to reproduce error
n_players = 100
winner_ids = np.random.randint(n_players - 1, size=200)
loser_ids = np.random.randint(n_players - 1, size=200)
# Avoid games where a player faces themselves
winner_ids[winner_ids == loser_ids] += 1

with pm.Model() as model:
    player_sd = pm.HalfNormal("player_prior_sd", sigma=1.0)
    player_skills = pm.Normal("player_skills", 0.0, sigma=player_sd, shape=(n_players,))

    logit_skills = player_skills[winner_ids] - player_skills[loser_ids]

    # Bradley-Terry likelihood
    lik = pm.Bernoulli(
        "win_lik", logit_p=logit_skills, observed=np.ones(winner_ids.shape[0])
    )

with model:
    jax_res = pymc.sampling_jax.sample_numpyro_nuts()
```
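For reference, the Bradley-Terry likelihood in the model gives each game a log-probability of `log sigmoid(s_winner - s_loser)`, which is what a Bernoulli with `logit_p = s_winner - s_loser` and `observed = 1` evaluates. The plain-Python sketch below only illustrates that formula; `log_sigmoid` and `bradley_terry_loglik` are hypothetical helpers, not part of PyMC.

```python
import math


def log_sigmoid(x: float) -> float:
    # Numerically stable log(1 / (1 + exp(-x))).
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def bradley_terry_loglik(skills, winner_ids, loser_ids) -> float:
    # Sum over games of log P(winner beats loser) = log sigmoid(s_w - s_l),
    # i.e. a Bernoulli log-probability with logit_p = s_w - s_l and observed = 1.
    return sum(
        log_sigmoid(skills[w] - skills[l])
        for w, l in zip(winner_ids, loser_ids)
    )


skills = [0.0, 1.0, -0.5]
print(bradley_terry_loglik(skills, [1, 0], [0, 2]))
```

Two equally skilled players give each game a probability of 0.5, and a larger skill gap in the winner's favour pushes the per-game log-probability toward 0.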
I get the following error:
```
Compiling...
/home/martin/miniconda3/envs/dl/lib/python3.8/site-packages/pymc/model.py:925: FutureWarning: `Model.initial_point` has been deprecated. Use `Model.recompute_initial_point(seed=None)`.
  warnings.warn(
/home/martin/miniconda3/envs/dl/lib/python3.8/site-packages/aesara/graph/opt.py:232: UserWarning: Supervisor is not added. Please build a FunctionGraphvia aesara.compile.function.types.std_graph()or add the Supervisor class manually.
  sub_prof = optimizer.optimize(fgraph)
Traceback (most recent call last):
  File "error_repro.py", line 43, in <module>
    jax_res = pymc.sampling_jax.sample_numpyro_nuts()
  File "/home/martin/miniconda3/envs/dl/lib/python3.8/site-packages/pymc/sampling_jax.py", line 175, in sample_numpyro_nuts
    logp_fn = get_jaxified_logp(model)
  File "/home/martin/miniconda3/envs/dl/lib/python3.8/site-packages/pymc/sampling_jax.py", line 81, in get_jaxified_logp
    logp_fn = jax_funcify(logpt_fgraph)
  File "/home/martin/miniconda3/envs/dl/lib/python3.8/functools.py", line 875, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/home/martin/miniconda3/envs/dl/lib/python3.8/site-packages/aesara/link/jax/dispatch.py", line 626, in jax_funcify_FunctionGraph
    return fgraph_to_python(
  File "/home/martin/miniconda3/envs/dl/lib/python3.8/site-packages/aesara/link/utils.py", line 724, in fgraph_to_python
    compiled_func = op_conversion_fn(
  File "/home/martin/miniconda3/envs/dl/lib/python3.8/functools.py", line 875, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/home/martin/miniconda3/envs/dl/lib/python3.8/site-packages/aesara/link/jax/dispatch.py", line 143, in jax_funcify
    raise NotImplementedError(f"No JAX conversion for the given `Op`: {op}")
NotImplementedError: No JAX conversion for the given `Op`: Check{sigma > 0}
```
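For context, the `Check{sigma > 0}` Op named in the last line is presumably a parameter-validity guard attached to the Normal's log-probability (this is a reading of the Op name, not something the traceback confirms): it lets the computation through when `sigma > 0` and raises otherwise. The plain-Python sketch below only illustrates that behaviour; `ParameterValueError`, `check_parameter` and `normal_logp` are hypothetical names, not PyMC's or aesara's implementation.

```python
import math


class ParameterValueError(ValueError):
    """Hypothetical error type raised when a parameter check fails."""


def check_parameter(condition: bool, msg: str) -> None:
    # Raise when the condition is violated, mimicking the role we assume
    # `Check{sigma > 0}` plays in the log-probability graph.
    if not condition:
        raise ParameterValueError(msg)


def normal_logp(x: float, mu: float, sigma: float) -> float:
    # Guard first so an invalid sigma never reaches the division or log below.
    check_parameter(sigma > 0, "sigma > 0")
    z = (x - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2 * math.pi)


print(normal_logp(0.5, 0.0, 1.0))
```

Whether such a guard passes depends on the runtime value of `sigma`, so a direct JAX translation would need a data-dependent error inside a traced function; that may be why this aesara version registers no JAX conversion for it.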
Here are my versions:
- PyMC commit: 95bd5e5
- aesara v2.3.1
- numpyro v0.8.0
- JAX v0.2.13
- jaxlib v0.1.65
I hope this is helpful -- looking forward to your comments!
All the best,
Martin