Skip to content

Commit 87c6ee6

Browse files
Deprecated redefinition of np.allclose in _allclose
1 parent 3cc2393 commit 87c6ee6

File tree

9 files changed

+80
-51
lines changed

9 files changed

+80
-51
lines changed

pytensor/tensor/math.py

-12
Original file line numberDiff line numberDiff line change
@@ -130,18 +130,6 @@ def _get_atol_rtol(a, b):
130130
return atol, rtol
131131

132132

133-
def _allclose(a, b, rtol=None, atol=None):
134-
a = np.asarray(a)
135-
b = np.asarray(b)
136-
atol_, rtol_ = _get_atol_rtol(a, b)
137-
if rtol is not None:
138-
rtol_ = rtol
139-
if atol is not None:
140-
atol_ = atol
141-
142-
return np.allclose(a, b, atol=atol_, rtol=rtol_)
143-
144-
145133
class Argmax(COp):
146134
"""
147135
Calculate the argmax over a given axis or over all axes.

pytensor/tensor/type.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -662,7 +662,12 @@ def values_eq_approx(
662662
if str(a.dtype) not in continuous_dtypes:
663663
return np.all(a == b)
664664
else:
665-
cmp = pytensor.tensor.math._allclose(a, b, rtol=rtol, atol=atol)
665+
atol_, rtol_ = pytensor.tensor.math._get_atol_rtol(a, b)
666+
if rtol is not None:
667+
rtol_ = rtol
668+
if atol is not None:
669+
atol_ = atol
670+
cmp = np.allclose(np.asarray(a), np.asarray(b), rtol=rtol_, atol=atol_)
666671
if cmp:
667672
# Numpy claims they are close, this is good enough for us.
668673
return True

tests/graph/test_compute_test_value.py

+30-7
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from pytensor.graph.op import Op
1111
from pytensor.graph.type import Type
1212
from pytensor.link.c.op import COp
13-
from pytensor.tensor.math import _allclose, dot
13+
from pytensor.tensor.math import _get_atol_rtol, dot
1414
from pytensor.tensor.type import fmatrix, iscalar, matrix, vector
1515

1616

@@ -85,7 +85,15 @@ def test_variable_only(self):
8585
z = dot(x, y)
8686
assert hasattr(z.tag, "test_value")
8787
f = pytensor.function([x, y], z)
88-
assert _allclose(f(x.tag.test_value, y.tag.test_value), z.tag.test_value)
88+
atol_, rtol_ = _get_atol_rtol(
89+
f(x.tag.test_value, y.tag.test_value), z.tag.test_value
90+
)
91+
assert np.allclose(
92+
f(x.tag.test_value, y.tag.test_value),
93+
z.tag.test_value,
94+
atol=atol_,
95+
rtol=rtol_,
96+
)
8997

9098
# this test should fail
9199
y.tag.test_value = np.random.random((6, 5)).astype(config.floatX)
@@ -122,7 +130,16 @@ def test_string_var(self):
122130
out = dot(dot(x, y), z)
123131
assert hasattr(out.tag, "test_value")
124132
tf = pytensor.function([x, y], out)
125-
assert _allclose(tf(x.tag.test_value, y.tag.test_value), out.tag.test_value)
133+
134+
atol_, rtol_ = _get_atol_rtol(
135+
tf(x.tag.test_value, y.tag.test_value), out.tag.test_value
136+
)
137+
assert np.allclose(
138+
tf(x.tag.test_value, y.tag.test_value),
139+
out.tag.test_value,
140+
atol=atol_,
141+
rtol=rtol_,
142+
)
126143

127144
def f(x, y, z):
128145
return dot(dot(x, y), z)
@@ -141,7 +158,10 @@ def test_shared(self):
141158
z = dot(x, y)
142159
assert hasattr(z.tag, "test_value")
143160
f = pytensor.function([x], z)
144-
assert _allclose(f(x.tag.test_value), z.tag.test_value)
161+
atol_, rtol_ = _get_atol_rtol(f(x.tag.test_value), z.tag.test_value)
162+
assert np.allclose(
163+
f(x.tag.test_value), z.tag.test_value, atol=atol_, rtol=rtol_
164+
)
145165

146166
# this test should fail
147167
y.set_value(np.random.random((5, 6)).astype(config.floatX))
@@ -156,7 +176,8 @@ def test_ndarray(self):
156176
z = dot(x, y)
157177
assert hasattr(z.tag, "test_value")
158178
f = pytensor.function([], z)
159-
assert _allclose(f(), z.tag.test_value)
179+
atol_, rtol_ = _get_atol_rtol(f(), z.tag.test_value)
180+
assert np.allclose(f(), z.tag.test_value, atol=atol_, rtol=rtol_)
160181

161182
# this test should fail
162183
x = np.random.random((2, 4)).astype(config.floatX)
@@ -170,7 +191,8 @@ def test_empty_elemwise(self):
170191
z = (x + 2) * 3
171192
assert hasattr(z.tag, "test_value")
172193
f = pytensor.function([], z)
173-
assert _allclose(f(), z.tag.test_value)
194+
atol_, rtol_ = _get_atol_rtol(f(), z.tag.test_value)
195+
assert np.allclose(f(), z.tag.test_value, atol=atol_, rtol=rtol_)
174196

175197
def test_constant(self):
176198
x = pt.constant(np.random.random((2, 3)), dtype=config.floatX)
@@ -180,7 +202,8 @@ def test_constant(self):
180202
z = dot(x, y)
181203
assert hasattr(z.tag, "test_value")
182204
f = pytensor.function([], z)
183-
assert _allclose(f(), z.tag.test_value)
205+
atol_, rtol_ = _get_atol_rtol(f(), z.tag.test_value)
206+
assert np.allclose(f(), z.tag.test_value, atol=atol_, rtol=rtol_)
184207

185208
# this test should fail
186209
x = pt.constant(np.random.random((2, 4)), dtype=config.floatX)

tests/scan/test_rewriting.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1270,7 +1270,7 @@ def f_rnn_cmpl(u1_t, u2_t, x_tm1, y_tm1, y_tm3, W_in1):
12701270

12711271
(pytensor_dump, pytensor_x, pytensor_y) = f4(v_u1, v_u2, v_x0, v_y0, vW_in1)
12721272

1273-
np.testing.assert_allclose(pytensor_x, v_x[-1:][0])
1273+
np.testing.assert_allclose(pytensor_x, v_x[-1:].squeeze(0))
12741274
np.testing.assert_allclose(pytensor_y, v_y[-1:])
12751275

12761276
def test_save_mem_reduced_number_of_steps(self):

tests/sparse/test_basic.py

-1
Original file line numberDiff line numberDiff line change
@@ -2091,7 +2091,6 @@ def test_op(self, op_type):
20912091
f = pytensor.function(variable, self.op(variable[0], axis=axis))
20922092
tested = f(*data)
20932093
expected = data[0].todense().sum(axis).ravel()
2094-
20952094
np.testing.assert_allclose(expected, [tested], atol=1e-08, rtol=1e-05)
20962095

20972096
def test_infer_shape(self):

tests/tensor/conv/test_abstract_conv.py

-4
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,6 @@
3838
from tests.tensor.conv import c_conv3d_corr3d_ref, c_conv_corr_ref
3939

4040

41-
pytensor.config.mode = "FAST_COMPILE"
42-
43-
4441
def conv2d_corr(
4542
inputs,
4643
filters,
@@ -2414,7 +2411,6 @@ def test_fwd(self):
24142411
for j in range(0, kshp[2]):
24152412
single_kern = kern[:, i, j, ...].reshape(single_kshp)
24162413
ref_val = ref_func(img, single_kern)
2417-
24182414
np.testing.assert_allclose(
24192415
ref_val[:, :, i, j],
24202416
unshared_output[:, :, i, j],

tests/tensor/test_basic.py

-6
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import pytensor
99
import pytensor.scalar as ps
1010
import pytensor.tensor.basic as ptb
11-
import pytensor.tensor.math as ptm
1211
from pytensor import compile, config, function, shared
1312
from pytensor.compile import SharedVariable
1413
from pytensor.compile.io import In, Out
@@ -1258,11 +1257,6 @@ def test_cast_from_complex_to_real_raises_error(self, real_dtype, complex_dtype)
12581257
# gradient numerically
12591258

12601259

1261-
def test_basic_allclose():
1262-
# This was raised by a user in https://github.com/Theano/Theano/issues/2975
1263-
assert ptm._allclose(-0.311023883434, -0.311022856884)
1264-
1265-
12661260
def test_get_vector_length():
12671261
# Test `Constant`s
12681262
empty_tuple = as_tensor_variable(())

tests/tensor/test_math.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@
4545
Prod,
4646
ProdWithoutZeros,
4747
Sum,
48-
_allclose,
4948
_dot,
49+
_get_atol_rtol,
5050
abs,
5151
add,
5252
allclose,
@@ -3608,7 +3608,8 @@ def setup_method(self):
36083608
def _validate_output(self, a, b):
36093609
pytensor_sol = self.op(a, b).eval()
36103610
numpy_sol = np.matmul(a, b)
3611-
assert _allclose(numpy_sol, pytensor_sol)
3611+
atol_, rtol_ = _get_atol_rtol(numpy_sol, pytensor_sol)
3612+
assert np.allclose(numpy_sol, pytensor_sol, atol=atol_, rtol=rtol_)
36123613

36133614
@pytest.mark.parametrize(
36143615
"x1, x2",

tests/tensor/test_nlinalg.py

+40-17
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from pytensor import function
1010
from pytensor.configdefaults import config
1111
from pytensor.tensor.basic import as_tensor_variable
12-
from pytensor.tensor.math import _allclose
12+
from pytensor.tensor.math import _get_atol_rtol
1313
from pytensor.tensor.nlinalg import (
1414
SVD,
1515
Eig,
@@ -60,7 +60,8 @@ def test_pseudoinverse_correctness():
6060
assert ri.dtype == r.dtype
6161
# Note that pseudoinverse can be quite imprecise so I prefer to compare
6262
# the result with what np.linalg returns
63-
assert _allclose(ri, np.linalg.pinv(r))
63+
atol_, rtol_ = _get_atol_rtol(ri, np.linalg.pinv(r))
64+
assert np.allclose(ri, np.linalg.pinv(r), atol=atol_, rtol=rtol_)
6465

6566

6667
def test_pseudoinverse_grad():
@@ -92,8 +93,11 @@ def test_inverse_correctness(self):
9293
rir = np.dot(ri, r)
9394
rri = np.dot(r, ri)
9495

95-
assert _allclose(np.identity(4), rir), rir
96-
assert _allclose(np.identity(4), rri), rri
96+
atol_, rtol_ = _get_atol_rtol(np.identity(4), rir)
97+
assert np.allclose(np.identity(4), rir, atol=atol_, rtol=rtol_), rir
98+
99+
atol_, rtol_ = _get_atol_rtol(np.identity(4), rri)
100+
assert np.allclose(np.identity(4), rri, atol=atol_, rtol=rtol_), rri
97101

98102
def test_infer_shape(self):
99103
r = self.rng.standard_normal((4, 4)).astype(config.floatX)
@@ -119,7 +123,8 @@ def test_matrix_dot():
119123
for r in rs[1:]:
120124
numpy_sol = np.dot(numpy_sol, r)
121125

122-
assert _allclose(numpy_sol, pytensor_sol)
126+
atol_, rtol_ = _get_atol_rtol(numpy_sol, pytensor_sol)
127+
assert np.allclose(numpy_sol, pytensor_sol, atol=atol_, rtol=rtol_)
123128

124129

125130
def test_qr_modes():
@@ -131,23 +136,34 @@ def test_qr_modes():
131136
f = function([A], qr(A))
132137
t_qr = f(a)
133138
n_qr = np.linalg.qr(a)
134-
assert _allclose(n_qr, t_qr)
139+
atol_, rtol_ = _get_atol_rtol(np.asarray(n_qr), np.asarray(t_qr))
140+
assert np.allclose(np.asarray(n_qr), np.asarray(t_qr), atol=atol_, rtol=rtol_)
135141

136142
for mode in ["reduced", "r", "raw"]:
137143
f = function([A], qr(A, mode))
138144
t_qr = f(a)
139145
n_qr = np.linalg.qr(a, mode)
140146
if isinstance(n_qr, list | tuple):
141-
assert _allclose(n_qr[0], t_qr[0])
142-
assert _allclose(n_qr[1], t_qr[1])
147+
atol_, rtol_ = _get_atol_rtol(np.asarray(n_qr[0]), np.asarray(t_qr[0]))
148+
assert np.allclose(
149+
np.asarray(n_qr[0]), np.asarray(t_qr[0]), atol=atol_, rtol=rtol_
150+
)
151+
atol_, rtol_ = _get_atol_rtol(np.asarray(n_qr[1]), np.asarray(t_qr[1]))
152+
assert np.allclose(
153+
np.asarray(n_qr[1]), np.asarray(t_qr[1]), atol=atol_, rtol=rtol_
154+
)
143155
else:
144-
assert _allclose(n_qr, t_qr)
156+
atol_, rtol_ = _get_atol_rtol(np.asarray(n_qr), np.asarray(t_qr))
157+
assert np.allclose(
158+
np.asarray(n_qr), np.asarray(t_qr), atol=atol_, rtol=rtol_
159+
)
145160

146161
try:
147162
n_qr = np.linalg.qr(a, "complete")
148163
f = function([A], qr(A, "complete"))
149164
t_qr = f(a)
150-
assert _allclose(n_qr, t_qr)
165+
atol_, rtol_ = _get_atol_rtol(np.asarray(n_qr), np.asarray(t_qr))
166+
assert np.allclose(np.asarray(n_qr), np.asarray(t_qr), atol=atol_, rtol=rtol_)
151167
except TypeError as e:
152168
assert "name 'complete' is not defined" in str(e)
153169

@@ -199,7 +215,8 @@ def test_svd(self, core_shape, full_matrix, compute_uv, batched, test_imag):
199215
np_outputs = np_outputs if isinstance(np_outputs, tuple) else [np_outputs]
200216

201217
for np_val, pt_val in zip(np_outputs, pt_outputs):
202-
assert _allclose(np_val, pt_val)
218+
atol_, rtol_ = _get_atol_rtol(np_val, pt_val)
219+
assert np.allclose(np_val, pt_val, atol=atol_, rtol=rtol_)
203220

204221
def test_svd_infer_shape(self):
205222
self.validate_shape((4, 4), full_matrices=True, compute_uv=True)
@@ -306,7 +323,8 @@ def test_tensorsolve():
306323

307324
n_x = np.linalg.tensorsolve(a, b)
308325
t_x = fn(a, b)
309-
assert _allclose(n_x, t_x)
326+
atol_, rtol_ = _get_atol_rtol(n_x, np.asarray(t_x))
327+
assert np.allclose(n_x, t_x, atol=atol_, rtol=rtol_)
310328

311329
# check the type upcast now
312330
C = tensor4("C", dtype="float32")
@@ -319,7 +337,8 @@ def test_tensorsolve():
319337
d = rng.random((2 * 3, 4)).astype("float64")
320338
n_y = np.linalg.tensorsolve(c, d)
321339
t_y = fn(c, d)
322-
assert _allclose(n_y, t_y)
340+
atol_, rtol_ = _get_atol_rtol(n_y, np.asarray(t_y))
341+
assert np.allclose(n_y, t_y, atol=atol_, rtol=rtol_)
323342
assert n_y.dtype == Y.dtype
324343

325344
# check the type upcast now
@@ -333,7 +352,8 @@ def test_tensorsolve():
333352
f = rng.random((2 * 3, 4)).astype("float64")
334353
n_z = np.linalg.tensorsolve(e, f)
335354
t_z = fn(e, f)
336-
assert _allclose(n_z, t_z)
355+
atol_, rtol_ = _get_atol_rtol(n_z, np.asarray(t_z))
356+
assert np.allclose(n_z, t_z, atol=atol_, rtol=rtol_)
337357
assert n_z.dtype == Z.dtype
338358

339359

@@ -653,7 +673,8 @@ def test_eval(self):
653673
n_ainv = np.linalg.tensorinv(self.a)
654674
tf_a = function([A], [Ai])
655675
t_ainv = tf_a(self.a)
656-
assert _allclose(n_ainv, t_ainv)
676+
atol_, rtol_ = _get_atol_rtol(n_ainv, np.asarray(t_ainv))
677+
assert np.allclose(n_ainv, t_ainv, atol=atol_, rtol=rtol_)
657678

658679
B = self.B
659680
Bi = tensorinv(B)
@@ -664,8 +685,10 @@ def test_eval(self):
664685
tf_b1 = function([B], [Bi1])
665686
t_binv = tf_b(self.b)
666687
t_binv1 = tf_b1(self.b1)
667-
assert _allclose(t_binv, n_binv)
668-
assert _allclose(t_binv1, n_binv1)
688+
atol_, rtol_ = _get_atol_rtol(np.asarray(t_binv), n_binv)
689+
assert np.allclose(t_binv, n_binv, atol=atol_, rtol=rtol_)
690+
atol_, rtol_ = _get_atol_rtol(np.asarray(t_binv1), n_binv1)
691+
assert np.allclose(t_binv1, n_binv1, atol=atol_, rtol=rtol_)
669692

670693

671694
class TestKron(utt.InferShapeTester):

0 commit comments

Comments
 (0)