17
17
import ctypes
18
18
import time
19
19
import logging
20
+ import pickle
20
21
from collections import namedtuple
21
22
import traceback
23
+ import platform
22
24
from pymc3 .exceptions import SamplingError
23
- import errno
24
25
25
26
import numpy as np
26
27
from fastprogress .fastprogress import progress_bar
30
31
logger = logging .getLogger ("pymc3" )
31
32
32
33
33
- def _get_broken_pipe_exception ():
34
- import sys
35
-
36
- if sys .platform == "win32" :
37
- return RuntimeError (
38
- "The communication pipe between the main process "
39
- "and its spawned children is broken.\n "
40
- "In Windows OS, this usually means that the child "
41
- "process raised an exception while it was being "
42
- "spawned, before it was setup to communicate to "
43
- "the main process.\n "
44
- "The exceptions raised by the child process while "
45
- "spawning cannot be caught or handled from the "
46
- "main process, and when running from an IPython or "
47
- "jupyter notebook interactive kernel, the child's "
48
- "exception and traceback appears to be lost.\n "
49
- "A known way to see the child's error, and try to "
50
- "fix or handle it, is to run the problematic code "
51
- "as a batch script from a system's Command Prompt. "
52
- "The child's exception will be printed to the "
53
- "Command Promt's stderr, and it should be visible "
54
- "above this error and traceback.\n "
55
- "Note that if running a jupyter notebook that was "
56
- "invoked from a Command Prompt, the child's "
57
- "exception should have been printed to the Command "
58
- "Prompt on which the notebook is running."
59
- )
60
- else :
61
- return None
62
-
63
-
64
34
class ParallelSamplingError (Exception ):
65
35
def __init__ (self , message , chain , warnings = None ):
66
36
super ().__init__ (message )
@@ -104,26 +74,65 @@ def rebuild_exc(exc, tb):
104
74
# ('start',)
105
75
106
76
107
- class _Process ( multiprocessing . Process ) :
77
+ class _Process :
108
78
"""Seperate process for each chain.
109
79
We communicate with the main process using a pipe,
110
80
and send finished samples using shared memory.
111
81
"""
112
82
113
- def __init__ (self , name :str , msg_pipe , step_method , shared_point , draws :int , tune :int , seed ):
114
- super ().__init__ (daemon = True , name = name )
83
+ def __init__ (
84
+ self ,
85
+ name : str ,
86
+ msg_pipe ,
87
+ step_method ,
88
+ step_method_is_pickled ,
89
+ shared_point ,
90
+ draws : int ,
91
+ tune : int ,
92
+ seed ,
93
+ pickle_backend ,
94
+ ):
115
95
self ._msg_pipe = msg_pipe
116
96
self ._step_method = step_method
97
+ self ._step_method_is_pickled = step_method_is_pickled
117
98
self ._shared_point = shared_point
118
99
self ._seed = seed
119
100
self ._tt_seed = seed + 1
120
101
self ._draws = draws
121
102
self ._tune = tune
103
+ self ._pickle_backend = pickle_backend
104
+
105
+ def _unpickle_step_method (self ):
106
+ unpickle_error = (
107
+ "The model could not be unpickled. This is required for sampling "
108
+ "with more than one core and multiprocessing context spawn "
109
+ "or forkserver."
110
+ )
111
+ if self ._step_method_is_pickled :
112
+ if self ._pickle_backend == 'pickle' :
113
+ try :
114
+ self ._step_method = pickle .loads (self ._step_method )
115
+ except Exception :
116
+ raise ValueError (unpickle_error )
117
+ elif self ._pickle_backend == 'dill' :
118
+ try :
119
+ import dill
120
+ except ImportError :
121
+ raise ValueError (
122
+ "dill must be installed for pickle_backend='dill'."
123
+ )
124
+ try :
125
+ self ._step_method = dill .loads (self ._step_method )
126
+ except Exception :
127
+ raise ValueError (unpickle_error )
128
+ else :
129
+ raise ValueError ("Unknown pickle backend" )
122
130
123
131
def run (self ):
124
132
try :
125
133
# We do not create this in __init__, as pickling this
126
134
# would destroy the shared memory.
135
+ self ._unpickle_step_method ()
127
136
self ._point = self ._make_numpy_refs ()
128
137
self ._start_loop ()
129
138
except KeyboardInterrupt :
@@ -219,10 +228,25 @@ def _collect_warnings(self):
219
228
return []
220
229
221
230
231
def _run_process(*args):
    # Module-level entry point so it stays picklable as a Process target
    # under the spawn/forkserver start methods.
    worker = _Process(*args)
    worker.run()
233
+
234
+
222
235
class ProcessAdapter :
223
236
"""Control a Chain process from the main thread."""
224
237
225
- def __init__ (self , draws :int , tune :int , step_method , chain :int , seed , start ):
238
+ def __init__ (
239
+ self ,
240
+ draws : int ,
241
+ tune : int ,
242
+ step_method ,
243
+ step_method_pickled ,
244
+ chain : int ,
245
+ seed ,
246
+ start ,
247
+ mp_ctx ,
248
+ pickle_backend ,
249
+ ):
226
250
self .chain = chain
227
251
process_name = "worker_chain_%s" % chain
228
252
self ._msg_pipe , remote_conn = multiprocessing .Pipe ()
@@ -237,7 +261,7 @@ def __init__(self, draws:int, tune:int, step_method, chain:int, seed, start):
237
261
if size != ctypes .c_size_t (size ).value :
238
262
raise ValueError ("Variable %s is too large" % name )
239
263
240
- array = multiprocessing . sharedctypes .RawArray ("c" , size )
264
+ array = mp_ctx .RawArray ("c" , size )
241
265
self ._shared_point [name ] = array
242
266
array_np = np .frombuffer (array , dtype ).reshape (shape )
243
267
array_np [...] = start [name ]
@@ -246,27 +270,31 @@ def __init__(self, draws:int, tune:int, step_method, chain:int, seed, start):
246
270
self ._readable = True
247
271
self ._num_samples = 0
248
272
249
- self ._process = _Process (
250
- process_name ,
251
- remote_conn ,
252
- step_method ,
253
- self ._shared_point ,
254
- draws ,
255
- tune ,
256
- seed ,
273
+ if step_method_pickled is not None :
274
+ step_method_send = step_method_pickled
275
+ else :
276
+ step_method_send = step_method
277
+
278
+ self ._process = mp_ctx .Process (
279
+ daemon = True ,
280
+ name = process_name ,
281
+ target = _run_process ,
282
+ args = (
283
+ process_name ,
284
+ remote_conn ,
285
+ step_method_send ,
286
+ step_method_pickled is not None ,
287
+ self ._shared_point ,
288
+ draws ,
289
+ tune ,
290
+ seed ,
291
+ pickle_backend ,
292
+ )
257
293
)
258
- try :
259
- self ._process .start ()
260
- except IOError as e :
261
- # Something may have gone wrong during the fork / spawn
262
- if e .errno == errno .EPIPE :
263
- exc = _get_broken_pipe_exception ()
264
- if exc is not None :
265
- # Sleep a little to give the child process time to flush
266
- # all its error message
267
- time .sleep (0.2 )
268
- raise exc
269
- raise
294
+ self ._process .start ()
295
+ # Close the remote pipe, so that we get notified if the other
296
+ # end is closed.
297
+ remote_conn .close ()
270
298
271
299
@property
272
300
def shared_point_view (self ):
@@ -277,15 +305,38 @@ def shared_point_view(self):
277
305
raise RuntimeError ()
278
306
return self ._point
279
307
308
+ def _send (self , msg , * args ):
309
+ try :
310
+ self ._msg_pipe .send ((msg , * args ))
311
+ except Exception :
312
+ # try to recive an error message
313
+ message = None
314
+ try :
315
+ message = self ._msg_pipe .recv ()
316
+ except Exception :
317
+ pass
318
+ if message is not None and message [0 ] == "error" :
319
+ warns , old_error = message [1 :]
320
+ if warns is not None :
321
+ error = ParallelSamplingError (
322
+ str (old_error ),
323
+ self .chain ,
324
+ warns
325
+ )
326
+ else :
327
+ error = RuntimeError ("Chain %s failed." % self .chain )
328
+ raise error from old_error
329
+ raise
330
+
280
331
def start(self):
    """Tell the worker process to begin sampling."""
    self._send("start")
282
333
283
334
def write_next(self):
    """Hand the shared buffer back to the worker and request a new draw."""
    # While the worker writes, the parent must not read the shared memory.
    self._readable = False
    self._send("write_next")
286
337
287
338
def abort(self):
    """Ask the worker process to stop sampling."""
    self._send("abort")
289
340
290
341
def join (self , timeout = None ):
291
342
self ._process .join (timeout )
@@ -324,7 +375,7 @@ def terminate_all(processes, patience=2):
324
375
for process in processes :
325
376
try :
326
377
process .abort ()
327
- except EOFError :
378
+ except Exception :
328
379
pass
329
380
330
381
start_time = time .time ()
@@ -353,23 +404,52 @@ def terminate_all(processes, patience=2):
353
404
class ParallelSampler :
354
405
def __init__ (
355
406
self ,
356
- draws :int ,
357
- tune :int ,
358
- chains :int ,
359
- cores :int ,
360
- seeds :list ,
361
- start_points :list ,
407
+ draws : int ,
408
+ tune : int ,
409
+ chains : int ,
410
+ cores : int ,
411
+ seeds : list ,
412
+ start_points : list ,
362
413
step_method ,
363
- start_chain_num :int = 0 ,
364
- progressbar :bool = True ,
414
+ start_chain_num : int = 0 ,
415
+ progressbar : bool = True ,
416
+ mp_ctx = None ,
417
+ pickle_backend : str = 'pickle' ,
365
418
):
366
419
367
420
if any (len (arg ) != chains for arg in [seeds , start_points ]):
368
421
raise ValueError ("Number of seeds and start_points must be %s." % chains )
369
422
423
+ if mp_ctx is None or isinstance (mp_ctx , str ):
424
+ # Closes issue https://github.com/pymc-devs/pymc3/issues/3849
425
+ if platform .system () == 'Darwin' :
426
+ mp_ctx = "forkserver"
427
+ mp_ctx = multiprocessing .get_context (mp_ctx )
428
+
429
+ step_method_pickled = None
430
+ if mp_ctx .get_start_method () != 'fork' :
431
+ if pickle_backend == 'pickle' :
432
+ step_method_pickled = pickle .dumps (step_method , protocol = - 1 )
433
+ elif pickle_backend == 'dill' :
434
+ try :
435
+ import dill
436
+ except ImportError :
437
+ raise ValueError (
438
+ "dill must be installed for pickle_backend='dill'."
439
+ )
440
+ step_method_pickled = dill .dumps (step_method , protocol = - 1 )
441
+
370
442
self ._samplers = [
371
443
ProcessAdapter (
372
- draws , tune , step_method , chain + start_chain_num , seed , start
444
+ draws ,
445
+ tune ,
446
+ step_method ,
447
+ step_method_pickled ,
448
+ chain + start_chain_num ,
449
+ seed ,
450
+ start ,
451
+ mp_ctx ,
452
+ pickle_backend
373
453
)
374
454
for chain , seed , start in zip (range (chains ), seeds , start_points )
375
455
]
0 commit comments