@@ -13,13 +13,13 @@
 from pytensor.tensor.type import all_dtypes, iscalar, tensor


-@pytest.fixture(scope="module", autouse=True)
-def set_pytensor_flags():
+@pytest.fixture(scope="function", autouse=False)
+def strict_test_value_flags():
     with config.change_flags(cxx="", compute_test_value="raise"):
         yield


-def test_RandomVariable_basics():
+def test_RandomVariable_basics(strict_test_value_flags):
     str_res = str(
         RandomVariable(
             "normal",
@@ -95,7 +95,7 @@ def test_RandomVariable_basics():
         grad(rv_out, [rv_node.inputs[0]])


-def test_RandomVariable_bcast():
+def test_RandomVariable_bcast(strict_test_value_flags):
     rv = RandomVariable("normal", 0, [0, 0], config.floatX, inplace=True)

     mu = tensor(dtype=config.floatX, shape=(1, None, None))
@@ -125,7 +125,7 @@ def test_RandomVariable_bcast():
     assert res.broadcastable == (True, False)


-def test_RandomVariable_bcast_specify_shape():
+def test_RandomVariable_bcast_specify_shape(strict_test_value_flags):
     rv = RandomVariable("normal", 0, [0, 0], config.floatX, inplace=True)

     s1 = pt.as_tensor(1, dtype=np.int64)
@@ -146,7 +146,7 @@ def test_RandomVariable_bcast_specify_shape():
     assert res.type.shape == (1, None, None, None, 1)


-def test_RandomVariable_floatX():
+def test_RandomVariable_floatX(strict_test_value_flags):
     test_rv_op = RandomVariable(
         "normal",
         0,
@@ -172,14 +172,14 @@ def test_RandomVariable_floatX():
         (3, default_rng, np.random.default_rng(3)),
     ],
 )
-def test_random_maker_op(seed, maker_op, numpy_res):
+def test_random_maker_op(strict_test_value_flags, seed, maker_op, numpy_res):
     seed = pt.as_tensor_variable(seed)
     z = function(inputs=[], outputs=[maker_op(seed)])()
     aes_res = z[0]
     assert maker_op.random_type.values_eq(aes_res, numpy_res)


-def test_random_maker_ops_no_seed():
+def test_random_maker_ops_no_seed(strict_test_value_flags):
     # Testing the initialization when seed=None
     # Since internal states randomly generated,
     # we just check the output classes
@@ -192,7 +192,7 @@ def test_random_maker_ops_no_seed():
     assert isinstance(aes_res, np.random.Generator)


-def test_RandomVariable_incompatible_size():
+def test_RandomVariable_incompatible_size(strict_test_value_flags):
     rv_op = RandomVariable("normal", 0, [0, 0], config.floatX, inplace=True)
     with pytest.raises(
         ValueError, match="Size length is incompatible with batched dimensions"
@@ -216,7 +216,6 @@ def _supp_shape_from_params(self, dist_params, param_shapes=None):
         return [dist_params[0].shape[-1]]


-@config.change_flags(compute_test_value="off")
 def test_multivariate_rv_infer_static_shape():
     """Test that infer shape for multivariate random variable works when a parameter must be broadcasted."""
     mv_op = MultivariateRandomVariable()
@@ -244,9 +243,7 @@ def test_multivariate_rv_infer_static_shape():

 def test_vectorize_node():
     vec = tensor(shape=(None,))
-    vec.tag.test_value = [0, 0, 0]
     mat = tensor(shape=(None, None))
-    mat.tag.test_value = [[0, 0, 0], [1, 1, 1]]

     # Test without size
     node = normal(vec).owner
@@ -273,4 +270,6 @@ def test_vectorize_node():
     vect_node = vectorize_node(node, *new_inputs)
     assert vect_node.op is normal
     assert vect_node.inputs[3] is mat
-    assert tuple(vect_node.inputs[1].eval({mat: mat.tag.test_value})) == (2, 3)
+    assert tuple(
+        vect_node.inputs[1].eval({mat: np.zeros((2, 3), dtype=config.floatX)})
+    ) == (2, 3)
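
A minimal sketch of the pattern this diff moves to, using only the pytest and PyTensor APIs already visible above; the example test name is hypothetical. With function scope and autouse=False, the strict flags apply only to tests that explicitly request the fixture, instead of to every test in the module:

import pytest

from pytensor import config


@pytest.fixture(scope="function", autouse=False)
def strict_test_value_flags():
    # Disable the C backend and make missing test values an error,
    # but only while a test that requests this fixture is running.
    with config.change_flags(cxx="", compute_test_value="raise"):
        yield


def test_opts_into_strict_flags(strict_test_value_flags):
    # Hypothetical test: the overrides are active inside its body and are
    # rolled back when the `with` block in the fixture exits.
    assert config.compute_test_value == "raise"
    assert config.cxx == ""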