
Interpolated not supported in JAX/Numba backends #6838

@FBruzzesi

Description

Describe the issue:

I am opening this issue to raise awareness, since I couldn't find any existing report of this error.

While trying to replicate the "Updating Priors" example notebook, I noticed that switching the NUTS sampler to numpyro raises the following error:

NotImplementedError: No JAX conversion for the given Op: SplineWrapper{spline=}
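As far as I can tell, the message comes from how PyTensor's JAX backend converts a graph. It looks up a conversion function per Op type: its `jax_funcify` is a `functools.singledispatch` function whose fallback raises this `NotImplementedError`. `SplineWrapper`, the Op that `Interpolated` uses to evaluate a SciPy spline, has no registered conversion. The stand-alone sketch below mimics that pattern. The `Op`, `Exp` and `SplineWrapper` classes here are hypothetical stand-ins, not PyTensor's actual classes.

```python
import math
from functools import singledispatch


class Op:
    """Hypothetical stand-in for a PyTensor Op."""


class Exp(Op):
    """Hypothetical Op that has a registered conversion."""


class SplineWrapper(Op):
    """Hypothetical stand-in: wraps a SciPy spline object that JAX cannot trace."""

    def __init__(self, spline):
        self.spline = spline

    def __repr__(self):
        return f"SplineWrapper{{spline={self.spline!r}}}"


@singledispatch
def jax_funcify(op):
    # Fallback used whenever no conversion is registered for type(op).
    raise NotImplementedError(f"No JAX conversion for the given Op: {op}")


@jax_funcify.register(Exp)
def _(op):
    # A real backend would return a jax.numpy function; math.exp keeps this runnable.
    return math.exp


def try_convert(op):
    """Return the converted function, or the error message if none is registered."""
    try:
        return jax_funcify(op)
    except NotImplementedError as err:
        return str(err)
```

In this model, supporting `Interpolated` under numpyro means registering a conversion for the spline Op. That conversion in turn requires a JAX-traceable way to evaluate the spline.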

Reproducible code example:

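from pymc import Model, sample

traces = []  # collects one InferenceData per iteration
for _ in range(10):
    ...  # elided in the original report
    with Model():
        ...  # model definition elided; per the title, it uses Interpolated priors
        trace = sample(1000, nuts_sampler="numpyro", nuts_sampler_kwargs={"chain_method": "parallel"})
        traces.append(trace)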

Error message:

NotImplementedError: No JAX conversion for the given Op: SplineWrapper{spline=}
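A JAX conversion would need an array-only equivalent of the spline evaluation. The sketch below assumes that `Interpolated`'s density is a linear interpolation of `pdf_points` over `x_points`, zero outside the grid and normalized to integrate to one. This is my reading of the implementation (a k=1 `InterpolatedUnivariateSpline` with `ext="zeros"`) and has not been verified against 5.6.1. The helper `interpolated_logp` is hypothetical. It uses NumPy only so that it runs anywhere. `jax.numpy.interp` accepts the same `(x, xp, fp, left, right)` arguments, so the same body could in principle be written with `jax.numpy`.

```python
import numpy as np


def interpolated_logp(value, x_points, pdf_points):
    """Hypothetical log-density of a piecewise-linear pdf, zero outside the grid."""
    x_points = np.asarray(x_points, dtype=float)
    pdf_points = np.asarray(pdf_points, dtype=float)
    # The area under a piecewise-linear interpolant is exactly the trapezoid rule on its knots.
    widths = np.diff(x_points)
    z = np.sum(widths * (pdf_points[1:] + pdf_points[:-1]) / 2.0)
    # Linear interpolation with left/right=0 plays the role of ext="zeros".
    density = np.interp(value, x_points, pdf_points, left=0.0, right=0.0)
    with np.errstate(divide="ignore"):
        return np.log(density / z)  # -inf outside the support


# Triangle pdf on [0, 2] with peak 1 at x = 1; its area is already 1.
print(interpolated_logp(1.0, [0, 1, 2], [0, 1, 0]))
```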

PyMC version information:

pymc version 5.6.1
numpyro version 0.12.1
jax version 0.4.14
