Skip to content

Commit 9ac810c

Browse files
committed
Default to JAX test mode in random tests
1 parent 0824dba commit 9ac810c

File tree

1 file changed

+25
-25
lines changed

1 file changed

+25
-25
lines changed

tests/link/jax/test_random.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from pytensor.link.jax.dispatch.random import numpyro_available # noqa: E402
2828

2929

30-
def compile_random_function(*args, mode="JAX", **kwargs):
30+
def compile_random_function(*args, mode=jax_mode, **kwargs):
3131
with pytest.warns(
3232
UserWarning, match=r"The RandomType SharedVariables \[.+\] will not be used"
3333
):
@@ -42,7 +42,7 @@ def test_random_RandomStream():
4242
srng = RandomStream(seed=123)
4343
out = srng.normal() - srng.normal()
4444

45-
fn = compile_random_function([], out, mode=jax_mode)
45+
fn = compile_random_function([], out)
4646
jax_res_1 = fn()
4747
jax_res_2 = fn()
4848

@@ -55,7 +55,7 @@ def test_random_updates(rng_ctor):
5555
rng = shared(original_value, name="original_rng", borrow=False)
5656
next_rng, x = pt.random.normal(name="x", rng=rng).owner.outputs
5757

58-
f = compile_random_function([], [x], updates={rng: next_rng}, mode=jax_mode)
58+
f = compile_random_function([], [x], updates={rng: next_rng})
5959
assert f() != f()
6060

6161
# Check that original rng variable content was not overwritten when calling jax_typify
@@ -479,7 +479,7 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c
479479
"""
480480
rng = shared(np.random.default_rng(29403))
481481
g = rv_op(*dist_params, size=(10000, *base_size), rng=rng)
482-
g_fn = compile_random_function(dist_params, g, mode=jax_mode)
482+
g_fn = compile_random_function(dist_params, g)
483483
samples = g_fn(
484484
*[
485485
i.tag.test_value
@@ -521,7 +521,7 @@ def test_size_implied_by_broadcasted_parameters(rv_fn):
521521
param_that_implies_size = pt.matrix("param_that_implies_size", shape=(None, None))
522522

523523
rv = rv_fn(param_that_implies_size)
524-
draws = rv.eval({param_that_implies_size: np.zeros((2, 2))}, mode=jax_mode)
524+
draws = rv.eval({param_that_implies_size: np.zeros((2, 2))})
525525

526526
assert draws.shape == (2, 2)
527527
assert np.unique(draws).size == 4
@@ -531,7 +531,7 @@ def test_size_implied_by_broadcasted_parameters(rv_fn):
531531
def test_random_bernoulli(size):
532532
rng = shared(np.random.default_rng(123))
533533
g = pt.random.bernoulli(0.5, size=(1000, *size), rng=rng)
534-
g_fn = compile_random_function([], g, mode=jax_mode)
534+
g_fn = compile_random_function([], g)
535535
samples = g_fn()
536536
np.testing.assert_allclose(samples.mean(axis=0), 0.5, 1)
537537

@@ -542,7 +542,7 @@ def test_random_mvnormal():
542542
mu = np.ones(4)
543543
cov = np.eye(4)
544544
g = pt.random.multivariate_normal(mu, cov, size=(10000,), rng=rng)
545-
g_fn = compile_random_function([], g, mode=jax_mode)
545+
g_fn = compile_random_function([], g)
546546
samples = g_fn()
547547
np.testing.assert_allclose(samples.mean(axis=0), mu, atol=0.1)
548548

@@ -557,7 +557,7 @@ def test_random_mvnormal():
557557
def test_random_dirichlet(parameter, size):
558558
rng = shared(np.random.default_rng(123))
559559
g = pt.random.dirichlet(parameter, size=(1000, *size), rng=rng)
560-
g_fn = compile_random_function([], g, mode=jax_mode)
560+
g_fn = compile_random_function([], g)
561561
samples = g_fn()
562562
np.testing.assert_allclose(samples.mean(axis=0), 0.5, 1)
563563

@@ -566,7 +566,7 @@ def test_random_choice():
566566
# `replace=True` and `p is None`
567567
rng = shared(np.random.default_rng(123))
568568
g = pt.random.choice(np.arange(4), size=10_000, rng=rng)
569-
g_fn = compile_random_function([], g, mode=jax_mode)
569+
g_fn = compile_random_function([], g)
570570
samples = g_fn()
571571
assert samples.shape == (10_000,)
572572
# Elements are picked at equal frequency
@@ -575,7 +575,7 @@ def test_random_choice():
575575
# `replace=True` and `p is not None`
576576
rng = shared(np.random.default_rng(123))
577577
g = pt.random.choice(4, p=np.array([0.0, 0.5, 0.0, 0.5]), size=(5, 2), rng=rng)
578-
g_fn = compile_random_function([], g, mode=jax_mode)
578+
g_fn = compile_random_function([], g)
579579
samples = g_fn()
580580
assert samples.shape == (5, 2)
581581
# Only odd numbers are picked
@@ -584,7 +584,7 @@ def test_random_choice():
584584
# `replace=False` and `p is None`
585585
rng = shared(np.random.default_rng(123))
586586
g = pt.random.choice(np.arange(100), replace=False, size=(2, 49), rng=rng)
587-
g_fn = compile_random_function([], g, mode=jax_mode)
587+
g_fn = compile_random_function([], g)
588588
samples = g_fn()
589589
assert samples.shape == (2, 49)
590590
# Elements are unique
@@ -599,7 +599,7 @@ def test_random_choice():
599599
rng=rng,
600600
replace=False,
601601
)
602-
g_fn = compile_random_function([], g, mode=jax_mode)
602+
g_fn = compile_random_function([], g)
603603
samples = g_fn()
604604
assert samples.shape == (3,)
605605
# Elements are unique
@@ -611,14 +611,14 @@ def test_random_choice():
611611
def test_random_categorical():
612612
rng = shared(np.random.default_rng(123))
613613
g = pt.random.categorical(0.25 * np.ones(4), size=(10000, 4), rng=rng)
614-
g_fn = compile_random_function([], g, mode=jax_mode)
614+
g_fn = compile_random_function([], g)
615615
samples = g_fn()
616616
assert samples.shape == (10000, 4)
617617
np.testing.assert_allclose(samples.mean(axis=0), 6 / 4, 1)
618618

619619
# Test zero probabilities
620620
g = pt.random.categorical([0, 0.5, 0, 0.5], size=(1000,), rng=rng)
621-
g_fn = compile_random_function([], g, mode=jax_mode)
621+
g_fn = compile_random_function([], g)
622622
samples = g_fn()
623623
assert samples.shape == (1000,)
624624
assert np.all(samples % 2 == 1)
@@ -628,7 +628,7 @@ def test_random_permutation():
628628
array = np.arange(4)
629629
rng = shared(np.random.default_rng(123))
630630
g = pt.random.permutation(array, rng=rng)
631-
g_fn = compile_random_function([], g, mode=jax_mode)
631+
g_fn = compile_random_function([], g)
632632
permuted = g_fn()
633633
with pytest.raises(AssertionError):
634634
np.testing.assert_allclose(array, permuted)
@@ -651,7 +651,7 @@ def test_random_geometric():
651651
rng = shared(np.random.default_rng(123))
652652
p = np.array([0.3, 0.7])
653653
g = pt.random.geometric(p, size=(10_000, 2), rng=rng)
654-
g_fn = compile_random_function([], g, mode=jax_mode)
654+
g_fn = compile_random_function([], g)
655655
samples = g_fn()
656656
np.testing.assert_allclose(samples.mean(axis=0), 1 / p, rtol=0.1)
657657
np.testing.assert_allclose(samples.std(axis=0), np.sqrt((1 - p) / p**2), rtol=0.1)
@@ -662,7 +662,7 @@ def test_negative_binomial():
662662
n = np.array([10, 40])
663663
p = np.array([0.3, 0.7])
664664
g = pt.random.negative_binomial(n, p, size=(10_000, 2), rng=rng)
665-
g_fn = compile_random_function([], g, mode=jax_mode)
665+
g_fn = compile_random_function([], g)
666666
samples = g_fn()
667667
np.testing.assert_allclose(samples.mean(axis=0), n * (1 - p) / p, rtol=0.1)
668668
np.testing.assert_allclose(
@@ -676,7 +676,7 @@ def test_binomial():
676676
n = np.array([10, 40])
677677
p = np.array([0.3, 0.7])
678678
g = pt.random.binomial(n, p, size=(10_000, 2), rng=rng)
679-
g_fn = compile_random_function([], g, mode=jax_mode)
679+
g_fn = compile_random_function([], g)
680680
samples = g_fn()
681681
np.testing.assert_allclose(samples.mean(axis=0), n * p, rtol=0.1)
682682
np.testing.assert_allclose(samples.std(axis=0), np.sqrt(n * p * (1 - p)), rtol=0.1)
@@ -691,7 +691,7 @@ def test_beta_binomial():
691691
a = np.array([1.5, 13])
692692
b = np.array([0.5, 9])
693693
g = pt.random.betabinom(n, a, b, size=(10_000, 2), rng=rng)
694-
g_fn = compile_random_function([], g, mode=jax_mode)
694+
g_fn = compile_random_function([], g)
695695
samples = g_fn()
696696
np.testing.assert_allclose(samples.mean(axis=0), n * a / (a + b), rtol=0.1)
697697
np.testing.assert_allclose(
@@ -725,7 +725,7 @@ def test_vonmises_mu_outside_circle():
725725
mu = np.array([-30, 40])
726726
kappa = np.array([100, 10])
727727
g = pt.random.vonmises(mu, kappa, size=(10_000, 2), rng=rng)
728-
g_fn = compile_random_function([], g, mode=jax_mode)
728+
g_fn = compile_random_function([], g)
729729
samples = g_fn()
730730
np.testing.assert_allclose(
731731
samples.mean(axis=0), (mu + np.pi) % (2.0 * np.pi) - np.pi, rtol=0.1
@@ -823,15 +823,15 @@ def test_random_concrete_shape():
823823
rng = shared(np.random.default_rng(123))
824824
x_pt = pt.dmatrix()
825825
out = pt.random.normal(0, 1, size=x_pt.shape, rng=rng)
826-
jax_fn = compile_random_function([x_pt], out, mode=jax_mode)
826+
jax_fn = compile_random_function([x_pt], out)
827827
assert jax_fn(np.ones((2, 3))).shape == (2, 3)
828828

829829

830830
def test_random_concrete_shape_from_param():
831831
rng = shared(np.random.default_rng(123))
832832
x_pt = pt.dmatrix()
833833
out = pt.random.normal(x_pt, 1, rng=rng)
834-
jax_fn = compile_random_function([x_pt], out, mode=jax_mode)
834+
jax_fn = compile_random_function([x_pt], out)
835835
assert jax_fn(np.ones((2, 3))).shape == (2, 3)
836836

837837

@@ -850,7 +850,7 @@ def test_random_concrete_shape_subtensor():
850850
rng = shared(np.random.default_rng(123))
851851
x_pt = pt.dmatrix()
852852
out = pt.random.normal(0, 1, size=x_pt.shape[1], rng=rng)
853-
jax_fn = compile_random_function([x_pt], out, mode=jax_mode)
853+
jax_fn = compile_random_function([x_pt], out)
854854
assert jax_fn(np.ones((2, 3))).shape == (3,)
855855

856856

@@ -866,7 +866,7 @@ def test_random_concrete_shape_subtensor_tuple():
866866
rng = shared(np.random.default_rng(123))
867867
x_pt = pt.dmatrix()
868868
out = pt.random.normal(0, 1, size=(x_pt.shape[0],), rng=rng)
869-
jax_fn = compile_random_function([x_pt], out, mode=jax_mode)
869+
jax_fn = compile_random_function([x_pt], out)
870870
assert jax_fn(np.ones((2, 3))).shape == (2,)
871871

872872

@@ -877,7 +877,7 @@ def test_random_concrete_shape_graph_input():
877877
rng = shared(np.random.default_rng(123))
878878
size_pt = pt.scalar()
879879
out = pt.random.normal(0, 1, size=size_pt, rng=rng)
880-
jax_fn = compile_random_function([size_pt], out, mode=jax_mode)
880+
jax_fn = compile_random_function([size_pt], out)
881881
assert jax_fn(10).shape == (10,)
882882

883883

0 commit comments

Comments
 (0)