Skip to content

Commit 197069d

Browse files
Implement numba overload for POTRF, LAPACK cholesky routine (#578)
* Implement numba overload for POTRF, LAPACK cholesky routine * Delete old numba_funcify_Cholesky * Refactor tests to include supported keywords and datatypes * Validate inputs and outputs of numba cholesky function * Raise on complex inputs * Change `cholesky` default for `check_finite` to `False` * Remove redundant dtype checks from numba linalg dispatchers * Add docstring to `numba_funcify_Cholesky` explaining why the overload is necessary.
1 parent f737996 commit 197069d

File tree

5 files changed

+200
-105
lines changed

5 files changed

+200
-105
lines changed

pytensor/link/numba/dispatch/basic.py

+1-36
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
from pytensor.tensor.blas import BatchedDot
3838
from pytensor.tensor.math import Dot
3939
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
40-
from pytensor.tensor.slinalg import Cholesky, Solve
40+
from pytensor.tensor.slinalg import Solve
4141
from pytensor.tensor.subtensor import (
4242
AdvancedIncSubtensor,
4343
AdvancedIncSubtensor1,
@@ -809,41 +809,6 @@ def softplus(x):
809809
return softplus
810810

811811

812-
@numba_funcify.register(Cholesky)
813-
def numba_funcify_Cholesky(op, node, **kwargs):
814-
lower = op.lower
815-
816-
out_dtype = node.outputs[0].type.numpy_dtype
817-
818-
if lower:
819-
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
820-
821-
@numba_njit
822-
def cholesky(a):
823-
return np.linalg.cholesky(inputs_cast(a)).astype(out_dtype)
824-
825-
else:
826-
# TODO: Use SciPy's BLAS/LAPACK Cython wrappers.
827-
828-
warnings.warn(
829-
(
830-
"Numba will use object mode to allow the "
831-
"`lower` argument to `scipy.linalg.cholesky`."
832-
),
833-
UserWarning,
834-
)
835-
836-
ret_sig = get_numba_type(node.outputs[0].type)
837-
838-
@numba_njit
839-
def cholesky(a):
840-
with numba.objmode(ret=ret_sig):
841-
ret = scipy.linalg.cholesky(a, lower=lower).astype(out_dtype)
842-
return ret
843-
844-
return cholesky
845-
846-
847812
@numba_funcify.register(Solve)
848813
def numba_funcify_Solve(op, node, **kwargs):
849814
assume_a = op.assume_a

pytensor/link/numba/dispatch/slinalg.py

+135-13
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from pytensor.link.numba.dispatch import basic as numba_basic
1111
from pytensor.link.numba.dispatch.basic import numba_funcify
12-
from pytensor.tensor.slinalg import BlockDiagonal, SolveTriangular
12+
from pytensor.tensor.slinalg import BlockDiagonal, Cholesky, SolveTriangular
1313

1414

1515
_PTR = ctypes.POINTER
@@ -25,6 +25,15 @@
2525
_ptr_int = _PTR(_int)
2626

2727

28+
@numba.core.extending.register_jitable
29+
def _check_finite_matrix(a, func_name):
30+
for v in np.nditer(a):
31+
if not np.isfinite(v.item()):
32+
raise np.linalg.LinAlgError(
33+
"Non-numeric values (nan or inf) in input to " + func_name
34+
)
35+
36+
2837
@intrinsic
2938
def val_to_dptr(typingctx, data):
3039
def impl(context, builder, signature, args):
@@ -177,6 +186,22 @@ def numba_xtrtrs(cls, dtype):
177186

178187
return functype(lapack_ptr)
179188

189+
@classmethod
190+
def numba_xpotrf(cls, dtype):
191+
"""
192+
Called by scipy.linalg.cholesky
193+
"""
194+
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "potrf")
195+
functype = ctypes.CFUNCTYPE(
196+
None,
197+
_ptr_int, # UPLO,
198+
_ptr_int, # N
199+
float_pointer, # A
200+
_ptr_int, # LDA
201+
_ptr_int, # INFO
202+
)
203+
return functype(lapack_ptr)
204+
180205

181206
def _solve_triangular(A, B, trans=0, lower=False, unit_diagonal=False):
182207
return linalg.solve_triangular(
@@ -190,13 +215,7 @@ def solve_triangular_impl(A, B, trans=0, lower=False, unit_diagonal=False):
190215

191216
_check_scipy_linalg_matrix(A, "solve_triangular")
192217
_check_scipy_linalg_matrix(B, "solve_triangular")
193-
194218
dtype = A.dtype
195-
if str(dtype).startswith("complex"):
196-
raise ValueError(
197-
"Complex inputs not currently supported by solve_triangular in Numba mode"
198-
)
199-
200219
w_type = _get_underlying_float(dtype)
201220
numba_trtrs = _LAPACK().numba_xtrtrs(dtype)
202221

@@ -249,8 +268,8 @@ def impl(A, B, trans=0, lower=False, unit_diagonal=False):
249268
)
250269

251270
if B_is_1d:
252-
return B_copy[..., 0]
253-
return B_copy
271+
return B_copy[..., 0], int_ptr_to_val(INFO)
272+
return B_copy, int_ptr_to_val(INFO)
254273

255274
return impl
256275

@@ -262,19 +281,122 @@ def numba_funcify_SolveTriangular(op, node, **kwargs):
262281
unit_diagonal = op.unit_diagonal
263282
check_finite = op.check_finite
264283

284+
dtype = node.inputs[0].dtype
285+
if str(dtype).startswith("complex"):
286+
raise NotImplementedError(
287+
"Complex inputs not currently supported by solve_triangular in Numba mode"
288+
)
289+
265290
@numba_basic.numba_njit(inline="always")
266291
def solve_triangular(a, b):
267-
res = _solve_triangular(a, b, trans, lower, unit_diagonal)
268292
if check_finite:
269-
if np.any(np.bitwise_or(np.isinf(res), np.isnan(res))):
270-
raise ValueError(
271-
"Non-numeric values (nan or inf) returned by solve_triangular"
293+
if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
294+
raise np.linalg.LinAlgError(
295+
"Non-numeric values (nan or inf) in input A to solve_triangular"
272296
)
297+
if np.any(np.bitwise_or(np.isinf(b), np.isnan(b))):
298+
raise np.linalg.LinAlgError(
299+
"Non-numeric values (nan or inf) in input b to solve_triangular"
300+
)
301+
302+
res, info = _solve_triangular(a, b, trans, lower, unit_diagonal)
303+
if info != 0:
304+
raise np.linalg.LinAlgError(
305+
"Singular matrix in input A to solve_triangular"
306+
)
273307
return res
274308

275309
return solve_triangular
276310

277311

312+
def _cholesky(a, lower=False, overwrite_a=False, check_finite=True):
313+
return linalg.cholesky(
314+
a, lower=lower, overwrite_a=overwrite_a, check_finite=check_finite
315+
)
316+
317+
318+
@overload(_cholesky)
319+
def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True):
320+
ensure_lapack()
321+
_check_scipy_linalg_matrix(A, "cholesky")
322+
dtype = A.dtype
323+
w_type = _get_underlying_float(dtype)
324+
numba_potrf = _LAPACK().numba_xpotrf(dtype)
325+
326+
def impl(A, lower=0, overwrite_a=False, check_finite=True):
327+
_N = np.int32(A.shape[-1])
328+
if A.shape[-2] != _N:
329+
raise linalg.LinAlgError("Last 2 dimensions of A must be square")
330+
331+
UPLO = val_to_int_ptr(ord("L") if lower else ord("U"))
332+
N = val_to_int_ptr(_N)
333+
LDA = val_to_int_ptr(_N)
334+
INFO = val_to_int_ptr(0)
335+
336+
if not overwrite_a:
337+
A_copy = _copy_to_fortran_order(A)
338+
else:
339+
A_copy = A
340+
341+
numba_potrf(
342+
UPLO,
343+
N,
344+
A_copy.view(w_type).ctypes,
345+
LDA,
346+
INFO,
347+
)
348+
349+
return A_copy, int_ptr_to_val(INFO)
350+
351+
return impl
352+
353+
354+
@numba_funcify.register(Cholesky)
355+
def numba_funcify_Cholesky(op, node, **kwargs):
356+
"""
357+
Overload scipy.linalg.cholesky with a numba function.
358+
359+
Note that np.linalg.cholesky is already implemented in numba, but it does not support additional keyword arguments.
360+
In particular, the `inplace` argument is not supported, which is why we choose to implement our own version.
361+
"""
362+
lower = op.lower
363+
overwrite_a = False
364+
check_finite = op.check_finite
365+
on_error = op.on_error
366+
367+
dtype = node.inputs[0].dtype
368+
if str(dtype).startswith("complex"):
369+
raise NotImplementedError(
370+
"Complex inputs not currently supported by cholesky in Numba mode"
371+
)
372+
373+
@numba_basic.numba_njit(inline="always")
374+
def nb_cholesky(a):
375+
if check_finite:
376+
if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
377+
raise np.linalg.LinAlgError(
378+
"Non-numeric values (nan or inf) found in input to cholesky"
379+
)
380+
res, info = _cholesky(a, lower, overwrite_a, check_finite)
381+
382+
if on_error == "raise":
383+
if info > 0:
384+
raise np.linalg.LinAlgError(
385+
"Input to cholesky is not positive definite"
386+
)
387+
if info < 0:
388+
raise ValueError(
389+
'LAPACK reported an illegal value in input on entry to "POTRF."'
390+
)
391+
else:
392+
if info != 0:
393+
res = np.full_like(res, np.nan)
394+
395+
return res
396+
397+
return nb_cholesky
398+
399+
278400
@numba_funcify.register(BlockDiagonal)
279401
def numba_funcify_BlockDiagonal(op, node, **kwargs):
280402
dtype = node.outputs[0].dtype

pytensor/tensor/slinalg.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,10 @@ class Cholesky(Op):
5151
__props__ = ("lower", "destructive", "on_error")
5252
gufunc_signature = "(m,m)->(m,m)"
5353

54-
def __init__(self, *, lower=True, on_error="raise"):
54+
def __init__(self, *, lower=True, check_finite=True, on_error="raise"):
5555
self.lower = lower
5656
self.destructive = False
57+
self.check_finite = check_finite
5758
if on_error not in ("raise", "nan"):
5859
raise ValueError('on_error must be one of "raise" or ""nan"')
5960
self.on_error = on_error
@@ -70,7 +71,9 @@ def perform(self, node, inputs, outputs):
7071
x = inputs[0]
7172
z = outputs[0]
7273
try:
73-
z[0] = scipy.linalg.cholesky(x, lower=self.lower).astype(x.dtype)
74+
z[0] = scipy.linalg.cholesky(
75+
x, lower=self.lower, check_finite=self.check_finite
76+
).astype(x.dtype)
7477
except scipy.linalg.LinAlgError:
7578
if self.on_error == "raise":
7679
raise
@@ -129,8 +132,10 @@ def conjugate_solve_triangular(outer, inner):
129132
return [grad]
130133

131134

132-
def cholesky(x, lower=True, on_error="raise"):
133-
return Blockwise(Cholesky(lower=lower, on_error=on_error))(x)
135+
def cholesky(x, lower=True, on_error="raise", check_finite=False):
136+
return Blockwise(
137+
Cholesky(lower=lower, on_error=on_error, check_finite=check_finite)
138+
)(x)
134139

135140

136141
class SolveBase(Op):

tests/link/numba/test_nlinalg.py

-51
Original file line numberDiff line numberDiff line change
@@ -14,57 +14,6 @@
1414
rng = np.random.default_rng(42849)
1515

1616

17-
@pytest.mark.parametrize(
18-
"x, lower, exc",
19-
[
20-
(
21-
set_test_value(
22-
pt.dmatrix(),
23-
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
24-
),
25-
True,
26-
None,
27-
),
28-
(
29-
set_test_value(
30-
pt.lmatrix(),
31-
(lambda x: x.T.dot(x))(
32-
rng.integers(1, 10, size=(3, 3)).astype("int64")
33-
),
34-
),
35-
True,
36-
None,
37-
),
38-
(
39-
set_test_value(
40-
pt.dmatrix(),
41-
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
42-
),
43-
False,
44-
UserWarning,
45-
),
46-
],
47-
)
48-
def test_Cholesky(x, lower, exc):
49-
g = slinalg.Cholesky(lower=lower)(x)
50-
51-
if isinstance(g, list):
52-
g_fg = FunctionGraph(outputs=g)
53-
else:
54-
g_fg = FunctionGraph(outputs=[g])
55-
56-
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
57-
with cm:
58-
compare_numba_and_py(
59-
g_fg,
60-
[
61-
i.tag.test_value
62-
for i in g_fg.inputs
63-
if not isinstance(i, (SharedVariable, Constant))
64-
],
65-
)
66-
67-
6817
@pytest.mark.parametrize(
6918
"A, x, lower, exc",
7019
[

0 commit comments

Comments
 (0)