@@ -772,17 +772,22 @@ def test_exec_nuts_init(method):
772
772
("adapt_diag" , None , does_not_raise ()),
773
773
],
774
774
)
775
- def test_default_sample_nuts_jitter (init , start , expectation ):
776
- # This test tries to check whether the starting points returned by init_nuts are actually being
777
- # used when pm.sample() is called without specifying an explicit start point (see
775
+ def test_default_sample_nuts_jitter (init , start , expectation , monkeypatch ):
776
+ # This test tries to check whether the starting points returned by init_nuts are actually
777
+ # being used when pm.sample() is called without specifying an explicit start point (see
778
778
# https://github.com/pymc-devs/pymc3/pull/4285).
779
- # A random seed was selected to make sure the initialization with "jitter+adapt_diag" would fail.
780
- # This will need to be changed in the future if the initialization or randomization method changes
781
- # or if default initialization is made more robust.
779
+ def _mocked_init_nuts (* args , ** kwargs ):
780
+ if init == 'adapt_diag' :
781
+ start_ = [{'x' : np .array (0.79788456 )}]
782
+ else :
783
+ start_ = [{'x' : np .array (- 0.04949886 )}]
784
+ _ , step = pm .init_nuts (* args , ** kwargs )
785
+ return start_ , step
786
+ monkeypatch .setattr ('pymc3.sampling.init_nuts' , _mocked_init_nuts )
782
787
with pm .Model () as m :
783
788
x = pm .HalfNormal ("x" , transform = None )
784
789
with expectation :
785
- pm .sample (tune = 1 , draws = 0 , chains = 1 , random_seed = 7 , init = init , start = start )
790
+ pm .sample (tune = 1 , draws = 0 , chains = 1 , init = init , start = start )
786
791
787
792
788
793
@pytest .fixture (scope = "class" )
0 commit comments