
Commit 61c15af

Handle implicit broadcasting correctly in RandomVariable vectorization

1 parent e827311

3 files changed: 82 additions, 8 deletions
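The gist of the fix: when a `RandomVariable` node is vectorized with batched replacements, any implicit broadcasting between the distribution parameters (and against `size`) must first be made explicit, or the new batch dimensions end up misaligned. A minimal user-level sketch of the behavior this enables (variable names and shapes are illustrative, not from the commit):

```python
import pytensor.tensor as pt
from pytensor.graph.replace import vectorize_graph
from pytensor.tensor.random.basic import normal

# sigma defaults to a scalar and is broadcast implicitly against mu
vec = pt.vector("vec")
rv = normal(vec)

# Swap the 1-d mu for a batched 2-d one; the vectorized RV should pick up
# the new leading batch dimension despite the implicitly broadcast sigma
mat = pt.tensor("mat", shape=(10, 5))
batched_rv = vectorize_graph(rv, {vec: mat})
assert batched_rv.type.ndim == 2
```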

pytensor/tensor/random/op.py

22 additions, 3 deletions

```diff
@@ -8,7 +8,7 @@
 from pytensor.configdefaults import config
 from pytensor.graph.basic import Apply, Variable, equal_computations
 from pytensor.graph.op import Op
-from pytensor.graph.replace import _vectorize_node
+from pytensor.graph.replace import _vectorize_node, vectorize_graph
 from pytensor.misc.safe_asarray import _asarray
 from pytensor.scalar import ScalarVariable
 from pytensor.tensor.basic import (
@@ -20,7 +20,10 @@
     infer_static_shape,
 )
 from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType, RandomType
-from pytensor.tensor.random.utils import broadcast_params, normalize_size_param
+from pytensor.tensor.random.utils import (
+    explicit_expand_dims,
+    normalize_size_param,
+)
 from pytensor.tensor.shape import shape_tuple
 from pytensor.tensor.type import TensorType, all_dtypes
 from pytensor.tensor.type_other import NoneConst
@@ -387,10 +390,26 @@ def vectorize_random_variable(
     # If size was provided originally and a new size hasn't been provided,
     # We extend it to accommodate the new input batch dimensions.
     # Otherwise, we assume the new size already has the right values
+
+    # Need to make parameters implicit broadcasting explicit
+    original_dist_params = node.inputs[3:]
     old_size = node.inputs[1]
     len_old_size = get_vector_length(old_size)
+
+    original_expanded_dist_params = explicit_expand_dims(
+        original_dist_params, op.ndims_params, len_old_size
+    )
+    # We call vectorize_graph to automatically handle any new explicit expand_dims
+    dist_params = vectorize_graph(
+        original_expanded_dist_params, dict(zip(original_dist_params, dist_params))
+    )
+
     if len_old_size and equal_computations([old_size], [size]):
-        bcasted_param = broadcast_params(dist_params, op.ndims_params)[0]
+        # If the original RV had a size variable and a new one has not been provided,
+        # we need to define a new size as the concatenation of the original size dimensions
+        # and the novel ones implied by new broadcasted batched parameters dimensions.
+        # We use the first broadcasted batch dimension for reference.
+        bcasted_param = explicit_expand_dims(dist_params, op.ndims_params)[0]
         new_param_ndim = (bcasted_param.type.ndim - op.ndims_params[0]) - len_old_size
         if new_param_ndim >= 0:
             new_size_dims = bcasted_param.shape[:new_param_ndim]
```
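In words: the op first makes the original parameters' implicit broadcasting explicit via `explicit_expand_dims`, then calls `vectorize_graph` so the new batch dimensions are inserted to the left of those explicit expand_dims. A rough sketch of the same two steps outside the Op, using the helper defined in `utils.py` below and assuming `vectorize_graph` handles the introduced `DimShuffle` (as the new code relies on); names and shapes are illustrative:

```python
import pytensor.tensor as pt
from pytensor.graph.replace import vectorize_graph
from pytensor.tensor.random.utils import explicit_expand_dims

# Original parameters of a normal RV (core ndims_params = (0, 0))
mu = pt.vector("mu")        # 1 batch dimension
sigma = pt.scalar("sigma")  # 0 batch dimensions

# Step 1: pad sigma on the left so both params have one explicit batch dim
expanded_mu, expanded_sigma = explicit_expand_dims([mu, sigma], (0, 0))
assert expanded_sigma.type.ndim == 1

# Step 2: vectorize the expanded graph with batched replacements; sigma's
# new batch dim of 10 lands left of the explicit expand_dims, giving a
# (10, 1) sigma that broadcasts cleanly against the (10, 5) mu
new_mu = pt.tensor("new_mu", shape=(10, 5))
new_sigma = pt.tensor("new_sigma", shape=(10,))
batched = vectorize_graph(
    [expanded_mu, expanded_sigma], {mu: new_mu, sigma: new_sigma}
)
assert batched[1].type.ndim == 2
```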

pytensor/tensor/random/utils.py

29 additions, 1 deletion

```diff
@@ -13,7 +13,7 @@
 from pytensor.tensor.basic import as_tensor_variable, cast, constant
 from pytensor.tensor.extra_ops import broadcast_to
 from pytensor.tensor.math import maximum
-from pytensor.tensor.shape import specify_shape
+from pytensor.tensor.shape import shape_padleft, specify_shape
 from pytensor.tensor.type import int_dtypes
 from pytensor.tensor.variable import TensorVariable
 
@@ -121,6 +121,34 @@ def broadcast_params(params, ndims_params):
     return bcast_params
 
 
+def explicit_expand_dims(
+    params: Sequence[TensorVariable],
+    ndim_params: tuple[int],
+    size_length: int = 0,
+) -> list[TensorVariable]:
+    """Introduce explicit expand_dims in RV parameters that are implicitly broadcasted together and/or by size."""
+
+    batch_dims = [
+        param.type.ndim - ndim_param for param, ndim_param in zip(params, ndim_params)
+    ]
+
+    if size_length:
+        # NOTE: PyTensor is currently treating zero-length size as size=None, which is not what Numpy does
+        # See: https://github.com/pymc-devs/pytensor/issues/568
+        max_batch_dims = size_length
+    else:
+        max_batch_dims = max(batch_dims)
+
+    new_params = []
+    for new_param, batch_dim in zip(params, batch_dims):
+        missing_dims = max_batch_dims - batch_dim
+        if missing_dims:
+            new_param = shape_padleft(new_param, missing_dims)
+        new_params.append(new_param)
+
+    return new_params
+
+
 def normalize_size_param(
     size: Optional[Union[int, np.ndarray, Variable, Sequence]],
 ) -> Variable:
```
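For reference, a small usage sketch of the new helper as defined above (shapes are illustrative): without a size, parameters are left-padded up to the largest batch rank among them; with a `size_length`, the size dictates the batch rank instead.

```python
import pytensor.tensor as pt
from pytensor.tensor.random.utils import explicit_expand_dims

mu = pt.tensor("mu", shape=(2, 5))  # 2 batch dims for a scalar-core param
sigma = pt.scalar("sigma")          # 0 batch dims

# Without size: everything is padded to max(batch_dims) = 2 leading dims
_, padded_sigma = explicit_expand_dims([mu, sigma], (0, 0))
assert padded_sigma.type.shape == (1, 1)

# With a size of length 3: size dictates the number of batch dims
padded_mu, _ = explicit_expand_dims([mu, sigma], (0, 0), size_length=3)
assert padded_mu.type.shape == (1, 2, 5)
```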

tests/tensor/random/test_op.py

31 additions, 4 deletions

```diff
@@ -248,16 +248,16 @@ def test_vectorize_node():
     # Test without size
     node = normal(vec).owner
     new_inputs = node.inputs.copy()
-    new_inputs[3] = mat
+    new_inputs[3] = mat  # mu
     vect_node = vectorize_node(node, *new_inputs)
     assert vect_node.op is normal
     assert vect_node.inputs[3] is mat
 
     # Test with size, new size provided
     node = normal(vec, size=(3,)).owner
     new_inputs = node.inputs.copy()
-    new_inputs[1] = (2, 3)
-    new_inputs[3] = mat
+    new_inputs[1] = (2, 3)  # size
+    new_inputs[3] = mat  # mu
     vect_node = vectorize_node(node, *new_inputs)
     assert vect_node.op is normal
     assert tuple(vect_node.inputs[1].eval()) == (2, 3)
@@ -266,10 +266,37 @@
     # Test with size, new size not provided
     node = normal(vec, size=(3,)).owner
     new_inputs = node.inputs.copy()
-    new_inputs[3] = mat
+    new_inputs[3] = mat  # mu
     vect_node = vectorize_node(node, *new_inputs)
     assert vect_node.op is normal
     assert vect_node.inputs[3] is mat
     assert tuple(
         vect_node.inputs[1].eval({mat: np.zeros((2, 3), dtype=config.floatX)})
     ) == (2, 3)
+
+    # Test parameter broadcasting
+    node = normal(vec).owner
+    new_inputs = node.inputs.copy()
+    new_inputs[3] = tensor("mu", shape=(10, 5))  # mu
+    new_inputs[4] = tensor("sigma", shape=(10,))  # sigma
+    vect_node = vectorize_node(node, *new_inputs)
+    assert vect_node.op is normal
+    assert vect_node.default_output().type.shape == (10, 5)
+
+    # Test parameter broadcasting with non-expanding size
+    node = normal(vec, size=(5,)).owner
+    new_inputs = node.inputs.copy()
+    new_inputs[3] = tensor("mu", shape=(10, 5))  # mu
+    new_inputs[4] = tensor("sigma", shape=(10,))  # sigma
+    vect_node = vectorize_node(node, *new_inputs)
+    assert vect_node.op is normal
+    assert vect_node.default_output().type.shape == (10, 5)
+
+    # Test parameter broadcasting with expanding size
+    node = normal(vec, size=(2, 5)).owner
+    new_inputs = node.inputs.copy()
+    new_inputs[3] = tensor("mu", shape=(10, 5))  # mu
+    new_inputs[4] = tensor("sigma", shape=(10,))  # sigma
+    vect_node = vectorize_node(node, *new_inputs)
+    assert vect_node.op is normal
+    assert vect_node.default_output().type.shape == (10, 2, 5)
```
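Note on the expanding-size case: the batched `mu` of shape (10, 5) carries one more leading batch dimension than `size=(2, 5)` accounts for, so the new-size logic in `op.py` concatenates that dimension with the original size, `(10,) + (2, 5)`, which is why the asserted output shape is `(10, 2, 5)`.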
