Skip to content

Commit 9bf2190

Browse files
committed
Fix error in warn_treedepth when using multiple NUTS sampler
1 parent 6c6fd13 commit 9bf2190

File tree

3 files changed

+18
-2
lines changed

3 files changed

+18
-2
lines changed

pymc/stats/convergence.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ def warn_treedepth(idata: arviz.InferenceData) -> list[SamplerWarning]:
164164

165165
warnings = []
166166
for c in rmtd.chain:
167-
if sum(rmtd.sel(chain=c)) / rmtd.sizes["draw"] > 0.05:
167+
if (rmtd.sel(chain=c).mean("draw") > 0.05).any():
168168
warnings.append(
169169
SamplerWarning(
170170
WarningType.TREEDEPTH,

pymc/step_methods/hmc/nuts.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def _hamiltonian_step(self, start, p0, step_size):
198198

199199
if divergence_info or turning:
200200
break
201-
else:
201+
else: # no-break
202202
reached_max_treedepth = not self.tune
203203

204204
stats = tree.stats()

tests/stats/test_convergence.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,22 @@ def test_warn_treedepth():
4242
assert "Chain 1 reached the maximum tree depth" in warns[0].message
4343

4444

45+
def test_warn_treedepth_multiple_samplers():
46+
"""Check we handle cases when sampling with multiple NUTS samplers, each of which reports max_treedepth."""
47+
max_treedepth = np.zeros((3, 2, 2), dtype=bool)
48+
max_treedepth[0, 0, 0] = True
49+
max_treedepth[2, 1, 1] = True
50+
idata = arviz.from_dict(
51+
sample_stats={
52+
"reached_max_treedepth": max_treedepth,
53+
}
54+
)
55+
warns = convergence.warn_treedepth(idata)
56+
assert len(warns) == 2
57+
assert "Chain 0 reached the maximum tree depth" in warns[0].message
58+
assert "Chain 2 reached the maximum tree depth" in warns[1].message
59+
60+
4561
def test_log_warning_stats(caplog):
4662
s1 = dict(warning="Temperature too low!")
4763
s2 = dict(warning="Temperature too high!")

0 commit comments

Comments
 (0)