Description
Before
import pymc as pm

with pm.Model() as m:
    a = pm.TruncatedNormal("a", 0, 1, lower=-1, upper=1)

pm.draw(a, mode="JAX")  # Fails
pm.draw(a, mode="NUMBA")  # Fails
After
with pm.Model() as m:
    a = pm.TruncatedNormal("a", 0, 1, lower=-1, upper=1)

pm.draw(a, mode="JAX")  # Works
pm.draw(a, mode="NUMBA")  # Works
Context for the issue:
The TruncatedNormal distribution creates a TruncatedNormalRV Op that doesn't have dispatch rules for either JAX or NUMBA. The Truncated class itself seems to work though. In general, it would be nice to know where to write dispatches for these special pymc Ops.
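For anyone picking this up, here is a toy sketch of the type-based dispatch pattern the backends rely on. As far as I know, PyTensor's `jax_funcify` (in `pytensor.link.jax.dispatch`) and `numba_funcify` (in `pytensor.link.numba.dispatch`) are `functools.singledispatch`-style registries keyed on the Op class. Everything below (`Op`, `ToyTruncatedNormalRV`, `toy_funcify`) is a hypothetical stand-in written with the standard library only, not the real PyTensor or PyMC API, and the rejection sampler is just a placeholder for whatever a real backend implementation would do.

```python
import random
from functools import singledispatch


class Op:
    """Hypothetical stand-in for a graph Op base class."""


class ToyTruncatedNormalRV(Op):
    """Hypothetical stand-in for an Op the backend has no rule for yet."""


@singledispatch
def toy_funcify(op):
    # Fallback branch: reached when no rule is registered for type(op).
    # This mirrors the kind of failure reported above.
    raise NotImplementedError(f"No conversion for {type(op).__name__}")


@toy_funcify.register(ToyTruncatedNormalRV)
def _funcify_toy_truncated_normal(op):
    # Return a plain function that the (toy) backend can call directly.
    def sample(rng, mu, sigma, lower, upper):
        # Naive rejection sampling: redraw until the value is in bounds.
        while True:
            x = rng.gauss(mu, sigma)
            if lower <= x <= upper:
                return x

    return sample


# Usage: look up the rule by Op type, then call the returned function.
sampler = toy_funcify(ToyTruncatedNormalRV())
draw = sampler(random.Random(0), 0.0, 1.0, -1.0, 1.0)
```

In this sketch, fixing the issue amounts to adding one registered function per backend for the new Op class. Where such registrations should live for PyMC-specific Ops is exactly the open question.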