-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Text and SQLite backends for PyMC3 #449
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
c1c6a85
3060c24
6dda7e1
683507b
2e80cef
ae77f69
e3152f7
d9b5d2f
caab2e1
40eb2b3
156fc42
978a528
2e40ff5
bcc7633
ad02151
a1c1458
9133a4c
8aa4d9a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
"""Backends for traces | ||
|
||
Available backends | ||
------------------ | ||
|
||
1. NumPy array (pymc.backends.NDArray) | ||
2. Text files (pymc.backends.Text) | ||
3. SQLite (pymc.backends.SQLite) | ||
|
||
The NumPy arrays and text files both hold the entire trace in memory, | ||
whereas SQLite commits the trace to the database while sampling. | ||
|
||
Selecting a backend | ||
------------------- | ||
|
||
By default, a NumPy array is used as the backend. To specify a different | ||
backend, pass a backend instance to `sample`. | ||
|
||
For example, the following would save traces to the file 'test.db'. | ||
|
||
>>> import pymc as pm | ||
>>> db = pm.backends.SQLite('test.db') | ||
>>> trace = pm.sample(..., db=db) | ||
|
||
Selecting values from a backend | ||
------------------------------- | ||
|
||
After a backend is finished sampling, it returns a Trace object. Values | ||
can be accessed in a few ways. The easiest way is to index the backend | ||
object with a variable or variable name. | ||
|
||
>>> trace['x'] # or trace[x] | ||
|
||
The call will return a list containing the sampling values for all | ||
chains of `x`. (Each call to `pymc.sample` creates a separate chain of | ||
samples.) | ||
|
||
For more control is needed of which values are returned, the | ||
`get_values` method can be used. The call below will return values from | ||
all chains, burning the first 1000 iterations from each chain. | ||
|
||
>>> trace.get_values('x', burn=1000) | ||
|
||
Setting the `combined` flag will concatenate the results from all the | ||
chains. | ||
|
||
>>> trace.get_values('x', burn=1000, combine=True) | ||
|
||
The `chains` parameter of `get_values` can be used to limit the chains | ||
that are retrieved. To work with a subset of chains without having to | ||
specify `chains` each call, you can set the `active_chains` attribute. | ||
|
||
>>> trace.chains | ||
[0, 1, 2] | ||
>>> trace.active_chains = [0, 2] | ||
|
||
After this, only chains 0 and 2 will be used in operations that work | ||
with multiple chains. | ||
|
||
Similary, the `default_chain` attribute sets which chain is used for | ||
functions that require a single chain (e.g., point). | ||
|
||
>>> trace.point(4) # or trace[4] | ||
|
||
Backends can also suppport slicing the trace object. For example, the | ||
following call would return a new trace object without the first 1000 | ||
sampling iterations for all variables. | ||
|
||
>>> sliced_trace = trace[1000:] | ||
|
||
Loading a saved backend | ||
----------------------- | ||
|
||
Saved backends can be loaded using `load` function in the module for the | ||
specific backend. | ||
|
||
>>> trace = pm.backends.sqlite.load('test.db') | ||
|
||
Writing custom backends | ||
----------------------- | ||
|
||
Backends consist of two classes: one that handles storing the sample | ||
results (e.g., backends.ndarray.NDArray or backends.sqlite.SQLite) and | ||
one that handles value selection (e.g., backends.ndarray.Trace or | ||
backends.sqlite.Trace). | ||
|
||
Three methods of the storage class will be called: | ||
|
||
- setup: Before sampling is started, the `setup` method will be called | ||
with two arguments: the number of draws and the chain number. This is | ||
useful setting up any structure for storing the sampling values that | ||
require the above information. | ||
|
||
- record: Record the sampling results for the current draw. This method | ||
will be called with a dictionary of values mapped to the variable | ||
names. This is the only function that *must* do something to have a | ||
meaningful backend. | ||
|
||
- close: This method is called following sampling and should perform any | ||
actions necessary for finalizing and cleaning up the backend. | ||
|
||
The base storage class `backends.base.Backend` provides model setup that | ||
is used by PyMC backends. | ||
|
||
After sampling has completed, the `trace` attribute of the storage | ||
object will be returned. To have a consistent interface with the backend | ||
trace objects in PyMC, this attribute should be an instance of a class | ||
that inherits from pymc.backends.base.Trace, and several methods in the | ||
inherited Trace object should be defined. | ||
|
||
- get_values: This is the core method for selecting values from the | ||
backend. It can be called directly and is used by __getitem__ when the | ||
backend is indexed with a variable name or object. | ||
|
||
- _slice: Defines how the backend returns a slice of itself. This | ||
is called if the backend is indexed with a slice range. | ||
|
||
- point: Returns values for each variables at a single iteration. This | ||
is called if the backend is indexed with a single integer. | ||
|
||
- __len__: This should return the number of draws (for the default | ||
chain). | ||
|
||
- chains: Property that returns a list of chains | ||
|
||
In addtion, a `merge_chains` method should be defined if the backend | ||
will be used with parallel sampling. This method describes how to merge | ||
sampling chains from a list of other traces. | ||
|
||
As mentioned above, the only method necessary to store the sampling | ||
values is `record`. Other methods in the storage may consist of only a | ||
pass statement. The storage object should have an attribute `trace` | ||
(with a `merge_chains` method for parallel sampling), but this does not | ||
have to do anything if storing the values is all that is desired. The | ||
backends.base.Trace is provided for convenience in setting up a | ||
consistent Trace object. | ||
|
||
For specific examples, see pymc.backends.{ndarray,text,sqlite}.py. | ||
""" | ||
from pymc.backends.ndarray import NDArray | ||
from pymc.backends.text import Text | ||
from pymc.backends.sqlite import SQLite |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,219 @@ | ||
"""Base backend for traces | ||
|
||
See the docstring for pymc.backends for more information (includng | ||
creating custom backends). | ||
""" | ||
import numpy as np | ||
from pymc.model import modelcontext | ||
|
||
|
||
class Backend(object): | ||
"""Base storage class | ||
|
||
Parameters | ||
---------- | ||
name : str | ||
Name of backend. | ||
model : Model | ||
If None, the model is taken from the `with` context. | ||
variables : list of variable objects | ||
Sampling values will be stored for these variables | ||
""" | ||
def __init__(self, name, model=None, variables=None): | ||
self.name = name | ||
|
||
model = modelcontext(model) | ||
if variables is None: | ||
variables = model.unobserved_RVs | ||
self.variables = variables | ||
self.var_names = [str(var) for var in variables] | ||
self.fn = model.fastfn(variables) | ||
|
||
## get variable shapes. common enough that I think most backends | ||
## will use this | ||
var_values = zip(self.var_names, self.fn(model.test_point)) | ||
self.var_shapes = {var: value.shape | ||
for var, value in var_values} | ||
self.chain = None | ||
self.trace = None | ||
|
||
def setup(self, draws, chain): | ||
"""Perform chain-specific setup | ||
|
||
draws : int | ||
Expected number of draws | ||
chain : int | ||
chain number | ||
""" | ||
pass | ||
|
||
def record(self, point): | ||
"""Record results of a sampling iteration | ||
|
||
point : dict | ||
Values mappled to variable names | ||
""" | ||
raise NotImplementedError | ||
|
||
def close(self): | ||
"""Close the database backend | ||
|
||
This is called after sampling has finished. | ||
""" | ||
pass | ||
|
||
|
||
class Trace(object): | ||
""" | ||
Parameters | ||
---------- | ||
var_names : list of strs | ||
Sample variables names | ||
backend : Backend object | ||
|
||
Attributes | ||
---------- | ||
var_names | ||
backend : Backend object | ||
nchains : int | ||
Number of sampling chains | ||
chains : list of ints | ||
List of sampling chain numbers | ||
default_chain : int | ||
Chain to be used if single chain requested | ||
active_chains : list of ints | ||
Values from chains to be used operations | ||
""" | ||
def __init__(self, var_names, backend=None): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. minor: why not just 'vars'? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I should definitely go through this and change it to be consistent. |
||
self.var_names = var_names | ||
self.backend = backend | ||
self._active_chains = [] | ||
self._default_chain = None | ||
|
||
@property | ||
def default_chain(self): | ||
"""Default chain to use for operations that require one chain (e.g., | ||
`point`) | ||
""" | ||
if self._default_chain is None: | ||
return self.active_chains[-1] | ||
return self._default_chain | ||
|
||
@default_chain.setter | ||
def default_chain(self, value): | ||
self._default_chain = value | ||
|
||
@property | ||
def active_chains(self): | ||
"""List of chains to be used. Defaults to all. | ||
""" | ||
if not self._active_chains: | ||
return self.chains | ||
return self._active_chains | ||
|
||
@active_chains.setter | ||
def active_chains(self, values): | ||
try: | ||
self._active_chains = [chain for chain in values] | ||
except TypeError: | ||
self._active_chains = [values] | ||
|
||
@property | ||
def nchains(self): | ||
"""Number of chains | ||
|
||
A chain is created for each sample call (including parallel | ||
threads). | ||
""" | ||
return len(self.chains) | ||
|
||
def __getitem__(self, idx): | ||
if isinstance(idx, slice): | ||
return self._slice(idx) | ||
|
||
try: | ||
return self.point(idx) | ||
except ValueError: | ||
pass | ||
except TypeError: | ||
pass | ||
return self.get_values(idx) | ||
|
||
## Selection methods that children must define | ||
|
||
@property | ||
def chains(self): | ||
"""All chains in trace""" | ||
raise NotImplementedError | ||
|
||
def __len__(self): | ||
raise NotImplementedError | ||
|
||
def get_values(self, var_name, burn=0, thin=1, combine=False, chains=None, | ||
squeeze=True): | ||
"""Get values from samples | ||
|
||
Parameters | ||
---------- | ||
var_name : str | ||
burn : int | ||
thin : int | ||
combine : bool | ||
If True, results from all chains will be concatenated. | ||
chains : list | ||
Chains to retrieve. If None, `active_chains` is used. | ||
squeeze : bool | ||
If `combine` is False, return a single array element if the | ||
resulting list of values only has one element (even if | ||
`combine` is True). | ||
|
||
Returns | ||
------- | ||
A list of NumPy array of values | ||
""" | ||
raise NotImplementedError | ||
|
||
def _slice(self, idx): | ||
"""Slice trace object""" | ||
raise NotImplementedError | ||
|
||
def point(self, idx, chain=None): | ||
"""Return dictionary of point values at `idx` for current chain | ||
with variables names as keys. | ||
|
||
If `chain` is not specified, `default_chain` is used. | ||
""" | ||
raise NotImplementedError | ||
|
||
def merge_chains(traces): | ||
"""Merge chains from trace instances | ||
|
||
Parameters | ||
---------- | ||
traces : list | ||
Backend trace instances. Each instance should have only one | ||
chain, and all chain numbers should be unique. | ||
|
||
Raises | ||
------ | ||
ValueError is raised if any traces have the same current chain | ||
number. | ||
|
||
Returns | ||
------- | ||
Backend instance with merge chains | ||
""" | ||
raise NotImplementedError | ||
|
||
|
||
def _squeeze_cat(results, combine, squeeze): | ||
"""Squeeze and concatenate the results dependending on values of | ||
`combine` and `squeeze`""" | ||
if combine: | ||
results = np.concatenate(results) | ||
if not squeeze: | ||
results = [results] | ||
else: | ||
if squeeze and len(results) == 1: | ||
results = results[0] | ||
return results |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In general, I think this class provides too much structure for the different backends, and I wasn't expecting that there would be a shared base class for all backends at all (though maybe there's enough shared code that you really want one).
I would prefer to have a very simple interface for the backends, mainly they implement a
record
and perhapsclose
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
On second thought, there is at least a significant subset of backends that will have significant shared code: backends with a distinct container for each variable (lets call those 'container backends'), like a NdArray backend or SQL backend. However, I still think this class can be made significantly simpler.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for your comments. I will try to incorporate your suggestions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is the difference that we're trying to draw between a Trace and a Backend? It looks to me like Trace and Backend should be the same class. The object that is responsible for storing the points should also be responsible for retrieving them.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It also seems to me that a trace with multiple chains should just be a container for multiple traces, and the logic for individual chains should be in separate classes. Composition is better than just adding features to a class.
Is the reason why even the basic traces deal with 'chains' because the SQL db has a shared resource between its chains?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In an earlier version, I experimented with using one class, and we could
move that way again, but there are a few reasons I think two classes are
a good idea.
users. By returning a separate trace object for retrieval, users have
an object with methods that only concern what they'll be using it for
(selecting and viewing values). They don't deal with the storage
object directly (
sample
does).clearer when defining custom backends which methods should be
overridden for each class. For a working backend, only the
record
method of the storage backend must be defined. The trace class just
provides a way to make the custom backend have a consistent interface
for accessing values.
responsibilities between storing and retrieving values seemed like a
natural division.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was thinking the other way around: a single-chain trace should just be
a "multiple" trace object with one chain. This way, all traces are
handled the same, and functions like traceplot don't need to check
whether the trace is an instance of the single or multiple trace class.
Also, when using non-memory backends, the distinction between single and
multiple trace isn't as clear to me because this will be made at the
level of the call to the database.
Is there a specific advantage to using a container object that you have
in mind?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think I'm mostly thinking of code clarity, extendability and composability. I've found this code somewhat challenging to read and understand and that concerns me.
That's fair. That is a benefit.
I definitely agree a common interface to traces is a good idea. There are other ways to achieve this, for example, by always returning the container trace class.
True, but in psample, each process still needs an individual object to record their traces.