Description
Before
No response
After
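```python
import pymc as pm

with pm.Model():
    ...
    # Proposed API: checkpoint_file and checkpoint_freq do not exist in pm.sample yet.
    pm.sample(..., checkpoint_file=some_path, checkpoint_freq=10)
```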
Context for the issue:
For models that take a very long time to sample, it would be great to have a way to store the steppers' state in a checkpoint file, so that if sampling stops for any reason we can pick up from where we left off. This is a very old feature request, related to #292, #143 and #3661.
Those issues talk about `iter_sample`, which works as a generator that one could simply pause and resume later. The problem is that there is no access to the stepper's state. I think we need two things to warm-start the samplers:
- The trace that was collected so far
- The step method's state
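To illustrate why both pieces are needed, here is a minimal sketch with a toy random-walk Metropolis step. `RandomWalkStep`, `make_checkpoint` and the dict layout are hypothetical, illustrative names, not pymc APIs. The checkpoint holds the draws collected so far plus the step's state (current point, proposal scale, and the NumPy bit generator state), and resuming from it reproduces the uninterrupted chain draw for draw.

```python
import math
import pickle

import numpy as np


def logp(x: float) -> float:
    # Toy target: standard normal log-density up to a constant.
    return -0.5 * x * x


class RandomWalkStep:
    """Hypothetical toy Metropolis step method (NOT a pymc class)."""

    def __init__(self, scale: float = 1.0, seed: int = 0):
        self.scale = scale  # stands in for tuned/adapted parameters
        self.rng = np.random.default_rng(seed)
        self.point = 0.0

    def step(self) -> float:
        proposal = self.point + self.scale * self.rng.normal()
        accept_prob = math.exp(min(0.0, logp(proposal) - logp(self.point)))
        if self.rng.uniform() < accept_prob:
            self.point = proposal
        return self.point

    def get_state(self) -> dict:
        # Everything needed to continue the chain exactly where it stopped.
        return {
            "scale": self.scale,
            "point": self.point,
            "rng": self.rng.bit_generator.state,
        }

    def set_state(self, state: dict) -> None:
        self.scale = state["scale"]
        self.point = state["point"]
        self.rng.bit_generator.state = state["rng"]


def make_checkpoint(draws: list, step: RandomWalkStep) -> bytes:
    # The two ingredients: (1) the trace so far, (2) the step method's state.
    return pickle.dumps({"draws": draws, "step_state": step.get_state()})


def run(step: RandomWalkStep, n: int, draws: list) -> list:
    # Append n new draws to an existing (possibly restored) trace.
    for _ in range(n):
        draws.append(step.step())
    return draws


# Reference run: 20 uninterrupted draws.
full = run(RandomWalkStep(seed=42), 20, [])

# Interrupted run: 10 draws, then checkpoint (these bytes would go to a file).
first = RandomWalkStep(seed=42)
blob = make_checkpoint(run(first, 10, []), first)

# Resume in a fresh object; its seed is irrelevant because the state is restored.
restored = pickle.loads(blob)
second = RandomWalkStep(seed=0)
second.set_state(restored["step_state"])
resumed = run(second, 10, restored["draws"])

assert resumed == full
```

Restarting from only the last draw would still give a valid chain, but a different one, and in a real sampler it would also discard whatever the step method had tuned.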
Currently, most samplers and step methods provide some way to get the first, but we never have access to the second. The current `pymc` samplers have a bunch of `KeyboardInterrupt` catches (here, here, here, and here). We could add a call there that also stores the step method's state. `nutpie` has non-blocking sampling with an `abort` function that is called when `KeyboardInterrupt` is hit; we could maybe add a similar state-recording step there. `blackjax` has its progress-bar conditional steps, which we could try to mimic to get the same effect. `numpyro` has something similar with its progress bar, but it looks like it is buried much deeper than in `blackjax`.
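Hooking the existing `KeyboardInterrupt` catches could look roughly like the sketch below. It is a sketch under assumptions: `sample_with_checkpoints`, `save_checkpoint`, the `get_state`/`set_state` methods and the pickle file format are all hypothetical, not pymc or nutpie APIs. It also shows one possible meaning for the proposed `checkpoint_freq` argument: a checkpoint every `checkpoint_freq` draws, plus one more when the user interrupts. Each checkpoint is written to a temporary file and moved into place with `os.replace`, which is atomic on POSIX when both paths are on the same filesystem, so a crash mid-write should not clobber the previous checkpoint.

```python
import os
import pickle
import tempfile


def save_checkpoint(path: str, draws: list, step_state: dict) -> None:
    # Write to a temporary file in the same directory, then swap it in.
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp = tempfile.mkstemp(dir=directory)
    with os.fdopen(fd, "wb") as f:
        pickle.dump({"draws": draws, "step_state": step_state}, f)
    os.replace(tmp, path)


def load_checkpoint(path: str) -> dict:
    with open(path, "rb") as f:
        return pickle.load(f)


def sample_with_checkpoints(step, draws, n_draws, checkpoint_file, checkpoint_freq=10):
    """Hypothetical driver loop; `step` needs .step() and .get_state()."""
    try:
        while len(draws) < n_draws:
            draws.append(step.step())
            if len(draws) % checkpoint_freq == 0:
                save_checkpoint(checkpoint_file, draws, step.get_state())
    except KeyboardInterrupt:
        # Where the existing samplers already catch KeyboardInterrupt,
        # also persist the step method's state before re-raising.
        save_checkpoint(checkpoint_file, draws, step.get_state())
        raise
    return draws


class CountingStep:
    """Toy step whose whole state is a counter; can simulate Ctrl-C."""

    def __init__(self, fail_at=None):
        self.count = 0
        self.fail_at = fail_at

    def step(self):
        if self.count == self.fail_at:
            raise KeyboardInterrupt  # simulate the user pressing Ctrl-C
        self.count += 1
        return self.count

    def get_state(self):
        return {"count": self.count}

    def set_state(self, state):
        self.count = state["count"]


path = os.path.join(tempfile.mkdtemp(), "chain.ckpt")
try:
    sample_with_checkpoints(CountingStep(fail_at=13), [], 30, path)
except KeyboardInterrupt:
    pass

# Resume from the checkpoint written inside the except block.
ckpt = load_checkpoint(path)
step = CountingStep()
step.set_state(ckpt["step_state"])
draws = sample_with_checkpoints(step, ckpt["draws"], 30, path)
```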
All of this to say that I think we need to define some kind of standard way for samplers to provide their state information. Each sampler would then have to conform to the standard using whatever internals it needs. For `pymc` samplers it would be some way to recreate the step methods (maybe using `__getstate__` and `__setstate__`), for `nutpie` it would have to be some new datatype that could be passed into Rust, and for `blackjax` it could be the kernel and random keys. The important thing is to agree on the standard that samplers should conform to; once we have it, we can build support for writing checkpoints and restarting sampling from them.
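One possible shape for such a standard, as a sketch only: a `typing.Protocol` with two hypothetical methods, `get_sampling_state()` returning a picklable snapshot and `set_sampling_state()` restoring it. The names `ResumableSampler` and `ToyStep` are illustrative assumptions, not an agreed pymc API. A `pymc`-style step method could implement the protocol on top of `__getstate__`/`__setstate__`, as suggested above, while other backends would decide for themselves what goes into the snapshot.

```python
import copy
import pickle
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class ResumableSampler(Protocol):
    """Hypothetical standard that every sampler backend would conform to."""

    def get_sampling_state(self) -> Any:
        """Return a picklable snapshot sufficient to resume sampling."""
        ...

    def set_sampling_state(self, state: Any) -> None:
        """Restore a snapshot produced by get_sampling_state()."""
        ...


class ToyStep:
    """Toy pymc-like step method; its attributes stand in for tuning state."""

    def __init__(self, step_size: float = 0.1):
        self.step_size = step_size
        self.n_tuned = 0

    def tune(self) -> None:
        # Pretend adaptation: shrink the step size a little on each call.
        self.step_size *= 0.9
        self.n_tuned += 1

    # Standard pickle hooks, as suggested for pymc step methods.
    def __getstate__(self) -> dict:
        return dict(self.__dict__)

    def __setstate__(self, state: dict) -> None:
        self.__dict__.update(state)

    # The protocol methods delegate to the pickle hooks.
    def get_sampling_state(self) -> dict:
        return self.__getstate__()

    def set_sampling_state(self, state: dict) -> None:
        self.__setstate__(copy.deepcopy(state))


step = ToyStep()
for _ in range(3):
    step.tune()

# runtime_checkable only checks that the methods exist, not their signatures.
assert isinstance(step, ResumableSampler)
blob = pickle.dumps(step.get_sampling_state())

fresh = ToyStep()
fresh.set_sampling_state(pickle.loads(blob))
```

Keeping the protocol this small is deliberate: the checkpoint machinery only needs something picklable in and out, and each backend stays free to choose its internal representation.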