Commit bae694d

Align batch_ndims of RandomVariable inputs in make_node
Also fixes a bug in the vectorization of RandomVariable, which wrongly used the first parameter to infer the new size dims, even though that parameter was not broadcasted, only expanded with new dims.
1 parent 16a4f3b commit bae694d
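
To illustrate the vectorize fix, here is a minimal sketch using the public vectorize_graph API, mirroring the last case of the updated test below (variable names are illustrative):

import numpy as np
import pytensor.tensor as pt
from pytensor.graph.replace import vectorize_graph

vec = pt.tensor(shape=(None,))
sigma = pt.as_tensor(np.array(1.0))
out = pt.random.normal(vec, sigma, size=(2, 5))

# Vectorize with batched parameters: mu only gains a broadcastable
# leading dim, while sigma carries the genuinely new batch dim of 10.
new_mu = pt.tensor("mu", shape=(1, 5))
new_sigma = pt.tensor("sigma", shape=(10,))
vect_out = vectorize_graph(out, {vec: new_mu, sigma: new_sigma})

# Before this commit, the new size dims were inferred from the first
# parameter alone, missing the batch dim contributed by sigma.
assert vect_out.type.shape == (10, 2, 5)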

File tree: 3 files changed (+58 −50 lines)

pytensor/tensor/random/op.py (+20 −18)
@@ -7,7 +7,7 @@
 from pytensor.configdefaults import config
 from pytensor.graph.basic import Apply, Variable, equal_computations
 from pytensor.graph.op import Op
-from pytensor.graph.replace import _vectorize_node, vectorize_graph
+from pytensor.graph.replace import _vectorize_node
 from pytensor.misc.safe_asarray import _asarray
 from pytensor.scalar import ScalarVariable
 from pytensor.tensor.basic import (
@@ -20,6 +20,7 @@
 )
 from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType, RandomType
 from pytensor.tensor.random.utils import (
+    compute_batch_shape,
     explicit_expand_dims,
     normalize_size_param,
 )
@@ -130,6 +131,9 @@ def __str__(self):
         props_str = ", ".join(f"{getattr(self, prop)}" for prop in self.__props__[1:])
         return f"{self.name}_rv{{{props_str}}}"
 
+    def batch_ndim(self, node):
+        return node.default_output().type.ndim - self.ndim_supp
+
     def _infer_shape(
         self,
         size: TensorVariable,
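
For intuition, the new batch_ndim helper counts the output dimensions beyond the Op's core (support) dimensions. A rough illustration, assuming a scalar-support normal (ndim_supp == 0):

import pytensor.tensor as pt

# With a (2, 3) mu and scalar support, every output dim is a batch dim.
node = pt.random.normal(pt.zeros((2, 3))).owner
assert node.op.batch_ndim(node) == 2
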
@@ -298,8 +302,12 @@ def make_node(self, rng, size, dtype, *dist_params):
         dtype_idx = constant(dtype, dtype="int64")
         dtype = all_dtypes[dtype_idx.data]
 
-        outtype = TensorType(dtype=dtype, shape=static_shape)
-        out_var = outtype()
+        out_var = TensorType(dtype=dtype, shape=static_shape)()
+
+        # Add expand_dims to align batch dimensions
+        dist_params = explicit_expand_dims(
+            dist_params, self.ndims_params, size_length=size.type.shape[0]
+        )
         inputs = (rng, size, dtype_idx, *dist_params)
         outputs = (rng.type(), out_var)
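
With this change, make_node pads the parameters so they all share the same number of batch dimensions. A rough sketch of the intended behavior (names illustrative):

import pytensor.tensor as pt

# mu is a matrix, sigma a scalar; make_node now inserts expand_dims so
# both stored parameter inputs carry the same number of batch dims.
rv = pt.random.normal(pt.zeros((4, 3)), 1.0)
mu, sigma = rv.owner.inputs[3:]
assert mu.type.ndim == sigma.type.ndim == 2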

@@ -390,28 +398,22 @@ def vectorize_random_variable(
     # We extend it to accommodate the new input batch dimensions.
     # Otherwise, we assume the new size already has the right values
 
-    # Need to make parameters implicit broadcasting explicit
-    original_dist_params = node.inputs[3:]
+    old_dist_params = node.inputs[3:]
     old_size = node.inputs[1]
     len_old_size = get_vector_length(old_size)
 
-    original_expanded_dist_params = explicit_expand_dims(
-        original_dist_params, op.ndims_params, len_old_size
-    )
-    # We call vectorize_graph to automatically handle any new explicit expand_dims
-    dist_params = vectorize_graph(
-        original_expanded_dist_params, dict(zip(original_dist_params, dist_params))
-    )
+    new_dist_params = explicit_expand_dims(dist_params, op.ndims_params)
 
     if len_old_size and equal_computations([old_size], [size]):
         # If the original RV had a size variable and a new one has not been provided,
         # we need to define a new size as the concatenation of the original size dimensions
         # and the novel ones implied by new broadcasted batched parameters dimensions.
-        # We use the first broadcasted batch dimension for reference.
-        bcasted_param = explicit_expand_dims(dist_params, op.ndims_params)[0]
-        new_param_ndim = (bcasted_param.type.ndim - op.ndims_params[0]) - len_old_size
-        if new_param_ndim >= 0:
-            new_size_dims = bcasted_param.shape[:new_param_ndim]
+        new_ndim = new_dist_params[0].type.ndim - old_dist_params[0].type.ndim
+        if new_ndim >= 0:
+            new_size = compute_batch_shape(
+                new_dist_params, ndims_params=op.ndims_params
+            )
+            new_size_dims = new_size[:new_ndim]
             size = concatenate([new_size_dims, size])
 
-    return op.make_node(rng, size, dtype, *dist_params)
+    return op.make_node(rng, size, dtype, *new_dist_params)
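
When the original RV had an explicit size and no new one is given, the new leading batch dims are prepended to it. A quick sketch, following the "new size not provided" case of the updated test:

import numpy as np
import pytensor.tensor as pt
from pytensor.graph.replace import vectorize_graph

vec = pt.tensor(shape=(None,))
mat = pt.tensor(shape=(None, None))

# The original RV has size (3,); replacing mu with a matrix adds one
# leading batch dim, so the inferred size becomes (2, 3).
out = pt.random.normal(vec, size=(3,))
vect_node = vectorize_graph(out, {vec: mat}).owner
size_val = vect_node.inputs[1].eval({mat: np.zeros((2, 3), dtype=mat.dtype)})
assert tuple(size_val) == (2, 3)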

pytensor/tensor/random/utils.py (+11 −2)
@@ -11,7 +11,7 @@
 from pytensor.scalar import ScalarVariable
 from pytensor.tensor import get_vector_length
 from pytensor.tensor.basic import as_tensor_variable, cast, constant
-from pytensor.tensor.extra_ops import broadcast_to
+from pytensor.tensor.extra_ops import broadcast_arrays, broadcast_to
 from pytensor.tensor.math import maximum
 from pytensor.tensor.shape import shape_padleft, specify_shape
 from pytensor.tensor.type import int_dtypes
@@ -123,7 +123,7 @@ def broadcast_params(params, ndims_params):
 
 def explicit_expand_dims(
     params: Sequence[TensorVariable],
-    ndim_params: tuple[int],
+    ndim_params: Sequence[int],
     size_length: int = 0,
 ) -> list[TensorVariable]:
     """Introduce explicit expand_dims in RV parameters that are implicitly broadcasted together and/or by size."""
@@ -149,6 +149,15 @@
     return new_params
 
 
+def compute_batch_shape(params, ndims_params: Sequence[int]) -> TensorVariable:
+    params = explicit_expand_dims(params, ndims_params)
+    batch_params = [
+        param[..., *[(0,) for _ in range(core_ndim)]]
+        for param, core_ndim in zip(params, ndims_params)
+    ]
+    return broadcast_arrays(*batch_params)[0].shape
+
+
 def normalize_size_param(
     size: int | np.ndarray | Variable | Sequence | None,
 ) -> Variable:
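
A quick sketch of how the new compute_batch_shape helper can be used: it indexes away the core dims of each (expanded) parameter and returns the broadcasted shape of what remains:

import pytensor.tensor as pt
from pytensor.tensor.random.utils import compute_batch_shape

# Both parameters have 0 core dims; their batch shapes (2, 3) and (3,)
# broadcast to (2, 3).
mu = pt.zeros((2, 3))
sigma = pt.ones((3,))
batch_shape = compute_batch_shape([mu, sigma], ndims_params=[0, 0])
assert tuple(batch_shape.eval()) == (2, 3)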

tests/tensor/random/test_op.py (+27 −30)
@@ -4,7 +4,7 @@
 import pytensor.tensor as pt
 from pytensor import config, function
 from pytensor.gradient import NullTypeGradError, grad
-from pytensor.graph.replace import vectorize_node
+from pytensor.graph.replace import vectorize_graph
 from pytensor.raise_op import Assert
 from pytensor.tensor.math import eq
 from pytensor.tensor.random import normal
@@ -241,62 +241,59 @@ def test_multivariate_rv_infer_static_shape():
     assert mv_op(param1, param2, size=(10, 2)).type.shape == (10, 2, 3)
 
 
-def test_vectorize_node():
+def test_vectorize():
     vec = tensor(shape=(None,))
     mat = tensor(shape=(None, None))
 
     # Test without size
-    node = normal(vec).owner
-    new_inputs = node.inputs.copy()
-    new_inputs[3] = mat  # mu
-    vect_node = vectorize_node(node, *new_inputs)
+    out = normal(vec)
+    vect_node = vectorize_graph(out, {vec: mat}).owner
     assert vect_node.op is normal
     assert vect_node.inputs[3] is mat
 
     # Test with size, new size provided
-    node = normal(vec, size=(3,)).owner
-    new_inputs = node.inputs.copy()
-    new_inputs[1] = (2, 3)  # size
-    new_inputs[3] = mat  # mu
-    vect_node = vectorize_node(node, *new_inputs)
+    size = pt.as_tensor(np.array((3,), dtype="int64"))
+    out = normal(vec, size=size)
+    vect_node = vectorize_graph(out, {vec: mat, size: (2, 3)}).owner
     assert vect_node.op is normal
     assert tuple(vect_node.inputs[1].eval()) == (2, 3)
    assert vect_node.inputs[3] is mat
 
     # Test with size, new size not provided
-    node = normal(vec, size=(3,)).owner
-    new_inputs = node.inputs.copy()
-    new_inputs[3] = mat  # mu
-    vect_node = vectorize_node(node, *new_inputs)
+    out = normal(vec, size=(3,))
+    vect_node = vectorize_graph(out, {vec: mat}).owner
     assert vect_node.op is normal
     assert vect_node.inputs[3] is mat
     assert tuple(
         vect_node.inputs[1].eval({mat: np.zeros((2, 3), dtype=config.floatX)})
     ) == (2, 3)
 
     # Test parameter broadcasting
-    node = normal(vec).owner
-    new_inputs = node.inputs.copy()
-    new_inputs[3] = tensor("mu", shape=(10, 5))  # mu
-    new_inputs[4] = tensor("sigma", shape=(10,))  # sigma
-    vect_node = vectorize_node(node, *new_inputs)
+    mu = vec
+    sigma = pt.as_tensor(np.array(1.0))
+    out = normal(mu, sigma)
+    new_mu = tensor("mu", shape=(10, 5))
+    new_sigma = tensor("sigma", shape=(10,))
+    vect_node = vectorize_graph(out, {mu: new_mu, sigma: new_sigma}).owner
    assert vect_node.op is normal
     assert vect_node.default_output().type.shape == (10, 5)
 
     # Test parameter broadcasting with non-expanding size
-    node = normal(vec, size=(5,)).owner
-    new_inputs = node.inputs.copy()
-    new_inputs[3] = tensor("mu", shape=(10, 5))  # mu
-    new_inputs[4] = tensor("sigma", shape=(10,))  # sigma
-    vect_node = vectorize_node(node, *new_inputs)
+    mu = vec
+    sigma = pt.as_tensor(np.array(1.0))
+    out = normal(mu, sigma, size=(5,))
+    new_mu = tensor("mu", shape=(10, 5))
+    new_sigma = tensor("sigma", shape=(10,))
+    vect_node = vectorize_graph(out, {mu: new_mu, sigma: new_sigma}).owner
     assert vect_node.op is normal
     assert vect_node.default_output().type.shape == (10, 5)
 
     # Test parameter broadcasting with expanding size
-    node = normal(vec, size=(2, 5)).owner
-    new_inputs = node.inputs.copy()
-    new_inputs[3] = tensor("mu", shape=(10, 5))  # mu
-    new_inputs[4] = tensor("sigma", shape=(10,))  # sigma
-    vect_node = vectorize_node(node, *new_inputs)
+    mu = vec
+    sigma = pt.as_tensor(np.array(1.0))
+    out = normal(mu, sigma, size=(2, 5))
+    new_mu = tensor("mu", shape=(1, 5))
+    new_sigma = tensor("sigma", shape=(10,))
+    vect_node = vectorize_graph(out, {mu: new_mu, sigma: new_sigma}).owner
     assert vect_node.op is normal
     assert vect_node.default_output().type.shape == (10, 2, 5)
