Skip to content

Commit 747db63

Browse files
authored
SMC: refactor, speed-up and run multiple chains in parallel for diagnostics (#3981)
* first attempt to vectorize smc kernel
* add ess, remove multiprocessing
* run multiple chains
* remove unused imports
* add more info to report
* minor fix
* test log
* fix type_num error
* remove unused imports update BF notebook
* update notebook with diagnostics
* update notebooks
* update notebook
* update notebook
1 parent facbdf1 commit 747db63

File tree

5 files changed

+580
-340
lines changed

5 files changed

+580
-340
lines changed

docs/source/notebooks/Bayes_factor.ipynb

Lines changed: 112 additions & 61 deletions
Large diffs are not rendered by default.

docs/source/notebooks/SMC2_gaussians.ipynb

Lines changed: 291 additions & 75 deletions
Large diffs are not rendered by default.

pymc3/smc/sample_smc.py

Lines changed: 124 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -14,25 +14,37 @@
1414

1515
import time
1616
import logging
17+
import warnings
18+
from collections.abc import Iterable
19+
import multiprocessing as mp
20+
import numpy as np
21+
1722
from .smc import SMC
23+
from ..model import modelcontext
24+
from ..backends.base import MultiTrace
25+
from ..parallel_sampling import _cpu_count
26+
27+
EXPERIMENTAL_WARNING = (
28+
"Warning: SMC-ABC is an experimental step method and not yet recommended for use in PyMC3!"
29+
)
1830

1931

2032
def sample_smc(
21-
draws=1000,
33+
draws=2000,
2234
kernel="metropolis",
2335
n_steps=25,
24-
parallel=False,
2536
start=None,
26-
cores=None,
2737
tune_steps=True,
2838
p_acc_rate=0.99,
2939
threshold=0.5,
3040
epsilon=1.0,
3141
dist_func="gaussian_kernel",
3242
sum_stat="identity",
33-
progressbar=False,
3443
model=None,
3544
random_seed=-1,
45+
parallel=False,
46+
chains=None,
47+
cores=None,
3648
):
3749
r"""
3850
Sequential Monte Carlo based sampling
@@ -49,15 +61,9 @@ def sample_smc(
4961
The number of steps of each Markov Chain. If ``tune_steps == True`` ``n_steps`` will be used
5062
for the first stage and for the others it will be determined automatically based on the
5163
acceptance rate and `p_acc_rate`, the max number of steps is ``n_steps``.
52-
parallel: bool
53-
Distribute computations across cores if the number of cores is larger than 1.
54-
Defaults to False.
5564
start: dict, or array of dict
5665
Starting point in parameter space. It should be a list of dict with length `chains`.
5766
When None (default) the starting point is sampled from the prior distribution.
58-
cores: int
59-
The number of chains to run in parallel. If ``None`` (default), it will be automatically
60-
set to the number of CPUs in the system.
6167
tune_steps: bool
6268
Whether to compute the number of steps automatically or not. Defaults to True
6369
p_acc_rate: float
@@ -75,11 +81,19 @@ def sample_smc(
7581
sum_stat: str or callable
7682
Summary statistics. Available options are ``identity``, ``sorted``, ``mean``, ``median``.
7783
If a callable is passed, it should return a number or a 1d numpy array.
78-
progressbar: bool
79-
Flag for displaying a progress bar. Defaults to False.
8084
model: Model (optional if in ``with`` context).
8185
random_seed: int
8286
random seed
87+
parallel: bool
88+
Distribute computations across cores if the number of cores is larger than 1.
89+
Defaults to False.
90+
cores : int
91+
The number of chains to run in parallel. If ``None``, set to the number of CPUs in the
92+
system, but at most 4.
93+
chains : int
94+
The number of chains to sample. Running independent chains is important for some
95+
convergence statistics. If ``None`` (default), then set to either ``cores`` or 2, whichever
96+
is larger.
8397
8498
Notes
8599
-----
@@ -126,52 +140,126 @@ def sample_smc(
126140
%282007%29133:7%28816%29>`__
127141
"""
128142

143+
_log = logging.getLogger("pymc3")
144+
_log.info("Initializing SMC sampler...")
145+
146+
if cores is None:
147+
cores = _cpu_count()
148+
149+
if chains is None:
150+
chains = max(2, cores)
151+
152+
_log.info(f"Multiprocess sampling ({chains} chains in {cores} jobs)")
153+
154+
if random_seed == -1:
155+
random_seed = None
156+
if chains == 1 and isinstance(random_seed, int):
157+
random_seed = [random_seed]
158+
if random_seed is None or isinstance(random_seed, int):
159+
if random_seed is not None:
160+
np.random.seed(random_seed)
161+
random_seed = [np.random.randint(2 ** 30) for _ in range(chains)]
162+
if not isinstance(random_seed, Iterable):
163+
raise TypeError("Invalid value for `random_seed`. Must be tuple, list or int")
164+
165+
if kernel.lower() == "abc":
166+
warnings.warn(EXPERIMENTAL_WARNING)
167+
if len(modelcontext(model).observed_RVs) != 1:
168+
warnings.warn("SMC-ABC only works properly with models with one observed variable")
169+
170+
params = (
171+
draws,
172+
kernel,
173+
n_steps,
174+
start,
175+
tune_steps,
176+
p_acc_rate,
177+
threshold,
178+
epsilon,
179+
dist_func,
180+
sum_stat,
181+
model,
182+
)
183+
184+
t1 = time.time()
185+
if parallel:
186+
loggers = [_log] + [None] * (chains - 1)
187+
pool = mp.Pool(cores)
188+
results = pool.starmap(
189+
sample_smc_int, [(*params, random_seed[i], i, loggers[i]) for i in range(chains)]
190+
)
191+
192+
pool.close()
193+
pool.join()
194+
else:
195+
results = []
196+
for i in range(chains):
197+
results.append((sample_smc_int(*params, random_seed[i], i, _log)))
198+
199+
traces, log_marginal_likelihoods, betas, accept_ratios, nsteps = zip(*results)
200+
trace = MultiTrace(traces)
201+
trace.report._n_draws = draws
202+
trace.report._n_tune = 0
203+
trace.report._t_sampling = time.time() - t1
204+
trace.report.log_marginal_likelihood = np.array(log_marginal_likelihoods)
205+
trace.report.betas = betas
206+
trace.report.accept_ratios = accept_ratios
207+
trace.report.nsteps = nsteps
208+
209+
return trace
210+
211+
212+
def sample_smc_int(
213+
draws,
214+
kernel,
215+
n_steps,
216+
start,
217+
tune_steps,
218+
p_acc_rate,
219+
threshold,
220+
epsilon,
221+
dist_func,
222+
sum_stat,
223+
model,
224+
random_seed,
225+
chain,
226+
_log,
227+
):
228+
129229
smc = SMC(
130230
draws=draws,
131231
kernel=kernel,
132232
n_steps=n_steps,
133-
parallel=parallel,
134233
start=start,
135-
cores=cores,
136234
tune_steps=tune_steps,
137235
p_acc_rate=p_acc_rate,
138236
threshold=threshold,
139237
epsilon=epsilon,
140238
dist_func=dist_func,
141239
sum_stat=sum_stat,
142-
progressbar=progressbar,
143240
model=model,
144241
random_seed=random_seed,
242+
chain=chain,
145243
)
146-
147-
t1 = time.time()
148-
_log = logging.getLogger("pymc3")
149-
_log.info("Sample initial stage: ...")
150244
stage = 0
245+
betas = []
246+
accept_ratios = []
247+
nsteps = []
151248
smc.initialize_population()
152249
smc.setup_kernel()
153250
smc.initialize_logp()
154251

155252
while smc.beta < 1:
156253
smc.update_weights_beta()
157-
_log.info(
158-
"Stage: {:3d} Beta: {:.3f} Steps: {:3d} Acce: {:.3f}".format(
159-
stage, smc.beta, smc.n_steps, smc.acc_rate
160-
)
161-
)
162-
smc.resample()
254+
if _log is not None:
255+
_log.info(f"Stage: {stage:3d} Beta: {smc.beta:.3f}")
163256
smc.update_proposal()
164-
if stage > 0:
165-
smc.tune()
257+
smc.resample()
166258
smc.mutate()
259+
smc.tune()
167260
stage += 1
261+
betas.append(smc.beta)
262+
accept_ratios.append(smc.acc_rate)
263+
nsteps.append(smc.n_steps)
168264

169-
if smc.parallel and smc.cores > 1:
170-
smc.pool.close()
171-
smc.pool.join()
172-
173-
trace = smc.posterior_to_trace()
174-
trace.report._n_draws = smc.draws
175-
trace.report._n_tune = 0
176-
trace.report._t_sampling = time.time() - t1
177-
return trace
265+
return smc.posterior_to_trace(), smc.log_marginal_likelihood, betas, accept_ratios, nsteps

0 commit comments

Comments
 (0)