Skip to content

Commit f833a14

Browse files
lucianopazJunpeng Lao
authored and
Junpeng Lao
committed
Refactor distribution.draw_values (#2902)
* Fix for #2900. Changed the way in which draw_values handles the named node-inputs. Now the tree dependence is constructed to set the givens dict. * Fixed conflicts * Fixed more conflicts * Fixed typo * Changed test_dep_vars to test for successful draws even in the cases of dependent variables. * Removed comments from test_random.py, and distribution.py. Added content to RELEASE-NOTES. Fixed bug in the interaction between draw_values, _draw_value and _compile_theano_function. In some cases, draw_values would set an item of the givens dictionary to a theano.tensor.TensorConstant. In _draw_value(param, ...), if param was a theano.tensor.TensorVariable without a random method, and not set in point, _compile_theano_function would be called, using as one of its variables, a theano.tensor.TensorConstant. This lead to TypeError: ("Constants not allowed in param list",...) exceptions being raised. The fix was to skip the inclusion into the givens dictionary of named nodes that were instances of theano.tensor.TensorConstant, because their value would already be available for theano during the function compilation. * Fixed another bug which was similar to the theano.tensor.TensorConstant, but it occurred on theano.tensor.sharedvar.SharedVariable instances. The error that was raised was similar, SharedVariables cannot be supplied as raw input to theano.function. The fix was the same as for TensorConstants, skip them when constructing the givens dictionary. * Guarded against a potencial bug. In draw_values, when skipping for TensorConstant and SharedVariable types, these nodes could be added to the stack again later because their names would not be in givens.keys(). To counter that, a separate set, `stored`, with the names of nodes that are either stored in givens or whos values should be available to theano.function, is used to chose which nodes to add to the stack. * Syntax change based on twiecki's comment. * Extended RELEASE-NOTES.md to also mention the sharedvar.ShareVariable fix.
1 parent 7d327b3 commit f833a14

File tree

4 files changed

+132
-34
lines changed

4 files changed

+132
-34
lines changed

RELEASE-NOTES.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,15 @@
2020
- Add `offset` kwarg to `.glm`.
2121
- Changed the `compare` function to accept a dictionary of model-trace pairs instead of two separate lists of models and traces.
2222
- add test and support for creating multivariate mixture and mixture of mixtures
23+
- `distribution.draw_values`, now is also able to draw values from conditionally dependent RVs, such as autotransformed RVs (Refer to PR #2902).
2324

2425
### Fixes
2526

2627
- `VonMises` does not overflow for large values of kappa. i0 and i1 have been removed and we now use log_i0 to compute the logp.
2728
- The bandwidth for KDE plots is computed using a modified version of Scott's rule. The new version uses entropy instead of standard deviation. This works better for multimodal distributions. Functions using KDE plots has a new argument `bw` controlling the bandwidth.
2829
- fix PyMC3 variable is not replaced if provided in more_replacements (#2890)
30+
- Fix for issue #2900. For many situations, named node-inputs do not have a `random` method, while some intermediate node may have it. This meant that if the named node-input at the leaf of the graph did not have a fixed value, `theano` would try to compile it and fail to find inputs, raising a `theano.gof.fg.MissingInputError`. This was fixed by going through the theano variable's owner inputs graph, trying to get intermediate named-nodes values if the leafs had failed.
31+
- In `distribution.draw_values`, some named nodes could be `theano.tensor.TensorConstant`s or `theano.tensor.sharedvar.SharedVariable`s. Nevertheless, in `distribution._draw_value`, these would be passed to `distribution._compile_theano_function` as if they were `theano.tensor.TensorVariable`s. This could lead to the following exceptions `TypeError: ('Constants not allowed in param list', ...)` or `TypeError: Cannot use a shared variable (...)`. The fix was to not add `theano.tensor.TensorConstant` or `theano.tensor.sharedvar.SharedVariable` named nodes into the `givens` dict that could be used in `distribution._compile_theano_function`.
2932

3033
### Deprecations
3134

pymc3/distributions/distribution.py

Lines changed: 65 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from theano import function
55
import theano
66
from ..memoize import memoize
7-
from ..model import Model, get_named_nodes, FreeRV, ObservedRV
7+
from ..model import Model, get_named_nodes_and_relations, FreeRV, ObservedRV
88
from ..vartypes import string_types
99

1010
__all__ = ['DensityDist', 'Distribution', 'Continuous', 'Discrete',
@@ -229,17 +229,72 @@ def draw_values(params, point=None):
229229
230230
"""
231231
# Distribution parameters may be nodes which have named node-inputs
232-
# specified in the point. Need to find the node-inputs to replace them.
233-
givens = {}
232+
# specified in the point. Need to find the node-inputs, their
233+
# parents and children to replace them.
234+
leaf_nodes = {}
235+
named_nodes_parents = {}
236+
named_nodes_children = {}
234237
for param in params:
235238
if hasattr(param, 'name'):
236-
named_nodes = get_named_nodes(param)
237-
if param.name in named_nodes:
238-
named_nodes.pop(param.name)
239-
for name, node in named_nodes.items():
240-
if not isinstance(node, (tt.sharedvar.SharedVariable,
241-
tt.TensorConstant)):
242-
givens[name] = (node, _draw_value(node, point=point))
239+
# Get the named nodes under the `param` node
240+
nn, nnp, nnc = get_named_nodes_and_relations(param)
241+
leaf_nodes.update(nn)
242+
# Update the discovered parental relationships
243+
for k in nnp.keys():
244+
if k not in named_nodes_parents.keys():
245+
named_nodes_parents[k] = nnp[k]
246+
else:
247+
named_nodes_parents[k].update(nnp[k])
248+
# Update the discovered child relationships
249+
for k in nnc.keys():
250+
if k not in named_nodes_children.keys():
251+
named_nodes_children[k] = nnc[k]
252+
else:
253+
named_nodes_children[k].update(nnc[k])
254+
255+
# Init givens and the stack of nodes to try to `_draw_value` from
256+
givens = {}
257+
stored = set([]) # Some nodes
258+
stack = list(leaf_nodes.values()) # A queue would be more appropriate
259+
while stack:
260+
next_ = stack.pop(0)
261+
if next_ in stored:
262+
# If the node already has a givens value, skip it
263+
continue
264+
elif isinstance(next_, (tt.TensorConstant,
265+
tt.sharedvar.SharedVariable)):
266+
# If the node is a theano.tensor.TensorConstant or a
267+
# theano.tensor.sharedvar.SharedVariable, its value will be
268+
# available automatically in _compile_theano_function so
269+
# we can skip it. Furthermore, if this node was treated as a
270+
# TensorVariable that should be compiled by theano in
271+
# _compile_theano_function, it would raise a `TypeError:
272+
# ('Constants not allowed in param list', ...)` for
273+
# TensorConstant, and a `TypeError: Cannot use a shared
274+
# variable (...) as explicit input` for SharedVariable.
275+
stored.add(next_.name)
276+
continue
277+
else:
278+
# If the node does not have a givens value, try to draw it.
279+
# The named node's children givens values must also be taken
280+
# into account.
281+
children = named_nodes_children[next_]
282+
temp_givens = [givens[k] for k in givens.keys() if k in children]
283+
try:
284+
# This may fail for autotransformed RVs, which don't
285+
# have the random method
286+
givens[next_.name] = (next_, _draw_value(next_,
287+
point=point,
288+
givens=temp_givens))
289+
stored.add(next_.name)
290+
except theano.gof.fg.MissingInputError:
291+
# The node failed, so we must add the node's parents to
292+
# the stack of nodes to try to draw from. We exclude the
293+
# nodes in the `params` list.
294+
stack.extend([node for node in named_nodes_parents[next_]
295+
if node is not None and
296+
node.name not in stored and
297+
node not in params])
243298
values = []
244299
for param in params:
245300
values.append(_draw_value(param, point=point, givens=givens.values()))

pymc3/model.py

Lines changed: 56 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -78,29 +78,71 @@ def incorporate_methods(source, destination, methods, default=None,
7878
else:
7979
setattr(destination, method, None)
8080

81-
82-
def get_named_nodes(graph):
83-
"""Get the named nodes in a theano graph
84-
(i.e., nodes whose name attribute is not None).
81+
def get_named_nodes_and_relations(graph):
82+
"""Get the named nodes in a theano graph (i.e., nodes whose name
83+
attribute is not None) along with their relationships (i.e., the
84+
node's named parents, and named children, while skipping unnamed
85+
intermediate nodes)
8586
8687
Parameters
8788
----------
8889
graph - a theano node
8990
9091
Returns:
91-
A dictionary of name:node pairs.
92+
leaf_nodes: A dictionary of name:node pairs, of the named nodes that
93+
are also leafs of the graph
94+
node_parents: A dictionary of node:set([parents]) pairs. Each key is
95+
a theano named node, and the corresponding value is the set of
96+
theano named nodes that are parents of the node. These parental
97+
relations skip unnamed intermediate nodes.
98+
node_children: A dictionary of node:set([children]) pairs. Each key
99+
is a theano named node, and the corresponding value is the set
100+
of theano named nodes that are children of the node. These child
101+
relations skip unnamed intermediate nodes.
102+
92103
"""
93-
return _get_named_nodes(graph, {})
94-
95-
96-
def _get_named_nodes(graph, nodes):
97-
if graph.owner is None:
98-
if graph.name is not None:
99-
nodes.update({graph.name: graph})
104+
if graph.name is not None:
105+
node_parents = {graph: set()}
106+
node_children = {graph: set()}
100107
else:
108+
node_parents = {}
109+
node_children = {}
110+
return _get_named_nodes_and_relations(graph, None, {}, node_parents, node_children)
111+
112+
def _get_named_nodes_and_relations(graph, parent, leaf_nodes,
113+
node_parents, node_children):
114+
if graph.owner is None: # Leaf node
115+
if graph.name is not None: # Named leaf node
116+
leaf_nodes.update({graph.name: graph})
117+
if parent is not None: # Is None for the root node
118+
try:
119+
node_parents[graph].add(parent)
120+
except KeyError:
121+
node_parents[graph] = set([parent])
122+
node_children[parent].add(graph)
123+
# Flag that the leaf node has no children
124+
node_children[graph] = set()
125+
else: # Intermediate node
126+
if graph.name is not None: # Intermediate named node
127+
if parent is not None: # Is only None for the root node
128+
try:
129+
node_parents[graph].add(parent)
130+
except KeyError:
131+
node_parents[graph] = set([parent])
132+
node_children[parent].add(graph)
133+
# The current node will be set as the parent of the next
134+
# nodes only if it is a named node
135+
parent = graph
136+
# Init the nodes children to an empty set
137+
node_children[graph] = set()
101138
for i in graph.owner.inputs:
102-
nodes.update(_get_named_nodes(i, nodes))
103-
return nodes
139+
temp_nodes, temp_inter, temp_tree = \
140+
_get_named_nodes_and_relations(i, parent, leaf_nodes,
141+
node_parents, node_children)
142+
leaf_nodes.update(temp_nodes)
143+
node_parents.update(temp_inter)
144+
node_children.update(temp_tree)
145+
return leaf_nodes, node_parents, node_children
104146

105147

106148
class Context(object):

pymc3/tests/test_random.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -80,13 +80,11 @@ def test_dep_vars(self):
8080
point = {'a': np.array([1., 2.])}
8181
npt.assert_equal(draw_values([a], point=point), [point['a']])
8282

83-
with pytest.raises(theano.gof.MissingInputError):
84-
draw_values([a])
85-
86-
# We need the untransformed vars
87-
with pytest.raises(theano.gof.MissingInputError):
88-
draw_values([a], point={'sd': np.array([2., 3.])})
89-
90-
val1 = draw_values([a], point={'sd_log__': np.array([2., 3.])})[0]
91-
val2 = draw_values([a], point={'sd_log__': np.array([2., 3.])})[0]
92-
assert np.all(val1 != val2)
83+
val1 = draw_values([a])[0]
84+
val2 = draw_values([a], point={'sd': np.array([2., 3.])})[0]
85+
val3 = draw_values([a], point={'sd_log__': np.array([2., 3.])})[0]
86+
val4 = draw_values([a], point={'sd_log__': np.array([2., 3.])})[0]
87+
88+
assert all([np.all(val1 != val2), np.all(val1 != val3),
89+
np.all(val1 != val4), np.all(val2 != val3),
90+
np.all(val2 != val4), np.all(val3 != val4)])

0 commit comments

Comments
 (0)