Skip to content

Commit b1b714c

Browse files
author
Matthew Emmett
committed
HDF5EA: Load meta info from existing file.
1 parent 304e667 commit b1b714c

File tree

1 file changed

+48
-69
lines changed

1 file changed

+48
-69
lines changed

pymc/database/hdf5ea.py

Lines changed: 48 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def __init__(self, dbname, dbmode='a',
138138
dbcomplevel=0, dbcomplib='zlib',
139139
**kwds):
140140
"""Create an HDF5 database instance, where samples are stored
141-
in tables.
141+
in extendable arrays.
142142
143143
:Parameters:
144144
dbname : string
@@ -175,49 +175,29 @@ def __init__(self, dbname, dbmode='a',
175175

176176
self.trace_names = []
177177
self._traces = {}
178-
self._states = {}
178+
# self._states = {}
179179
self._chains = {}
180180
self._arrays = {}
181181

182-
# # LOAD LOGIC
183-
# if self.chains > 0:
184-
# # Create traces from objects stored in Table.
185-
# db = self
186-
# for k in db._tables[-1].colnames:
187-
# db._traces[k] = Trace(name=k, db=db)
188-
# setattr(db, k, db._traces[k])
182+
# load existing data
183+
existing_chains = [ gr for gr in self._h5file.listNodes("/")
184+
if gr._v_name[:5] == 'chain' ]
189185

186+
for chain in existing_chains:
187+
nchain = int(chain._v_name[5:])
188+
self._chains[nchain] = chain
190189

191-
# # Walk nodes proceed from top to bottom, so we need to invert
192-
# # the list to have the chains in chronological order.
193-
# objects = {}
194-
# for chain in self._chains:
195-
# for node in db._h5file.walkNodes(chain, classname='VLArray'):
196-
# if node._v_name != '_state_':
197-
# try:
198-
# objects[node._v_name].append(node)
199-
# except:
200-
# objects[node._v_name] = [node,]
190+
names = []
191+
for array in chain._f_listNodes():
192+
name = array._v_name
193+
self._arrays[nchain, name] = array
201194

202-
# # Note that the list vlarrays is in reverse order.
203-
# for k, vlarrays in six.iteritems(objects):
204-
# db._traces[k] = TraceObject(name=k, db=db, vlarrays=vlarrays)
205-
# setattr(db, k, db._traces[k])
195+
if name not in self._traces:
196+
self._traces[name] = Trace(name, db=self)
206197

207-
# # Restore table attributes.
208-
# # This restores the sampler's state for the last chain.
209-
# table = db._tables[-1]
210-
# for k in table.attrs._v_attrnamesuser:
211-
# setattr(db, k, getattr(table.attrs, k))
212-
213-
# # Restore group attributes.
214-
# for k in db._chains[-1]._f_listNodes():
215-
# if k.__class__ not in [tables.Table, tables.Group]:
216-
# setattr(db, k.name, k)
217-
218-
# varnames = db._tables[-1].colnames+ objects.keys()
219-
# db.trace_names = db.chains * [varnames,]
198+
names.append(name)
220199

200+
self.trace_names.append(names)
221201

222202

223203
@property
@@ -230,42 +210,42 @@ def nchains(self):
230210
return len(self._chains)
231211

232212

233-
def connect_model(self, model):
234-
"""Link the Database to the Model instance.
213+
# def connect_model(self, model):
214+
# """Link the Database to the Model instance.
235215

236-
In case a new database is created from scratch, ``connect_model``
237-
creates Trace objects for all tallyable pymc objects defined in
238-
`model`.
216+
# In case a new database is created from scratch, ``connect_model``
217+
# creates Trace objects for all tallyable pymc objects defined in
218+
# `model`.
239219

240-
If the database is being loaded from an existing file, ``connect_model``
241-
restore the objects trace to their stored value.
220+
# If the database is being loaded from an existing file, ``connect_model``
221+
# restore the objects trace to their stored value.
242222

243-
:Parameters:
244-
model : pymc.Model instance
245-
An instance holding the pymc objects defining a statistical
246-
model (stochastics, deterministics, data, ...)
247-
"""
223+
# :Parameters:
224+
# model : pymc.Model instance
225+
# An instance holding the pymc objects defining a statistical
226+
# model (stochastics, deterministics, data, ...)
227+
# """
248228

249-
# Changed this to allow non-Model models. -AP
250-
if isinstance(model, pymc.Model):
251-
self.model = model
252-
else:
253-
raise AttributeError('Not a Model instance.')
254-
255-
# Restore the state of the Model from an existing Database.
256-
# The `load` method will have already created the Trace objects.
257-
if hasattr(self, '_state_'):
258-
names = set()
259-
for morenames in self.trace_names:
260-
names.update(morenames)
261-
for name, fun in six.iteritems(model._funs_to_tally):
262-
if name in self._traces:
263-
self._traces[name]._getfunc = fun
264-
names.remove(name)
265-
if len(names) > 0:
266-
raise RuntimeError("Some objects from the database"
267-
+ "have not been assigned a getfunc: %s"
268-
% ', '.join(names))
229+
# # Changed this to allow non-Model models. -AP
230+
# if isinstance(model, pymc.Model):
231+
# self.model = model
232+
# else:
233+
# raise AttributeError('Not a Model instance.')
234+
235+
# # Restore the state of the Model from an existing Database.
236+
# # The `load` method will have already created the Trace objects.
237+
# if hasattr(self, '_state_'):
238+
# names = set()
239+
# for morenames in self.trace_names:
240+
# names.update(morenames)
241+
# for name, fun in six.iteritems(model._funs_to_tally):
242+
# if name in self._traces:
243+
# self._traces[name]._getfunc = fun
244+
# names.remove(name)
245+
# if len(names) > 0:
246+
# raise RuntimeError("Some objects from the database"
247+
# + "have not been assigned a getfunc: %s"
248+
# % ', '.join(names))
269249

270250

271251
def _initialize(self, funs_to_tally, length):
@@ -290,7 +270,6 @@ def _initialize(self, funs_to_tally, length):
290270
self._traces[name] = Trace(name, getfunc=fun, db=self)
291271
self._traces[name]._initialize(self.chains, length)
292272

293-
# XXX: not quite sure if this is right
294273
self.trace_names.append(funs_to_tally.keys())
295274

296275

0 commit comments

Comments
 (0)