Skip to content

Commit 5e28205

Browse files
committed
Update RNG in numba Dirichlet test
1 parent db0ce0e commit 5e28205

File tree

1 file changed

+3
-7
lines changed

1 file changed

+3
-7
lines changed

tests/link/numba/test_random.py

+3-7
Original file line numberDiff line numberDiff line change
@@ -653,15 +653,11 @@ def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_
653653
def test_DirichletRV(a, size, cm):
654654
a, a_val = a
655655
rng = shared(np.random.default_rng(29402))
656-
g = ptr.dirichlet(a, size=size, rng=rng)
657-
g_fn = function([a], g, mode=numba_mode)
656+
next_rng, g = ptr.dirichlet(a, size=size, rng=rng).owner.outputs
657+
g_fn = function([a], g, mode=numba_mode, updates={rng: next_rng})
658658

659659
with cm:
660-
all_samples = []
661-
for i in range(1000):
662-
samples = g_fn(a_val)
663-
all_samples.append(samples)
664-
660+
all_samples = [g_fn(a_val) for _ in range(1000)]
665661
exp_res = a_val / a_val.sum(-1)
666662
res = np.mean(all_samples, axis=tuple(range(0, a_val.ndim - 1)))
667663
assert np.allclose(res, exp_res, atol=1e-4)

0 commit comments

Comments
 (0)