Skip to content

Commit 69af4cc

Browse files
Validate inputs and outputs of numba cholesky function
1 parent 5023274 commit 69af4cc

File tree

3 files changed

+57
-12
lines changed

3 files changed

+57
-12
lines changed

pytensor/link/numba/dispatch/slinalg.py

+26-4
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from pytensor.link.numba.dispatch import basic as numba_basic
1111
from pytensor.link.numba.dispatch.basic import numba_funcify
12-
from pytensor.tensor.slinalg import Cholesky, BlockDiagonal, SolveTriangular
12+
from pytensor.tensor.slinalg import BlockDiagonal, Cholesky, SolveTriangular
1313

1414

1515
_PTR = ctypes.POINTER
@@ -292,13 +292,14 @@ def solve_triangular(a, b):
292292
res = _solve_triangular(a, b, trans, lower, unit_diagonal)
293293
if check_finite:
294294
if np.any(np.bitwise_or(np.isinf(res), np.isnan(res))):
295-
raise ValueError(
295+
raise np.linalg.LinAlgError(
296296
"Non-numeric values (nan or inf) returned by solve_triangular"
297297
)
298298
return res
299299

300300
return solve_triangular
301301

302+
302303
def _cholesky(a, lower=False, overwrite_a=False, check_finite=True):
303304
return linalg.cholesky(
304305
a, lower=lower, overwrite_a=overwrite_a, check_finite=check_finite
@@ -339,7 +340,7 @@ def impl(A, lower=0, overwrite_a=False, check_finite=True):
339340
INFO,
340341
)
341342

342-
return A_copy
343+
return A_copy, int_ptr_to_val(INFO)
343344

344345
return impl
345346

@@ -348,15 +349,36 @@ def impl(A, lower=0, overwrite_a=False, check_finite=True):
348349
def numba_funcify_Cholesky(op, node, **kwargs):
349350
lower = op.lower
350351
overwrite_a = False
352+
check_finite = op.check_finite
351353
on_error = op.on_error
352354

353355
@numba_basic.numba_njit(inline="always")
354356
def nb_cholesky(a):
355-
res = _cholesky(a, lower, overwrite_a, on_error)
357+
if check_finite:
358+
if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
359+
raise np.linalg.LinAlgError(
360+
"Non-numeric values (nan or inf) found in input to cholesky"
361+
)
362+
res, info = _cholesky(a, lower, overwrite_a, check_finite)
363+
364+
if on_error == "raise":
365+
if info > 0:
366+
raise np.linalg.LinAlgError(
367+
"Input to cholesky is not positive definite"
368+
)
369+
if info < 0:
370+
raise ValueError(
371+
'LAPACK reported an illegal value in input on entry to "POTRF."'
372+
)
373+
else:
374+
if info != 0:
375+
res = np.full_like(res, np.nan)
376+
356377
return res
357378

358379
return nb_cholesky
359380

381+
360382
@numba_funcify.register(BlockDiagonal)
361383
def numba_funcify_BlockDiagonal(op, node, **kwargs):
362384
dtype = node.outputs[0].dtype

pytensor/tensor/slinalg.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,10 @@ class Cholesky(Op):
5151
__props__ = ("lower", "destructive", "on_error")
5252
gufunc_signature = "(m,m)->(m,m)"
5353

54-
def __init__(self, *, lower=True, on_error="raise"):
54+
def __init__(self, *, lower=True, check_finite=True, on_error="raise"):
5555
self.lower = lower
5656
self.destructive = False
57+
self.check_finite = check_finite
5758
if on_error not in ("raise", "nan"):
5859
raise ValueError('on_error must be one of "raise" or ""nan"')
5960
self.on_error = on_error
@@ -70,7 +71,9 @@ def perform(self, node, inputs, outputs):
7071
x = inputs[0]
7172
z = outputs[0]
7273
try:
73-
z[0] = scipy.linalg.cholesky(x, lower=self.lower).astype(x.dtype)
74+
z[0] = scipy.linalg.cholesky(
75+
x, lower=self.lower, check_finite=self.check_finite
76+
).astype(x.dtype)
7477
except scipy.linalg.LinAlgError:
7578
if self.on_error == "raise":
7679
raise
@@ -129,8 +132,10 @@ def conjugate_solve_triangular(outer, inner):
129132
return [grad]
130133

131134

132-
def cholesky(x, lower=True, on_error="raise"):
133-
return Blockwise(Cholesky(lower=lower, on_error=on_error))(x)
135+
def cholesky(x, lower=True, on_error="raise", check_finite=True):
136+
return Blockwise(
137+
Cholesky(lower=lower, on_error=on_error, check_finite=check_finite)
138+
)(x)
134139

135140

136141
class SolveBase(Op):

tests/link/numba/test_slinalg.py

+22-4
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,8 @@ def test_solve_triangular_raises_on_nan_inf(value):
102102
b = np.full((5, 1), value)
103103

104104
with pytest.raises(
105-
ValueError, match=re.escape("Non-numeric values (nan or inf) returned ")
105+
np.linalg.LinAlgError,
106+
match=re.escape("Non-numeric values (nan or inf) returned "),
106107
):
107108
f(A_tri, b)
108109

@@ -127,19 +128,36 @@ def test_numba_Cholesky(lower):
127128
)
128129

129130

130-
def test_numba_Cholesky_raises_on_nan():
131+
def test_numba_Cholesky_raises_on_nan_input():
131132
test_value = rng.random(size=(3, 3)).astype(config.floatX)
132133
test_value[0, 0] = np.nan
133134

134135
x = pt.tensor(dtype=config.floatX, shape=(3, 3))
135136
x = x.T.dot(x)
136-
g = pt.linalg.cholesky(x, on_error="raise")
137+
g = pt.linalg.cholesky(x, check_finite=True)
137138
f = pytensor.function([x], g, mode="NUMBA")
138139

139-
with pytest.raises(ValueError, match=r"Non-numeric values"):
140+
with pytest.raises(np.linalg.LinAlgError, match=r"Non-numeric values"):
140141
f(test_value)
141142

142143

144+
@pytest.mark.parametrize("on_error", ["nan", "raise"])
145+
def test_numba_Cholesky_raise_on(on_error):
146+
test_value = rng.random(size=(3, 3)).astype(config.floatX)
147+
148+
x = pt.tensor(dtype=config.floatX, shape=(3, 3))
149+
g = pt.linalg.cholesky(x, on_error=on_error)
150+
f = pytensor.function([x], g, mode="NUMBA")
151+
152+
if on_error == "raise":
153+
with pytest.raises(
154+
np.linalg.LinAlgError, match=r"Input to cholesky is not positive definite"
155+
):
156+
f(test_value)
157+
else:
158+
assert np.all(np.isnan(f(test_value)))
159+
160+
143161
def test_block_diag():
144162
A = pt.matrix("A")
145163
B = pt.matrix("B")

0 commit comments

Comments
 (0)