Skip to content

Commit 237f54f

Browse files
committed
Fix shape_inference of ChoiceRV when param_shapes are provided
1 parent cc05486 commit 237f54f

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

pytensor/tensor/random/basic.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -1990,11 +1990,12 @@ def _supp_shape_from_params(self, *args, **kwargs):
19901990
raise NotImplementedError()
19911991

19921992
def _infer_shape(self, size, dist_params, param_shapes=None):
1993-
(a, p, _) = dist_params
1994-
1993+
a, p, _ = dist_params
19951994
if isinstance(p.type, pytensor.tensor.type_other.NoneTypeT):
1995+
param_shapes = param_shapes[:1] if param_shapes is not None else None
19961996
shape = super()._infer_shape(size, (a,), param_shapes)
19971997
else:
1998+
param_shapes = param_shapes[:2] if param_shapes is not None else None
19981999
shape = super()._infer_shape(size, (a, p), param_shapes)
19992000

20002001
return shape

tests/tensor/random/test_basic.py

+13
Original file line numberDiff line numberDiff line change
@@ -1390,6 +1390,19 @@ def test_choice_samples():
13901390
compare_sample_values(choice, at.as_tensor_variable([1, 2, 3]), 2, replace=True)
13911391

13921392

1393+
def test_choice_infer_shape():
1394+
node = choice([0, 1]).owner
1395+
res = node.op._infer_shape((), node.inputs[3:], None)
1396+
assert tuple(res.eval()) == ()
1397+
1398+
node = choice([0, 1]).owner
1399+
# The param_shape of a NoneConst is None, during shape_inference
1400+
res = node.op._infer_shape(
1401+
(), node.inputs[3:], (node.inputs[3].shape, None, node.inputs[5].shape)
1402+
)
1403+
assert tuple(res.eval()) == ()
1404+
1405+
13931406
def test_permutation_samples():
13941407
compare_sample_values(
13951408
permutation,

0 commit comments

Comments
 (0)