Skip to content

Commit a0b3c0a

Browse files
Validate inputs and outputs of numba cholesky function
1 parent d596ff3 commit a0b3c0a

File tree

3 files changed

+47
-11
lines changed

3 files changed

+47
-11
lines changed

pytensor/link/numba/dispatch/slinalg.py

+19-5
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ def impl(A, lower=0, overwrite_a=False, check_finite=True):
340340
INFO,
341341
)
342342

343-
return A_copy
343+
return A_copy, int_ptr_to_val(INFO)
344344

345345
return impl
346346

@@ -349,16 +349,30 @@ def impl(A, lower=0, overwrite_a=False, check_finite=True):
349349
def numba_funcify_Cholesky(op, node, **kwargs):
350350
lower = op.lower
351351
overwrite_a = False
352-
check_finite = op.on_error == "raise"
352+
check_finite = op.check_finite
353+
on_error = op.on_error
353354

354355
@numba_basic.numba_njit(inline="always")
355356
def nb_cholesky(a):
356-
res = _cholesky(a, lower, overwrite_a, check_finite)
357357
if check_finite:
358-
if np.any(np.bitwise_or(np.isinf(res), np.isnan(res))):
358+
if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
359359
raise np.linalg.LinAlgError(
360-
"Non-numeric values (nan or inf) returned by cholesky"
360+
"Non-numeric values (nan or inf) found in input to cholesky"
361361
)
362+
res, info = _cholesky(a, lower, overwrite_a, check_finite)
363+
364+
if on_error == "raise":
365+
if info > 0:
366+
raise np.linalg.LinAlgError(
367+
"Input to cholesky is not positive definite"
368+
)
369+
if info < 0:
370+
raise ValueError(
371+
'LAPACK reported an illegal value in input on entry to "POTRF."'
372+
)
373+
else:
374+
if info != 0:
375+
res = np.full_like(res, np.nan)
362376

363377
return res
364378

pytensor/tensor/slinalg.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,10 @@ class Cholesky(Op):
5151
__props__ = ("lower", "destructive", "on_error")
5252
gufunc_signature = "(m,m)->(m,m)"
5353

54-
def __init__(self, *, lower=True, on_error="raise"):
54+
def __init__(self, *, lower=True, check_finite=True, on_error="raise"):
5555
self.lower = lower
5656
self.destructive = False
57+
self.check_finite = check_finite
5758
if on_error not in ("raise", "nan"):
5859
raise ValueError('on_error must be one of "raise" or ""nan"')
5960
self.on_error = on_error
@@ -70,7 +71,9 @@ def perform(self, node, inputs, outputs):
7071
x = inputs[0]
7172
z = outputs[0]
7273
try:
73-
z[0] = scipy.linalg.cholesky(x, lower=self.lower).astype(x.dtype)
74+
z[0] = scipy.linalg.cholesky(
75+
x, lower=self.lower, check_finite=self.check_finite
76+
).astype(x.dtype)
7477
except scipy.linalg.LinAlgError:
7578
if self.on_error == "raise":
7679
raise
@@ -129,8 +132,10 @@ def conjugate_solve_triangular(outer, inner):
129132
return [grad]
130133

131134

132-
def cholesky(x, lower=True, on_error="raise"):
133-
return Blockwise(Cholesky(lower=lower, on_error=on_error))(x)
135+
def cholesky(x, lower=True, on_error="raise", check_finite=True):
136+
return Blockwise(
137+
Cholesky(lower=lower, on_error=on_error, check_finite=check_finite)
138+
)(x)
134139

135140

136141
class SolveBase(Op):

tests/link/numba/test_slinalg.py

+19-2
Original file line numberDiff line numberDiff line change
@@ -128,19 +128,36 @@ def test_numba_Cholesky(lower):
128128
)
129129

130130

131-
def test_numba_Cholesky_raises_on_nan():
131+
def test_numba_Cholesky_raises_on_nan_input():
132132
test_value = rng.random(size=(3, 3)).astype(config.floatX)
133133
test_value[0, 0] = np.nan
134134

135135
x = pt.tensor(dtype=config.floatX, shape=(3, 3))
136136
x = x.T.dot(x)
137-
g = pt.linalg.cholesky(x, on_error="raise")
137+
g = pt.linalg.cholesky(x, check_finite=True)
138138
f = pytensor.function([x], g, mode="NUMBA")
139139

140140
with pytest.raises(np.linalg.LinAlgError, match=r"Non-numeric values"):
141141
f(test_value)
142142

143143

144+
@pytest.mark.parametrize("on_error", ["nan", "raise"])
145+
def test_numba_Cholesky_raise_on(on_error):
146+
test_value = rng.random(size=(3, 3)).astype(config.floatX)
147+
148+
x = pt.tensor(dtype=config.floatX, shape=(3, 3))
149+
g = pt.linalg.cholesky(x, on_error=on_error)
150+
f = pytensor.function([x], g, mode="NUMBA")
151+
152+
if on_error == "raise":
153+
with pytest.raises(
154+
np.linalg.LinAlgError, match=r"Input to cholesky is not positive definite"
155+
):
156+
f(test_value)
157+
else:
158+
assert np.all(np.isnan(f(test_value)))
159+
160+
144161
def test_block_diag():
145162
A = pt.matrix("A")
146163
B = pt.matrix("B")

0 commit comments

Comments
 (0)