
ENH: Add dispatches for the TruncatedNormal distribution for forward sampling #7489

Open
@lucianopaz

Description

Before

import pymc as pm

with pm.Model() as m:
    a = pm.TruncatedNormal("a", 0, 1, lower=-1, upper=1)

pm.draw(a, mode="JAX")  # Fails: no JAX dispatch for TruncatedNormalRV
pm.draw(a, mode="NUMBA")  # Fails: no Numba dispatch for TruncatedNormalRV

After

import pymc as pm

with pm.Model() as m:
    a = pm.TruncatedNormal("a", 0, 1, lower=-1, upper=1)

pm.draw(a, mode="JAX")  # Works
pm.draw(a, mode="NUMBA")  # Works
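Whichever backend a dispatch targets, it has to produce draws from a normal distribution restricted to [lower, upper]. Below is a minimal sketch of one standard way to do that, inverse-CDF sampling, using only the Python standard library. It draws u uniformly between Phi(a) and Phi(b), where a and b are the standardized bounds, and then maps u back through the normal quantile function. The helper name truncated_normal_draws is hypothetical. This is not PyMC's or PyTensor's implementation, and a real JAX or Numba dispatch would use that backend's own random-number API instead of the random module.

```python
import random
from statistics import NormalDist


def truncated_normal_draws(mu, sigma, lower, upper, n, seed=None):
    """Hypothetical helper: n draws from Normal(mu, sigma) truncated to [lower, upper].

    Inverse-CDF sampling: a uniform draw between Phi(a) and Phi(b), where
    a and b are the standardized bounds, is mapped back through the
    standard normal quantile function Phi^-1.
    """
    if sigma <= 0:
        raise ValueError("sigma must be positive")
    if not lower < upper:
        raise ValueError("lower must be smaller than upper")
    std_normal = NormalDist(0.0, 1.0)
    cdf_lo = std_normal.cdf((lower - mu) / sigma)  # Phi(a)
    cdf_hi = std_normal.cdf((upper - mu) / sigma)  # Phi(b)
    rng = random.Random(seed)
    draws = []
    for _ in range(n):
        u = rng.uniform(cdf_lo, cdf_hi)  # uniform on [Phi(a), Phi(b)]
        x = mu + sigma * std_normal.inv_cdf(u)
        # Guard against floating-point round-off pushing x just past a bound.
        draws.append(min(max(x, lower), upper))
    return draws


# Same parameters as the issue: Normal(0, 1) truncated to [-1, 1].
samples = truncated_normal_draws(0.0, 1.0, lower=-1.0, upper=1.0, n=1000, seed=0)
print(min(samples), max(samples))
```

This naive version loses precision when both bounds lie far out in the same tail, because Phi(a) and Phi(b) then round to the same value. That weakness is one reason dedicated truncated-normal samplers exist.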

Context for the issue:

The TruncatedNormal distribution creates a TruncatedNormalRV Op, which has no dispatch rules for either the JAX or the Numba backend. The generic Truncated class, on the other hand, seems to work. More generally, it would help to document where dispatches for these PyMC-specific Ops should be written.
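PyTensor translates graphs for its JAX and Numba backends through jax_funcify and numba_funcify. Both are built on Python's functools.singledispatch, so an Op class with no registered rule cannot be translated. The toy below mimics that pattern using hypothetical stand-in classes and a hypothetical toy_funcify. It does not import PyTensor, and the real functions take extra arguments. Random-variable Ops may also be routed through more specialized hooks. Read it as an illustration of the mechanism, not as the actual place to register a fix.

```python
from functools import singledispatch


# Hypothetical stand-ins for Op classes; these are NOT the PyTensor/PyMC classes.
class NormalRV:
    pass


class TruncatedNormalRV:
    pass


@singledispatch
def toy_funcify(op):
    # Fallback for Op types with no registered rule: this is the situation
    # the issue describes, where compiling the graph for a backend fails.
    raise NotImplementedError(f"No dispatch rule for {type(op).__name__}")


@toy_funcify.register(NormalRV)
def _(op):
    # A rule the toy backend already ships with.
    return "normal sampler"


def compile_for_backend(op):
    """Return the backend implementation for op, or a failure message."""
    try:
        return toy_funcify(op)
    except NotImplementedError as err:
        return f"fails: {err}"


op = TruncatedNormalRV()
before = compile_for_backend(op)


# Registering a rule for the new Op class is all the toy backend needs.
@toy_funcify.register(TruncatedNormalRV)
def _(op):
    return "truncated normal sampler"


after = compile_for_backend(op)
print(before)  # fails: No dispatch rule for TruncatedNormalRV
print(after)  # truncated normal sampler
```

Registration is keyed on the class and takes effect once the registering module has been imported. The rule can therefore live in any module imported before compilation, so deciding where PyMC should keep such dispatches is largely a question of code organization.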
