Skip to content

Commit 8297c5b

Browse files
Add test for warn_treedepth function
1 parent 0775f3e commit 8297c5b

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

tests/stats/test_convergence.py

+11
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,17 @@ def test_warn_divergences():
3131
assert "2 divergences after tuning" in warns[0].message
3232

3333

34+
def test_warn_treedepth():
35+
idata = arviz.from_dict(
36+
sample_stats={
37+
"reached_max_treedepth": np.array([[0, 0, 0], [0, 1, 0]]).astype(bool),
38+
}
39+
)
40+
warns = convergence.warn_treedepth(idata)
41+
assert len(warns) == 1
42+
assert "Chain 1 reached the maximum tree depth" in warns[0].message
43+
44+
3445
def test_log_warning_stats(caplog):
3546
s1 = dict(warning="Temperature too low!")
3647
s2 = dict(warning="Temperature too high!")

0 commit comments

Comments
 (0)