Commit 90f48ed

Show pickling issues in notebook on windows (#3991)
* Merge close remote connection
* Manually pickle step method in multiprocess sampling
* Fix tests for extra divergence info
* Add test for remote process crash
* Better formatting in test_parallel_sampling

Co-authored-by: Junpeng Lao <[email protected]>

* Use mp_ctx forkserver on MacOS
* Add test for pickle with dill

Co-authored-by: Junpeng Lao <[email protected]>
1 parent 77873e9 commit 90f48ed

6 files changed: 254 additions & 83 deletions

RELEASE-NOTES.md

Lines changed: 5 additions & 0 deletions

@@ -1,10 +1,15 @@
 # Release Notes

 ## PyMC3 3.9.x (on deck)
+
+### Maintenance
+- Fix an error on Windows and Mac where the error message from unpickling models did not show up in the notebook, or where sampling froze when a worker process crashed (see [#3991](https://github.com/pymc-devs/pymc3/pull/3991)).
+
 ### Documentation
 - Notebook on [multilevel modeling](https://docs.pymc.io/notebooks/multilevel_modeling.html) has been rewritten to showcase ArviZ and xarray usage for inference result analysis (see [#3963](https://github.com/pymc-devs/pymc3/pull/3963))

 ### New features
+- Introduce optional arguments to `pm.sample`: `mp_ctx` to control how the processes for parallel sampling are started, and `pickle_backend` to specify which library is used to pickle models in parallel sampling when the multiprocessing context is not of type `fork`. (see [#3991](https://github.com/pymc-devs/pymc3/pull/3991))
 - Add sampler stats `process_time_diff`, `perf_counter_diff` and `perf_counter_start`, that record wall and CPU times for each NUTS and HMC sample (see [#3986](https://github.com/pymc-devs/pymc3/pull/3986)).

 ## PyMC3 3.9.2 (24 June 2020)
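
The two new `pm.sample` arguments compose as below; a minimal usage sketch, with a made-up stand-in model (not part of this commit):

```python
import pymc3 as pm

with pm.Model():
    x = pm.Normal("x", mu=0.0, sigma=1.0)
    # Start worker processes with the "spawn" context. Spawned workers do
    # not inherit the parent's memory, so the step method is pickled and
    # sent over; pickle_backend="dill" swaps in dill for that step.
    trace = pm.sample(draws=1000, cores=2, mp_ctx="spawn", pickle_backend="dill")
```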

pymc3/__init__.py

Lines changed: 0 additions & 6 deletions

@@ -27,12 +27,6 @@
 handler = logging.StreamHandler()
 _log.addHandler(handler)

-# Set start method to forkserver for MacOS to enable multiprocessing
-# Closes issue https://github.com/pymc-devs/pymc3/issues/3849
-sys = platform.system()
-if sys == "Darwin":
-    new_context = mp.get_context("forkserver")
-

 def __set_compiler_flags():
     # Workarounds for Theano compiler problems on various platforms
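
Note that the removed block never changed the default start method: `mp.get_context` returns a context object rather than mutating global state, so the assignment to `new_context` had no effect on its own. The commit moves the macOS forkserver choice into `ParallelSampler` below, where the context object is actually used. For reference, the standard-library mechanism involved:

```python
import multiprocessing
import platform

# get_context returns a module-like object bound to one start method,
# leaving the process-wide default untouched.
method = "forkserver" if platform.system() == "Darwin" else None
ctx = multiprocessing.get_context(method)  # None selects the platform default
print(ctx.get_start_method())  # e.g. "forkserver" on macOS, "fork" on Linux
```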

pymc3/parallel_sampling.py

Lines changed: 150 additions & 70 deletions

@@ -17,10 +17,11 @@
 import ctypes
 import time
 import logging
+import pickle
 from collections import namedtuple
 import traceback
+import platform
 from pymc3.exceptions import SamplingError
-import errno

 import numpy as np
 from fastprogress.fastprogress import progress_bar
@@ -30,37 +31,6 @@
 logger = logging.getLogger("pymc3")


-def _get_broken_pipe_exception():
-    import sys
-
-    if sys.platform == "win32":
-        return RuntimeError(
-            "The communication pipe between the main process "
-            "and its spawned children is broken.\n"
-            "In Windows OS, this usually means that the child "
-            "process raised an exception while it was being "
-            "spawned, before it was setup to communicate to "
-            "the main process.\n"
-            "The exceptions raised by the child process while "
-            "spawning cannot be caught or handled from the "
-            "main process, and when running from an IPython or "
-            "jupyter notebook interactive kernel, the child's "
-            "exception and traceback appears to be lost.\n"
-            "A known way to see the child's error, and try to "
-            "fix or handle it, is to run the problematic code "
-            "as a batch script from a system's Command Prompt. "
-            "The child's exception will be printed to the "
-            "Command Promt's stderr, and it should be visible "
-            "above this error and traceback.\n"
-            "Note that if running a jupyter notebook that was "
-            "invoked from a Command Prompt, the child's "
-            "exception should have been printed to the Command "
-            "Prompt on which the notebook is running."
-        )
-    else:
-        return None
-
-
 class ParallelSamplingError(Exception):
     def __init__(self, message, chain, warnings=None):
         super().__init__(message)
@@ -104,26 +74,65 @@ def rebuild_exc(exc, tb):
 # ('start',)


-class _Process(multiprocessing.Process):
+class _Process:
     """Separate process for each chain.
     We communicate with the main process using a pipe,
     and send finished samples using shared memory.
     """

-    def __init__(self, name:str, msg_pipe, step_method, shared_point, draws:int, tune:int, seed):
-        super().__init__(daemon=True, name=name)
+    def __init__(
+        self,
+        name: str,
+        msg_pipe,
+        step_method,
+        step_method_is_pickled,
+        shared_point,
+        draws: int,
+        tune: int,
+        seed,
+        pickle_backend,
+    ):
         self._msg_pipe = msg_pipe
         self._step_method = step_method
+        self._step_method_is_pickled = step_method_is_pickled
         self._shared_point = shared_point
         self._seed = seed
         self._tt_seed = seed + 1
         self._draws = draws
         self._tune = tune
+        self._pickle_backend = pickle_backend
+
+    def _unpickle_step_method(self):
+        unpickle_error = (
+            "The model could not be unpickled. This is required for sampling "
+            "with more than one core and multiprocessing context spawn "
+            "or forkserver."
+        )
+        if self._step_method_is_pickled:
+            if self._pickle_backend == 'pickle':
+                try:
+                    self._step_method = pickle.loads(self._step_method)
+                except Exception:
+                    raise ValueError(unpickle_error)
+            elif self._pickle_backend == 'dill':
+                try:
+                    import dill
+                except ImportError:
+                    raise ValueError(
+                        "dill must be installed for pickle_backend='dill'."
+                    )
+                try:
+                    self._step_method = dill.loads(self._step_method)
+                except Exception:
+                    raise ValueError(unpickle_error)
+            else:
+                raise ValueError("Unknown pickle backend")

     def run(self):
         try:
             # We do not create this in __init__, as pickling this
             # would destroy the shared memory.
+            self._unpickle_step_method()
             self._point = self._make_numpy_refs()
             self._start_loop()
         except KeyboardInterrupt:
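
Unpickling now happens inside the worker's `run`, wrapped so that a model which cannot be deserialized fails with a readable `ValueError` rather than a silent worker crash. As background on why a dill backend is offered at all, a toy standalone illustration (not PyMC3 code) of an object the standard-library pickle rejects but dill handles:

```python
import pickle

import dill  # optional third-party dependency


def make_step(scale):
    # A closure: a classic object that pickle cannot serialize by reference.
    def step(x):
        return scale * x

    return step


step = make_step(2.0)

try:
    pickle.dumps(step)
except Exception as e:
    print("pickle failed:", e)  # Can't pickle local object 'make_step.<locals>.step'

restored = dill.loads(dill.dumps(step))
print(restored(3.0))  # 6.0
```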
@@ -219,10 +228,25 @@ def _collect_warnings(self):
             return []


+def _run_process(*args):
+    _Process(*args).run()
+
+
 class ProcessAdapter:
     """Control a Chain process from the main thread."""

-    def __init__(self, draws:int, tune:int, step_method, chain:int, seed, start):
+    def __init__(
+        self,
+        draws: int,
+        tune: int,
+        step_method,
+        step_method_pickled,
+        chain: int,
+        seed,
+        start,
+        mp_ctx,
+        pickle_backend,
+    ):
         self.chain = chain
         process_name = "worker_chain_%s" % chain
         self._msg_pipe, remote_conn = multiprocessing.Pipe()
@@ -237,7 +261,7 @@ def __init__(self, draws:int, tune:int, step_method, chain:int, seed, start):
             if size != ctypes.c_size_t(size).value:
                 raise ValueError("Variable %s is too large" % name)

-            array = multiprocessing.sharedctypes.RawArray("c", size)
+            array = mp_ctx.RawArray("c", size)
             self._shared_point[name] = array
             array_np = np.frombuffer(array, dtype).reshape(shape)
             array_np[...] = start[name]
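
Draws are handed back through shared memory rather than the message pipe: each variable gets a raw shared buffer that both sides wrap as a NumPy array. The change here is only that the buffer now comes from the chosen context (`mp_ctx.RawArray`) instead of the global `multiprocessing.sharedctypes`. The pattern in isolation (standalone sketch, not PyMC3 code):

```python
import multiprocessing

import numpy as np

ctx = multiprocessing.get_context("spawn")

shape, dtype = (3, 2), np.float64
size = int(np.prod(shape)) * np.dtype(dtype).itemsize

# RawArray allocates untyped shared memory ("c" = char) visible to any
# child process that receives a handle to it.
raw = ctx.RawArray("c", size)

# Wrap the shared buffer as a NumPy view; no copy is made, so writes by
# a worker are immediately visible to the parent.
view = np.frombuffer(raw, dtype=dtype).reshape(shape)
view[...] = 1.0
print(view)
```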
@@ -246,27 +270,31 @@ def __init__(self, draws:int, tune:int, step_method, chain:int, seed, start):
         self._readable = True
         self._num_samples = 0

-        self._process = _Process(
-            process_name,
-            remote_conn,
-            step_method,
-            self._shared_point,
-            draws,
-            tune,
-            seed,
+        if step_method_pickled is not None:
+            step_method_send = step_method_pickled
+        else:
+            step_method_send = step_method
+
+        self._process = mp_ctx.Process(
+            daemon=True,
+            name=process_name,
+            target=_run_process,
+            args=(
+                process_name,
+                remote_conn,
+                step_method_send,
+                step_method_pickled is not None,
+                self._shared_point,
+                draws,
+                tune,
+                seed,
+                pickle_backend,
+            )
         )
-        try:
-            self._process.start()
-        except IOError as e:
-            # Something may have gone wrong during the fork / spawn
-            if e.errno == errno.EPIPE:
-                exc = _get_broken_pipe_exception()
-                if exc is not None:
-                    # Sleep a little to give the child process time to flush
-                    # all its error message
-                    time.sleep(0.2)
-                    raise exc
-            raise
+        self._process.start()
+        # Close the remote pipe, so that we get notified if the other
+        # end is closed.
+        remote_conn.close()

     @property
     def shared_point_view(self):
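
Two details are doing the real work here. First, `_Process` is no longer a `multiprocessing.Process` subclass; the worker is launched through the module-level `_run_process` target, so under spawn/forkserver only plain constructor arguments get pickled, never a `Process` instance. Second, the parent closes its copy of `remote_conn` after the start, so the pipe actually breaks (instead of hanging) when the worker dies. A standalone sketch of the target-function pattern (illustrative names, not PyMC3 code):

```python
import multiprocessing


class Worker:
    def __init__(self, conn, payload):
        self._conn = conn
        self._payload = payload

    def run(self):
        self._conn.send(self._payload * 2)
        self._conn.close()


def _run_worker(*args):
    # Module-level target: picklable by reference under any start method.
    Worker(*args).run()


if __name__ == "__main__":
    ctx = multiprocessing.get_context("spawn")
    parent_conn, child_conn = ctx.Pipe()
    proc = ctx.Process(target=_run_worker, args=(child_conn, 21), daemon=True)
    proc.start()
    child_conn.close()  # drop the parent's copy of the child's pipe end
    print(parent_conn.recv())  # 42
    proc.join()
```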
@@ -277,15 +305,38 @@ def shared_point_view(self):
             raise RuntimeError()
         return self._point

+    def _send(self, msg, *args):
+        try:
+            self._msg_pipe.send((msg, *args))
+        except Exception:
+            # try to receive an error message
+            message = None
+            try:
+                message = self._msg_pipe.recv()
+            except Exception:
+                pass
+            if message is not None and message[0] == "error":
+                warns, old_error = message[1:]
+                if warns is not None:
+                    error = ParallelSamplingError(
+                        str(old_error),
+                        self.chain,
+                        warns
+                    )
+                else:
+                    error = RuntimeError("Chain %s failed." % self.chain)
+                raise error from old_error
+            raise
+
     def start(self):
-        self._msg_pipe.send(("start",))
+        self._send("start")

     def write_next(self):
         self._readable = False
-        self._msg_pipe.send(("write_next",))
+        self._send("write_next")

     def abort(self):
-        self._msg_pipe.send(("abort",))
+        self._send("abort")

     def join(self, timeout=None):
         self._process.join(timeout)
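
Because the parent holds the only remaining handle once `remote_conn` is closed, a crashed worker turns the next pipe operation into an immediate error, and `_send` then tries to drain a final ("error", ...) message before re-raising. The operating-system behavior this relies on, in isolation (standalone sketch):

```python
import multiprocessing


def worker(conn):
    conn.close()  # simulate a worker dying without answering


if __name__ == "__main__":
    ctx = multiprocessing.get_context("spawn")
    parent_conn, child_conn = ctx.Pipe()
    proc = ctx.Process(target=worker, args=(child_conn,))
    proc.start()
    child_conn.close()  # without this, the pipe would never "break"
    proc.join()
    try:
        parent_conn.send(("write_next",))
    except OSError as e:  # BrokenPipeError on most platforms
        # This is the failure _send catches before looking for a final
        # error message from the worker.
        print("pipe broken:", e)
```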
@@ -324,7 +375,7 @@ def terminate_all(processes, patience=2):
     for process in processes:
         try:
             process.abort()
-        except EOFError:
+        except Exception:
            pass

    start_time = time.time()
@@ -353,23 +404,52 @@
 class ParallelSampler:
     def __init__(
         self,
-        draws:int,
-        tune:int,
-        chains:int,
-        cores:int,
-        seeds:list,
-        start_points:list,
+        draws: int,
+        tune: int,
+        chains: int,
+        cores: int,
+        seeds: list,
+        start_points: list,
         step_method,
-        start_chain_num:int=0,
-        progressbar:bool=True,
+        start_chain_num: int = 0,
+        progressbar: bool = True,
+        mp_ctx=None,
+        pickle_backend: str = 'pickle',
     ):

         if any(len(arg) != chains for arg in [seeds, start_points]):
             raise ValueError("Number of seeds and start_points must be %s." % chains)

+        if mp_ctx is None or isinstance(mp_ctx, str):
+            # Closes issue https://github.com/pymc-devs/pymc3/issues/3849
+            if platform.system() == 'Darwin':
+                mp_ctx = "forkserver"
+            mp_ctx = multiprocessing.get_context(mp_ctx)
+
+        step_method_pickled = None
+        if mp_ctx.get_start_method() != 'fork':
+            if pickle_backend == 'pickle':
+                step_method_pickled = pickle.dumps(step_method, protocol=-1)
+            elif pickle_backend == 'dill':
+                try:
+                    import dill
+                except ImportError:
+                    raise ValueError(
+                        "dill must be installed for pickle_backend='dill'."
+                    )
+                step_method_pickled = dill.dumps(step_method, protocol=-1)
+
         self._samplers = [
             ProcessAdapter(
-                draws, tune, step_method, chain + start_chain_num, seed, start
+                draws,
+                tune,
+                step_method,
+                step_method_pickled,
+                chain + start_chain_num,
+                seed,
+                start,
+                mp_ctx,
+                pickle_backend
             )
             for chain, seed, start in zip(range(chains), seeds, start_points)
         ]
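
The parent thus serializes the step method exactly when the start method is not `fork`, since forked workers inherit the parent's memory and need no pickling. A quick standalone check of which start methods that condition covers on a given platform:

```python
import multiprocessing

for method in multiprocessing.get_all_start_methods():
    ctx = multiprocessing.get_context(method)
    # fork inherits memory; spawn and forkserver require pickling.
    print(method, "->", "pre-pickle" if ctx.get_start_method() != "fork" else "inherit")
```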
