Skip to content

Commit 98187ad

Browse files
committed
- Improve test documentation and add a new condition
1 parent 38d21c8 commit 98187ad

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

pymc3/tests/test_sampling.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -769,11 +769,15 @@ def test_exec_nuts_init(method):
769769
("jitter+adapt_diag", None, pytest.raises(SamplingError)),
770770
("auto", {"x": 0}, does_not_raise()),
771771
("jitter+adapt_diag", {"x": 0}, does_not_raise()),
772+
("adapt_diag", None, does_not_raise()),
772773
],
773774
)
774775
def test_default_sample_nuts_jitter(init, start, expectation):
775-
# Random seed was selected to make sure initialization with "jitter+adapt_diag" would fail.
776-
# This will need to be changed in the future if initialization or randomization method changes
776+
# This test tries to check whether the starting points returned by init_nuts are actually being
777+
# used when pm.sample() is called without specifying an explicit start point (see
778+
# https://github.com/pymc-devs/pymc3/pull/4285).
779+
# A random seed was selected to make sure the initialization with "jitter+adapt_diag" would fail.
780+
# This will need to be changed in the future if the initialization or randomization method changes
777781
# or if default initialization is made more robust.
778782
with pm.Model() as m:
779783
x = pm.HalfNormal("x", transform=None)

0 commit comments

Comments
 (0)