@@ -20,6 +20,7 @@
 )
 from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType, RandomType
 from pytensor.tensor.random.utils import (
+    compute_batch_shape,
     explicit_expand_dims,
     normalize_size_param,
 )
@@ -403,15 +404,14 @@ def vectorize_random_variable(
         original_expanded_dist_params, dict(zip(original_dist_params, dist_params))
     )
 
-    if len_old_size and equal_computations([old_size], [size]):
+    new_ndim = dist_params[0].type.ndim - original_expanded_dist_params[0].type.ndim
+
+    if new_ndim and len_old_size and equal_computations([old_size], [size]):
         # If the original RV had a size variable and a new one has not been provided,
         # we need to define a new size as the concatenation of the original size dimensions
         # and the novel ones implied by new broadcasted batched parameters dimensions.
-        # We use the first broadcasted batch dimension for reference.
-        bcasted_param = explicit_expand_dims(dist_params, op.ndims_params)[0]
-        new_param_ndim = (bcasted_param.type.ndim - op.ndims_params[0]) - len_old_size
-        if new_param_ndim >= 0:
-            new_size_dims = bcasted_param.shape[:new_param_ndim]
-            size = concatenate([new_size_dims, size])
+        broadcasted_batch_shape = compute_batch_shape(dist_params, op.ndims_params)
+        new_size_dims = broadcasted_batch_shape[:new_ndim]
+        size = concatenate([new_size_dims, size])
 
     return op.make_node(rng, size, dtype, *dist_params)
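
For context, not part of the diff: a minimal NumPy sketch of the shape arithmetic the new branch performs. The toy shapes, the example parameter, and all variable names below are illustrative assumptions, not taken from the PR; they only mirror the `new_ndim` / `new_size_dims` / `concatenate` logic above.

```python
import numpy as np

old_size = (5,)                   # size the original RV was built with
old_param = np.zeros((5,))        # param already expanded to the old size (ndim 1)
new_param = np.zeros((3, 4, 5))   # same param after vectorization adds (3, 4) batch dims

# new_ndim: how many novel leading dims vectorization introduced,
# mirroring dist_params[0].type.ndim - original_expanded_dist_params[0].type.ndim
new_ndim = new_param.ndim - old_param.ndim   # 2

# stand-in for compute_batch_shape: the broadcasted batch shape of all params
# (here there is only one param, so it is just that param's shape)
batch_shape = new_param.shape                # (3, 4, 5)

# prepend the novel dims to the old size, as concatenate([new_size_dims, size]) does
new_size = batch_shape[:new_ndim] + old_size # (3, 4) + (5,) == (3, 4, 5)
print(new_size)
```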