9
9
from pytensor import function
10
10
from pytensor .configdefaults import config
11
11
from pytensor .tensor .basic import as_tensor_variable
12
- from pytensor .tensor .math import _allclose
12
+ from pytensor .tensor .math import _get_atol_rtol
13
13
from pytensor .tensor .nlinalg import (
14
14
SVD ,
15
15
Eig ,
@@ -60,7 +60,8 @@ def test_pseudoinverse_correctness():
60
60
assert ri .dtype == r .dtype
61
61
# Note that pseudoinverse can be quite imprecise so I prefer to compare
62
62
# the result with what np.linalg returns
63
- assert _allclose (ri , np .linalg .pinv (r ))
63
+ atol_ , rtol_ = _get_atol_rtol (ri , np .linalg .pinv (r ))
64
+ assert np .allclose (ri , np .linalg .pinv (r ), atol = atol_ , rtol = rtol_ )
64
65
65
66
66
67
def test_pseudoinverse_grad ():
@@ -92,8 +93,11 @@ def test_inverse_correctness(self):
92
93
rir = np .dot (ri , r )
93
94
rri = np .dot (r , ri )
94
95
95
- assert _allclose (np .identity (4 ), rir ), rir
96
- assert _allclose (np .identity (4 ), rri ), rri
96
+ atol_ , rtol_ = _get_atol_rtol (np .identity (4 ), rir )
97
+ assert np .allclose (np .identity (4 ), rir , atol = atol_ , rtol = rtol_ ), rir
98
+
99
+ atol_ , rtol_ = _get_atol_rtol (np .identity (4 ), rri )
100
+ assert np .allclose (np .identity (4 ), rri , atol = atol_ , rtol = rtol_ ), rri
97
101
98
102
def test_infer_shape (self ):
99
103
r = self .rng .standard_normal ((4 , 4 )).astype (config .floatX )
@@ -119,7 +123,8 @@ def test_matrix_dot():
119
123
for r in rs [1 :]:
120
124
numpy_sol = np .dot (numpy_sol , r )
121
125
122
- assert _allclose (numpy_sol , pytensor_sol )
126
+ atol_ , rtol_ = _get_atol_rtol (numpy_sol , pytensor_sol )
127
+ assert np .allclose (numpy_sol , pytensor_sol , atol = atol_ , rtol = rtol_ )
123
128
124
129
125
130
def test_qr_modes ():
@@ -131,23 +136,34 @@ def test_qr_modes():
131
136
f = function ([A ], qr (A ))
132
137
t_qr = f (a )
133
138
n_qr = np .linalg .qr (a )
134
- assert _allclose (n_qr , t_qr )
139
+ atol_ , rtol_ = _get_atol_rtol (np .asarray (n_qr ), np .asarray (t_qr ))
140
+ assert np .allclose (np .asarray (n_qr ), np .asarray (t_qr ), atol = atol_ , rtol = rtol_ )
135
141
136
142
for mode in ["reduced" , "r" , "raw" ]:
137
143
f = function ([A ], qr (A , mode ))
138
144
t_qr = f (a )
139
145
n_qr = np .linalg .qr (a , mode )
140
146
if isinstance (n_qr , list | tuple ):
141
- assert _allclose (n_qr [0 ], t_qr [0 ])
142
- assert _allclose (n_qr [1 ], t_qr [1 ])
147
+ atol_ , rtol_ = _get_atol_rtol (np .asarray (n_qr [0 ]), np .asarray (t_qr [0 ]))
148
+ assert np .allclose (
149
+ np .asarray (n_qr [0 ]), np .asarray (t_qr [0 ]), atol = atol_ , rtol = rtol_
150
+ )
151
+ atol_ , rtol_ = _get_atol_rtol (np .asarray (n_qr [1 ]), np .asarray (t_qr [1 ]))
152
+ assert np .allclose (
153
+ np .asarray (n_qr [1 ]), np .asarray (t_qr [1 ]), atol = atol_ , rtol = rtol_
154
+ )
143
155
else :
144
- assert _allclose (n_qr , t_qr )
156
+ atol_ , rtol_ = _get_atol_rtol (np .asarray (n_qr ), np .asarray (t_qr ))
157
+ assert np .allclose (
158
+ np .asarray (n_qr ), np .asarray (t_qr ), atol = atol_ , rtol = rtol_
159
+ )
145
160
146
161
try :
147
162
n_qr = np .linalg .qr (a , "complete" )
148
163
f = function ([A ], qr (A , "complete" ))
149
164
t_qr = f (a )
150
- assert _allclose (n_qr , t_qr )
165
+ atol_ , rtol_ = _get_atol_rtol (np .asarray (n_qr ), np .asarray (t_qr ))
166
+ assert np .allclose (np .asarray (n_qr ), np .asarray (t_qr ), atol = atol_ , rtol = rtol_ )
151
167
except TypeError as e :
152
168
assert "name 'complete' is not defined" in str (e )
153
169
@@ -199,7 +215,8 @@ def test_svd(self, core_shape, full_matrix, compute_uv, batched, test_imag):
199
215
np_outputs = np_outputs if isinstance (np_outputs , tuple ) else [np_outputs ]
200
216
201
217
for np_val , pt_val in zip (np_outputs , pt_outputs ):
202
- assert _allclose (np_val , pt_val )
218
+ atol_ , rtol_ = _get_atol_rtol (np_val , pt_val )
219
+ assert np .allclose (np_val , pt_val , atol = atol_ , rtol = rtol_ )
203
220
204
221
def test_svd_infer_shape (self ):
205
222
self .validate_shape ((4 , 4 ), full_matrices = True , compute_uv = True )
@@ -306,7 +323,8 @@ def test_tensorsolve():
306
323
307
324
n_x = np .linalg .tensorsolve (a , b )
308
325
t_x = fn (a , b )
309
- assert _allclose (n_x , t_x )
326
+ atol_ , rtol_ = _get_atol_rtol (n_x , np .asarray (t_x ))
327
+ assert np .allclose (n_x , t_x , atol = atol_ , rtol = rtol_ )
310
328
311
329
# check the type upcast now
312
330
C = tensor4 ("C" , dtype = "float32" )
@@ -319,7 +337,8 @@ def test_tensorsolve():
319
337
d = rng .random ((2 * 3 , 4 )).astype ("float64" )
320
338
n_y = np .linalg .tensorsolve (c , d )
321
339
t_y = fn (c , d )
322
- assert _allclose (n_y , t_y )
340
+ atol_ , rtol_ = _get_atol_rtol (n_y , np .asarray (t_y ))
341
+ assert np .allclose (n_y , t_y , atol = atol_ , rtol = rtol_ )
323
342
assert n_y .dtype == Y .dtype
324
343
325
344
# check the type upcast now
@@ -333,7 +352,8 @@ def test_tensorsolve():
333
352
f = rng .random ((2 * 3 , 4 )).astype ("float64" )
334
353
n_z = np .linalg .tensorsolve (e , f )
335
354
t_z = fn (e , f )
336
- assert _allclose (n_z , t_z )
355
+ atol_ , rtol_ = _get_atol_rtol (n_z , np .asarray (t_z ))
356
+ assert np .allclose (n_z , t_z , atol = atol_ , rtol = rtol_ )
337
357
assert n_z .dtype == Z .dtype
338
358
339
359
@@ -653,7 +673,8 @@ def test_eval(self):
653
673
n_ainv = np .linalg .tensorinv (self .a )
654
674
tf_a = function ([A ], [Ai ])
655
675
t_ainv = tf_a (self .a )
656
- assert _allclose (n_ainv , t_ainv )
676
+ atol_ , rtol_ = _get_atol_rtol (n_ainv , np .asarray (t_ainv ))
677
+ assert np .allclose (n_ainv , t_ainv , atol = atol_ , rtol = rtol_ )
657
678
658
679
B = self .B
659
680
Bi = tensorinv (B )
@@ -664,8 +685,10 @@ def test_eval(self):
664
685
tf_b1 = function ([B ], [Bi1 ])
665
686
t_binv = tf_b (self .b )
666
687
t_binv1 = tf_b1 (self .b1 )
667
- assert _allclose (t_binv , n_binv )
668
- assert _allclose (t_binv1 , n_binv1 )
688
+ atol_ , rtol_ = _get_atol_rtol (np .asarray (t_binv ), n_binv )
689
+ assert np .allclose (t_binv , n_binv , atol = atol_ , rtol = rtol_ )
690
+ atol_ , rtol_ = _get_atol_rtol (np .asarray (t_binv1 ), n_binv1 )
691
+ assert np .allclose (t_binv1 , n_binv1 , atol = atol_ , rtol = rtol_ )
669
692
670
693
671
694
class TestKron (utt .InferShapeTester ):
0 commit comments