Skip to content

Commit 8c28e9d

Browse files
author
Chris Fonnesbeck
committed
Merge pull request #115 from memmett/hdf5-earray
Add a new database (hdf5ea) based on HDF5 and extendable arrays (instead of tables).
2 parents b0daa21 + b1b714c commit 8c28e9d

File tree

2 files changed

+345
-1
lines changed

2 files changed

+345
-1
lines changed

pymc/database/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
3333
"""
3434

35-
__modules__ = ['no_trace', 'txt', 'ram', 'pickle', 'sqlite', 'hdf5', 'hdf52', "__test_import__"]
35+
__modules__ = ['no_trace', 'txt', 'ram', 'pickle', 'sqlite', 'hdf5', 'hdf5ea', "__test_import__"]
3636

3737
from . import no_trace
3838
from . import txt
@@ -49,4 +49,9 @@
4949
except ImportError:
5050
pass
5151

52+
try:
53+
from . import hdf5ea
54+
except ImportError:
55+
pass
56+
5257

pymc/database/hdf5ea.py

Lines changed: 339 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,339 @@
1+
"""HDF5 database module.
2+
3+
Store the traces in an HDF5 array using pytables.
4+
5+
6+
Implementation Notes
7+
--------------------
8+
9+
This version only supports numeric objects, and stores them in
10+
extentable HDF5 arrays. This allows the implementation to handle very
11+
large data vectors.
12+
13+
14+
Additional Dependencies
15+
-----------------------
16+
* HDF5 version 1.6.5, required by pytables.
17+
* pytables version 2 and up. <http://sourceforge.net/projects/pytables/>
18+
19+
"""
20+
21+
import os
22+
import sys
23+
import traceback
24+
import warnings
25+
26+
import numpy as np
27+
import pymc
28+
import tables
29+
30+
from pymc.database import base, pickle
31+
from pymc import six
32+
33+
__all__ = ['Trace', 'Database', 'load']
34+
35+
warn_tally = """
36+
Error tallying %s, will not try to tally it again this chain.
37+
Did you make all the same variables and step methods tallyable
38+
as were tallyable last time you used the database file?
39+
40+
Error:
41+
42+
%s"""
43+
44+
45+
###############################################################################
46+
47+
class Trace(base.Trace):
48+
"""HDF5 trace."""
49+
50+
51+
def tally(self, chain):
52+
"""Adds current value to trace."""
53+
54+
arr = np.asarray(self._getfunc())
55+
arr = arr.reshape((1,) + arr.shape)
56+
self.db._arrays[chain, self.name].append(arr)
57+
58+
59+
# def __getitem__(self, index):
60+
# """Mimic NumPy indexing for arrays."""
61+
# chain = self._chain
62+
63+
# if chain is not None:
64+
# tables = [self.db._gettables()[chain],]
65+
# else:
66+
# tables = self.db._gettables()
67+
68+
# out = []
69+
# for table in tables:
70+
# out.append(table.col(self.name))
71+
72+
# if np.isscalar(chain):
73+
# return out[0][index]
74+
# else:
75+
# return np.hstack(out)[index]
76+
77+
78+
def gettrace(self, burn=0, thin=1, chain=-1, slicing=None):
79+
"""Return the trace (last by default).
80+
81+
:Parameters:
82+
burn : integer
83+
The number of transient steps to skip.
84+
thin : integer
85+
Keep one in thin.
86+
chain : integer
87+
The index of the chain to fetch. If None, return all chains. The
88+
default is to return the last chain.
89+
slicing : slice object
90+
A slice overriding burn and thin assignement.
91+
"""
92+
93+
# XXX: handle chain == None case properly
94+
95+
if chain is None:
96+
chain = -1
97+
chain = self.db.chains[chain]
98+
99+
arr = self.db._arrays[chain, self.name]
100+
101+
if slicing is not None:
102+
burn, stop, thin = slicing.start, slicing.stop, slicing.step
103+
104+
if slicing is None or stop is None:
105+
stop = arr.nrows
106+
return np.asarray(arr.read(start=burn, stop=stop, step=thin))
107+
108+
__call__ = gettrace
109+
110+
# def length(self, chain=-1):
111+
# """Return the length of the trace.
112+
113+
# :Parameters:
114+
# chain : int or None
115+
# The chain index. If None, returns the combined length of all chains.
116+
# """
117+
# if chain is not None:
118+
# tables = [self.db._gettables()[chain],]
119+
# else:
120+
# tables = self.db._gettables()
121+
122+
# n = np.asarray([table.nrows for table in tables])
123+
# return n.sum()
124+
125+
126+
###############################################################################
127+
128+
class Database(pickle.Database):
129+
"""HDF5 database.
130+
131+
Create an HDF5 file <model>.h5. Each chain is stored in a group,
132+
and the stochastics and deterministics are stored as extendable
133+
arrays in each group.
134+
"""
135+
136+
137+
def __init__(self, dbname, dbmode='a',
138+
dbcomplevel=0, dbcomplib='zlib',
139+
**kwds):
140+
"""Create an HDF5 database instance, where samples are stored
141+
in extendable arrays.
142+
143+
:Parameters:
144+
dbname : string
145+
Name of the hdf5 file.
146+
dbmode : {'a', 'w', 'r'}
147+
File mode: 'a': append, 'w': overwrite, 'r': read-only.
148+
dbcomplevel : integer (0-9)
149+
Compression level, 0: no compression.
150+
dbcomplib : string
151+
Compression library (zlib, bzip2, lzo)
152+
153+
:Notes:
154+
* zlib has a good compression ratio, although somewhat slow, and
155+
reasonably fast decompression.
156+
* lzo is a fast compression library offering however a low compression
157+
ratio.
158+
* bzip2 has an excellent compression ratio but requires more CPU.
159+
"""
160+
161+
self.__name__ = 'hdf5ea'
162+
self.__Trace__ = Trace
163+
164+
self.dbname = dbname
165+
self.mode = dbmode
166+
167+
db_exists = os.path.exists(self.dbname)
168+
self._h5file = tables.openFile(self.dbname, self.mode)
169+
170+
default_filter = tables.Filters(complevel=dbcomplevel, complib=dbcomplib)
171+
if self.mode =='r' or (self.mode=='a' and db_exists):
172+
self.filter = getattr(self._h5file, 'filters', default_filter)
173+
else:
174+
self.filter = default_filter
175+
176+
self.trace_names = []
177+
self._traces = {}
178+
# self._states = {}
179+
self._chains = {}
180+
self._arrays = {}
181+
182+
# load existing data
183+
existing_chains = [ gr for gr in self._h5file.listNodes("/")
184+
if gr._v_name[:5] == 'chain' ]
185+
186+
for chain in existing_chains:
187+
nchain = int(chain._v_name[5:])
188+
self._chains[nchain] = chain
189+
190+
names = []
191+
for array in chain._f_listNodes():
192+
name = array._v_name
193+
self._arrays[nchain, name] = array
194+
195+
if name not in self._traces:
196+
self._traces[name] = Trace(name, db=self)
197+
198+
names.append(name)
199+
200+
self.trace_names.append(names)
201+
202+
203+
@property
204+
def chains(self):
205+
return range(len(self._chains))
206+
207+
208+
@property
209+
def nchains(self):
210+
return len(self._chains)
211+
212+
213+
# def connect_model(self, model):
214+
# """Link the Database to the Model instance.
215+
216+
# In case a new database is created from scratch, ``connect_model``
217+
# creates Trace objects for all tallyable pymc objects defined in
218+
# `model`.
219+
220+
# If the database is being loaded from an existing file, ``connect_model``
221+
# restore the objects trace to their stored value.
222+
223+
# :Parameters:
224+
# model : pymc.Model instance
225+
# An instance holding the pymc objects defining a statistical
226+
# model (stochastics, deterministics, data, ...)
227+
# """
228+
229+
# # Changed this to allow non-Model models. -AP
230+
# if isinstance(model, pymc.Model):
231+
# self.model = model
232+
# else:
233+
# raise AttributeError('Not a Model instance.')
234+
235+
# # Restore the state of the Model from an existing Database.
236+
# # The `load` method will have already created the Trace objects.
237+
# if hasattr(self, '_state_'):
238+
# names = set()
239+
# for morenames in self.trace_names:
240+
# names.update(morenames)
241+
# for name, fun in six.iteritems(model._funs_to_tally):
242+
# if name in self._traces:
243+
# self._traces[name]._getfunc = fun
244+
# names.remove(name)
245+
# if len(names) > 0:
246+
# raise RuntimeError("Some objects from the database"
247+
# + "have not been assigned a getfunc: %s"
248+
# % ', '.join(names))
249+
250+
251+
def _initialize(self, funs_to_tally, length):
252+
"""Create a group named ``chain#`` to store all data for this chain."""
253+
254+
chain = self.nchains
255+
self._chains[chain] = self._h5file.createGroup(
256+
'/', 'chain%d' % chain, 'chain #%d' % chain)
257+
258+
for name, fun in six.iteritems(funs_to_tally):
259+
260+
arr = np.asarray(fun())
261+
262+
assert arr.dtype != np.dtype('object')
263+
264+
array = self._h5file.createEArray(
265+
self._chains[chain], name,
266+
tables.Atom.from_dtype(arr.dtype), (0,) + arr.shape,
267+
filters=self.filter)
268+
269+
self._arrays[chain, name] = array
270+
self._traces[name] = Trace(name, getfunc=fun, db=self)
271+
self._traces[name]._initialize(self.chains, length)
272+
273+
self.trace_names.append(funs_to_tally.keys())
274+
275+
276+
def tally(self, chain=-1):
277+
chain = self.chains[chain]
278+
for name in self.trace_names[chain]:
279+
try:
280+
self._traces[name].tally(chain)
281+
self._arrays[chain, name].flush()
282+
except:
283+
cls, inst, tb = sys.exc_info()
284+
warnings.warn(warn_tally
285+
% (name, ''.join(traceback.format_exception(cls, inst, tb))))
286+
self.trace_names[chain].remove(name)
287+
288+
289+
290+
# def savestate(self, state, chain=-1):
291+
# """Store a dictionnary containing the state of the Model and its
292+
# StepMethods."""
293+
294+
# chain = self.chains[chain]
295+
# if chain in self._states:
296+
# self._states[chain] = state
297+
# else:
298+
# s = self._h5file.createVLArray(chain,'_state_',tables.ObjectAtom(),title='The saved state of the sampler',filters=self.filter)
299+
# s.append(state)
300+
# self._h5file.flush()
301+
302+
303+
# def getstate(self, chain=-1):
304+
# if len(self._chains)==0:
305+
# return {}
306+
# elif hasattr(self._chains[chain],'_state_'):
307+
# if len(self._chains[chain]._state_)>0:
308+
# return self._chains[chain]._state_[0]
309+
# else:
310+
# return {}
311+
# else:
312+
# return {}
313+
314+
315+
def _finalize(self, chain=-1):
316+
self._h5file.flush()
317+
318+
def close(self):
319+
self._h5file.close()
320+
321+
322+
323+
324+
def load(dbname, dbmode='a'):
325+
"""Load an existing hdf5 database.
326+
327+
Return a Database instance.
328+
329+
:Parameters:
330+
filename : string
331+
Name of the hdf5 database to open.
332+
mode : 'a', 'r'
333+
File mode : 'a': append, 'r': read-only.
334+
"""
335+
if dbmode == 'w':
336+
raise AttributeError("dbmode='w' not allowed for load.")
337+
db = Database(dbname, dbmode=dbmode)
338+
339+
return db

0 commit comments

Comments
 (0)