NotImplementedError with JAX backend and PyMC v4 #5240

Closed
@martiningram

Hi all,

I'm trying to sample from a pretty simple model with the JAX backend in PyMC v4 and am running into an issue. This one worked fine with PyMC v3, if that's useful to know.

The model is here:

import pymc as pm
import numpy as np
import pymc.sampling_jax

# Generate bogus data to reproduce error
n_players = 100
winner_ids = np.random.randint(n_players - 1, size=200)
loser_ids = np.random.randint(n_players - 1, size=200)
# Make sure no player is matched against themselves
winner_ids[winner_ids == loser_ids] += 1

with pm.Model() as model:
    player_sd = pm.HalfNormal("player_prior_sd", sigma=1.0)
    player_skills = pm.Normal("player_skills", 0.0, sigma=player_sd, shape=(n_players,))
    logit_skills = player_skills[winner_ids] - player_skills[loser_ids]

    # Bradley-Terry likelihood
    lik = pm.Bernoulli(
        "win_lik", logit_p=logit_skills, observed=np.ones(winner_ids.shape[0])
    )

with model:
    jax_res = pymc.sampling_jax.sample_numpyro_nuts()
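
For reference, here is a NumPy-only sketch of what the likelihood term in the model computes, so the model can be sanity-checked without PyMC or JAX. The function name `bradley_terry_logp`, the fixed seed, and the randomly drawn `skills` vector are assumptions made for illustration; they are not part of the failing script.

```python
import numpy as np

rng = np.random.default_rng(0)  # fixed seed: an assumption, for reproducibility only

n_players = 100
winner_ids = rng.integers(n_players - 1, size=200)  # values in 0..98
loser_ids = rng.integers(n_players - 1, size=200)
# Bump the winner ID wherever it collides with the loser ID (stays <= 99)
winner_ids[winner_ids == loser_ids] += 1


def bradley_terry_logp(skills, winner_ids, loser_ids):
    """Sum of log P(winner beats loser) under a Bradley-Terry model."""
    logits = skills[winner_ids] - skills[loser_ids]
    # log(sigmoid(x)) == -log(1 + exp(-x)), evaluated stably via logaddexp
    return -np.logaddexp(0.0, -logits).sum()


# Hypothetical skill values, drawn only to exercise the function
skills = rng.normal(0.0, 1.0, size=n_players)
total_logp = bradley_terry_logp(skills, winner_ids, loser_ids)
```

With all skills equal, every match is a coin flip, so the function returns `200 * log(0.5)`.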

I get the following error:

Compiling...
/home/martin/miniconda3/envs/dl/lib/python3.8/site-packages/pymc/model.py:925: FutureWarning: `Model.initial_point` has been deprecated. Use `Model.recompute_initial_point(seed=None)`.
  warnings.warn(
/home/martin/miniconda3/envs/dl/lib/python3.8/site-packages/aesara/graph/opt.py:232: UserWarning: Supervisor is not added. Please build a FunctionGraphvia aesara.compile.function.types.std_graph()or add the Supervisor class manually.
  sub_prof = optimizer.optimize(fgraph)
Traceback (most recent call last):
  File "error_repro.py", line 43, in <module>
    jax_res = pymc.sampling_jax.sample_numpyro_nuts()
  File "/home/martin/miniconda3/envs/dl/lib/python3.8/site-packages/pymc/sampling_jax.py", line 175, in sample_numpyro_nuts
    logp_fn = get_jaxified_logp(model)
  File "/home/martin/miniconda3/envs/dl/lib/python3.8/site-packages/pymc/sampling_jax.py", line 81, in get_jaxified_logp
    logp_fn = jax_funcify(logpt_fgraph)
  File "/home/martin/miniconda3/envs/dl/lib/python3.8/functools.py", line 875, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/home/martin/miniconda3/envs/dl/lib/python3.8/site-packages/aesara/link/jax/dispatch.py", line 626, in jax_funcify_FunctionGraph
    return fgraph_to_python(
  File "/home/martin/miniconda3/envs/dl/lib/python3.8/site-packages/aesara/link/utils.py", line 724, in fgraph_to_python
    compiled_func = op_conversion_fn(
  File "/home/martin/miniconda3/envs/dl/lib/python3.8/functools.py", line 875, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/home/martin/miniconda3/envs/dl/lib/python3.8/site-packages/aesara/link/jax/dispatch.py", line 143, in jax_funcify
    raise NotImplementedError(f"No JAX conversion for the given `Op`: {op}")
NotImplementedError: No JAX conversion for the given `Op`: Check{sigma > 0}
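
The traceback passes through `functools` dispatch twice before hitting the `raise`, which suggests the conversion is looked up by the type of each `Op` and falls back to a base function that raises when nothing is registered for that type. The sketch below reproduces that pattern with the standard library only. The names `Op`, `Add`, `Check`, and `to_backend` are hypothetical stand-ins and not the aesara implementation.

```python
from functools import singledispatch


class Op:
    """Hypothetical base class standing in for a graph operation."""


class Add(Op):
    pass


class Check(Op):
    def __init__(self, msg):
        self.msg = msg

    def __str__(self):
        return f"Check{{{self.msg}}}"


@singledispatch
def to_backend(op):
    # Fallback: reached when no conversion is registered for type(op)
    raise NotImplementedError(f"No conversion for the given `Op`: {op}")


@to_backend.register(Add)
def _(op):
    return lambda x, y: x + y


# Add has a registered conversion, Check does not
add_fn = to_backend(Add())
try:
    to_backend(Check("sigma > 0"))
except NotImplementedError as err:
    message = str(err)


# Registering a handler for the missing type makes the lookup succeed.
# This one passes the value through and skips the check, which is an
# assumption for illustration, not necessarily the right behaviour.
@to_backend.register(Check)
def _(op):
    return lambda x: x


check_fn = to_backend(Check("sigma > 0"))
```

If the real dispatch works the same way, the error here would mean that no conversion is registered for the `Check` op that guards `sigma > 0`, rather than a problem with the model itself.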

Here are my versions:

  • PyMC commit: 95bd5e5
  • aesara v2.3.1
  • numpyro v0.8.0
  • JAX v0.2.13
  • jaxlib v0.1.65

I hope this is helpful -- looking forward to your comments!

All the best,
Martin