Skip to content

Commit 75789de

Browse files
committed
Implement InvGamma and Multinomial in Numba
1 parent 084bfd7 commit 75789de

File tree

2 files changed

+57
-1
lines changed

2 files changed

+57
-1
lines changed

pytensor/link/numba/dispatch/random.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,6 @@ def numba_core_rv_funcify(op: Op, node: Apply) -> Callable:
6464
@numba_core_rv_funcify.register(ptr.LaplaceRV)
6565
@numba_core_rv_funcify.register(ptr.BinomialRV)
6666
@numba_core_rv_funcify.register(ptr.NegBinomialRV)
67-
@numba_core_rv_funcify.register(ptr.MultinomialRV)
6867
@numba_core_rv_funcify.register(ptr.PermutationRV)
6968
@numba_core_rv_funcify.register(ptr.IntegersRV)
7069
def numba_core_rv_default(op, node):
@@ -132,6 +131,15 @@ def random(rng, b, scale):
132131
return random
133132

134133

134+
@numba_core_rv_funcify.register(ptr.InvGammaRV)
135+
def numba_core_InvGammaRV(op, node):
136+
@numba_basic.numba_njit
137+
def random(rng, shape, scale):
138+
return 1 / rng.gamma(shape, 1 / scale)
139+
140+
return random
141+
142+
135143
@numba_core_rv_funcify.register(ptr.CategoricalRV)
136144
def core_CategoricalRV(op, node):
137145
@numba_basic.numba_njit
@@ -142,6 +150,29 @@ def random_fn(rng, p):
142150
return random_fn
143151

144152

153+
@numba_core_rv_funcify.register(ptr.MultinomialRV)
154+
def core_MultinomialRV(op, node):
155+
dtype = op.dtype
156+
157+
@numba_basic.numba_njit
158+
def random_fn(rng, n, p):
159+
n_cat = p.shape[0]
160+
draws = np.zeros(n_cat, dtype=dtype)
161+
remaining_p = np.float64(1.0)
162+
remaining_n = n
163+
for i in range(n_cat - 1):
164+
draws[i] = rng.binomial(remaining_n, p[i] / remaining_p)
165+
remaining_n -= draws[i]
166+
if remaining_n <= 0:
167+
break
168+
remaining_p -= p[i]
169+
if remaining_n > 0:
170+
draws[n_cat - 1] = remaining_n
171+
return draws
172+
173+
return random_fn
174+
175+
145176
@numba_core_rv_funcify.register(ptr.MvNormalRV)
146177
def core_MvNormalRV(op, node):
147178
method = op.method

tests/link/numba/test_random.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -514,6 +514,31 @@ def test_multivariate_normal():
514514
],
515515
(pt.as_tensor([2, 1])),
516516
),
517+
(
518+
ptr.invgamma,
519+
[
520+
(
521+
pt.dvector("shape"),
522+
np.array([1.0, 2.0], dtype=np.float64),
523+
),
524+
(
525+
pt.dvector("scale"),
526+
np.array([0.5, 3.0], dtype=np.float64),
527+
),
528+
],
529+
(2,),
530+
),
531+
(
532+
ptr.multinomial,
533+
[
534+
(
535+
pt.lvector("n"),
536+
np.array([1, 10, 1000], dtype=np.int64),
537+
),
538+
(pt.dvector("p"), np.array([0.3, 0.7], dtype=np.float64)),
539+
],
540+
None,
541+
),
517542
],
518543
ids=str,
519544
)

0 commit comments

Comments
 (0)