Skip to content

Commit 5444083

Browse files
committed
Implement step method sampler for DiscreteMarkovChain
1 parent 0d3021f commit 5444083

File tree

2 files changed

+120
-6
lines changed

2 files changed

+120
-6
lines changed

pymc_experimental/distributions/timeseries.py

Lines changed: 81 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,12 @@
2020
from pymc.logprob.abstract import _logprob
2121
from pymc.logprob.basic import logp
2222
from pymc.pytensorf import constant_fold, intX
23-
from pymc.util import check_dist_not_registered
23+
from pymc.step_methods import STEP_METHODS
24+
from pymc.step_methods.arraystep import ArrayStep
25+
from pymc.step_methods.compound import Competence
26+
from pymc.step_methods.metropolis import CategoricalGibbsMetropolis
27+
from pymc.util import check_dist_not_registered, get_value_vars_from_user_vars
28+
from pytensor import Mode
2429
from pytensor.graph.basic import Node
2530
from pytensor.tensor import TensorVariable
2631
from pytensor.tensor.random.op import RandomVariable
@@ -101,10 +106,15 @@ class DiscreteMarkovChain(Distribution):
101106
Create a Markov Chain of length 100 with 3 states. The number of states is given by the shape of P,
102107
3 in this case.
103108
104-
>>> with pm.Model() as markov_chain:
105-
>>> P = pm.Dirichlet("P", a=[1, 1, 1], size=(3,))
106-
>>> init_dist = pm.Categorical.dist(p = np.full(3, 1 / 3))
107-
>>> markov_chain = pm.DiscreteMarkovChain("markov_chain", P=P, init_dist=init_dist, shape=(100,))
109+
.. code-block:: python
110+
111+
import numpy as np
import pymc as pm
112+
import pymc_experimental as pmx
113+
114+
with pm.Model() as markov_chain:
115+
P = pm.Dirichlet("P", a=[1, 1, 1], size=(3,))
116+
init_dist = pm.Categorical.dist(p = np.full(3, 1 / 3))
117+
markov_chain = pmx.DiscreteMarkovChain("markov_chain", P=P, init_dist=init_dist, shape=(100,))
108118
109119
"""
110120

@@ -266,3 +276,69 @@ def discrete_mc_logp(op, values, P, steps, init_dist, state_rng, **kwargs):
266276
"P must sum to 1 along the last axis, "
267277
"First dimension of init_dist must be n_lags",
268278
)
279+
280+
281+
class DiscreteMarkovChainGibbsMetropolis(CategoricalGibbsMetropolis):
    """Gibbs-Metropolis step method restricted to ``DiscreteMarkovChainRV`` variables.

    The constructor performs its own set-up instead of delegating to
    ``CategoricalGibbsMetropolis.__init__``: the number of categories of each
    variable is read from the last dimension of the first input of its random
    variable. The ``astep_unif`` / ``astep_prop`` kernels are looked up on the
    instance (presumably inherited from ``CategoricalGibbsMetropolis`` -- confirm).
    """

    name = "discrete_markov_chain_gibbs_metropolis"

    def __init__(self, vars, proposal="uniform", order="random", model=None):
        """Set up the sampler state for the given variables.

        Parameters
        ----------
        vars
            Variables to sample; converted to value variables with
            ``get_value_vars_from_user_vars``.
        proposal
            Either ``"uniform"`` or ``"proportional"``.
        order
            Either ``"random"`` or a permutation of
            ``range(total number of scalar elements across vars)``.
        model
            Forwarded to ``pm.modelcontext`` to resolve the model.

        Raises
        ------
        TypeError
            If any variable is not backed by a ``DiscreteMarkovChainRV``.
        ValueError
            If ``order`` is not a permutation or ``proposal`` is not recognised.
        """
        model = pm.modelcontext(model)
        vars = get_value_vars_from_user_vars(vars, model)
        start_point = model.initial_point()

        # One (flat dimension index, number of categories) pair per scalar
        # element of every variable. For example, with vars = [x, y] where x
        # has 2 elements and M categories and y has 3 elements and N
        # categories: dimcats = [(0, M), (1, M), (2, N), (3, N), (4, N)].
        dimcats = []
        for value_var in vars:
            init_value = start_point[value_var.name]
            rv = model.values_to_rvs[value_var]

            if not isinstance(rv.owner.op, DiscreteMarkovChainRV):
                raise TypeError("All variables must be DiscreteMarkovChainRV")

            # The category count is the size of the last axis of the RV's
            # first input, evaluated at the model's initial point.
            n_states_graph = rv.owner.inputs[0].shape[-1]
            (n_states_graph,) = model.replace_rvs_by_values((n_states_graph,))
            n_states_fn = model.compile_fn(
                n_states_graph,
                inputs=model.value_vars,
                on_unused_input="ignore",
                mode=Mode(linker="py", optimizer=None),
            )
            n_states = n_states_fn(start_point)

            offset = len(dimcats)
            dimcats.extend((offset + i, n_states) for i in range(init_value.size))

        if order == "random":
            self.shuffle_dims = True
        else:
            if sorted(order) != list(range(len(dimcats))):
                raise ValueError("Argument 'order' has to be a permutation")
            self.shuffle_dims = False
            # Visit the dimensions in the order requested by the caller.
            dimcats = [dimcats[j] for j in order]
        self.dimcats = dimcats

        if proposal == "uniform":
            step_fn = self.astep_unif
        elif proposal == "proportional":
            # Use the optimized "Metropolized Gibbs Sampler" described in Liu96.
            step_fn = self.astep_prop
        else:
            raise ValueError("Argument 'proposal' should either be 'uniform' or 'proportional'")
        self.astep = step_fn

        # No tuning actually happens, but the flag is required to emit a
        # sampler stat telling whether a draw was made during a tuning phase.
        self.tune = True

        # Deliberately skip CategoricalGibbsMetropolis.__init__ to avoid its
        # specialized initialization logic; initialise the ArrayStep base directly.
        ArrayStep.__init__(self, vars, [model.compile_logp()])

    @staticmethod
    def competence(var):
        """Return ``Competence.IDEAL`` for ``DiscreteMarkovChainRV`` variables, else ``INCOMPATIBLE``."""
        is_markov_chain = isinstance(var.owner.op, DiscreteMarkovChainRV)
        return Competence.IDEAL if is_markov_chain else Competence.INCOMPATIBLE
343+
344+
# Add the sampler to the STEP_METHODS list imported from pymc.step_methods.
# NOTE(review): presumably this is what lets PyMC's automatic step assignment
# pick it for DiscreteMarkovChain variables (the accompanying test checks
# assign_step_methods returns this class) -- confirm against PyMC.
STEP_METHODS.append(DiscreteMarkovChainGibbsMetropolis)

tests/distributions/test_discrete_markov_chain.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,15 @@
55
import pytensor.tensor as pt
66
import pytest
77

8+
from pymc.distributions import Categorical
89
from pymc.distributions.shape_utils import change_dist_size
910
from pymc.logprob.utils import ParameterValueError
11+
from pymc.sampling.mcmc import assign_step_methods
1012

11-
from pymc_experimental.distributions.timeseries import DiscreteMarkovChain
13+
from pymc_experimental.distributions.timeseries import (
14+
DiscreteMarkovChain,
15+
DiscreteMarkovChainGibbsMetropolis,
16+
)
1217

1318

1419
def transition_probability_tests(steps, n_states, n_lags, n_draws, atol):
@@ -216,3 +221,36 @@ def test_change_size_univariate(self):
216221

217222
new_rw = change_dist_size(chain, new_size=(4, 3), expand=True)
218223
assert tuple(new_rw.shape.eval()) == (4, 3, 100, 5)
224+
225+
def test_mcmc_sampling(self):
    """Check step-method assignment and sampled marginals for a two-state chain.

    Both rows of ``P`` are ``[0.1, 0.9]``, so the test expects every step after
    the first to average about 0.9, while the first step follows the uniform
    ``init_dist`` and should average about 0.5.
    """
    with pm.Model(coords={"step": range(100)}) as model:
        init_dist = Categorical.dist(p=[0.5, 0.5])
        DiscreteMarkovChain(
            "markov_chain",
            P=[[0.1, 0.9], [0.1, 0.9]],
            init_dist=init_dist,
            shape=(100,),
            dims="step",
        )

        step_method = assign_step_methods(model)
        assert isinstance(step_method, DiscreteMarkovChainGibbsMetropolis)

        # Sampler needs no tuning
        idata = pm.sample(
            tune=0, chains=4, draws=250, progressbar=False, compute_convergence_checks=False
        )

    samples = idata.posterior["markov_chain"]

    np.testing.assert_allclose(
        samples.isel(step=0).mean(("chain", "draw")),
        0.5,
        atol=0.05,
    )

    np.testing.assert_allclose(
        samples.isel(step=slice(1, None)).mean(("chain", "draw")),
        0.9,
        atol=0.05,
    )

    # ``pm.stats.ess`` on InferenceData yields an xarray Dataset, and asserting
    # on ``Dataset > 950`` only checks that the dataset has data variables.
    # Select the variable and reduce it to a scalar so the threshold is
    # actually enforced.
    tail_ess = pm.stats.ess(idata, method="tail")["markov_chain"]
    assert float(tail_ess.min()) > 950

0 commit comments

Comments
 (0)