14
14
15
15
import time
16
16
import logging
17
+ import warnings
18
+ from collections .abc import Iterable
19
+ import multiprocessing as mp
20
+ import numpy as np
21
+
17
22
from .smc import SMC
23
+ from ..model import modelcontext
24
+ from ..backends .base import MultiTrace
25
+ from ..parallel_sampling import _cpu_count
26
+
27
+ EXPERIMENTAL_WARNING = (
28
+ "Warning: SMC-ABC is an experimental step method and not yet recommended for use in PyMC3!"
29
+ )
18
30
19
31
20
32
def sample_smc (
21
- draws = 1000 ,
33
+ draws = 2000 ,
22
34
kernel = "metropolis" ,
23
35
n_steps = 25 ,
24
- parallel = False ,
25
36
start = None ,
26
- cores = None ,
27
37
tune_steps = True ,
28
38
p_acc_rate = 0.99 ,
29
39
threshold = 0.5 ,
30
40
epsilon = 1.0 ,
31
41
dist_func = "gaussian_kernel" ,
32
42
sum_stat = "identity" ,
33
- progressbar = False ,
34
43
model = None ,
35
44
random_seed = - 1 ,
45
+ parallel = False ,
46
+ chains = None ,
47
+ cores = None ,
36
48
):
37
49
r"""
38
50
Sequential Monte Carlo based sampling
@@ -49,15 +61,9 @@ def sample_smc(
49
61
The number of steps of each Markov Chain. If ``tune_steps == True`` ``n_steps`` will be used
50
62
for the first stage and for the others it will be determined automatically based on the
51
63
acceptance rate and `p_acc_rate`, the max number of steps is ``n_steps``.
52
- parallel: bool
53
- Distribute computations across cores if the number of cores is larger than 1.
54
- Defaults to False.
55
64
start: dict, or array of dict
56
65
Starting point in parameter space. It should be a list of dict with length `chains`.
57
66
When None (default) the starting point is sampled from the prior distribution.
58
- cores: int
59
- The number of chains to run in parallel. If ``None`` (default), it will be automatically
60
- set to the number of CPUs in the system.
61
67
tune_steps: bool
62
68
Whether to compute the number of steps automatically or not. Defaults to True
63
69
p_acc_rate: float
@@ -75,11 +81,19 @@ def sample_smc(
75
81
sum_stat: str or callable
76
82
Summary statistics. Available options are ``indentity``, ``sorted``, ``mean``, ``median``.
77
83
If a callable is based it should return a number or a 1d numpy array.
78
- progressbar: bool
79
- Flag for displaying a progress bar. Defaults to False.
80
84
model: Model (optional if in ``with`` context)).
81
85
random_seed: int
82
86
random seed
87
+ parallel: bool
88
+ Distribute computations across cores if the number of cores is larger than 1.
89
+ Defaults to False.
90
+ cores : int
91
+ The number of chains to run in parallel. If ``None``, set to the number of CPUs in the
92
+ system, but at most 4.
93
+ chains : int
94
+ The number of chains to sample. Running independent chains is important for some
95
+ convergence statistics. If ``None`` (default), then set to either ``cores`` or 2, whichever
96
+ is larger.
83
97
84
98
Notes
85
99
-----
@@ -126,52 +140,126 @@ def sample_smc(
126
140
%282007%29133:7%28816%29>`__
127
141
"""
128
142
143
+ _log = logging .getLogger ("pymc3" )
144
+ _log .info ("Initializing SMC sampler..." )
145
+
146
+ if cores is None :
147
+ cores = _cpu_count ()
148
+
149
+ if chains is None :
150
+ chains = max (2 , cores )
151
+
152
+ _log .info (f"Multiprocess sampling ({ chains } chains in { cores } jobs)" )
153
+
154
+ if random_seed == - 1 :
155
+ random_seed = None
156
+ if chains == 1 and isinstance (random_seed , int ):
157
+ random_seed = [random_seed ]
158
+ if random_seed is None or isinstance (random_seed , int ):
159
+ if random_seed is not None :
160
+ np .random .seed (random_seed )
161
+ random_seed = [np .random .randint (2 ** 30 ) for _ in range (chains )]
162
+ if not isinstance (random_seed , Iterable ):
163
+ raise TypeError ("Invalid value for `random_seed`. Must be tuple, list or int" )
164
+
165
+ if kernel .lower () == "abc" :
166
+ warnings .warn (EXPERIMENTAL_WARNING )
167
+ if len (modelcontext (model ).observed_RVs ) != 1 :
168
+ warnings .warn ("SMC-ABC only works properly with models with one observed variable" )
169
+
170
+ params = (
171
+ draws ,
172
+ kernel ,
173
+ n_steps ,
174
+ start ,
175
+ tune_steps ,
176
+ p_acc_rate ,
177
+ threshold ,
178
+ epsilon ,
179
+ dist_func ,
180
+ sum_stat ,
181
+ model ,
182
+ )
183
+
184
+ t1 = time .time ()
185
+ if parallel :
186
+ loggers = [_log ] + [None ] * (chains - 1 )
187
+ pool = mp .Pool (cores )
188
+ results = pool .starmap (
189
+ sample_smc_int , [(* params , random_seed [i ], i , loggers [i ]) for i in range (chains )]
190
+ )
191
+
192
+ pool .close ()
193
+ pool .join ()
194
+ else :
195
+ results = []
196
+ for i in range (chains ):
197
+ results .append ((sample_smc_int (* params , random_seed [i ], i , _log )))
198
+
199
+ traces , log_marginal_likelihoods , betas , accept_ratios , nsteps = zip (* results )
200
+ trace = MultiTrace (traces )
201
+ trace .report ._n_draws = draws
202
+ trace .report ._n_tune = 0
203
+ trace .report ._t_sampling = time .time () - t1
204
+ trace .report .log_marginal_likelihood = np .array (log_marginal_likelihoods )
205
+ trace .report .betas = betas
206
+ trace .report .accept_ratios = accept_ratios
207
+ trace .report .nsteps = nsteps
208
+
209
+ return trace
210
+
211
+
212
+ def sample_smc_int (
213
+ draws ,
214
+ kernel ,
215
+ n_steps ,
216
+ start ,
217
+ tune_steps ,
218
+ p_acc_rate ,
219
+ threshold ,
220
+ epsilon ,
221
+ dist_func ,
222
+ sum_stat ,
223
+ model ,
224
+ random_seed ,
225
+ chain ,
226
+ _log ,
227
+ ):
228
+
129
229
smc = SMC (
130
230
draws = draws ,
131
231
kernel = kernel ,
132
232
n_steps = n_steps ,
133
- parallel = parallel ,
134
233
start = start ,
135
- cores = cores ,
136
234
tune_steps = tune_steps ,
137
235
p_acc_rate = p_acc_rate ,
138
236
threshold = threshold ,
139
237
epsilon = epsilon ,
140
238
dist_func = dist_func ,
141
239
sum_stat = sum_stat ,
142
- progressbar = progressbar ,
143
240
model = model ,
144
241
random_seed = random_seed ,
242
+ chain = chain ,
145
243
)
146
-
147
- t1 = time .time ()
148
- _log = logging .getLogger ("pymc3" )
149
- _log .info ("Sample initial stage: ..." )
150
244
stage = 0
245
+ betas = []
246
+ accept_ratios = []
247
+ nsteps = []
151
248
smc .initialize_population ()
152
249
smc .setup_kernel ()
153
250
smc .initialize_logp ()
154
251
155
252
while smc .beta < 1 :
156
253
smc .update_weights_beta ()
157
- _log .info (
158
- "Stage: {:3d} Beta: {:.3f} Steps: {:3d} Acce: {:.3f}" .format (
159
- stage , smc .beta , smc .n_steps , smc .acc_rate
160
- )
161
- )
162
- smc .resample ()
254
+ if _log is not None :
255
+ _log .info (f"Stage: { stage :3d} Beta: { smc .beta :.3f} " )
163
256
smc .update_proposal ()
164
- if stage > 0 :
165
- smc .tune ()
257
+ smc .resample ()
166
258
smc .mutate ()
259
+ smc .tune ()
167
260
stage += 1
261
+ betas .append (smc .beta )
262
+ accept_ratios .append (smc .acc_rate )
263
+ nsteps .append (smc .n_steps )
168
264
169
- if smc .parallel and smc .cores > 1 :
170
- smc .pool .close ()
171
- smc .pool .join ()
172
-
173
- trace = smc .posterior_to_trace ()
174
- trace .report ._n_draws = smc .draws
175
- trace .report ._n_tune = 0
176
- trace .report ._t_sampling = time .time () - t1
177
- return trace
265
+ return smc .posterior_to_trace (), smc .log_marginal_likelihood , betas , accept_ratios , nsteps
0 commit comments