Commit a745bc1

Fix broadcasting bug in vectorize of RandomVariables
1 parent: 044910b

3 files changed: +25, −8 lines

pytensor/tensor/random/op.py (+7, −7)

@@ -20,6 +20,7 @@
 )
 from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType, RandomType
 from pytensor.tensor.random.utils import (
+    compute_batch_shape,
     explicit_expand_dims,
     normalize_size_param,
 )
@@ -403,15 +404,14 @@ def vectorize_random_variable(
         original_expanded_dist_params, dict(zip(original_dist_params, dist_params))
     )

-    if len_old_size and equal_computations([old_size], [size]):
+    new_ndim = dist_params[0].type.ndim - original_expanded_dist_params[0].type.ndim
+
+    if new_ndim and len_old_size and equal_computations([old_size], [size]):
         # If the original RV had a size variable and a new one has not been provided,
         # we need to define a new size as the concatenation of the original size dimensions
         # and the novel ones implied by new broadcasted batched parameters dimensions.
-        # We use the first broadcasted batch dimension for reference.
-        bcasted_param = explicit_expand_dims(dist_params, op.ndims_params)[0]
-        new_param_ndim = (bcasted_param.type.ndim - op.ndims_params[0]) - len_old_size
-        if new_param_ndim >= 0:
-            new_size_dims = bcasted_param.shape[:new_param_ndim]
-            size = concatenate([new_size_dims, size])
+        broadcasted_batch_shape = compute_batch_shape(dist_params, op.ndims_params)
+        new_size_dims = broadcasted_batch_shape[:new_ndim]
+        size = concatenate([new_size_dims, size])

     return op.make_node(rng, size, dtype, *dist_params)
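
The hunk above replaces a first-parameter heuristic with a full broadcast: previously the new leading size dimensions were read off the first broadcasted parameter only, so a batch dimension contributed by any other parameter was lost. A minimal numpy sketch of the shape arithmetic (illustration only; np.broadcast_shapes stands in for the symbolic compute_batch_shape, and the shapes are hypothetical):

import numpy as np

# Hypothetical shapes after vectorization: mu gained one new leading batch
# dim of length 1, while sigma carries the actual batch size, 10.
mu = np.zeros((1, 5))
sigma = np.zeros((10, 1))

# Old logic: slice the new dims off the first parameter alone,
# losing the 10 contributed by sigma.
old_new_dims = mu.shape[:1]  # (1,) -> extended size would be (1, 5)

# Fixed logic: broadcast all parameter batch shapes first, then slice.
batch_shape = np.broadcast_shapes(mu.shape, sigma.shape)  # (10, 5)
new_dims = batch_shape[:1]  # (10,) -> extended size is (10, 5)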

pytensor/tensor/random/utils.py (+10, −1)

@@ -11,7 +11,7 @@
 from pytensor.scalar import ScalarVariable
 from pytensor.tensor import get_vector_length
 from pytensor.tensor.basic import as_tensor_variable, cast, constant
-from pytensor.tensor.extra_ops import broadcast_to
+from pytensor.tensor.extra_ops import broadcast_arrays, broadcast_to
 from pytensor.tensor.math import maximum
 from pytensor.tensor.shape import shape_padleft, specify_shape
 from pytensor.tensor.type import int_dtypes
@@ -149,6 +149,15 @@ def explicit_expand_dims(
     return new_params


+def compute_batch_shape(params, ndims_params: Sequence[int]) -> TensorVariable:
+    params = explicit_expand_dims(params, ndims_params)
+    batch_params = [
+        param[(..., *(0,) * core_ndim)]
+        for param, core_ndim in zip(params, ndims_params)
+    ]
+    return broadcast_arrays(*batch_params)[0].shape
+
+
 def normalize_size_param(
     size: int | np.ndarray | Variable | Sequence | None,
 ) -> Variable:
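
The new compute_batch_shape helper aligns all parameters, indexes a single element along each core dimension so that only batch dimensions remain, and broadcasts the results together. A rough numpy analogue (compute_batch_shape_np and its example shapes are made up for illustration):

import numpy as np

def compute_batch_shape_np(params, ndims_params):
    # Left-pad each parameter with length-1 dims so batch dims align,
    # mirroring explicit_expand_dims.
    max_batch = max(p.ndim - n for p, n in zip(params, ndims_params))
    params = [
        p.reshape((1,) * (max_batch - (p.ndim - n)) + p.shape)
        for p, n in zip(params, ndims_params)
    ]
    # Index element 0 along every core dim, keeping only batch dims,
    # then broadcast the remaining batch shapes together.
    batch_params = [p[(..., *(0,) * n)] for p, n in zip(params, ndims_params)]
    return np.broadcast_shapes(*(bp.shape for bp in batch_params))

# A vector-core parameter with batch shape (10, 1) and a scalar-core
# parameter with batch shape (1, 5) yield a (10, 5) batch shape:
assert compute_batch_shape_np([np.zeros((10, 1, 3)), np.zeros((1, 5))], [1, 0]) == (10, 5)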

tests/tensor/random/test_op.py (+8)

@@ -292,6 +292,14 @@ def test_vectorize_node():
     assert vect_node.op is normal
     assert vect_node.default_output().type.shape == (10, 5)

+    node = normal(vec, size=(5,)).owner
+    new_inputs = node.inputs.copy()
+    new_inputs[3] = tensor("mu", shape=(1, 5))  # mu
+    new_inputs[4] = tensor("sigma", shape=(10,))  # sigma
+    vect_node = vectorize_node(node, *new_inputs)
+    assert vect_node.op is normal
+    assert vect_node.default_output().type.shape == (10, 5)
+
     # Test parameter broadcasting with expanding size
     node = normal(vec, size=(2, 5)).owner
     new_inputs = node.inputs.copy()
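
The added case exercises the fixed path: the original variable keeps its size=(5,), the replacement mu contributes only a broadcastable leading dimension of length 1, and the batch size of 10 is implied by sigma. Under the removed first-parameter logic, the extended size would have been taken from mu's leading dimension, (1,), instead of the broadcasted (10,), so the output type could not resolve to (10, 5).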
