
Commit e2db2ce

Implement step method sampler for DiscreteMarkovChain
1 parent 19d1bd0 commit e2db2ce

2 files changed: +122 -6 lines changed

pymc_experimental/distributions/timeseries.py

Lines changed: 82 additions & 5 deletions
@@ -20,7 +20,12 @@
 from pymc.logprob.abstract import _logprob
 from pymc.logprob.basic import logp
 from pymc.pytensorf import constant_fold, intX
-from pymc.util import check_dist_not_registered
+from pymc.step_methods import STEP_METHODS
+from pymc.step_methods.arraystep import ArrayStep
+from pymc.step_methods.compound import Competence
+from pymc.step_methods.metropolis import CategoricalGibbsMetropolis
+from pymc.util import check_dist_not_registered, get_value_vars_from_user_vars
+from pytensor import Mode
 from pytensor.graph.basic import Node
 from pytensor.tensor import TensorVariable
 from pytensor.tensor.random.op import RandomVariable
@@ -101,10 +106,15 @@ class DiscreteMarkovChain(Distribution):
     Create a Markov Chain of length 100 with 3 states. The number of states is given by the shape of P,
     3 in this case.
 
-    >>> with pm.Model() as markov_chain:
-    >>>     P = pm.Dirichlet("P", a=[1, 1, 1], size=(3,))
-    >>>     init_dist = pm.Categorical.dist(p = np.full(3, 1 / 3))
-    >>>     markov_chain = pm.DiscreteMarkovChain("markov_chain", P=P, init_dist=init_dist, shape=(100,))
+    .. code-block:: python
+
+        import pymc as pm
+        import pymc_experimental as pmx
+
+        with pm.Model() as markov_chain:
+            P = pm.Dirichlet("P", a=[1, 1, 1], size=(3,))
+            init_dist = pm.Categorical.dist(p = np.full(3, 1 / 3))
+            markov_chain = pmx.DiscreteMarkovChain("markov_chain", P=P, init_dist=init_dist, shape=(100,))
 
     """
 
@@ -266,3 +276,70 @@ def discrete_mc_logp(op, values, P, steps, init_dist, state_rng, **kwargs):
         "P must sum to 1 along the last axis, "
         "First dimension of init_dist must be n_lags",
     )
+
+
+class DiscreteMarkovChainGibbsMetropolis(CategoricalGibbsMetropolis):
+
+    name = "discrete_markov_chain_gibbs_metropolis"
+
+    def __init__(self, vars, proposal="uniform", order="random", model=None):
+        model = pm.modelcontext(model)
+        vars = get_value_vars_from_user_vars(vars, model)
+        initial_point = model.initial_point()
+
+        dimcats = []
+        # dimcats is a list of pairs (aggregate dimension, number of
+        # categories). For example, if vars = [x, y] with x being a 2-D
+        # variable with M categories and y being a 3-D variable with N
+        # categories, we will have dimcats = [(0, M), (1, M), (2, N), (3, N), (4, N)].
+        for v in vars:
+            v_init_val = initial_point[v.name]
+            rv_var = model.values_to_rvs[v]
+            rv_op = rv_var.owner.op
+
+            if not isinstance(rv_op, DiscreteMarkovChainRV):
+                raise TypeError("All variables must be DiscreteMarkovChainRV")
+
+            k_graph = rv_var.owner.inputs[0].shape[-1]
+            (k_graph,) = model.replace_rvs_by_values((k_graph,))
+            k = model.compile_fn(
+                k_graph,
+                inputs=model.value_vars,
+                on_unused_input="ignore",
+                mode=Mode(linker="py", optimizer=None),
+            )(initial_point)
+            start = len(dimcats)
+            dimcats += [(dim, k) for dim in range(start, start + v_init_val.size)]
+
+        if order == "random":
+            self.shuffle_dims = True
+            self.dimcats = dimcats
+        else:
+            if sorted(order) != list(range(len(dimcats))):
+                raise ValueError("Argument 'order' has to be a permutation")
+            self.shuffle_dims = False
+            self.dimcats = [dimcats[j] for j in order]
+
+        if proposal == "uniform":
+            self.astep = self.astep_unif
+        elif proposal == "proportional":
+            # Use the optimized "Metropolized Gibbs Sampler" described in Liu96.
+            self.astep = self.astep_prop
+        else:
+            raise ValueError("Argument 'proposal' should either be 'uniform' or 'proportional'")
+
+        # Doesn't actually tune, but it's required to emit a sampler stat
+        # that indicates whether a draw was done in a tuning phase.
+        self.tune = True
+
+        # We bypass CategoricalGibbsMetropolis's __init__ to avoid its specialized initialization logic.
+        ArrayStep.__init__(self, vars, [model.compile_logp()])
+
+    @staticmethod
+    def competence(var):
+        if isinstance(var.owner.op, DiscreteMarkovChainRV):
+            return Competence.IDEAL
+        return Competence.INCOMPATIBLE
+
+
+STEP_METHODS.append(DiscreteMarkovChainGibbsMetropolis)
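
Appending the class to STEP_METHODS together with the competence() override is what makes the sampler available automatically: when no step is passed to pm.sample, PyMC's assign_step_methods asks each registered step method how competent it is for every free variable and picks the best match, and this sampler reports Competence.IDEAL exactly for DiscreteMarkovChainRV variables. A minimal sketch of the resulting user-facing behaviour, mirroring the test added below (the two-state transition matrix is illustrative, not part of this module):

import pymc as pm

from pymc_experimental.distributions.timeseries import (
    DiscreteMarkovChain,
    DiscreteMarkovChainGibbsMetropolis,
)

with pm.Model() as model:
    # Two-state chain with a fixed (illustrative) transition matrix.
    init_dist = pm.Categorical.dist(p=[0.5, 0.5])
    DiscreteMarkovChain(
        "markov_chain", P=[[0.1, 0.9], [0.1, 0.9]], init_dist=init_dist, shape=(100,)
    )

    # No `step=` argument: because the class was appended to STEP_METHODS and
    # its competence() returns Competence.IDEAL for DiscreteMarkovChainRV,
    # pm.sample assigns DiscreteMarkovChainGibbsMetropolis automatically.
    idata = pm.sample(chains=4, draws=250)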

pymc_experimental/tests/distributions/test_discrete_markov_chain.py

Lines changed: 40 additions & 1 deletion
@@ -4,10 +4,15 @@
 # general imports
 import pytensor.tensor as pt
 import pytest
+from pymc.distributions import Categorical
 from pymc.distributions.shape_utils import change_dist_size
 from pymc.logprob.utils import ParameterValueError
+from pymc.sampling.mcmc import assign_step_methods
 
-from pymc_experimental.distributions.timeseries import DiscreteMarkovChain
+from pymc_experimental.distributions.timeseries import (
+    DiscreteMarkovChain,
+    DiscreteMarkovChainGibbsMetropolis,
+)
 
 
 def transition_probability_tests(steps, n_states, n_lags, n_draws, atol):
@@ -215,3 +220,37 @@ def test_change_size_univariate(self):
 
         new_rw = change_dist_size(chain, new_size=(4, 3), expand=True)
         assert tuple(new_rw.shape.eval()) == (4, 3, 100, 5)
+
+    def test_mcmc_sampling(self):
+
+        with pm.Model(coords={"step": range(100)}) as model:
+            init_dist = Categorical.dist(p=[0.5, 0.5])
+            DiscreteMarkovChain(
+                "markov_chain",
+                P=[[0.1, 0.9], [0.1, 0.9]],
+                init_dist=init_dist,
+                shape=(100,),
+                dims="step",
+            )
+
+            step_method = assign_step_methods(model)
+            assert isinstance(step_method, DiscreteMarkovChainGibbsMetropolis)
+
+            # Sampler needs no tuning
+            idata = pm.sample(
+                tune=0, chains=4, draws=250, progressbar=False, compute_convergence_checks=False
+            )
+
+        np.testing.assert_allclose(
+            idata.posterior["markov_chain"].isel(step=0).mean(("chain", "draw")),
+            0.5,
+            atol=0.05,
+        )
+
+        np.testing.assert_allclose(
+            idata.posterior["markov_chain"].isel(step=slice(1, None)).mean(("chain", "draw")),
+            0.9,
+            atol=0.05,
+        )
+
+        assert pm.stats.ess(idata, method="tail").min() > 950
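
The expected values asserted in this test follow from the transition matrix it uses: both rows of P are (0.1, 0.9), so whichever state the chain is in, the next state is 1 with probability 0.9; the mean of 0.5 at step 0 simply reflects the uniform init_dist. A small, hypothetical check (not part of the commit) that recovers the same stationary distribution numerically with numpy:

import numpy as np

P = np.array([[0.1, 0.9], [0.1, 0.9]])

# The stationary distribution is the left eigenvector of P with eigenvalue 1.
eigvals, eigvecs = np.linalg.eig(P.T)
stationary = np.real(eigvecs[:, np.isclose(eigvals, 1.0)].ravel())
stationary /= stationary.sum()

print(stationary)  # approximately [0.1, 0.9] -> mean state of 0.9 at stationarity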
