Skip to content

Commit 65eb592

Browse files
Fix duplicate "tune" stat in McBackend adapter
1 parent 38e87e2 commit 65eb592

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

pymc/backends/mcbackend.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
BlockedStep,
3434
CompoundStep,
3535
StatsBijection,
36+
check_step_emits_tune,
3637
flat_statname,
3738
flatten_steps,
3839
)
@@ -207,11 +208,10 @@ def make_runmeta_and_point_fn(
207208
) -> Tuple[mcb.RunMeta, PointFunc]:
208209
variables, point_fn = get_variables_and_point_fn(model, initial_point)
209210

210-
sample_stats = [
211-
mcb.Variable("tune", "bool"),
212-
]
211+
check_step_emits_tune(step)
213212

214213
# In PyMC the sampler stats are grouped by the sampler.
214+
sample_stats = []
215215
steps = flatten_steps(step)
216216
for s, sm in enumerate(steps):
217217
for statname, (dtype, shape) in sm.stats_dtypes_shapes.items():

tests/backends/test_mcbackend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def test_make_runmeta_and_point_fn(simple_model):
119119
assert not vars["vector"].is_deterministic
120120
assert not vars["vector_interval__"].is_deterministic
121121
assert vars["matrix"].is_deterministic
122-
assert len(rmeta.sample_stats) == 1 + len(step.stats_dtypes[0])
122+
assert len(rmeta.sample_stats) == len(step.stats_dtypes[0])
123123
pass
124124

125125

0 commit comments

Comments
 (0)