27
27
from pytensor .link .jax .dispatch .random import numpyro_available # noqa: E402
28
28
29
29
30
- def compile_random_function (* args , mode = "JAX" , ** kwargs ):
30
+ def compile_random_function (* args , mode = jax_mode , ** kwargs ):
31
31
with pytest .warns (
32
32
UserWarning , match = r"The RandomType SharedVariables \[.+\] will not be used"
33
33
):
@@ -42,7 +42,7 @@ def test_random_RandomStream():
42
42
srng = RandomStream (seed = 123 )
43
43
out = srng .normal () - srng .normal ()
44
44
45
- fn = compile_random_function ([], out , mode = jax_mode )
45
+ fn = compile_random_function ([], out )
46
46
jax_res_1 = fn ()
47
47
jax_res_2 = fn ()
48
48
@@ -55,7 +55,7 @@ def test_random_updates(rng_ctor):
55
55
rng = shared (original_value , name = "original_rng" , borrow = False )
56
56
next_rng , x = pt .random .normal (name = "x" , rng = rng ).owner .outputs
57
57
58
- f = compile_random_function ([], [x ], updates = {rng : next_rng }, mode = jax_mode )
58
+ f = compile_random_function ([], [x ], updates = {rng : next_rng })
59
59
assert f () != f ()
60
60
61
61
# Check that original rng variable content was not overwritten when calling jax_typify
@@ -479,7 +479,7 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c
479
479
"""
480
480
rng = shared (np .random .default_rng (29403 ))
481
481
g = rv_op (* dist_params , size = (10000 , * base_size ), rng = rng )
482
- g_fn = compile_random_function (dist_params , g , mode = jax_mode )
482
+ g_fn = compile_random_function (dist_params , g )
483
483
samples = g_fn (
484
484
* [
485
485
i .tag .test_value
@@ -521,7 +521,7 @@ def test_size_implied_by_broadcasted_parameters(rv_fn):
521
521
param_that_implies_size = pt .matrix ("param_that_implies_size" , shape = (None , None ))
522
522
523
523
rv = rv_fn (param_that_implies_size )
524
- draws = rv .eval ({param_that_implies_size : np .zeros ((2 , 2 ))}, mode = jax_mode )
524
+ draws = rv .eval ({param_that_implies_size : np .zeros ((2 , 2 ))})
525
525
526
526
assert draws .shape == (2 , 2 )
527
527
assert np .unique (draws ).size == 4
@@ -531,7 +531,7 @@ def test_size_implied_by_broadcasted_parameters(rv_fn):
531
531
def test_random_bernoulli (size ):
532
532
rng = shared (np .random .default_rng (123 ))
533
533
g = pt .random .bernoulli (0.5 , size = (1000 , * size ), rng = rng )
534
- g_fn = compile_random_function ([], g , mode = jax_mode )
534
+ g_fn = compile_random_function ([], g )
535
535
samples = g_fn ()
536
536
np .testing .assert_allclose (samples .mean (axis = 0 ), 0.5 , 1 )
537
537
@@ -542,7 +542,7 @@ def test_random_mvnormal():
542
542
mu = np .ones (4 )
543
543
cov = np .eye (4 )
544
544
g = pt .random .multivariate_normal (mu , cov , size = (10000 ,), rng = rng )
545
- g_fn = compile_random_function ([], g , mode = jax_mode )
545
+ g_fn = compile_random_function ([], g )
546
546
samples = g_fn ()
547
547
np .testing .assert_allclose (samples .mean (axis = 0 ), mu , atol = 0.1 )
548
548
@@ -557,7 +557,7 @@ def test_random_mvnormal():
557
557
def test_random_dirichlet (parameter , size ):
558
558
rng = shared (np .random .default_rng (123 ))
559
559
g = pt .random .dirichlet (parameter , size = (1000 , * size ), rng = rng )
560
- g_fn = compile_random_function ([], g , mode = jax_mode )
560
+ g_fn = compile_random_function ([], g )
561
561
samples = g_fn ()
562
562
np .testing .assert_allclose (samples .mean (axis = 0 ), 0.5 , 1 )
563
563
@@ -566,7 +566,7 @@ def test_random_choice():
566
566
# `replace=True` and `p is None`
567
567
rng = shared (np .random .default_rng (123 ))
568
568
g = pt .random .choice (np .arange (4 ), size = 10_000 , rng = rng )
569
- g_fn = compile_random_function ([], g , mode = jax_mode )
569
+ g_fn = compile_random_function ([], g )
570
570
samples = g_fn ()
571
571
assert samples .shape == (10_000 ,)
572
572
# Elements are picked at equal frequency
@@ -575,7 +575,7 @@ def test_random_choice():
575
575
# `replace=True` and `p is not None`
576
576
rng = shared (np .random .default_rng (123 ))
577
577
g = pt .random .choice (4 , p = np .array ([0.0 , 0.5 , 0.0 , 0.5 ]), size = (5 , 2 ), rng = rng )
578
- g_fn = compile_random_function ([], g , mode = jax_mode )
578
+ g_fn = compile_random_function ([], g )
579
579
samples = g_fn ()
580
580
assert samples .shape == (5 , 2 )
581
581
# Only odd numbers are picked
@@ -584,7 +584,7 @@ def test_random_choice():
584
584
# `replace=False` and `p is None`
585
585
rng = shared (np .random .default_rng (123 ))
586
586
g = pt .random .choice (np .arange (100 ), replace = False , size = (2 , 49 ), rng = rng )
587
- g_fn = compile_random_function ([], g , mode = jax_mode )
587
+ g_fn = compile_random_function ([], g )
588
588
samples = g_fn ()
589
589
assert samples .shape == (2 , 49 )
590
590
# Elements are unique
@@ -599,7 +599,7 @@ def test_random_choice():
599
599
rng = rng ,
600
600
replace = False ,
601
601
)
602
- g_fn = compile_random_function ([], g , mode = jax_mode )
602
+ g_fn = compile_random_function ([], g )
603
603
samples = g_fn ()
604
604
assert samples .shape == (3 ,)
605
605
# Elements are unique
@@ -611,14 +611,14 @@ def test_random_choice():
611
611
def test_random_categorical ():
612
612
rng = shared (np .random .default_rng (123 ))
613
613
g = pt .random .categorical (0.25 * np .ones (4 ), size = (10000 , 4 ), rng = rng )
614
- g_fn = compile_random_function ([], g , mode = jax_mode )
614
+ g_fn = compile_random_function ([], g )
615
615
samples = g_fn ()
616
616
assert samples .shape == (10000 , 4 )
617
617
np .testing .assert_allclose (samples .mean (axis = 0 ), 6 / 4 , 1 )
618
618
619
619
# Test zero probabilities
620
620
g = pt .random .categorical ([0 , 0.5 , 0 , 0.5 ], size = (1000 ,), rng = rng )
621
- g_fn = compile_random_function ([], g , mode = jax_mode )
621
+ g_fn = compile_random_function ([], g )
622
622
samples = g_fn ()
623
623
assert samples .shape == (1000 ,)
624
624
assert np .all (samples % 2 == 1 )
@@ -628,7 +628,7 @@ def test_random_permutation():
628
628
array = np .arange (4 )
629
629
rng = shared (np .random .default_rng (123 ))
630
630
g = pt .random .permutation (array , rng = rng )
631
- g_fn = compile_random_function ([], g , mode = jax_mode )
631
+ g_fn = compile_random_function ([], g )
632
632
permuted = g_fn ()
633
633
with pytest .raises (AssertionError ):
634
634
np .testing .assert_allclose (array , permuted )
@@ -651,7 +651,7 @@ def test_random_geometric():
651
651
rng = shared (np .random .default_rng (123 ))
652
652
p = np .array ([0.3 , 0.7 ])
653
653
g = pt .random .geometric (p , size = (10_000 , 2 ), rng = rng )
654
- g_fn = compile_random_function ([], g , mode = jax_mode )
654
+ g_fn = compile_random_function ([], g )
655
655
samples = g_fn ()
656
656
np .testing .assert_allclose (samples .mean (axis = 0 ), 1 / p , rtol = 0.1 )
657
657
np .testing .assert_allclose (samples .std (axis = 0 ), np .sqrt ((1 - p ) / p ** 2 ), rtol = 0.1 )
@@ -662,7 +662,7 @@ def test_negative_binomial():
662
662
n = np .array ([10 , 40 ])
663
663
p = np .array ([0.3 , 0.7 ])
664
664
g = pt .random .negative_binomial (n , p , size = (10_000 , 2 ), rng = rng )
665
- g_fn = compile_random_function ([], g , mode = jax_mode )
665
+ g_fn = compile_random_function ([], g )
666
666
samples = g_fn ()
667
667
np .testing .assert_allclose (samples .mean (axis = 0 ), n * (1 - p ) / p , rtol = 0.1 )
668
668
np .testing .assert_allclose (
@@ -676,7 +676,7 @@ def test_binomial():
676
676
n = np .array ([10 , 40 ])
677
677
p = np .array ([0.3 , 0.7 ])
678
678
g = pt .random .binomial (n , p , size = (10_000 , 2 ), rng = rng )
679
- g_fn = compile_random_function ([], g , mode = jax_mode )
679
+ g_fn = compile_random_function ([], g )
680
680
samples = g_fn ()
681
681
np .testing .assert_allclose (samples .mean (axis = 0 ), n * p , rtol = 0.1 )
682
682
np .testing .assert_allclose (samples .std (axis = 0 ), np .sqrt (n * p * (1 - p )), rtol = 0.1 )
@@ -691,7 +691,7 @@ def test_beta_binomial():
691
691
a = np .array ([1.5 , 13 ])
692
692
b = np .array ([0.5 , 9 ])
693
693
g = pt .random .betabinom (n , a , b , size = (10_000 , 2 ), rng = rng )
694
- g_fn = compile_random_function ([], g , mode = jax_mode )
694
+ g_fn = compile_random_function ([], g )
695
695
samples = g_fn ()
696
696
np .testing .assert_allclose (samples .mean (axis = 0 ), n * a / (a + b ), rtol = 0.1 )
697
697
np .testing .assert_allclose (
@@ -725,7 +725,7 @@ def test_vonmises_mu_outside_circle():
725
725
mu = np .array ([- 30 , 40 ])
726
726
kappa = np .array ([100 , 10 ])
727
727
g = pt .random .vonmises (mu , kappa , size = (10_000 , 2 ), rng = rng )
728
- g_fn = compile_random_function ([], g , mode = jax_mode )
728
+ g_fn = compile_random_function ([], g )
729
729
samples = g_fn ()
730
730
np .testing .assert_allclose (
731
731
samples .mean (axis = 0 ), (mu + np .pi ) % (2.0 * np .pi ) - np .pi , rtol = 0.1
@@ -823,15 +823,15 @@ def test_random_concrete_shape():
823
823
rng = shared (np .random .default_rng (123 ))
824
824
x_pt = pt .dmatrix ()
825
825
out = pt .random .normal (0 , 1 , size = x_pt .shape , rng = rng )
826
- jax_fn = compile_random_function ([x_pt ], out , mode = jax_mode )
826
+ jax_fn = compile_random_function ([x_pt ], out )
827
827
assert jax_fn (np .ones ((2 , 3 ))).shape == (2 , 3 )
828
828
829
829
830
830
def test_random_concrete_shape_from_param ():
831
831
rng = shared (np .random .default_rng (123 ))
832
832
x_pt = pt .dmatrix ()
833
833
out = pt .random .normal (x_pt , 1 , rng = rng )
834
- jax_fn = compile_random_function ([x_pt ], out , mode = jax_mode )
834
+ jax_fn = compile_random_function ([x_pt ], out )
835
835
assert jax_fn (np .ones ((2 , 3 ))).shape == (2 , 3 )
836
836
837
837
@@ -850,7 +850,7 @@ def test_random_concrete_shape_subtensor():
850
850
rng = shared (np .random .default_rng (123 ))
851
851
x_pt = pt .dmatrix ()
852
852
out = pt .random .normal (0 , 1 , size = x_pt .shape [1 ], rng = rng )
853
- jax_fn = compile_random_function ([x_pt ], out , mode = jax_mode )
853
+ jax_fn = compile_random_function ([x_pt ], out )
854
854
assert jax_fn (np .ones ((2 , 3 ))).shape == (3 ,)
855
855
856
856
@@ -866,7 +866,7 @@ def test_random_concrete_shape_subtensor_tuple():
866
866
rng = shared (np .random .default_rng (123 ))
867
867
x_pt = pt .dmatrix ()
868
868
out = pt .random .normal (0 , 1 , size = (x_pt .shape [0 ],), rng = rng )
869
- jax_fn = compile_random_function ([x_pt ], out , mode = jax_mode )
869
+ jax_fn = compile_random_function ([x_pt ], out )
870
870
assert jax_fn (np .ones ((2 , 3 ))).shape == (2 ,)
871
871
872
872
@@ -877,7 +877,7 @@ def test_random_concrete_shape_graph_input():
877
877
rng = shared (np .random .default_rng (123 ))
878
878
size_pt = pt .scalar ()
879
879
out = pt .random .normal (0 , 1 , size = size_pt , rng = rng )
880
- jax_fn = compile_random_function ([size_pt ], out , mode = jax_mode )
880
+ jax_fn = compile_random_function ([size_pt ], out )
881
881
assert jax_fn (10 ).shape == (10 ,)
882
882
883
883
0 commit comments