Describe the issue:
When I use the blackjax backend (`nuts_sampler="blackjax"`), sampling fails with a dtype error.
The same code runs fine with the default PyMC sampler and with nutpie.
Reproducible code example:
import numpy as np
import pymc as pm
from pymc import HalfCauchy, Model, Normal, sample

if __name__ == "__main__":
    print(f"Running on PyMC v{pm.__version__}")
    RANDOM_SEED = 8927
    rng = np.random.default_rng(RANDOM_SEED)
    y = 1 + rng.normal(scale=0.5, size=200)

    with Model() as model:
        sigma = HalfCauchy("sigma", beta=10)
        mu = Normal("mu", mu=0, sigma=1)
        _ = Normal("y", mu=mu, sigma=sigma, observed=y)
        # Fails only with the blackjax backend; "pymc" and "nutpie" work
        idata = sample(3000, progressbar=True, nuts_sampler="blackjax")
Error message:
<details>
XlaRuntimeError: INTERNAL: Compute error: CpuCallback error: Traceback (most recent call last):
File "C:\.....\Lib\site-packages\jax\_src\interpreters\mlir.py", line 2781, in _wrapped_callback
RuntimeError: Incorrect output dtype for return value #0: Expected: int64, Actual: int32
</details>
PyMC version information:
Platform: Windows 11 (WinPython distribution), Python 3.12.6, PyMC v5.18.2, blackjax 1.2.4
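To help narrow down where the `int32` comes from, the following sketch reports the platform's default integer dtype in NumPy and in JAX. It is a diagnostic only, built on an assumption that is not confirmed in this report: that the mismatch is related to integer width (NumPy before 2.0 defaults to `int32` on Windows, and JAX uses 32-bit types unless 64-bit mode is enabled). The helper name `dtype_diagnostics` is hypothetical.

```python
import importlib.util
import platform


def dtype_diagnostics():
    """Collect integer-width settings that may be relevant to the
    'Expected: int64, Actual: int32' callback error.

    Assumption: the error is related to default integer widths. This only
    reports settings; it does not change or fix anything.
    """
    info = {"platform": platform.system(), "python": platform.python_version()}

    if importlib.util.find_spec("numpy") is not None:
        import numpy as np

        # NumPy < 2.0 uses int32 as the default integer on Windows, int64 elsewhere.
        info["numpy_version"] = np.__version__
        info["numpy_default_int"] = str(np.array([0]).dtype)

    if importlib.util.find_spec("jax") is not None:
        import jax
        import jax.numpy as jnp

        # JAX reports int32 here unless 64-bit mode (jax_enable_x64) is enabled.
        info["jax_version"] = jax.__version__
        info["jax_default_int"] = str(jnp.array(0).dtype)

    return info


if __name__ == "__main__":
    for key, value in dtype_diagnostics().items():
        print(f"{key}: {value}")
```

Running it in the same environment as the reproducer would show whether NumPy or JAX is producing 32-bit integers on this setup.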
Context for the issue:
No response