Text and SQLite backends for PyMC3 #449

Closed
wants to merge 18 commits into from
142 changes: 142 additions & 0 deletions pymc/backends/__init__.py
@@ -0,0 +1,142 @@
"""Backends for traces

Available backends
------------------

1. NumPy array (pymc.backends.NDArray)
2. Text files (pymc.backends.Text)
3. SQLite (pymc.backends.SQLite)

The NumPy arrays and text files both hold the entire trace in memory,
whereas SQLite commits the trace to the database while sampling.

Selecting a backend
-------------------

By default, a NumPy array is used as the backend. To specify a different
backend, pass a backend instance to `sample`.

For example, the following would save traces to the file 'test.db'.

>>> import pymc as pm
>>> db = pm.backends.SQLite('test.db')
>>> trace = pm.sample(..., db=db)

Selecting values from a backend
-------------------------------

After a backend is finished sampling, it returns a Trace object. Values
can be accessed in a few ways. The easiest way is to index the backend
object with a variable or variable name.

>>> trace['x'] # or trace[x]

The call will return a list containing the sampling values for all
chains of `x`. (Each call to `pymc.sample` creates a separate chain of
samples.)

If more control is needed over which values are returned, the
`get_values` method can be used. The call below will return values from
all chains, burning the first 1000 iterations from each chain.

>>> trace.get_values('x', burn=1000)

Setting the `combine` flag will concatenate the results from all the
chains.

>>> trace.get_values('x', burn=1000, combine=True)

The `chains` parameter of `get_values` can be used to limit the chains
that are retrieved. To work with a subset of chains without having to
specify `chains` on each call, you can set the `active_chains` attribute.

>>> trace.chains
[0, 1, 2]
>>> trace.active_chains = [0, 2]

After this, only chains 0 and 2 will be used in operations that work
with multiple chains.

Similarly, the `default_chain` attribute sets which chain is used for
functions that require a single chain (e.g., point).

>>> trace.point(4) # or trace[4]

Backends can also support slicing the trace object. For example, the
following call would return a new trace object without the first 1000
sampling iterations for all variables.

>>> sliced_trace = trace[1000:]

Loading a saved backend
-----------------------

Saved backends can be loaded using the `load` function in the module
for the specific backend.

>>> trace = pm.backends.sqlite.load('test.db')

Writing custom backends
-----------------------

Backends consist of two classes: one that handles storing the sample
results (e.g., backends.ndarray.NDArray or backends.sqlite.SQLite) and
one that handles value selection (e.g., backends.ndarray.Trace or
backends.sqlite.Trace).

Three methods of the storage class will be called:

- setup: Before sampling is started, the `setup` method will be called
with two arguments: the number of draws and the chain number. This is
useful for setting up any structure for storing the sampling values that
requires the above information.

- record: Record the sampling results for the current draw. This method
will be called with a dictionary of values mapped to the variable
names. This is the only function that *must* do something to have a
meaningful backend.

- close: This method is called following sampling and should perform any
actions necessary for finalizing and cleaning up the backend.

The base storage class `backends.base.Backend` provides model setup that
is used by PyMC backends.

After sampling has completed, the `trace` attribute of the storage
object will be returned. To have a consistent interface with the backend
trace objects in PyMC, this attribute should be an instance of a class
that inherits from pymc.backends.base.Trace, and several methods in the
inherited Trace object should be defined.

- get_values: This is the core method for selecting values from the
backend. It can be called directly and is used by __getitem__ when the
backend is indexed with a variable name or object.

- _slice: Defines how the backend returns a slice of itself. This
is called if the backend is indexed with a slice range.

- point: Returns values for each variable at a single iteration. This
is called if the backend is indexed with a single integer.

- __len__: This should return the number of draws (for the default
chain).

- chains: Property that returns a list of chains

In addition, a `merge_chains` method should be defined if the backend
will be used with parallel sampling. This method describes how to merge
sampling chains from a list of other traces.

As mentioned above, the only method necessary to store the sampling
values is `record`. Other methods in the storage class may consist of only a
pass statement. The storage object should have an attribute `trace`
(with a `merge_chains` method for parallel sampling), but this does not
have to do anything if storing the values is all that is desired. The
backends.base.Trace is provided for convenience in setting up a
consistent Trace object.

For specific examples, see pymc.backends.{ndarray,text,sqlite}.py.
"""
from pymc.backends.ndarray import NDArray
from pymc.backends.text import Text
from pymc.backends.sqlite import SQLite
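
The storage protocol described in the docstring above (`setup`, `record`, `close`) can be illustrated without PyMC installed. The following is only a minimal sketch: `ListBackend` and `fake_sample` are hypothetical names that do not appear in this PR, the class does not inherit from `pymc.backends.base.Backend`, and the driver loop is a stand-in for what `sample` is described as doing, not its actual implementation.

```python
class ListBackend(object):
    """Hypothetical storage class following the setup/record/close protocol."""

    def __init__(self, name):
        self.name = name
        self.chain = None
        self.samples = {}  # chain number -> list of recorded points
        self.closed = False

    def setup(self, draws, chain):
        # Called once per chain before sampling starts.
        self.chain = chain
        self.samples[chain] = []

    def record(self, point):
        # Called once per draw with a dict mapping variable names to values.
        self.samples[self.chain].append(dict(point))

    def close(self):
        # Called after sampling; an in-memory store has nothing to flush.
        self.closed = True


def fake_sample(backend, draws, chain=0):
    """Stand-in for the sampler loop (an assumption, not PyMC's `sample`)."""
    backend.setup(draws, chain)
    for i in range(draws):
        backend.record({'x': float(i)})
    backend.close()
    return backend


backend = fake_sample(ListBackend('demo'), draws=3)
print(backend.samples[0])
```

As the docstring notes, only `record` has to do real work here; `setup` and `close` could be reduced to `pass` for a backend that needs no per-chain structure.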
219 changes: 219 additions & 0 deletions pymc/backends/base.py
@@ -0,0 +1,219 @@
"""Base backend for traces

See the docstring for pymc.backends for more information (including
creating custom backends).
"""
import numpy as np
from pymc.model import modelcontext


class Backend(object):
Member

In general, I think this class provides too much structure for the different backends, and I wasn't expecting that there would be a shared base class for all backends at all (though maybe there's enough shared code that you really want one).

I would prefer to have a very simple interface for the backends, mainly they implement a record and perhaps close.

Member

On second thought, there is at least a significant subset of backends that will have significant shared code: backends with a distinct container for each variable (let's call those 'container backends'), like an NDArray backend or SQL backend. However, I still think this class can be made significantly simpler.

Contributor Author

Thanks for your comments. I will try to incorporate your suggestions.

Member

What is the difference that we're trying to draw between a Trace and a Backend? It looks to me like Trace and Backend should be the same class. The object that is responsible for storing the points should also be responsible for retrieving them.

Member

It also seems to me that a trace with multiple chains should just be a container for multiple traces, and the logic for individual chains should be in separate classes. Composition is better than just adding features to a class.

Is the reason why even the basic traces deal with 'chains' because the SQL db has a shared resource between its chains?
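
The composition idea in this comment could look roughly like the sketch below. It is an illustration only: `ChainTrace` and `MultiTrace` are hypothetical names that are not part of this diff, and values are kept in plain lists rather than in any of the PR's backends.

```python
class ChainTrace(object):
    """Hypothetical single-chain trace: stores and retrieves its own draws."""

    def __init__(self, chain):
        self.chain = chain
        self.points = []

    def record(self, point):
        self.points.append(dict(point))

    def get_values(self, var_name, burn=0, thin=1):
        return [p[var_name] for p in self.points[burn::thin]]

    def __len__(self):
        return len(self.points)


class MultiTrace(object):
    """Hypothetical container that is always returned, even for one chain."""

    def __init__(self, traces):
        chains = [t.chain for t in traces]
        if len(set(chains)) != len(chains):
            raise ValueError('chain numbers must be unique')
        self.traces = dict((t.chain, t) for t in traces)

    @property
    def chains(self):
        return sorted(self.traces)

    def get_values(self, var_name, burn=0, thin=1, combine=False):
        results = [self.traces[c].get_values(var_name, burn, thin)
                   for c in self.chains]
        if combine:
            # Flatten the per-chain lists into one list.
            return [v for chain_values in results for v in chain_values]
        return results


# Each process records into its own single-chain object.
t0, t1 = ChainTrace(0), ChainTrace(1)
for i in range(3):
    t0.record({'x': i})
    t1.record({'x': 10 + i})

# The caller only ever sees the container.
trace = MultiTrace([t0, t1])
print(trace.get_values('x', burn=1))
print(trace.get_values('x', burn=1, combine=True))
```

In this arrangement the per-chain logic lives in one small class and the container only dispatches, which is the separation the comment argues for.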

Contributor Author

> What is the difference that we're trying to draw between a Trace and a
> Backend? It looks to me like Trace and Backend should be the same
> class. The object that is responsible for storing the points should
> also be responsible for retrieving them.

In an earlier version, I experimented with using one class, and we could
move that way again, but there are a few reasons I think two classes are
a good idea.

  • It makes the trace interface more consistent and less cluttered for
    users. By returning a separate trace object for retrieval, users have
    an object with methods that only concern what they'll be using it for
    (selecting and viewing values). They don't deal with the storage
    object directly (sample does).
  • I think having separate classes for storage versus retrieval makes it
    clearer when defining custom backends which methods should be
    overridden for each class. For a working backend, only the record
    method of the storage backend must be defined. The trace class just
    provides a way to make the custom backend have a consistent interface
    for accessing values.
  • If put in one class, it's pretty large, so dividing the
    responsibilities between storing and retrieving values seemed like a
    natural division.

Contributor Author

> It also seems to me that a trace with multiple chains should just be a
> container for multiple traces, and the logic for individual chains
> should be in separate classes.

I was thinking the other way around: a single-chain trace should just be
a "multiple" trace object with one chain. This way, all traces are
handled the same, and functions like traceplot don't need to check
whether the trace is an instance of the single or multiple trace class.

Also, when using non-memory backends, the distinction between single and
multiple trace isn't as clear to me because this will be made at the
level of the call to the database.

Is there a specific advantage to using a container object that you have
in mind?

Member

> Is there a specific advantage to using a container object that you have
> in mind?

I think I'm mostly thinking of code clarity, extendability and composability. I've found this code somewhat challenging to read and understand and that concerns me.

> It makes the trace interface more consistent and less cluttered for
> users.

That's fair. That is a benefit.

> I think having separate classes for storage versus retrieval makes it
> clearer when defining custom backends which methods should be
> overridden for each class.

> This way, all traces are
> handled the same, and functions like traceplot don't need to check
> whether the trace is an instance of the single or multiple trace class.

I definitely agree a common interface to traces is a good idea. There are other ways to achieve this, for example, by always returning the container trace class.

> Also, when using non-memory backends, the distinction between single and
> multiple trace isn't as clear to me because this will be made at the
> level of the call to the database.

True, but in psample, each process still needs an individual object to record their traces.

    """Base storage class

    Parameters
    ----------
    name : str
        Name of backend.
    model : Model
        If None, the model is taken from the `with` context.
    variables : list of variable objects
        Sampling values will be stored for these variables
    """
    def __init__(self, name, model=None, variables=None):
        self.name = name

        model = modelcontext(model)
        if variables is None:
            variables = model.unobserved_RVs
        self.variables = variables
        self.var_names = [str(var) for var in variables]
        self.fn = model.fastfn(variables)

        ## get variable shapes. common enough that I think most backends
        ## will use this
        var_values = zip(self.var_names, self.fn(model.test_point))
        self.var_shapes = {var: value.shape
                           for var, value in var_values}
        self.chain = None
        self.trace = None
    def setup(self, draws, chain):
        """Perform chain-specific setup

        draws : int
            Expected number of draws
        chain : int
            chain number
        """
        pass

    def record(self, point):
        """Record results of a sampling iteration

        point : dict
            Values mapped to variable names
        """
        raise NotImplementedError

    def close(self):
        """Close the database backend

        This is called after sampling has finished.
        """
        pass


class Trace(object):
    """
    Parameters
    ----------
    var_names : list of strs
        Sample variable names
    backend : Backend object

    Attributes
    ----------
    var_names
    backend : Backend object
    nchains : int
        Number of sampling chains
    chains : list of ints
        List of sampling chain numbers
    default_chain : int
        Chain to be used if single chain requested
    active_chains : list of ints
        Chains whose values are used in operations
    """
    def __init__(self, var_names, backend=None):
Member

minor: why not just 'vars'?

Contributor Author

I should definitely go through this and change it to be consistent.
Should it be varnames instead of vars because they're strings or are
you using vars for both? Also, any concern that vars is a built-in?

        self.var_names = var_names
        self.backend = backend
        self._active_chains = []
        self._default_chain = None

    @property
    def default_chain(self):
        """Default chain to use for operations that require one chain (e.g.,
        `point`)
        """
        if self._default_chain is None:
            return self.active_chains[-1]
        return self._default_chain

    @default_chain.setter
    def default_chain(self, value):
        self._default_chain = value

    @property
    def active_chains(self):
        """List of chains to be used. Defaults to all.
        """
        if not self._active_chains:
            return self.chains
        return self._active_chains

    @active_chains.setter
    def active_chains(self, values):
        try:
            self._active_chains = [chain for chain in values]
        except TypeError:
            self._active_chains = [values]

    @property
    def nchains(self):
        """Number of chains

        A chain is created for each sample call (including parallel
        threads).
        """
        return len(self.chains)

    def __getitem__(self, idx):
        if isinstance(idx, slice):
            return self._slice(idx)

        try:
            return self.point(idx)
        except ValueError:
            pass
        except TypeError:
            pass
        return self.get_values(idx)

    ## Selection methods that children must define

    @property
    def chains(self):
        """All chains in trace"""
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def get_values(self, var_name, burn=0, thin=1, combine=False, chains=None,
                   squeeze=True):
        """Get values from samples

        Parameters
        ----------
        var_name : str
        burn : int
        thin : int
        combine : bool
            If True, results from all chains will be concatenated.
        chains : list
            Chains to retrieve. If None, `active_chains` is used.
        squeeze : bool
            If True, return a single array rather than a one-element
            list when only one array results (for example, when
            `combine` is True or only one chain is selected).

        Returns
        -------
        A list of NumPy arrays of values, or a single array (see
        `squeeze`)
        """
        raise NotImplementedError

    def _slice(self, idx):
        """Slice trace object"""
        raise NotImplementedError

    def point(self, idx, chain=None):
        """Return dictionary of point values at `idx` for current chain
        with variables names as keys.

        If `chain` is not specified, `default_chain` is used.
        """
        raise NotImplementedError

    def merge_chains(self, traces):
        """Merge chains from trace instances

        Parameters
        ----------
        traces : list
            Backend trace instances. Each instance should have only one
            chain, and all chain numbers should be unique.

        Raises
        ------
        ValueError is raised if any traces have the same current chain
        number.

        Returns
        -------
        Backend instance with merged chains
        """
        raise NotImplementedError


def _squeeze_cat(results, combine, squeeze):
    """Squeeze and concatenate the results depending on values of
    `combine` and `squeeze`"""
    if combine:
        results = np.concatenate(results)
        if not squeeze:
            results = [results]
    else:
        if squeeze and len(results) == 1:
            results = results[0]
    return results
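
How `combine` and `squeeze` interact may be easier to see with concrete inputs. The sketch below repeats `_squeeze_cat` from the diff so that it runs on its own, then applies it to two made-up chains; the array values are arbitrary and only NumPy is assumed.

```python
import numpy as np


def _squeeze_cat(results, combine, squeeze):
    # Repeated from the diff above so the example is self-contained.
    if combine:
        results = np.concatenate(results)
        if not squeeze:
            results = [results]
    else:
        if squeeze and len(results) == 1:
            results = results[0]
    return results


chains = [np.array([1.0, 2.0]), np.array([3.0, 4.0])]  # made-up values

# Two chains, no combining: the list of per-chain arrays is returned as is.
per_chain = _squeeze_cat(chains, combine=False, squeeze=True)

# Combining with squeeze: one bare concatenated array.
combined = _squeeze_cat(chains, combine=True, squeeze=True)

# Combining without squeeze: the concatenated array wrapped in a list.
wrapped = _squeeze_cat(chains, combine=True, squeeze=False)

# A single chain with squeeze: the bare array instead of a one-element list.
single = _squeeze_cat(chains[:1], combine=False, squeeze=True)

print(len(per_chain), combined.shape, len(wrapped), single.shape)
```

With `squeeze=False` the return value is always a list, which can be convenient for callers that iterate over chains without checking the type first.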