
Specialized DiscreteMarkovChain step sampler #359

Merged: 3 commits, Sep 15, 2024
3 changes: 3 additions & 0 deletions .pre-commit-config.yaml
@@ -1,3 +1,6 @@
+ci:
+  autofix_prs: false
+
 repos:
 - repo: https://github.com/pre-commit/pre-commit-hooks
   rev: v4.6.0
15 changes: 2 additions & 13 deletions pymc_experimental/__init__.py
@@ -13,7 +13,8 @@
 # limitations under the License.
 import logging

-from pymc_experimental import distributions, gp, statespace, utils
+from pymc_experimental import gp, statespace, utils
+from pymc_experimental.distributions import *
 from pymc_experimental.inference.fit import fit
 from pymc_experimental.model.marginal_model import MarginalModel
 from pymc_experimental.model.model_api import as_model
@@ -26,15 +27,3 @@
 if len(_log.handlers) == 0:
     handler = logging.StreamHandler()
     _log.addHandler(handler)
-
-
-__all__ = [
-    "distributions",
-    "gp",
-    "statespace",
-    "utils",
-    "fit",
-    "MarginalModel",
-    "as_model",
-    "__version__",
-]
86 changes: 81 additions & 5 deletions pymc_experimental/distributions/timeseries.py
@@ -20,7 +25,12 @@
 from pymc.logprob.abstract import _logprob
 from pymc.logprob.basic import logp
 from pymc.pytensorf import constant_fold, intX
-from pymc.util import check_dist_not_registered
+from pymc.step_methods import STEP_METHODS
+from pymc.step_methods.arraystep import ArrayStep
+from pymc.step_methods.compound import Competence
+from pymc.step_methods.metropolis import CategoricalGibbsMetropolis
+from pymc.util import check_dist_not_registered, get_value_vars_from_user_vars
+from pytensor import Mode
 from pytensor.graph.basic import Node
 from pytensor.tensor import TensorVariable
 from pytensor.tensor.random.op import RandomVariable
@@ -101,10 +106,15 @@ class DiscreteMarkovChain(Distribution):
     Create a Markov Chain of length 100 with 3 states. The number of states is given by the shape of P,
     3 in this case.

-    >>> with pm.Model() as markov_chain:
-    >>> P = pm.Dirichlet("P", a=[1, 1, 1], size=(3,))
-    >>> init_dist = pm.Categorical.dist(p = np.full(3, 1 / 3))
-    >>> markov_chain = pm.DiscreteMarkovChain("markov_chain", P=P, init_dist=init_dist, shape=(100,))
+    .. code-block:: python
+
+        import numpy as np
+        import pymc as pm
+        import pymc_experimental as pmx
+
+        with pm.Model() as markov_chain:
+            P = pm.Dirichlet("P", a=[1, 1, 1], size=(3,))
+            init_dist = pm.Categorical.dist(p=np.full(3, 1 / 3))
+            markov_chain = pmx.DiscreteMarkovChain("markov_chain", P=P, init_dist=init_dist, shape=(100,))

     """

@@ -266,3 +276,69 @@ def discrete_mc_logp(op, values, P, steps, init_dist, state_rng, **kwargs):
        "P must sum to 1 along the last axis, "
        "First dimension of init_dist must be n_lags",
    )


class DiscreteMarkovChainGibbsMetropolis(CategoricalGibbsMetropolis):
    name = "discrete_markov_chain_gibbs_metropolis"

    def __init__(self, vars, proposal="uniform", order="random", model=None):
        model = pm.modelcontext(model)
        vars = get_value_vars_from_user_vars(vars, model)
        initial_point = model.initial_point()

        dimcats = []
        # The above variable is a list of pairs (aggregate dimension, number
        # of categories). For example, if vars = [x, y] with x being a 2-D
        # variable with M categories and y being a 3-D variable with N
        # categories, we will have dimcats = [(0, M), (1, M), (2, N), (3, N), (4, N)].
        for v in vars:
            v_init_val = initial_point[v.name]
            rv_var = model.values_to_rvs[v]
            rv_op = rv_var.owner.op

            if not isinstance(rv_op, DiscreteMarkovChainRV):
                raise TypeError("All variables must be DiscreteMarkovChainRV")

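            # The number of states k is the size of the last axis of the transition
            # matrix P (the first input of the DiscreteMarkovChainRV node). Evaluate
            # that symbolic shape once at the initial point; Python mode keeps this
            # one-off evaluation cheap (no C compilation).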
            k_graph = rv_var.owner.inputs[0].shape[-1]
            (k_graph,) = model.replace_rvs_by_values((k_graph,))
            k = model.compile_fn(
                k_graph,
                inputs=model.value_vars,
                on_unused_input="ignore",
                mode=Mode(linker="py", optimizer=None),
            )(initial_point)
            start = len(dimcats)
            dimcats += [(dim, k) for dim in range(start, start + v_init_val.size)]

        if order == "random":
            self.shuffle_dims = True
            self.dimcats = dimcats
        else:
            if sorted(order) != list(range(len(dimcats))):
                raise ValueError("Argument 'order' has to be a permutation")
            self.shuffle_dims = False
            self.dimcats = [dimcats[j] for j in order]

        if proposal == "uniform":
            self.astep = self.astep_unif
        elif proposal == "proportional":
            # Use the optimized "Metropolized Gibbs Sampler" described in Liu96.
            self.astep = self.astep_prop
        else:
            raise ValueError("Argument 'proposal' should either be 'uniform' or 'proportional'")

        # Doesn't actually tune, but it's required to emit a sampler stat
        # that indicates whether a draw was done in a tuning phase.
        self.tune = True

        # We bypass CategoricalGibbsMetropolis's __init__ to avoid its specialized
        # initialization logic.
        ArrayStep.__init__(self, vars, [model.compile_logp()])

    @staticmethod
    def competence(var):
        if isinstance(var.owner.op, DiscreteMarkovChainRV):
            return Competence.IDEAL
        return Competence.INCOMPATIBLE


STEP_METHODS.append(DiscreteMarkovChainGibbsMetropolis)
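
Because the class is appended to `STEP_METHODS` and reports `Competence.IDEAL` for `DiscreteMarkovChainRV` variables, `pm.sample` will pick it automatically for any DiscreteMarkovChain in a model. The step method can also be requested explicitly; a minimal sketch (the model and variable names here are illustrative, not from the PR):

import numpy as np
import pymc as pm

from pymc_experimental.distributions.timeseries import (
    DiscreteMarkovChain,
    DiscreteMarkovChainGibbsMetropolis,
)

with pm.Model() as model:
    init_dist = pm.Categorical.dist(p=np.full(2, 0.5))
    chain = DiscreteMarkovChain(
        "chain", P=[[0.9, 0.1], [0.2, 0.8]], init_dist=init_dist, shape=(25,)
    )
    # Pass the specialized step method explicitly instead of relying on
    # automatic assignment via STEP_METHODS.
    step = DiscreteMarkovChainGibbsMetropolis([chain])
    idata = pm.sample(step=step, chains=2, draws=500)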
1 change: 1 addition & 0 deletions pyproject.toml
@@ -76,3 +76,4 @@ lines-between-types = 1
     'F401', # Unused import warning for test files -- this check removes imports of fixtures
     'F811' # Redefine while unused -- this check fails on imported fixtures
 ]
+'pymc_experimental/__init__.py' = ['F401', 'F403']
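
These per-file ignores are needed because `pymc_experimental/__init__.py` now re-exports the distributions module through a star import (F403), which would otherwise also be flagged as an unused import (F401).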
40 changes: 39 additions & 1 deletion tests/distributions/test_discrete_markov_chain.py
@@ -5,10 +5,15 @@
 import pytensor.tensor as pt
 import pytest

+from pymc.distributions import Categorical
 from pymc.distributions.shape_utils import change_dist_size
 from pymc.logprob.utils import ParameterValueError
+from pymc.sampling.mcmc import assign_step_methods

-from pymc_experimental.distributions.timeseries import DiscreteMarkovChain
+from pymc_experimental.distributions.timeseries import (
+    DiscreteMarkovChain,
+    DiscreteMarkovChainGibbsMetropolis,
+)


def transition_probability_tests(steps, n_states, n_lags, n_draws, atol):
@@ -216,3 +221,36 @@ def test_change_size_univariate(self):

        new_rw = change_dist_size(chain, new_size=(4, 3), expand=True)
        assert tuple(new_rw.shape.eval()) == (4, 3, 100, 5)

    def test_mcmc_sampling(self):
        with pm.Model(coords={"step": range(100)}) as model:
            init_dist = Categorical.dist(p=[0.5, 0.5])
            DiscreteMarkovChain(
                "markov_chain",
                P=[[0.1, 0.9], [0.1, 0.9]],
                init_dist=init_dist,
                shape=(100,),
                dims="step",
            )

            step_method = assign_step_methods(model)
            assert isinstance(step_method, DiscreteMarkovChainGibbsMetropolis)

            # Sampler needs no tuning
            idata = pm.sample(
                tune=0, chains=4, draws=250, progressbar=False, compute_convergence_checks=False
            )

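            # With P = [[0.1, 0.9], [0.1, 0.9]] the chain moves to state 1 with
            # probability 0.9 regardless of the current state, so the uniformly
            # initialized first step has mean 0.5 and every later step has mean 0.9.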
            np.testing.assert_allclose(
                idata.posterior["markov_chain"].isel(step=0).mean(("chain", "draw")),
                0.5,
                atol=0.05,
            )

            np.testing.assert_allclose(
                idata.posterior["markov_chain"].isel(step=slice(1, None)).mean(("chain", "draw")),
                0.9,
                atol=0.05,
            )

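            # 4 chains x 250 draws = 1000 total draws, so a minimum tail-ESS above
            # 950 means the sampler's draws are close to independent.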
            assert pm.stats.ess(idata, method="tail").min() > 950