Skip to content

Commit 8699dad

Browse files
committed
Use monkeypatch for more robust test
1 parent 98187ad commit 8699dad

File tree

1 file changed

+12
-7
lines changed

1 file changed

+12
-7
lines changed

pymc3/tests/test_sampling.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -772,17 +772,22 @@ def test_exec_nuts_init(method):
772772
("adapt_diag", None, does_not_raise()),
773773
],
774774
)
775-
def test_default_sample_nuts_jitter(init, start, expectation):
776-
# This test tries to check whether the starting points returned by init_nuts are actually being
777-
# used when pm.sample() is called without specifying an explicit start point (see
775+
def test_default_sample_nuts_jitter(init, start, expectation, monkeypatch):
776+
# This test tries to check whether the starting points returned by init_nuts are actually
777+
# being used when pm.sample() is called without specifying an explicit start point (see
778778
# https://github.com/pymc-devs/pymc3/pull/4285).
779-
# A random seed was selected to make sure the initialization with "jitter+adapt_diag" would fail.
780-
# This will need to be changed in the future if the initialization or randomization method changes
781-
# or if default initialization is made more robust.
779+
def _mocked_init_nuts(*args, **kwargs):
780+
if init == 'adapt_diag':
781+
start_ = [{'x': np.array(0.79788456)}]
782+
else:
783+
start_ = [{'x': np.array(-0.04949886)}]
784+
_, step = pm.init_nuts(*args, **kwargs)
785+
return start_, step
786+
monkeypatch.setattr('pymc3.sampling.init_nuts', _mocked_init_nuts)
782787
with pm.Model() as m:
783788
x = pm.HalfNormal("x", transform=None)
784789
with expectation:
785-
pm.sample(tune=1, draws=0, chains=1, random_seed=7, init=init, start=start)
790+
pm.sample(tune=1, draws=0, chains=1, init=init, start=start)
786791

787792

788793
@pytest.fixture(scope="class")

0 commit comments

Comments
 (0)