Skip to content

Commit cc05486

Browse files
committed
Apply casting in as_tensor_variable in normalize_size_param
This allows PyTensor to infer more broadcastable patterns, by placing the casting inside the MakeVector Op
1 parent da5281b commit cc05486

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

pytensor/tensor/random/utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def normalize_size_param(
134134
"Parameter size must be None, an integer, or a sequence with integers."
135135
)
136136
else:
137-
size = cast(as_tensor_variable(size, ndim=1), "int64")
137+
size = cast(as_tensor_variable(size, ndim=1, dtype="int64"), "int64")
138138

139139
if not isinstance(size, Constant):
140140
# This should help ensure that the length of non-constant `size`s

tests/tensor/random/test_op.py

+3
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,9 @@ def test_RandomVariable_bcast():
148148
res = rv(0, 1, size=at.as_tensor(1, dtype=np.int64))
149149
assert res.broadcastable == (True,)
150150

151+
res = rv(0, 1, size=(at.as_tensor(1, dtype=np.int32), s3))
152+
assert res.broadcastable == (True, False)
153+
151154

152155
def test_RandomVariable_bcast_specify_shape():
153156
rv = RandomVariable("normal", 0, [0, 0], config.floatX, inplace=True)

0 commit comments

Comments
 (0)