Description
Describe the issue:
I am raising this issue to create awareness, as I couldn't find any existing reports of this error.
While trying to replicate the "Updating priors" example notebook, I noticed that switching the NUTS sampler to numpyro raises the following error:
NotImplementedError: No JAX conversion for the given `Op`: SplineWrapper{spline=}
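For context, a toy sketch of why this error appears: as I understand it, PyTensor's JAX backend converts each Op in the compiled graph through a type-dispatched registry (`jax_funcify`), and an Op type with no registered converter falls through to a default that raises `NotImplementedError`. The classes and the `to_jax_callable` function below are hypothetical stand-ins, not PyTensor's actual code; they only model that dispatch pattern with the standard library.

```python
import math
from functools import singledispatch


class Op:
    """Hypothetical stand-in for a graph operation (not PyTensor's class)."""


class ExpOp(Op):
    """Hypothetical Op that has a converter registered below."""


class FakeSplineWrapper(Op):
    """Hypothetical stand-in for SplineWrapper: no converter is registered."""

    def __repr__(self):
        return "FakeSplineWrapper{spline=...}"


@singledispatch
def to_jax_callable(op):
    # Default implementation, reached only when no converter matches type(op).
    raise NotImplementedError(f"No JAX conversion for the given `Op`: {op}")


@to_jax_callable.register
def _(op: ExpOp):
    # A registered converter returns a plain callable implementing the Op.
    return math.exp


# Converting a supported Op works; the unsupported one raises.
print(to_jax_callable(ExpOp())(0.0))  # prints 1.0
try:
    to_jax_callable(FakeSplineWrapper())
except NotImplementedError as err:
    print(err)  # No JAX conversion for the given `Op`: FakeSplineWrapper{spline=...}
```

If this reading is right, supporting `nuts_sampler="numpyro"` here would need a JAX converter registered for `SplineWrapper`.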
Reproducible code example:
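from pymc import Model, sample

# traces is created before this loop (elided)
for _ in range(10):
    ...
    with Model():
        ...  # model definition elided; it uses Interpolated priors (SplineWrapper)
        trace = sample(1000, nuts_sampler="numpyro", nuts_sampler_kwargs={"chain_method": "parallel"})
    traces.append(trace)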
Error message:
NotImplementedError: No JAX conversion for the given `Op`: SplineWrapper{spline=}
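As a sketch of why a JAX path seems feasible: `pm.Interpolated` is documented as a linear interpolation of the probability density evaluated on a grid of points, and piecewise-linear interpolation is available as `jax.numpy.interp`. The helper below (`interpolated_logpdf` is a hypothetical name, not a PyMC or PyTensor API) computes such a normalized log-density in JAX. It does not plug into a PyMC model by itself and is not PyMC's implementation; it only shows that the density is expressible with JAX primitives.

```python
import jax
import jax.numpy as jnp


def interpolated_logpdf(value, x, y):
    """Hypothetical helper: log-density of a piecewise-linear pdf on sorted grid x."""
    # Normalize so the density integrates to 1. The trapezoid rule is exact
    # for a piecewise-linear function, so this matches the true integral.
    norm = jnp.sum(0.5 * (y[1:] + y[:-1]) * jnp.diff(x))
    # Outside the grid the density is 0, so the log-density is -inf there.
    pdf = jnp.interp(value, x, y / norm, left=0.0, right=0.0)
    return jnp.log(pdf)


# jit-compiles cleanly because every operation is a JAX primitive.
logpdf_jit = jax.jit(interpolated_logpdf)

# Triangular density on [0, 2] with its peak at 1 (already normalized).
x = jnp.array([0.0, 1.0, 2.0])
y = jnp.array([0.0, 1.0, 0.0])
print(logpdf_jit(0.5, x, y))  # log(0.5)
```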
PyMC version information:
pymc version 5.6.1
numpyro version 0.12.1
jax version 0.4.14