Skip to content

Commit ff8a4c7

Browse files
Fix warn_treedepth looking at the wrong stat (#6591)
* Fix `warn_treedepth` looking at the wrong stat * Add test for `warn_treedepth` function Closes #6587
1 parent b6521f2 commit ff8a4c7

File tree

2 files changed

+16
-5
lines changed

2 files changed

+16
-5
lines changed

pymc/stats/convergence.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -154,17 +154,17 @@ def warn_treedepth(idata: arviz.InferenceData) -> List[SamplerWarning]:
154154
if sampler_stats is None:
155155
return []
156156

157-
treedepth = sampler_stats.get("tree_depth", None)
158-
if treedepth is None:
157+
rmtd = sampler_stats.get("reached_max_treedepth", None)
158+
if rmtd is None:
159159
return []
160160

161161
warnings = []
162-
for c in treedepth.chain:
163-
if sum(treedepth.sel(chain=c)) / treedepth.sizes["draw"] > 0.05:
162+
for c in rmtd.chain:
163+
if sum(rmtd.sel(chain=c)) / rmtd.sizes["draw"] > 0.05:
164164
warnings.append(
165165
SamplerWarning(
166166
WarningType.TREEDEPTH,
167-
f"Chain {c} reached the maximum tree depth."
167+
f"Chain {int(c)} reached the maximum tree depth."
168168
" Increase `max_treedepth`, increase `target_accept` or reparameterize.",
169169
"warn",
170170
)

tests/stats/test_convergence.py

+11
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,17 @@ def test_warn_divergences():
3131
assert "2 divergences after tuning" in warns[0].message
3232

3333

34+
def test_warn_treedepth():
35+
idata = arviz.from_dict(
36+
sample_stats={
37+
"reached_max_treedepth": np.array([[0, 0, 0], [0, 1, 0]]).astype(bool),
38+
}
39+
)
40+
warns = convergence.warn_treedepth(idata)
41+
assert len(warns) == 1
42+
assert "Chain 1 reached the maximum tree depth" in warns[0].message
43+
44+
3445
def test_log_warning_stats(caplog):
3546
s1 = dict(warning="Temperature too low!")
3647
s2 = dict(warning="Temperature too high!")

0 commit comments

Comments
 (0)