Skip to content

Commit d70fe9b

Browse files
committed
Make zip strict
1 parent a14cb2b commit d70fe9b

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

pytensor/tensor/random/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,8 @@ def explicit_expand_dims(
129129
"""Introduce explicit expand_dims in RV parameters that are implicitly broadcasted together and/or by size."""
130130

131131
batch_dims = [
132-
param.type.ndim - ndim_param for param, ndim_param in zip(params, ndim_params)
132+
param.type.ndim - ndim_param
133+
for param, ndim_param in zip(params, ndim_params, strict=True)
133134
]
134135

135136
if size_length is not None:

tests/tensor/random/test_op.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,9 @@ def test_RandomVariable_basics(strict_test_value_flags):
7474
# `dtype` is respected
7575
rv = RandomVariable("normal", signature="(),()->()", dtype="int32")
7676
with config.change_flags(compute_test_value="off"):
77-
rv_out = rv()
77+
rv_out = rv(0, 0)
7878
assert rv_out.dtype == "int32"
79-
rv_out = rv(dtype="int64")
79+
rv_out = rv(0, 0, dtype="int64")
8080
assert rv_out.dtype == "int64"
8181

8282
with pytest.raises(
@@ -85,6 +85,10 @@ def test_RandomVariable_basics(strict_test_value_flags):
8585
):
8686
assert rv(dtype="float32").dtype == "float32"
8787

88+
# If we pass fewer arguments (and there are no defaults), an error is raised
89+
with pytest.raises(ValueError):
90+
rv(0)
91+
8892

8993
def test_RandomVariable_bcast(strict_test_value_flags):
9094
rv = RandomVariable("normal", 0, [0, 0], config.floatX, inplace=True)

0 commit comments

Comments
 (0)