Commit e827311

Do not autouse test_value flag fixture
1 parent 62cee00
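
The test module previously defined a module-scoped autouse fixture, so every test in the file ran with the C backend disabled (cxx="") and with compute_test_value="raise". This commit renames the fixture to strict_test_value_flags, makes it function-scoped and opt-in, and passes it explicitly to the tests that actually need those flags; test_multivariate_rv_infer_static_shape and test_vectorize_node now run under default flags, which also makes the old @config.change_flags(compute_test_value="off") decorator redundant. A minimal, self-contained sketch of the opt-in pattern follows (the fixture body matches the diff below; the two example tests are illustrative and not part of the commit):

import pytest

from pytensor import config


@pytest.fixture(scope="function", autouse=False)
def strict_test_value_flags():
    # Disable the C backend and raise on missing test values, but only
    # for tests that request this fixture by name.
    with config.change_flags(cxx="", compute_test_value="raise"):
        yield


def test_strict(strict_test_value_flags):
    # Hypothetical test: requests the fixture, so the strict flags apply.
    assert config.compute_test_value == "raise"


def test_default():
    # Hypothetical test: the fixture is not autouse, so default flags apply
    # (compute_test_value defaults to "off").
    assert config.compute_test_value == "off"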


tests/tensor/random/test_op.py

Lines changed: 12 additions & 13 deletions
@@ -13,13 +13,13 @@
 from pytensor.tensor.type import all_dtypes, iscalar, tensor


-@pytest.fixture(scope="module", autouse=True)
-def set_pytensor_flags():
+@pytest.fixture(scope="function", autouse=False)
+def strict_test_value_flags():
     with config.change_flags(cxx="", compute_test_value="raise"):
         yield


-def test_RandomVariable_basics():
+def test_RandomVariable_basics(strict_test_value_flags):
     str_res = str(
         RandomVariable(
             "normal",
@@ -95,7 +95,7 @@ def test_RandomVariable_basics():
     grad(rv_out, [rv_node.inputs[0]])


-def test_RandomVariable_bcast():
+def test_RandomVariable_bcast(strict_test_value_flags):
     rv = RandomVariable("normal", 0, [0, 0], config.floatX, inplace=True)

     mu = tensor(dtype=config.floatX, shape=(1, None, None))
@@ -125,7 +125,7 @@ def test_RandomVariable_bcast():
     assert res.broadcastable == (True, False)


-def test_RandomVariable_bcast_specify_shape():
+def test_RandomVariable_bcast_specify_shape(strict_test_value_flags):
     rv = RandomVariable("normal", 0, [0, 0], config.floatX, inplace=True)

     s1 = pt.as_tensor(1, dtype=np.int64)
@@ -146,7 +146,7 @@ def test_RandomVariable_bcast_specify_shape():
     assert res.type.shape == (1, None, None, None, 1)


-def test_RandomVariable_floatX():
+def test_RandomVariable_floatX(strict_test_value_flags):
     test_rv_op = RandomVariable(
         "normal",
         0,
@@ -172,14 +172,14 @@ def test_RandomVariable_floatX():
         (3, default_rng, np.random.default_rng(3)),
     ],
 )
-def test_random_maker_op(seed, maker_op, numpy_res):
+def test_random_maker_op(strict_test_value_flags, seed, maker_op, numpy_res):
     seed = pt.as_tensor_variable(seed)
     z = function(inputs=[], outputs=[maker_op(seed)])()
     aes_res = z[0]
     assert maker_op.random_type.values_eq(aes_res, numpy_res)


-def test_random_maker_ops_no_seed():
+def test_random_maker_ops_no_seed(strict_test_value_flags):
     # Testing the initialization when seed=None
     # Since internal states randomly generated,
     # we just check the output classes
@@ -192,7 +192,7 @@ def test_random_maker_ops_no_seed():
     assert isinstance(aes_res, np.random.Generator)


-def test_RandomVariable_incompatible_size():
+def test_RandomVariable_incompatible_size(strict_test_value_flags):
     rv_op = RandomVariable("normal", 0, [0, 0], config.floatX, inplace=True)
     with pytest.raises(
         ValueError, match="Size length is incompatible with batched dimensions"
@@ -216,7 +216,6 @@ def _supp_shape_from_params(self, dist_params, param_shapes=None):
         return [dist_params[0].shape[-1]]


-@config.change_flags(compute_test_value="off")
 def test_multivariate_rv_infer_static_shape():
     """Test that infer shape for multivariate random variable works when a parameter must be broadcasted."""
     mv_op = MultivariateRandomVariable()
@@ -244,9 +243,7 @@ def test_multivariate_rv_infer_static_shape():

 def test_vectorize_node():
     vec = tensor(shape=(None,))
-    vec.tag.test_value = [0, 0, 0]
     mat = tensor(shape=(None, None))
-    mat.tag.test_value = [[0, 0, 0], [1, 1, 1]]

     # Test without size
     node = normal(vec).owner
@@ -273,4 +270,6 @@ def test_vectorize_node():
     vect_node = vectorize_node(node, *new_inputs)
     assert vect_node.op is normal
     assert vect_node.inputs[3] is mat
-    assert tuple(vect_node.inputs[1].eval({mat: mat.tag.test_value})) == (2, 3)
+    assert tuple(
+        vect_node.inputs[1].eval({mat: np.zeros((2, 3), dtype=config.floatX)})
+    ) == (2, 3)
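
Because the fixture no longer applies everywhere, test_vectorize_node runs with compute_test_value at its default and can no longer read mat.tag.test_value; the final assertion instead feeds a concrete array to eval. A minimal sketch of that substitution (names follow the diff; the rest of the test is elided):

import numpy as np

from pytensor import config
from pytensor.tensor.type import tensor

mat = tensor(dtype=config.floatX, shape=(None, None))

# Instead of attaching mat.tag.test_value up front, supply a concrete
# value only when a symbolic expression is evaluated:
nrows = mat.shape[0].eval({mat: np.zeros((2, 3), dtype=config.floatX)})
assert nrows == 2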
