ENH: Add checkpoints during sampling #7503

Open
@lucianopaz

Description

Before

No response

After

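```python
import pymc as pm

with pm.Model():
    ...  # model definition
    # Proposed arguments; they do not exist in pm.sample yet.
    pm.sample(..., checkpoint_file=some_path, checkpoint_freq=10)
```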

Context for the issue:

For models that take a very long time to sample, it would be useful to store the state of the step methods in a checkpoint file, so that if sampling is interrupted we can resume from where we left off. This is a long-standing feature request related to #292, #143 and #3661.

Those issues discuss iter_sample, which works as a generator that one could pause and resume later. The problem is that it gives no access to the step method's state. I think we need two things to warm-start the samplers:

  1. The trace that was collected so far
  2. The step method's state
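A checkpoint therefore has to bundle both pieces. The following is a minimal, self-contained sketch in plain Python, not PyMC code: `RandomWalkStep`, `sample`, `resume` and the pickle-based `checkpoint_file`/`checkpoint_freq` handling are hypothetical stand-ins, the target is a standard normal, and the step method's state is assumed to be just its current point, its proposal scale and its RNG state. It writes the draws collected so far together with the step state every `checkpoint_freq` draws, and resuming from that file reproduces exactly the draws an uninterrupted run would have produced.

```python
import math
import os
import pickle
import random
import tempfile


class RandomWalkStep:
    """Hypothetical random-walk Metropolis step targeting a standard normal."""

    def __init__(self, seed=0, scale=1.0):
        self.rng = random.Random(seed)
        self.scale = scale
        self.point = 0.0

    def step(self):
        proposal = self.point + self.rng.gauss(0.0, self.scale)
        log_ratio = 0.5 * (self.point ** 2 - proposal ** 2)  # standard normal logp
        if self.rng.random() < math.exp(min(0.0, log_ratio)):
            self.point = proposal
        return self.point

    def get_state(self):
        # Everything needed to continue the chain exactly where it stopped.
        return {"point": self.point, "scale": self.scale, "rng": self.rng.getstate()}

    def set_state(self, state):
        self.point = state["point"]
        self.scale = state["scale"]
        self.rng.setstate(state["rng"])


def sample(step, draws, checkpoint_file, checkpoint_freq, trace=None):
    """Draw until `draws` samples exist, checkpointing every `checkpoint_freq` draws."""
    trace = [] if trace is None else list(trace)
    while len(trace) < draws:
        trace.append(step.step())
        if len(trace) % checkpoint_freq == 0:
            with open(checkpoint_file, "wb") as f:
                pickle.dump({"trace": trace, "step_state": step.get_state()}, f)
    return trace


def resume(checkpoint_file, draws, checkpoint_freq):
    """Rebuild the step method from the checkpoint and keep sampling."""
    with open(checkpoint_file, "rb") as f:
        checkpoint = pickle.load(f)
    step = RandomWalkStep()  # its initial seed is overwritten by set_state
    step.set_state(checkpoint["step_state"])
    return sample(step, draws, checkpoint_file, checkpoint_freq, checkpoint["trace"])


path = os.path.join(tempfile.mkdtemp(), "checkpoint.pkl")
full = sample(RandomWalkStep(seed=1), 100, path, checkpoint_freq=10)

# Simulate a run that stopped after 37 draws: its last checkpoint holds 30 draws.
sample(RandomWalkStep(seed=1), 37, path, checkpoint_freq=10)
resumed = resume(path, 100, checkpoint_freq=10)
assert resumed == full  # identical to the uninterrupted run
```

Restoring only the trace would not be enough here: without the RNG state and current point, the resumed chain would diverge from the uninterrupted one.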

Currently, most samplers and step methods provide some way to get 1, but we never have access to 2. The current pymc samplers have several KeyboardInterrupt catches (here, here, here, and here). We could add a call there that also stores the step method's state. nutpie has non-blocking sampling with an abort function that is called when KeyboardInterrupt is hit; we could add similar state recording there. blackjax has progress-bar conditional steps that we could mimic to get the same effect. numpyro does something similar with its progress bar, but there the mechanism seems to be buried much deeper than in blackjax.
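Extending a KeyboardInterrupt handler in this way could look roughly like the sketch below. It is a stand-alone illustration rather than PyMC's real sampling loop: `ToyStep`, `run_chain` and the dictionary layout of the checkpoint are hypothetical, the checkpoint is kept as pickled bytes instead of being written to a file, and the interrupt is simulated by raising KeyboardInterrupt from inside the step.

```python
import pickle
import random


class ToyStep:
    """Hypothetical step method; `interrupt_at` fakes a Ctrl+C mid-sampling."""

    def __init__(self, seed, interrupt_at=None):
        self.rng = random.Random(seed)
        self.n_steps = 0
        self.interrupt_at = interrupt_at

    def step(self):
        if self.n_steps == self.interrupt_at:
            raise KeyboardInterrupt
        self.n_steps += 1
        return self.rng.random()

    def get_state(self):
        return {"n_steps": self.n_steps, "rng": self.rng.getstate()}


def run_chain(step, draws):
    """Sample until done or interrupted; on interrupt, also snapshot the step."""
    trace, checkpoint = [], None
    try:
        for _ in range(draws):
            trace.append(step.step())
    except KeyboardInterrupt:
        # Keep the partial trace and also serialize the step method's state,
        # which is what a later restart would need.
        checkpoint = pickle.dumps({"trace": trace, "step_state": step.get_state()})
    return trace, checkpoint


trace, checkpoint = run_chain(ToyStep(seed=0, interrupt_at=25), draws=100)
saved = pickle.loads(checkpoint)

# The saved RNG state continues exactly where the interrupted chain stopped.
reference = ToyStep(seed=0)
expected_next = [reference.step() for _ in range(26)][-1]
rng = random.Random()
rng.setstate(saved["step_state"]["rng"])
assert len(trace) == 25 and rng.random() == expected_next
```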

In short, I think we need to define a standard way for samplers to provide their state information. Each sampler would then conform to the standard using whatever internals it needs. For pymc samplers, it would be some way to recreate the step methods (maybe using __getstate__ and __setstate__); for nutpie, it would have to be a new datatype that could be sent into Rust; for blackjax, it could be the kernel and the random keys. The important thing is to first agree on the standard that samplers should conform to. Once we have it, we can build support for checkpoints and for restarting sampling from them.
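One possible shape for such a standard is sketched below with `typing.Protocol`. The method names `get_sampler_state` and `set_sampler_state` are hypothetical, not an existing PyMC, nutpie or blackjax API; the only assumption is that the returned state is picklable. The toy `ScaledStep` shows how a pymc-style step method could satisfy the protocol on top of `__getstate__`/`__setstate__`; real step methods carry far more state (tuning and adaptation information, random generators).

```python
import pickle
import random
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class CheckpointableSampler(Protocol):
    """Hypothetical standard that each sampler backend would conform to."""

    def get_sampler_state(self) -> Any:
        """Return a picklable snapshot of everything needed to resume."""
        ...

    def set_sampler_state(self, state: Any) -> None:
        """Restore a snapshot produced by get_sampler_state."""
        ...


class ScaledStep:
    """Toy pymc-style step method with a tuned proposal scale and its own RNG."""

    def __init__(self, seed: int, scale: float = 1.0):
        self.rng = random.Random(seed)
        self.scale = scale

    def step(self) -> float:
        return self.scale * self.rng.gauss(0.0, 1.0)

    def __getstate__(self):
        # Explicitly list the state that must survive a restart.
        return {"scale": self.scale, "rng_state": self.rng.getstate()}

    def __setstate__(self, state):
        self.scale = state["scale"]
        self.rng = random.Random()
        self.rng.setstate(state["rng_state"])

    # Conform to the standard by reusing the pickling hooks.
    def get_sampler_state(self):
        return self.__getstate__()

    def set_sampler_state(self, state):
        self.__setstate__(state)


step = ScaledStep(seed=42, scale=0.5)
for _ in range(10):
    step.step()
assert isinstance(step, CheckpointableSampler)

snapshot = pickle.dumps(step.get_sampler_state())
expected = [step.step() for _ in range(5)]

restored = ScaledStep(seed=0)  # different seed on purpose
restored.set_sampler_state(pickle.loads(snapshot))
replayed = [restored.step() for _ in range(5)]
assert replayed == expected
```

Restoring into an instance created with a different seed still reproduces the original chain, because the RNG state travels inside the snapshot. Other backends could implement the same two methods with whatever internals they need.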

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests