BUG: blackjax sampler gives Incorrect output dtype for return value #0: Expected: int64, Actual: int32 #7593

Open
@mvds314

Describe the issue:

When I use the blackjax backend, I get dtype errors.

The code runs fine with the pymc and nutpie samplers.

Reproducible code example:

import numpy as np
import pymc as pm

from pymc import HalfCauchy, Model, Normal, sample

if __name__ == "__main__":
    print(f"Running on PyMC v{pm.__version__}")
    RANDOM_SEED = 8927
    rng = np.random.default_rng(RANDOM_SEED)
    y = 1 + rng.normal(scale=0.5, size=200)
    with Model() as model:
        sigma = HalfCauchy("sigma", beta=10)
        mu = Normal("mu", mu=0, sigma=1)
        _ = Normal("y", mu=mu, sigma=sigma, observed=y)
        idata = sample(3000, progressbar=True, nuts_sampler="blackjax")

Error message:

<details>
XlaRuntimeError: INTERNAL: Compute error: CpuCallback error: Traceback (most recent call last):
  File "C:\.....\Lib\site-packages\jax\_src\interpreters\mlir.py", line 2781, in _wrapped_callback
RuntimeError: Incorrect output dtype for return value #0: Expected: int64, Actual: int32
</details>

PyMC version information:

Platform: Windows 11 (WinPython distribution), Python 3.12.6, PyMC v5.18.2, blackjax 1.2.4
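Since the error is an int64/int32 mismatch, the integer widths in the environment may be worth recording next to the version information. The following is a minimal diagnostic sketch, not part of the original report. The helper name `describe_int_defaults` is hypothetical, and any link between these values and the error is an assumption, not a confirmed cause. It only reads settings and does not change them; NumPy and JAX are reported if they are installed.

```python
import ctypes
import platform
import sys


def describe_int_defaults():
    """Collect integer-width facts that may be relevant to an int32/int64 mismatch.

    Hypothetical helper for illustration; it does not diagnose the bug itself.
    """
    info = {
        "platform": platform.system(),
        "python": sys.version.split()[0],
        # Width of the C `long` type: 32 bits on Windows,
        # typically 64 bits on 64-bit Linux and macOS.
        "c_long_bits": ctypes.sizeof(ctypes.c_long) * 8,
    }
    try:
        import numpy as np

        info["numpy"] = np.__version__
        # Default integer dtype NumPy uses for Python ints on this platform.
        info["numpy_default_int"] = np.dtype(int).name
    except ImportError:
        info["numpy"] = None
    try:
        import jax

        info["jax"] = jax.__version__
        # Whether JAX is configured for 64-bit types (None if not readable).
        info["jax_enable_x64"] = getattr(jax.config, "jax_enable_x64", None)
    except ImportError:
        info["jax"] = None
    return info


if __name__ == "__main__":
    for key, value in describe_int_defaults().items():
        print(f"{key}: {value}")
```

Running this in the same environment as the reproducer and pasting the output into the issue would show whether the platform's default integer width differs from the int64 the callback expects.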

Context for the issue:

No response
