Skip to content

Commit fe0e0d7

Browse files
committed
Explain difference between BinaryMetropolis and BinaryGibbsMetropolis
Also test assignment of variables in BinaryMetropolis
1 parent ec6e4c0 commit fe0e0d7

File tree

2 files changed

+22
-8
lines changed

2 files changed

+22
-8
lines changed

pymc/step_methods/metropolis.py

+15
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,13 @@ class BinaryMetropolisState(StepMethodState):
383383
class BinaryMetropolis(ArrayStep):
384384
"""Metropolis-Hastings optimized for binary variables.
385385
386+
Unlike BinaryGibbsMetropolis, this step sampler proposes an update for all variable dimensions at once.
387+
388+
This will perform a single logp evaluation per step, at the expense of a lower acceptance rate when
389+
the posteriors of the binary variables are highly correlated.
390+
391+
The BinaryGibbsMetropolis (not this one) is the default step sampler for binary variables
392+
386393
Parameters
387394
----------
388395
vars: list
@@ -489,6 +496,14 @@ class BinaryGibbsMetropolisState(StepMethodState):
489496
class BinaryGibbsMetropolis(ArrayStep):
490497
"""A Metropolis-within-Gibbs step method optimized for binary variables.
491498
499+
Unlike BinaryMetropolis, this step sampler proposes a variable dimension update at a time.
500+
501+
This will increase acceptance rate when the posteriors of the binary variables are highly correlated,
502+
at the expense of doing more logp evaluations per step.
503+
504+
This is the default step sampler for binary variables.
505+
506+
492507
Parameters
493508
----------
494509
vars: list

tests/step_methods/test_metropolis.py

+7-8
Original file line numberDiff line numberDiff line change
@@ -356,22 +356,21 @@ def test_step_continuous(self, step_fn, draws):
356356

357357
class TestRVsAssignmentMetropolis(RVsAssignmentStepsTester):
358358
@pytest.mark.parametrize(
359-
"step, step_kwargs",
359+
"step",
360360
[
361-
(BinaryGibbsMetropolis, {}),
362-
(CategoricalGibbsMetropolis, {}),
361+
BinaryMetropolis,
362+
BinaryGibbsMetropolis,
363+
CategoricalGibbsMetropolis,
363364
],
364365
)
365-
def test_discrete_steps(self, step, step_kwargs):
366+
def test_discrete_steps(self, step):
366367
with pm.Model() as m:
367368
d1 = pm.Bernoulli("d1", p=0.5)
368369
d2 = pm.Bernoulli("d2", p=0.5)
369370

370371
with pytensor.config.change_flags(mode=fast_unstable_sampling_mode):
371-
assert [m.rvs_to_values[d1]] == step([d1], **step_kwargs).vars
372-
assert {m.rvs_to_values[d1], m.rvs_to_values[d2]} == set(
373-
step([d1, d2], **step_kwargs).vars
374-
)
372+
assert [m.rvs_to_values[d1]] == step([d1]).vars
373+
assert {m.rvs_to_values[d1], m.rvs_to_values[d2]} == set(step([d1, d2]).vars)
375374

376375
@pytest.mark.parametrize(
377376
"step, step_kwargs", [(Metropolis, {}), (DEMetropolis, {}), (DEMetropolisZ, {})]

0 commit comments

Comments
 (0)