|
4 | 4 | from theano import function
|
5 | 5 | import theano
|
6 | 6 | from ..memoize import memoize
|
7 |
| -from ..model import Model, get_named_nodes, FreeRV, ObservedRV |
| 7 | +from ..model import Model, get_named_nodes_and_relations, FreeRV, ObservedRV |
8 | 8 | from ..vartypes import string_types
|
9 | 9 |
|
10 | 10 | __all__ = ['DensityDist', 'Distribution', 'Continuous', 'Discrete',
|
@@ -229,17 +229,72 @@ def draw_values(params, point=None):
|
229 | 229 |
|
230 | 230 | """
|
231 | 231 | # Distribution parameters may be nodes which have named node-inputs
|
232 |
| - # specified in the point. Need to find the node-inputs to replace them. |
233 |
| - givens = {} |
| 232 | + # specified in the point. Need to find the node-inputs, their |
| 233 | + # parents and children to replace them. |
| 234 | + leaf_nodes = {} |
| 235 | + named_nodes_parents = {} |
| 236 | + named_nodes_children = {} |
234 | 237 | for param in params:
|
235 | 238 | if hasattr(param, 'name'):
|
236 |
| - named_nodes = get_named_nodes(param) |
237 |
| - if param.name in named_nodes: |
238 |
| - named_nodes.pop(param.name) |
239 |
| - for name, node in named_nodes.items(): |
240 |
| - if not isinstance(node, (tt.sharedvar.SharedVariable, |
241 |
| - tt.TensorConstant)): |
242 |
| - givens[name] = (node, _draw_value(node, point=point)) |
| 239 | + # Get the named nodes under the `param` node |
| 240 | + nn, nnp, nnc = get_named_nodes_and_relations(param) |
| 241 | + leaf_nodes.update(nn) |
| 242 | + # Update the discovered parental relationships |
| 243 | + for k in nnp.keys(): |
| 244 | + if k not in named_nodes_parents.keys(): |
| 245 | + named_nodes_parents[k] = nnp[k] |
| 246 | + else: |
| 247 | + named_nodes_parents[k].update(nnp[k]) |
| 248 | + # Update the discovered child relationships |
| 249 | + for k in nnc.keys(): |
| 250 | + if k not in named_nodes_children.keys(): |
| 251 | + named_nodes_children[k] = nnc[k] |
| 252 | + else: |
| 253 | + named_nodes_children[k].update(nnc[k]) |
| 254 | + |
| 255 | + # Init givens and the stack of nodes to try to `_draw_value` from |
| 256 | + givens = {} |
| 257 | + stored = set([]) # Some nodes |
| 258 | + stack = list(leaf_nodes.values()) # A queue would be more appropriate |
| 259 | + while stack: |
| 260 | + next_ = stack.pop(0) |
| 261 | + if next_ in stored: |
| 262 | + # If the node already has a givens value, skip it |
| 263 | + continue |
| 264 | + elif isinstance(next_, (tt.TensorConstant, |
| 265 | + tt.sharedvar.SharedVariable)): |
| 266 | + # If the node is a theano.tensor.TensorConstant or a |
| 267 | + # theano.tensor.sharedvar.SharedVariable, its value will be |
| 268 | + # available automatically in _compile_theano_function so |
| 269 | + # we can skip it. Furthermore, if this node was treated as a |
| 270 | + # TensorVariable that should be compiled by theano in |
| 271 | + # _compile_theano_function, it would raise a `TypeError: |
| 272 | + # ('Constants not allowed in param list', ...)` for |
| 273 | + # TensorConstant, and a `TypeError: Cannot use a shared |
| 274 | + # variable (...) as explicit input` for SharedVariable. |
| 275 | + stored.add(next_.name) |
| 276 | + continue |
| 277 | + else: |
| 278 | + # If the node does not have a givens value, try to draw it. |
| 279 | + # The named node's children givens values must also be taken |
| 280 | + # into account. |
| 281 | + children = named_nodes_children[next_] |
| 282 | + temp_givens = [givens[k] for k in givens.keys() if k in children] |
| 283 | + try: |
| 284 | + # This may fail for autotransformed RVs, which don't |
| 285 | + # have the random method |
| 286 | + givens[next_.name] = (next_, _draw_value(next_, |
| 287 | + point=point, |
| 288 | + givens=temp_givens)) |
| 289 | + stored.add(next_.name) |
| 290 | + except theano.gof.fg.MissingInputError: |
| 291 | + # The node failed, so we must add the node's parents to |
| 292 | + # the stack of nodes to try to draw from. We exclude the |
| 293 | + # nodes in the `params` list. |
| 294 | + stack.extend([node for node in named_nodes_parents[next_] |
| 295 | + if node is not None and |
| 296 | + node.name not in stored and |
| 297 | + node not in params]) |
243 | 298 | values = []
|
244 | 299 | for param in params:
|
245 | 300 | values.append(_draw_value(param, point=point, givens=givens.values()))
|
|
0 commit comments